From f95df001d743162822d750071f2fc8b472657d96 Mon Sep 17 00:00:00 2001 From: nl <3210346136@qq.com> Date: Sat, 13 Nov 2021 15:16:11 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0speech=20model?= =?UTF-8?q?=E6=A8=A1=E6=9D=BF=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- speech_model.py | 248 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 248 insertions(+) create mode 100644 speech_model.py diff --git a/speech_model.py b/speech_model.py new file mode 100644 index 0000000..abe82d2 --- /dev/null +++ b/speech_model.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright 2016-2099 Ailemon.net +# +# This file is part of ASRT Speech Recognition Tool. +# +# ASRT is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# ASRT is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with ASRT. If not, see . +# ============================================================================ + +""" +@author: nl8590687 +声学模型基础功能模板定义 +""" +import os +import time +import numpy as np +import random + +from utils.ops import get_edit_distance, read_wav_data +from utils.config import load_config_file, default_config_filename, load_pinyin_dict + +class ModelSpeech: + ''' + 语音模型类 + + 参数: + speech_model: 声学模型类型 (BaseModel类) 实例对象 + speech_features: 声学特征类型(SpeechFeatureMeta类)实例对象 + ''' + def __init__(self, speech_model, speech_features, max_label_length=64): + self.data_loader = None + self.speech_model = speech_model + self.trained_model, self.base_model = speech_model.get_model() + self.speech_features = speech_features + self.max_label_length = max_label_length + + def _data_generator(self, batch_size, data_loader): + ''' + 数据生成器函数,用于Keras的generator_fit训练 + batch_size: 一次产生的数据量 + ''' + labels = np.zeros((batch_size,1), dtype = np.float) + data_count = data_loader.get_data_count() + index = 0 + + while True: + X = np.zeros((batch_size, )+self.speech_model.input_shape, dtype = np.float) + y = np.zeros((batch_size, self.max_label_length), dtype=np.int16) + input_length = [] + label_length = [] + + for i in range(batch_size): + wavdata, fs, data_labels = data_loader.get_data(index) + data_input = self.speech_features.run(wavdata, fs) + data_input = data_input.reshape(data_input.shape[0],data_input.shape[1],1) + # 必须加上模pool_size得到的值,否则会出现inf问题,然后提示No valid path found. + # 但是直接加又可能会出现sequence_length <= xxx 的问题,因此不能让其超过时间序列长度的最大值,比如200 + pool_size = self.speech_model.input_shape[0] // self.speech_model.output_shape[0] + inlen = min(data_input.shape[0] // pool_size + data_input.shape[0] % pool_size, self.speech_model.output_shape[0]) + input_length.append(inlen) + + X[i,0:len(data_input)] = data_input + y[i,0:len(data_labels)] = data_labels + label_length.append([len(data_labels)]) + + label_length = np.matrix(label_length) + input_length = np.array([input_length]).T + + index = (index+1) % data_count + yield [X, y, input_length, label_length ], labels + + def train_model(self, optimizer, data_loader, epochs = 1, save_step = 1, batch_size = 16, last_epoch = 0, call_back=None): + ''' + 训练模型 + + 参数: + optimizer:tensorflow.keras.optimizers 优化器实例对象 + data_loader:数据加载器类型 (SpeechData) 实例对象 + epochs: 迭代轮数 + save_step: 每多少epoch保存一次模型 + batch_size: mini batch大小 + last_epoch: 上一次epoch的编号,可用于断点处继续训练时,epoch编号不冲突 + call_back: keras call back函数 + ''' + save_filename = os.path.join('save_models', self.speech_model.get_model_name(), self.speech_model.get_model_name()) + + self.trained_model.compile(loss=self.speech_model.get_loss_function(), optimizer = optimizer) + print('[ASRT] Compiles Model Successfully.') + + yielddatas = self._data_generator(batch_size, data_loader) + + data_count = data_loader.get_data_count() # 获取数据的数量 + # 计算每一个epoch迭代的次数 + num_iterate = data_count // batch_size + iter_start = last_epoch + iter_end = last_epoch + epochs + for epoch in range(iter_start, iter_end): # 迭代轮数 + try: + epoch += 1 + print('[ASRT Training] train epoch %d/%d .' % (epoch, iter_end)) + self.trained_model.fit_generator(yielddatas, num_iterate, callbacks = call_back) + except StopIteration: + print('[error] generator error. please check data format.') + break + + if epoch % save_step == 0: + if(not os.path.exists('save_models')): # 判断保存模型的目录是否存在 + os.makedirs('save_models') # 如果不存在,就新建一个,避免之后保存模型的时候炸掉 + if(not os.path.exists(os.path.join('save_models',self.speech_model.get_model_name()))): # 判断保存模型的目录是否存在 + os.makedirs(os.path.join('save_models',self.speech_model.get_model_name())) # 如果不存在,就新建一个,避免之后保存模型的时候炸掉 + + self.save_model(save_filename + '_epoch' + str(epoch)) + + print('[ASRT Info] Model training complete. ') + + def load_model(self,filename): + ''' + 加载模型参数 + ''' + self.speech_model.load_weights(filename) + + def save_model(self,filename): + ''' + 保存模型参数 + ''' + self.speech_model.save_weights(filename) + + def evaluate_model(self, data_loader, data_count = -1, out_report = False, show_ratio = True, show_per_step = 100): + ''' + 评估检验模型的识别效果 + ''' + data_nums = data_loader.get_data_count() + + if(data_count <= 0 or data_count > data_nums): # 当data_count为小于等于0或者大于测试数据量的值时,则使用全部数据来测试 + data_count = data_nums + + try: + ran_num = random.randint(0,data_nums - 1) # 获取一个随机数 + + words_num = 0 + word_error_num = 0 + + nowtime = time.strftime('%Y%m%d_%H%M%S',time.localtime(time.time())) + if(out_report == True): + txt_obj = open('Test_Report_' + data_loader.dataset_type + '_' + nowtime + '.txt', 'w', encoding='UTF-8') # 打开文件并读入 + txt_obj.truncate((data_count + 1) * 300) # 预先分配一定数量的磁盘空间,避免后期在硬盘中文件存储位置频繁移动,以防写入速度越来越慢 + txt_obj.seek(0) # 从文件首开始 + + txt = '' + i = 0 + while i < data_count: + wavdata, fs, data_labels = data_loader.get_data((ran_num + i) % data_nums) # 从随机数开始连续向后取一定数量数据 + data_input = self.speech_features.run(wavdata, fs) + data_input = data_input.reshape(data_input.shape[0],data_input.shape[1],1) + # 数据格式出错处理 开始 + # 当输入的wav文件长度过长时自动跳过该文件,转而使用下一个wav文件来运行 + while(data_input.shape[0] > self.speech_model.input_shape[0]): + print('*[Error]','wave data lenghth of num',(ran_num + i) % data_nums, 'is too long.','\n A Exception raise when test Speech Model.') + i += 1 + continue + # 数据格式出错处理 结束 + + pre = self.predict(data_input) + + words_n = data_labels.shape[0] # 获取每个句子的字数 + words_num += words_n # 把句子的总字数加上 + edit_distance = get_edit_distance(data_labels, pre) # 获取编辑距离 + if(edit_distance <= words_n): # 当编辑距离小于等于句子字数时 + word_error_num += edit_distance # 使用编辑距离作为错误字数 + else: # 否则肯定是增加了一堆乱七八糟的奇奇怪怪的字 + word_error_num += words_n # 就直接加句子本来的总字数就好了 + + if(i % show_per_step == 0 and show_ratio == True): + print('[ASRT Info] Testing: ',i,'/',data_count) + + txt = '' + if(out_report == True): + txt += str(i) + '\n' + txt += 'True:\t' + str(data_labels) + '\n' + txt += 'Pred:\t' + str(pre) + '\n' + txt += '\n' + txt_obj.write(txt) + + i += 1 + + #print('*[测试结果] 语音识别 ' + str_dataset + ' 集语音单字错误率:', word_error_num / words_num * 100, '%') + print('*[ASRT Test Result] Speech Recognition ' + data_loader.dataset_type + ' set word error ratio: ', word_error_num / words_num * 100, '%') + if(out_report == True): + txt = '*[ASRT Test Result] Speech Recognition ' + data_loader.dataset_type + ' set word error ratio: ' + str(word_error_num / words_num * 100) + ' %' + txt_obj.write(txt) + txt_obj.truncate() # 去除文件末尾剩余未使用的空白存储字节 + txt_obj.close() + + except StopIteration: + print('[ASRT Error] Model testing raise a error. Please check data format.') + + def predict(self, data_input): + ''' + 预测结果 + + 返回语音识别后的forward结果 + ''' + return self.speech_model.forward(data_input) + + def recognize_speech(self, wavsignal, fs): + ''' + 最终做语音识别用的函数,识别一个wav序列的语音 + ''' + # 获取输入特征 + data_input = self.speech_features.run(wavsignal, fs) + data_input = np.array(data_input, dtype = np.float) + #print(data_input,data_input.shape) + data_input = data_input.reshape(data_input.shape[0],data_input.shape[1],1) + r1 = self.predict(data_input) + # 获取拼音列表 + list_symbol_dic, _ = load_pinyin_dict(load_config_file(default_config_filename)['dict_filename']) + + r_str=[] + for i in r1: + r_str.append(list_symbol_dic[i]) + + return r_str + + def recognize_speech_from_file(self, filename): + ''' + 最终做语音识别用的函数,识别指定文件名的语音 + ''' + wavsignal,fs, _, _ = read_wav_data(filename) + r = self.recognize_speech(wavsignal, fs) + return r + + @property + def model(self): + ''' + 返回tf.keras model + ''' + return self.trained_model