Paste: co1
Author: | 302 |
Mode: | factor |
Date: | Tue, 11 Jan 2022 03:54:28 |
Plain Text |
# -*- coding:utf-8 -*-
###############################################################################
# Copyright (c) 2022 Horizon Robotics, All rights reserved.
# __ __ _ ___ __ __ _
# / // /__ ____(_)__ ___ ___ / _ \___ / / ___ / /_(_)______
# / _ / _ \/ __/ /_ // _ \/ _ \ / , _/ _ \/ _ \/ _ \/ __/ / __(_-<
# /_//_/\___/_/ /_//__/\___/_//_/ /_/|_|\___/_.__/\___/\__/_/\__/___/
# -----
# Filename : phoneme_confusion.py
# Author : junzhe.jiang
# Date : Tuesday, 2022-01-11 11:15
# Modified By: junzhe.jiang
# Modified At: Tuesday, 2022-01-11 11:39
# -----
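# Describe : Calculate the phoneme confusion matrix.
# Phoneme confusion-matrix statistics script.
# Requires the offline-test config files and the result directories to analyse (the two lists must be given in matching positional order).
# Writes several files, selected by the -o output options, to the -p path.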
# -----
# History:
###############################################################################
import os
import sys
import argparse
import textwrap
import numpy as np
import confusion_matrix_module
from sklearn.metrics import confusion_matrix
class RawFormatter(argparse.HelpFormatter):
    """Help formatter that keeps the line breaks written in the description.

    The text is dedented, re-indented with the indent argparse asks for, and
    then each source line is wrapped on its own, so separate lines are never
    merged into one paragraph.
    """

    def _fill_text(self, text, width, indent):
        """Return *text* wrapped line by line to *width* columns."""
        prepared = textwrap.indent(textwrap.dedent(text), indent)
        return "\n".join(
            textwrap.fill(single_line, width) for single_line in prepared.splitlines()
        )
def parse_args():
    """Parse the command-line options of the phoneme confusion script.

    Returns
    -------
    argparse.Namespace
        Namespace with the attributes ``config`` and ``result`` (lists of
        str), ``path_output`` (str, defaults to the current working
        directory), ``output_options`` (str, defaults to ``"11000"``) and
        ``decoder_extra`` (list of str, defaults to an empty list).

    Notes
    -----
    The usage text requires CONFIG-n to match RES_DIR-n, so a different
    number of ``-c`` and ``-r`` values is reported through ``parser.error``
    (usage message on stderr, exit status 2) instead of being paired up
    incompletely later on.
    """
    parser = argparse.ArgumentParser(
        # RawFormatter dedents this text, so the common indentation below is
        # not part of the displayed help.
        description="""
        Calculate Confusion Matrix.\n
        Usage:
        python3 phoneme_confusion.py \\
        -c CONFIG1 [CONFIG2 ...] \\
        -r RES_DIR1 [RES_DIR2 ...] \\
        -p OUTPUT_PATH
        [...] and `-o` `-d` are optional.
        Notice that CONFIG-n must match RES_DIR-n.
        """,
        formatter_class=RawFormatter,
    )
    parser.add_argument(
        "-c",
        "--config",
        help="task offline-test config yaml file path.",
        nargs="+",
        type=str,
        required=True,
    )
    parser.add_argument(
        "-r",
        "--result",
        help="The result data directory which is the parent directory for the result multi-ID folder.",
        nargs="+",
        type=str,
        required=True,
    )
    parser.add_argument(
        "-p",
        "--path_output",
        help="Path for output txt file and png file",
        default=os.getcwd(),
        type=str,
    )
    parser.add_argument(
        "-o",
        "--output_options",
        help="Output options. 5-digit 0-1 code. 1 for output and 0 for not. First digit for output replacement confusion statistics. Second for png file. Third for change error statistics. Fourth for delete error statistics. Fifth for add error statistics.",
        default="11000",
        type=str,
    )
    parser.add_argument(
        "-d",
        "--decoder_extra",
        help="You may need extra decoder root path for calculate based on other's result data. If some needed files couldn't be found, the program will use extra decoder.",
        nargs="+",
        default=[],
        type=str,
    )
    args = parser.parse_args()
    # Configs and result directories are matched by position; refuse lists
    # of different length rather than leaving some entries unpaired.
    if len(args.config) != len(args.result):
        parser.error(
            f"got {len(args.config)} CONFIG value(s) but {len(args.result)} "
            "RES_DIR value(s); CONFIG-n must match RES_DIR-n"
        )
    return args
def _find_result_folders(data_path):
    """Return the entries of *data_path* that contain ``offline_process_data``."""
    folders = []
    for entry in os.listdir(data_path):
        entry_path = os.path.join(data_path, entry)
        # Names containing a dot are skipped as before. The isdir() check
        # additionally protects against dot-less regular files, on which
        # os.listdir() would raise NotADirectoryError.
        if "." in entry or not os.path.isdir(entry_path):
            continue
        if "offline_process_data" in os.listdir(entry_path):
            folders.append(entry)
    return folders


def _find_recall_json(recall_dir):
    """Return the first entry of *recall_dir* whose name contains ``.json``."""
    for name in os.listdir(recall_dir):
        if ".json" in name:
            return name
    raise FileNotFoundError(f"no .json recall file found in {recall_dir}")


def _load_post_proc_matrix(post_file_path):
    """Read a post_proc text file into an array, one row of floats per non-blank line."""
    with open(post_file_path, "r", encoding="utf-8") as f_predict:
        rows = [
            [float(item) for item in line.split()]
            for line in f_predict
            if line.strip()
        ]
    return np.array(rows)


def main_code(
    data_path_all,
    post_path_all,
    recall_path_all,
    graph_path_list_path_all,
    phoneme_path_all,
    output_path,
    error_threshold=None,
    output_options=None,
):
    """main_code
    Count phoneme confusions over all configs and write the selected outputs.

    The five ``*_all`` arguments are iterated in lockstep; entry n of each
    one describes the n-th CONFIG/result pair.

    Parameters
    ----------
    data_path_all : sequence of str
        Result data directories; each one is scanned for sub-folders that
        contain an ``offline_process_data`` entry.
    post_path_all : sequence of str
        Post-process data paths, joined onto each result sub-folder.
    recall_path_all : sequence of str
        Paths (joined onto each result sub-folder) of the directories holding
        the ``.json`` file that is read line by line as labels.
    graph_path_list_path_all : sequence of str
        Graph path list files, passed to ``get_word2phoneme_dict``.
    phoneme_path_all : sequence of str
        Phoneme list files, passed to ``get_dict_between_phoneme_and_id``.
    output_path : str
        Directory handed to the output helpers.
    error_threshold : int, optional
        Confusion-matrix cells less than or equal to this value are zeroed.
        When omitted, the module-level ``error_threshold`` is used if it
        exists, otherwise 0.
    output_options : str, optional
        Five-character 0/1 code selecting the outputs. When omitted, the
        module-level ``output_options`` is used if it exists, otherwise
        ``"11000"``.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``output_options`` has fewer than five characters or there is no
        CONFIG/result pair to process.
    FileNotFoundError
        If a recall directory contains no ``.json`` file.
    """
    # Historically these two settings were only available as globals created
    # by the script entry point; keep that as the fallback so existing
    # callers behave the same.
    if error_threshold is None:
        error_threshold = globals().get("error_threshold", 0)
    if output_options is None:
        output_options = globals().get("output_options", "11000")
    # Fail before the expensive processing rather than on the final indexing.
    if len(output_options) < 5:
        raise ValueError(
            f"output_options {output_options} illegal. legal example: 11000, 10000, 00101..."
        )

    total_y_pred = []
    total_y_true = []
    total_add_list = []
    total_del_list = []
    # The values of the last processed config are used for the final outputs.
    phoneme2id = None
    total_phoneme_num = None
    plot_classes = None

    configs = zip(
        data_path_all,
        post_path_all,
        recall_path_all,
        graph_path_list_path_all,
        phoneme_path_all,
    )
    for count, config in enumerate(configs, start=1):
        data_path, post_path, recall_path, graph_path_list_path, phoneme_path = config
        # Report which result directory is being processed.
        print(f"*** Processing CONFIG {count}: {data_path} ***")
        folder_list = _find_result_folders(data_path)
        (
            phoneme2id,
            id2phoneme,
            total_phoneme_num,
            plot_classes,
        ) = confusion_matrix_module.get_dict_between_phoneme_and_id(phoneme_path)
        word2phoneme = confusion_matrix_module.get_word2phoneme_dict(
            graph_path_list_path
        )
        for folder in folder_list:
            work_dir = os.path.join(data_path, folder)
            post_dir = os.path.join(work_dir, post_path)
            recall_dir = os.path.join(work_dir, recall_path)
            # Sorted list of every post_proc file of this folder.
            file_list = sorted(os.listdir(post_dir))
            recall_file = _find_recall_json(recall_dir)
            # Each line of the recall file is one label.
            with open(
                os.path.join(recall_dir, recall_file), "r", encoding="utf-8"
            ) as f_label:
                for label_line in f_label:
                    file_name = (
                        confusion_matrix_module.find_label_matched_post_proc_file(
                            label_line, file_list
                        )
                    )
                    if not file_name:
                        continue
                    raw_list = _load_post_proc_matrix(
                        os.path.join(post_dir, file_name)
                    )
                    # An empty post_proc file carries no prediction.
                    if raw_list.size == 0:
                        continue
                    # Column index of the largest value in every row.
                    id_list = np.argmax(raw_list, axis=1)
                    # NOTE(review): phoneme2id is indexed by column index here
                    # and id2phoneme is indexed by route entries below, i.e.
                    # the names look swapped relative to their use - verify
                    # against confusion_matrix_module.
                    phoneme_list = [phoneme2id[i] for i in id_list]
                    optimized_phoneme_list = (
                        confusion_matrix_module.get_optimized_phoneme_list(
                            phoneme_list
                        )
                    )
                    true_label_phoneme_list = (
                        confusion_matrix_module.get_true_label_phoneme_list(
                            label_line, word2phoneme
                        )
                    )
                    # Only compare when both label and prediction exist.
                    if not (true_label_phoneme_list and optimized_phoneme_list):
                        continue
                    # Minimum edit distance table and the routes through it.
                    dp = confusion_matrix_module.min_ed_route(
                        true_label_phoneme_list, optimized_phoneme_list
                    )
                    res = confusion_matrix_module.find_route(
                        dp,
                        len(true_label_phoneme_list) - 1,
                        len(optimized_phoneme_list) - 1,
                    )
                    # Keep only the routes accepted by result_filter; every
                    # kept route contributes to the statistics.
                    route_list = [
                        route
                        for route in res
                        if confusion_matrix_module.result_filter(route)
                    ]
                    for route in route_list:
                        for step in route:
                            operation = step[0]
                            if operation == "REPLACE":
                                total_y_pred.append(id2phoneme[step[2]])
                                total_y_true.append(id2phoneme[step[4]])
                            elif operation == "ADD":
                                total_add_list.append(step[2])
                            elif operation == "DELETE":
                                total_del_list.append(step[2])

    if total_phoneme_num is None:
        raise ValueError("no CONFIG/result pair to process")

    # Build the phoneme confusion matrix.
    matrix_result = confusion_matrix(
        y_true=total_y_true,
        y_pred=total_y_pred,
        labels=list(range(total_phoneme_num)),
    )
    # Only keep cells whose count is greater than error_threshold.
    matrix_result[matrix_result <= error_threshold] = 0

    # Confusion statistics file.
    if output_options[0] == "1":
        confusion_matrix_module.output2file(
            matrix_result, phoneme2id, output_path, output_with_phoneme=True
        )
    # Confusion matrix image.
    if output_options[1] == "1":
        confusion_matrix_module.plot_confusion_matrix(
            matrix_result, plot_classes, output_path
        )
    # Change / delete / add error files.
    if output_options[2] == "1":
        confusion_matrix_module.count_change_error(
            total_phoneme_num, matrix_result, phoneme2id, output_path
        )
    # NOTE(review): ADD steps feed count_delete_error and DELETE steps feed
    # count_add_error, exactly as before. Presumably a route lists the edits
    # applied to the prediction, so an ADD step means the prediction lacked a
    # phoneme - confirm against confusion_matrix_module.
    if output_options[3] == "1":
        confusion_matrix_module.count_delete_error(total_add_list, output_path)
    if output_options[4] == "1":
        confusion_matrix_module.count_add_error(total_del_list, output_path)
    return
def use_saved_file(npy_file_path, phoneme_path, output_path=None):
    """use_saved_file
    Produce the confusion statistics file and image from a saved matrix.

    The confusion matrix must have been saved to a ``.npy`` file beforehand.

    Parameters
    ----------
    npy_file_path : str
        Path of the saved confusion matrix ``.npy`` file.
    phoneme_path : str
        Phoneme list file, passed to ``get_dict_between_phoneme_and_id``.
    output_path : str, optional
        Directory handed to the output helpers, in the same argument
        positions ``main_code`` uses. When omitted, the helpers are called
        without an output path, exactly as before.

    Returns
    -------
    None
    """
    (
        phoneme2id,
        _,
        _,
        plot_classes,
    ) = confusion_matrix_module.get_dict_between_phoneme_and_id(phoneme_path)
    loaded_matrix = np.load(npy_file_path)
    if output_path is None:
        # Historical call form. NOTE(review): this relies on the helpers
        # accepting a missing output path - confirm in confusion_matrix_module.
        confusion_matrix_module.output2file(
            loaded_matrix, phoneme2id, output_with_phoneme=True
        )
        confusion_matrix_module.plot_confusion_matrix(
            loaded_matrix, classes=plot_classes
        )
    else:
        confusion_matrix_module.output2file(
            loaded_matrix, phoneme2id, output_path, output_with_phoneme=True
        )
        confusion_matrix_module.plot_confusion_matrix(
            loaded_matrix, plot_classes, output_path
        )
    return
if __name__ == "__main__":
    # Script entry point: read the options, validate them, resolve the
    # per-config paths and run the statistics.

    # default config
    # main_code reads error_threshold and output_options as module globals.
    error_threshold = 0  # only replacement counts greater than this are kept
    output_to_file = True  # file vs. stdout switch; not read anywhere in this file

    # get argument
    args = parse_args()
    config_all = args.config
    result_path_all = args.result
    output_options = args.output_options
    output_path = args.path_output

    # check output_options: exactly five characters, each "0" or "1".
    # A numeric range check alone would accept codes such as "111" or
    # "0011000", whose digits do not line up with the five outputs.
    if len(output_options) != 5 or set(output_options) - {"0", "1"}:
        raise ValueError(
            f"output_options {args.output_options} illegal. legal example: 11000, 10000, 00101..."
        )

    # Every result directory must exist.
    for result_dir in result_path_all:
        if not os.path.isdir(result_dir):
            raise Exception(f"data_path {result_dir} not exists")

    # Decoder shipped next to this script: <script dir>/../offline-test/decoder/
    decoder_root = os.path.join(
        os.path.abspath(
            os.path.dirname(os.path.abspath(sys.argv[0])) + os.path.sep + ".."
        ),
        "offline-test/decoder/",
    )
    # The current user's decoder goes first, ahead of any extra decoder roots.
    decoder_root_all = [decoder_root, *args.decoder_extra]

    # Resolve the paths needed to compute the phoneme confusion matrix.
    (
        data_path_all,
        post_path_all,
        recall_path_all,
        graph_path_list_path_all,
        phoneme_path_all,
    ) = confusion_matrix_module.resolve_yaml_config(
        config_all, result_path_all, decoder_root_all
    )
    # use_saved_file('/mnt/mnt-data-3/junzhe.jiang/map_reduce_visual_audio_tool/mrtasks/mmcmd_predictor/scripts/tmp_save_arr3.npy', "/mnt/mnt-data-3/junzhe.jiang/map_reduce_visual_audio_tool/mrtasks/mmcmd_predictor/offline-test/decoder/resources/mono.list")
    main_code(
        data_path_all,
        post_path_all,
        recall_path_all,
        graph_path_list_path_all,
        phoneme_path_all,
        output_path,
    )
New Annotation