# -*- coding:utf-8 -*-
###############################################################################
#   Copyright (c) 2022 Horizon Robotics, All rights reserved.
#      __ __         _                 ___       __        __  _
#     / // /__  ____(_)__ ___  ___    / _ \___  / /  ___  / /_(_)______
#    / _  / _ \/ __/ /_ // _ \/ _ \  / , _/ _ \/ _ \/ _ \/ __/ / __(_-<
#   /_//_/\___/_/ /_//__/\___/_//_/ /_/|_|\___/_.__/\___/\__/_/\__/___/
#   -----
#   Filename   : phoneme_confusion.py
#   Author     : junzhe.jiang
#   Date       : Tuesday, 2022-01-11 11:15
#   Modified By: junzhe.jiang
#   Modified At: Tuesday, 2022-01-11 11:39
#   -----
#   Describe   : Calculate confusion matrix.
#                Phoneme confusion matrix statistics script.
#                Requires the offline-test config files and the result paths to
#                be analysed (the two lists are matched by position on the
#                command line).
#                Writes one or more files to the -p path according to the -o
#                output options.
#   -----
#   History:
###############################################################################
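#
# An illustrative invocation (hypothetical placeholder paths, mirroring the
# argparse usage text below):
#   python3 phoneme_confusion.py \
#       -c configs/task_a.yaml configs/task_b.yaml \
#       -r results/task_a/ results/task_b/ \
#       -p ./confusion_out -o 11000
#   (each -c config is paired with the -r result directory at the same position)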


import os
import sys
import argparse
import textwrap
import numpy as np
import confusion_matrix_module
from sklearn.metrics import confusion_matrix


class RawFormatter(argparse.HelpFormatter):
    def _fill_text(self, text, width, indent):
        return "\n".join(
            [
                textwrap.fill(line, width)
                for line in textwrap.indent(textwrap.dedent(text), indent).splitlines()
            ]
        )


def parse_args():
    parser = argparse.ArgumentParser(
        description=f"""
                    Calculate Confusion Matrix.\n

                    Usage:
                        python3 phoneme_confusion.py \\
                        -c CONFIG1 [CONFIG2 ...] \\
                        -r RES_DIR1 [RES_DIR2 ...] \\
                        -p OUTPUT_PATH

                    [...] and `-o` `-d` are optional.
                    Notice that CONFIG-n must match RES_DIR-n.
                    """,
        formatter_class=RawFormatter,
    )
    parser.add_argument(
        "-c",
        "--config",
        help="task offline-test config yaml file path.",
        nargs="+",
        type=str,
        required=True,
    )
    parser.add_argument(
        "-r",
        "--result",
        help="The result data directory which is the parent directory for the result multi-ID folder.",
        nargs="+",
        type=str,
        required=True,
    )
    parser.add_argument(
        "-p",
        "--path_output",
        help="Path for output txt file and png file",
        default=os.getcwd(),
        type=str,
    )
    parser.add_argument(
        "-o",
        "--output_options",
        help="Output options. 5-digit 0-1 code. 1 for output and 0 for not. First digit for output replacement confusion statistics. Second for png file. Third for change error statistics. Fourth for delete error statistics. Fifth for add error statistics.",
        default="11000",
        type=str,
    )
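    # Illustrative reading of the -o code (hypothetical value): "10100" writes the
    # replacement confusion statistics (digit 1) and the change error statistics
    # (digit 3), but skips the png and the delete/add error files.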
    parser.add_argument(
        "-d",
        "--decoder_extra",
        help="You may need extra decoder root path for calculate based on other's result data. If some needed files couldn't be found, the program will use extra decoder.",
        nargs="+",
        default=[],
        type=str,
    )
    args = parser.parse_args()
    return args


def main_code(
    data_path_all,
    post_path_all,
    recall_path_all,
    graph_path_list_path_all,
    phoneme_path_all,
    output_path,
):
    """main_code

    The main routine for counting phoneme confusions. Running it writes the statistics to several files.

    Parameters
    ----------
    data_path_all : list of str
        All data paths of the results.

    post_path_all : list of str
        All post paths of the offline process data.

    recall_path_all : list of str
        All statistic/recall/*.json paths of the offline process data, used to get the labels.

    graph_path_list_path_all : list of str
        All graph path list files.

    phoneme_path_all : list of str
        All files that store the full phoneme inventory.

    output_path : str
        Path where the result files are written.

    Returns
    -------
    None
    """

    total_y_pred = []
    total_y_true = []
    total_add_list = []
    total_del_list = []
    count = 0

    for data_path, post_path, recall_path, graph_path_list_path, phoneme_path in zip(
        data_path_all,
        post_path_all,
        recall_path_all,
        graph_path_list_path_all,
        phoneme_path_all,
    ):
        # Print the directory that is currently being processed
        count += 1
        print(f"*** Processing CONFIG {count}: {data_path} ***")

        # Collect every sub-folder under the current path that holds result data
        folder_list = []
        file_list = os.listdir(data_path)
        # print(file_list)
        for file_name in file_list:
            # Skip plain files; keep only folders that contain offline_process_data
            if "." not in file_name:
                directory = os.listdir(os.path.join(data_path, file_name))
                if "offline_process_data" in directory:
                    folder_list.append(file_name)
        # print(folder_list)

        (
            phoneme2id,
            id2phoneme,
            total_phoneme_num,
            plot_classes,
        ) = confusion_matrix_module.get_dict_between_phoneme_and_id(phoneme_path)
        # print(total_phoneme_num)

        word2phoneme = confusion_matrix_module.get_word2phoneme_dict(
            graph_path_list_path
        )

        for folder in folder_list:
            # Enter the working directory
            work_dir = os.path.join(data_path, folder)

            # List all files under the post_proc path
            file_list = os.listdir(os.path.join(work_dir, post_path))
            file_list.sort()

            # Find the recall result .json file
            recall_file_path = None
            recall_folder = os.listdir(os.path.join(work_dir, recall_path))
            for i in recall_folder:
                if ".json" in i:
                    recall_file_path = i
                    break
            if recall_file_path is None:
                # No recall .json in this folder; skip it
                continue

            # Open the statistics label file
            f_label = open(
                os.path.join(work_dir, recall_path, recall_file_path),
                "r",
                encoding="utf-8",
            )

            # Process each labelled file
            for label_line in f_label.readlines():
                file_name = confusion_matrix_module.find_label_matched_post_proc_file(
                    label_line, file_list
                )
                if not file_name:
                    continue
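                # The post_proc file is assumed (based on the parsing below) to hold
                # one whitespace-separated row of per-phoneme scores per frame, e.g.
                #   0.01 0.95 0.04
                #   0.70 0.10 0.20
                # argmax over each row then picks the most likely phoneme per frame.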
                with open(
                    os.path.join(work_dir, post_path, file_name), "r", encoding="utf-8"
                ) as f_predict:
                    # Convert the post_proc data into a probability matrix
                    raw_list = np.array(
                        [
                            [float(item) for item in line.strip().split()]
                            for line in f_predict
                        ]
                    )
                    # Find the phoneme corresponding to the maximum value in each row
                    id_list = np.argmax(raw_list, axis=1)
                    phoneme_list = [phoneme2id[i] for i in id_list]

                optimized_phoneme_list = (
                    confusion_matrix_module.get_optimized_phoneme_list(phoneme_list)
                )

                true_label_phoneme_list = (
                    confusion_matrix_module.get_true_label_phoneme_list(
                        label_line, word2phoneme
                    )
                )

                # Only evaluate samples for which both the label and the prediction exist
                if true_label_phoneme_list and optimized_phoneme_list:
                    # Compute the minimum edit distance
                    dp = confusion_matrix_module.min_ed_route(
                        true_label_phoneme_list, optimized_phoneme_list
                    )
                    res = confusion_matrix_module.find_route(
                        dp,
                        len(true_label_phoneme_list) - 1,
                        len(optimized_phoneme_list) - 1,
                    )
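                    # Illustrative alignment (hypothetical phonemes): for
                    # true = [a, b, c] and pred = [a, d, c], a minimum-edit route
                    # contains one REPLACE (b -> d) that feeds the confusion matrix,
                    # while ADD / DELETE steps feed the insertion and deletion
                    # statistics collected below.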
                    route_list = []
                    for route in res:
                        # Filter out all substitution errors between initials and finals (shengmu-yunmu), and add every remaining substitution route (each with equal weight) to the confusion matrix
                        if confusion_matrix_module.result_filter(route):
                            route_list.append(route)

                    y_pred = []
                    y_true = []
                    for route in route_list:
                        for i in route:
                            if i[0] == "REPLACE":
                                y_pred.append(id2phoneme[i[2]])
                                y_true.append(id2phoneme[i[4]])
                            if i[0] == "ADD":
                                total_add_list.append(i[2])
                            if i[0] == "DELETE":
                                total_del_list.append(i[2])
                    total_y_pred.extend(y_pred)
                    total_y_true.extend(y_true)

            f_label.close()

    # Build the phoneme confusion matrix
    matrix_result = confusion_matrix(
        y_true=total_y_true,
        y_pred=total_y_pred,
        labels=list(range(total_phoneme_num)),
    )

    # Only keep substitution errors whose count exceeds error_threshold
    matrix_result[matrix_result <= error_threshold] = 0

    # Save the phoneme confusion matrix (left commented out below)
    # np.save('/mnt/mnt-data-3/junzhe.jiang/map_reduce_visual_audio_tool/mrtasks/mmcmd_predictor/scripts/tmp_save_arr3', matrix_result)

    # Write the phoneme confusion statistics file
    if output_options[0] == "1":
        confusion_matrix_module.output2file(
            matrix_result, phoneme2id, output_path, output_with_phoneme=True
        )

    # Write the phoneme confusion matrix image
    if output_options[1] == "1":
        confusion_matrix_module.plot_confusion_matrix(
            matrix_result, plot_classes, output_path
        )

    # Write the substitution, insertion and deletion error files
    if output_options[2] == "1":
        confusion_matrix_module.count_change_error(
            total_phoneme_num, matrix_result, phoneme2id, output_path
        )
    if output_options[3] == "1":
        confusion_matrix_module.count_delete_error(total_add_list, output_path)
    if output_options[4] == "1":
        confusion_matrix_module.count_add_error(total_del_list, output_path)

    return


def use_saved_file(npy_file_path, phoneme_path):
    """use_saved_file

    Load a previously saved confusion matrix from a .npy file. The phoneme confusion matrix must have been saved to the .npy file beforehand.

    Parameters
    ----------
    npy_file_path : str
        The saved confusion matrix .npy file path.

    phoneme_path : str
        Contains file which stores all the phonemes.

    Returns
    -------
    None
    """
    (
        phoneme2id,
        _,
        _,
        plot_classes,
    ) = confusion_matrix_module.get_dict_between_phoneme_and_id(phoneme_path)
    loaded_matrix = np.load(npy_file_path)
    confusion_matrix_module.output2file(
        loaded_matrix, phoneme2id, output_with_phoneme=True
    )
    confusion_matrix_module.plot_confusion_matrix(loaded_matrix, classes=plot_classes)
    return


if __name__ == "__main__":
    # default config
    error_threshold = 0  # Only count substitution errors whose count is greater than error_threshold
    output_to_file = True  # Choose between writing to files and writing to stdout

    # get argument
    args = parse_args()
    config_all = args.config
    result_path_all = args.result
    output_options = args.output_options
    decimal_output_options = int("0b" + output_options, 2)
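    # The 5-digit 0/1 string is parsed as a binary number purely for validation,
    # e.g. "11000" -> 0b11000 == 24, so every legal value falls in the range 0..31.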
    output_path = args.path_output
    decoder_root_all = args.decoder_extra

    # Check that every data path exists
    for i in result_path_all:
        if not os.path.isdir(i):
            raise Exception(f"data_path {i} does not exist")

    # check output_options: must be a 5-digit 0/1 string
    if len(output_options) != 5 or decimal_output_options > 31 or decimal_output_options < 0:
        raise Exception(
            f"output_options {args.output_options} is invalid. Valid examples: 11000, 10000, 00101 ..."
        )

    # Get the default decoder path: offline-test/decoder/ under the parent of this script's directory
    decoder_root = os.path.join(
        os.path.abspath(
            os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), "..")
        ),
        "offline-test/decoder/",
    )
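    # For example, with the layout seen in the commented-out paths in this file
    # (.../mmcmd_predictor/scripts/ for the script), decoder_root resolves to
    # .../mmcmd_predictor/offline-test/decoder/.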

    # Insert the current user's decoder at the front so it takes priority
    decoder_root_all.insert(0, decoder_root)

    # Resolve all the parameters needed to compute the phoneme confusion matrix
    (
        data_path_all,
        post_path_all,
        recall_path_all,
        graph_path_list_path_all,
        phoneme_path_all,
    ) = confusion_matrix_module.resolve_yaml_config(
        config_all, result_path_all, decoder_root_all
    )

    # use_saved_file('/mnt/mnt-data-3/junzhe.jiang/map_reduce_visual_audio_tool/mrtasks/mmcmd_predictor/scripts/tmp_save_arr3.npy', "/mnt/mnt-data-3/junzhe.jiang/map_reduce_visual_audio_tool/mrtasks/mmcmd_predictor/offline-test/decoder/resources/mono.list")
    main_code(
        data_path_all,
        post_path_all,
        recall_path_all,
        graph_path_list_path_all,
        phoneme_path_all,
        output_path,
    )
