# -*- coding:utf-8 -*-
import os
import sys
import argparse
import textwrap
import numpy as np
import confusion_matrix_module
from sklearn.metrics import confusion_matrix


class RawFormatter(argparse.HelpFormatter):
    """Help formatter that keeps the line breaks written in the description.

    The text is dedented, prefixed with the requested indent, and then each
    of its lines is wrapped on its own, so the author's line structure is
    preserved while over-long lines are still folded to the given width.
    """

    def _fill_text(self, text, width, indent):
        # Normalise the block first: strip the common leading whitespace and
        # apply the indent argparse asks for.
        normalized = textwrap.indent(textwrap.dedent(text), indent)
        # Wrap line by line instead of re-flowing the whole paragraph.
        wrapped_rows = (
            textwrap.fill(row, width) for row in normalized.splitlines()
        )
        return "\n".join(wrapped_rows)


def parse_args():
    """Parse the command-line arguments of the phoneme confusion tool.

    Returns
    -------
    argparse.Namespace
        Namespace carrying ``config``, ``result``, ``path_output``,
        ``output_options`` and ``decoder_extra``.
    """
    parser = argparse.ArgumentParser(
        description="""
        Calculate Confusion Matrix.\n
        Usage:
            python3 phoneme_confusion.py \\
                -c CONFIG1 [CONFIG2 ...] \\
                -r RES_DIR1 [RES_DIR2 ...] \\
                -p OUTPUT_PATH [...]
        and `-o` `-d` are optional. Notice that CONFIG-n must match RES_DIR-n. """,
        formatter_class=RawFormatter,
    )
    parser.add_argument(
        "-c",
        "--config",
        help="task offline-test config yaml file path.",
        nargs="+",
        type=str,
        required=True,
    )
    parser.add_argument(
        "-r",
        "--result",
        help=(
            "The result data directory which is the parent directory "
            "for the result multi-ID folder."
        ),
        nargs="+",
        type=str,
        required=True,
    )
    parser.add_argument(
        "-p",
        "--path_output",
        help="Path for output txt file and png file",
        default=os.getcwd(),
        type=str,
    )
    parser.add_argument(
        "-o",
        "--output_options",
        help=(
            "Output options. 5-digit 0-1 code. 1 for output and 0 for not. "
            "First digit for output replacement confusion statistics. "
            "Second for png file. Third for change error statistics. "
            "Fourth for delete error statistics. Fifth for add error statistics."
        ),
        default="11000",
        type=str,
    )
    parser.add_argument(
        "-d",
        "--decoder_extra",
        help=(
            "You may need extra decoder root path for calculate based on other's result data. "
            "If some needed files couldn't be found, the program will use extra decoder."
        ),
        nargs="+",
        default=[],
        type=str,
    )
    return parser.parse_args()


def is_valid_output_options(output_options):
    """Tell whether ``output_options`` is a well-formed 5-digit 0/1 code.

    Parameters
    ----------
    output_options : str
        Candidate option code, e.g. ``"11000"``.

    Returns
    -------
    bool
        True only for a string of exactly five characters, each ``"0"`` or
        ``"1"``. Shorter codes are rejected because every one of the five
        positions is indexed when the outputs are written.
    """
    return (
        isinstance(output_options, str)
        and len(output_options) == 5
        and all(flag in "01" for flag in output_options)
    )


def _find_data_folders(data_path):
    """Return the sub-folders of ``data_path`` that hold offline process data."""
    folders = []
    for entry in os.listdir(data_path):
        entry_path = os.path.join(data_path, entry)
        # Names containing a dot are skipped; the isdir check additionally
        # protects os.listdir from dot-less regular files.
        if "." in entry or not os.path.isdir(entry_path):
            continue
        if "offline_process_data" in os.listdir(entry_path):
            folders.append(entry)
    return folders


def _find_recall_file(recall_dir):
    """Return the name of the first entry of ``recall_dir`` containing '.json'."""
    for entry in os.listdir(recall_dir):
        if ".json" in entry:
            return entry
    raise FileNotFoundError(f"no .json label file found in {recall_dir}")


def _load_predicted_ids(post_file_path):
    """Return the per-row argmax of a post_proc file, or None when it has no rows."""
    with open(post_file_path, "r", encoding="utf-8") as f_predict:
        # Each non-blank line is one row of whitespace-separated floats.
        rows = [
            [float(item) for item in line.split()]
            for line in f_predict
            if line.strip()
        ]
    if not rows:
        return None
    return np.argmax(np.array(rows), axis=1)


def main_code(
    data_path_all,
    post_path_all,
    recall_path_all,
    graph_path_list_path_all,
    phoneme_path_all,
    output_path,
    output_options="11000",
    error_threshold=0,
):
    """main_code

    The main code of count phoneme confusion.
    Run this code will output to several files.

    Parameters
    ----------
    data_path_all : iterable of str
        Contains all the data paths of results.
    post_path_all : iterable of str
        Contains all the post paths of offline process data.
    recall_path_all : iterable of str
        Contains all statistic/recall/.json files of offline process data
        to get the label.
    graph_path_list_path_all : iterable of str
        Contains all the graph path list file.
    phoneme_path_all : iterable of str
        Contains all the files which stores all the phonemes.
    output_path : str
        Path where outputs several results.
    output_options : str, optional
        5-digit 0/1 code selecting the outputs, in order: replacement
        confusion statistics, png image, change errors, delete errors,
        add errors.
    error_threshold : int, optional
        Confusion-matrix cells whose count is less than or equal to this
        value are zeroed before any output is produced.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``output_options`` is malformed or no config was supplied.
    FileNotFoundError
        If a recall directory contains no ``.json`` label file.
    """
    # Fail before the expensive processing rather than at output time.
    if not is_valid_output_options(output_options):
        raise ValueError(
            f"output_options {output_options} illegal. legal example: 11000, 10000, 00101..."
        )

    total_y_pred = []
    total_y_true = []
    total_add_list = []
    total_del_list = []

    count = 0
    configs = zip(
        data_path_all,
        post_path_all,
        recall_path_all,
        graph_path_list_path_all,
        phoneme_path_all,
    )
    for count, config in enumerate(configs, start=1):
        data_path, post_path, recall_path, graph_path_list_path, phoneme_path = config
        # Report which result directory is being processed.
        print(f"*** Processing CONFIG {count}: {data_path} ***")

        # Collect every folder under data_path that stores result data.
        folder_list = _find_data_folders(data_path)

        # NOTE(review): despite their names, `phoneme2id` is indexed with
        # argmax class ids below and `id2phoneme` yields the labels fed to
        # the confusion matrix -- confirm against confusion_matrix_module.
        (
            phoneme2id,
            id2phoneme,
            total_phoneme_num,
            plot_classes,
        ) = confusion_matrix_module.get_dict_between_phoneme_and_id(phoneme_path)
        word2phoneme = confusion_matrix_module.get_word2phoneme_dict(
            graph_path_list_path
        )

        for folder in folder_list:
            work_dir = os.path.join(data_path, folder)
            post_dir = os.path.join(work_dir, post_path)
            recall_dir = os.path.join(work_dir, recall_path)

            # All post_proc files of this folder, in sorted order.
            file_list = sorted(os.listdir(post_dir))
            # The statistics file that carries the labels.
            recall_file = _find_recall_file(recall_dir)

            with open(
                os.path.join(recall_dir, recall_file), "r", encoding="utf-8"
            ) as f_label:
                # One label line per utterance.
                for label_line in f_label:
                    file_name = (
                        confusion_matrix_module.find_label_matched_post_proc_file(
                            label_line, file_list
                        )
                    )
                    if not file_name:
                        continue

                    # Turn the post_proc data into a matrix and pick the
                    # highest-scoring class of every row.
                    id_list = _load_predicted_ids(os.path.join(post_dir, file_name))
                    if id_list is None:
                        continue
                    phoneme_list = [phoneme2id[i] for i in id_list]
                    optimized_phoneme_list = (
                        confusion_matrix_module.get_optimized_phoneme_list(
                            phoneme_list
                        )
                    )
                    true_label_phoneme_list = (
                        confusion_matrix_module.get_true_label_phoneme_list(
                            label_line, word2phoneme
                        )
                    )

                    # Only compare when both label and prediction exist.
                    if not (true_label_phoneme_list and optimized_phoneme_list):
                        continue

                    # Minimum edit distance table and the routes through it.
                    dp = confusion_matrix_module.min_ed_route(
                        true_label_phoneme_list, optimized_phoneme_list
                    )
                    res = confusion_matrix_module.find_route(
                        dp,
                        len(true_label_phoneme_list) - 1,
                        len(optimized_phoneme_list) - 1,
                    )
                    # Routes rejected by result_filter are dropped; every
                    # remaining route contributes with the same weight.
                    for route in res:
                        if not confusion_matrix_module.result_filter(route):
                            continue
                        for step in route:
                            operation = step[0]
                            if operation == "REPLACE":
                                total_y_pred.append(id2phoneme[step[2]])
                                total_y_true.append(id2phoneme[step[4]])
                            elif operation == "ADD":
                                total_add_list.append(step[2])
                            elif operation == "DELETE":
                                total_del_list.append(step[2])

    if count == 0:
        raise ValueError("no config to process: the input path lists are empty")

    # Build the phoneme confusion matrix (tables of the last config are used).
    matrix_result = confusion_matrix(
        y_true=total_y_true,
        y_pred=total_y_pred,
        labels=list(range(total_phoneme_num)),
    )
    # Keep only replacements that occur more than error_threshold times.
    matrix_result[matrix_result <= error_threshold] = 0
    # To reuse the matrix later through use_saved_file, np.save it here.

    # Phoneme confusion statistics file.
    if output_options[0] == "1":
        confusion_matrix_module.output2file(
            matrix_result, phoneme2id, output_path, output_with_phoneme=True
        )
    # Phoneme confusion matrix image.
    if output_options[1] == "1":
        confusion_matrix_module.plot_confusion_matrix(
            matrix_result, plot_classes, output_path
        )
    # Replacement / delete / add error files.
    if output_options[2] == "1":
        confusion_matrix_module.count_change_error(
            total_phoneme_num, matrix_result, phoneme2id, output_path
        )
    # NOTE(review): the delete report receives the ADD list and the add
    # report the DELETE list. Kept as in the original; this may be intended
    # (edit direction label -> prediction) -- confirm before changing.
    if output_options[3] == "1":
        confusion_matrix_module.count_delete_error(total_add_list, output_path)
    if output_options[4] == "1":
        confusion_matrix_module.count_add_error(total_del_list, output_path)


def use_saved_file(npy_file_path, phoneme_path, output_path=None):
    """use_saved_file

    Use saved confusion matrix .npy file.
    The premise is to save the phoneme confusion matrix into the .npy file
    first.

    Parameters
    ----------
    npy_file_path : str
        The saved confusion matrix .npy file path.
    phoneme_path : str
        Contains file which stores all the phonemes.
    output_path : str, optional
        When given, it is forwarded to the output helpers the same way
        ``main_code`` does; when omitted the helpers are called without it,
        exactly as before.

    Returns
    -------
    None
    """
    (
        phoneme2id,
        _,
        _,
        plot_classes,
    ) = confusion_matrix_module.get_dict_between_phoneme_and_id(phoneme_path)
    loaded_matrix = np.load(npy_file_path)

    if output_path is None:
        confusion_matrix_module.output2file(
            loaded_matrix, phoneme2id, output_with_phoneme=True
        )
        confusion_matrix_module.plot_confusion_matrix(
            loaded_matrix, classes=plot_classes
        )
    else:
        confusion_matrix_module.output2file(
            loaded_matrix, phoneme2id, output_path, output_with_phoneme=True
        )
        confusion_matrix_module.plot_confusion_matrix(
            loaded_matrix, plot_classes, output_path
        )


if __name__ == "__main__":
    # default config
    error_threshold = 0  # only count replacements occurring more than this
    output_to_file = True  # output to files or to stdout (not read below)

    # get argument
    args = parse_args()
    config_all = args.config
    result_path_all = args.result
    output_options = args.output_options
    output_path = args.path_output

    # Check that every result directory exists.
    for i in result_path_all:
        if not os.path.isdir(i):
            raise Exception(f"data_path {i} not exists")

    # check output_options
    if not is_valid_output_options(output_options):
        raise Exception(
            f"output_options {args.output_options} illegal. legal example: 11000, 10000, 00101..."
        )

    # Decoder shipped next to this script's parent directory.
    decoder_root = os.path.join(
        os.path.abspath(
            os.path.dirname(os.path.abspath(sys.argv[0])) + os.path.sep + ".."
        ),
        "offline-test/decoder/",
    )
    # The current user's decoder goes first; extra decoders follow.
    decoder_root_all = [decoder_root, *args.decoder_extra]

    # Resolve every parameter needed to build the phoneme confusion matrix.
    (
        data_path_all,
        post_path_all,
        recall_path_all,
        graph_path_list_path_all,
        phoneme_path_all,
    ) = confusion_matrix_module.resolve_yaml_config(
        config_all, result_path_all, decoder_root_all
    )

    main_code(
        data_path_all,
        post_path_all,
        recall_path_all,
        graph_path_list_path_all,
        phoneme_path_all,
        output_path,
        output_options=output_options,
        error_threshold=error_threshold,
    )