2022-09-18 20:56:14 +08:00
|
|
|
|
# !/usr/bin/env python3
|
2021-11-13 15:16:11 +08:00
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
#
|
|
|
|
|
# Copyright 2016-2099 Ailemon.net
|
|
|
|
|
#
|
|
|
|
|
# This file is part of ASRT Speech Recognition Tool.
|
|
|
|
|
#
|
|
|
|
|
# ASRT is free software: you can redistribute it and/or modify
|
|
|
|
|
# it under the terms of the GNU General Public License as published by
|
|
|
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
|
|
|
# (at your option) any later version.
|
|
|
|
|
# ASRT is distributed in the hope that it will be useful,
|
|
|
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
|
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
|
|
|
# GNU General Public License for more details.
|
|
|
|
|
#
|
|
|
|
|
# You should have received a copy of the GNU General Public License
|
|
|
|
|
# along with ASRT. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
|
# ============================================================================
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
@author: nl8590687
|
|
|
|
|
声学模型基础功能模板定义
|
|
|
|
|
"""
|
|
|
|
|
import os
|
|
|
|
|
import time
|
|
|
|
|
import random
|
2021-11-24 15:11:08 +08:00
|
|
|
|
import numpy as np
|
2021-11-13 15:16:11 +08:00
|
|
|
|
|
|
|
|
|
from utils.ops import get_edit_distance, read_wav_data
|
2021-11-24 15:11:08 +08:00
|
|
|
|
from utils.config import load_config_file, DEFAULT_CONFIG_FILENAME, load_pinyin_dict
|
2022-04-18 14:48:44 +08:00
|
|
|
|
from utils.thread import threadsafe_generator
|
2021-11-13 15:16:11 +08:00
|
|
|
|
|
2022-09-18 20:56:14 +08:00
|
|
|
|
|
2021-11-13 15:16:11 +08:00
|
|
|
|
class ModelSpeech:
|
2022-09-18 20:56:14 +08:00
|
|
|
|
"""
|
2021-11-13 15:16:11 +08:00
|
|
|
|
语音模型类
|
|
|
|
|
|
|
|
|
|
参数:
|
|
|
|
|
speech_model: 声学模型类型 (BaseModel类) 实例对象
|
|
|
|
|
speech_features: 声学特征类型(SpeechFeatureMeta类)实例对象
|
2022-09-18 20:56:14 +08:00
|
|
|
|
"""
|
|
|
|
|
|
2021-11-13 15:16:11 +08:00
|
|
|
|
def __init__(self, speech_model, speech_features, max_label_length=64):
|
|
|
|
|
self.data_loader = None
|
|
|
|
|
self.speech_model = speech_model
|
|
|
|
|
self.trained_model, self.base_model = speech_model.get_model()
|
|
|
|
|
self.speech_features = speech_features
|
|
|
|
|
self.max_label_length = max_label_length
|
|
|
|
|
|
2022-04-18 14:48:44 +08:00
|
|
|
|
@threadsafe_generator
|
2021-11-13 15:16:11 +08:00
|
|
|
|
def _data_generator(self, batch_size, data_loader):
|
2022-09-18 20:56:14 +08:00
|
|
|
|
"""
|
2021-11-13 15:16:11 +08:00
|
|
|
|
数据生成器函数,用于Keras的generator_fit训练
|
|
|
|
|
batch_size: 一次产生的数据量
|
2022-09-18 20:56:14 +08:00
|
|
|
|
"""
|
2023-02-04 18:34:38 +08:00
|
|
|
|
labels = np.zeros((batch_size, 1), dtype=np.float64)
|
2021-11-13 15:16:11 +08:00
|
|
|
|
data_count = data_loader.get_data_count()
|
|
|
|
|
index = 0
|
2021-11-24 15:11:08 +08:00
|
|
|
|
|
2021-11-13 15:16:11 +08:00
|
|
|
|
while True:
|
2023-02-04 18:34:38 +08:00
|
|
|
|
X = np.zeros((batch_size,) + self.speech_model.input_shape, dtype=np.float64)
|
2021-11-13 15:16:11 +08:00
|
|
|
|
y = np.zeros((batch_size, self.max_label_length), dtype=np.int16)
|
|
|
|
|
input_length = []
|
|
|
|
|
label_length = []
|
2021-11-24 15:11:08 +08:00
|
|
|
|
|
2021-11-13 15:16:11 +08:00
|
|
|
|
for i in range(batch_size):
|
2021-11-24 15:11:08 +08:00
|
|
|
|
wavdata, sample_rate, data_labels = data_loader.get_data(index)
|
|
|
|
|
data_input = self.speech_features.run(wavdata, sample_rate)
|
2022-09-18 20:56:14 +08:00
|
|
|
|
data_input = data_input.reshape(data_input.shape[0], data_input.shape[1], 1)
|
2021-11-13 15:16:11 +08:00
|
|
|
|
# 必须加上模pool_size得到的值,否则会出现inf问题,然后提示No valid path found.
|
|
|
|
|
# 但是直接加又可能会出现sequence_length <= xxx 的问题,因此不能让其超过时间序列长度的最大值,比如200
|
|
|
|
|
pool_size = self.speech_model.input_shape[0] // self.speech_model.output_shape[0]
|
2022-09-18 20:56:14 +08:00
|
|
|
|
inlen = min(data_input.shape[0] // pool_size + data_input.shape[0] % pool_size,
|
|
|
|
|
self.speech_model.output_shape[0])
|
2021-11-13 15:16:11 +08:00
|
|
|
|
input_length.append(inlen)
|
2021-11-24 15:11:08 +08:00
|
|
|
|
|
2022-09-18 20:56:14 +08:00
|
|
|
|
X[i, 0:len(data_input)] = data_input
|
|
|
|
|
y[i, 0:len(data_labels)] = data_labels
|
2021-11-13 15:16:11 +08:00
|
|
|
|
label_length.append([len(data_labels)])
|
2021-11-24 15:11:08 +08:00
|
|
|
|
|
2022-09-18 20:56:14 +08:00
|
|
|
|
index = (index + 1) % data_count
|
2022-03-13 11:04:00 +08:00
|
|
|
|
|
2021-11-13 15:16:11 +08:00
|
|
|
|
label_length = np.matrix(label_length)
|
|
|
|
|
input_length = np.array([input_length]).T
|
|
|
|
|
|
2022-09-18 20:56:14 +08:00
|
|
|
|
yield [X, y, input_length, label_length], labels
|
2021-11-13 15:16:11 +08:00
|
|
|
|
|
2022-09-18 20:56:14 +08:00
|
|
|
|
def train_model(self, optimizer, data_loader, epochs=1, save_step=1, batch_size=16, last_epoch=0, call_back=None):
|
|
|
|
|
"""
|
2021-11-13 15:16:11 +08:00
|
|
|
|
训练模型
|
|
|
|
|
|
|
|
|
|
参数:
|
|
|
|
|
optimizer:tensorflow.keras.optimizers 优化器实例对象
|
|
|
|
|
data_loader:数据加载器类型 (SpeechData) 实例对象
|
|
|
|
|
epochs: 迭代轮数
|
|
|
|
|
save_step: 每多少epoch保存一次模型
|
|
|
|
|
batch_size: mini batch大小
|
|
|
|
|
last_epoch: 上一次epoch的编号,可用于断点处继续训练时,epoch编号不冲突
|
|
|
|
|
call_back: keras call back函数
|
2022-09-18 20:56:14 +08:00
|
|
|
|
"""
|
|
|
|
|
save_filename = os.path.join('save_models', self.speech_model.get_model_name(),
|
|
|
|
|
self.speech_model.get_model_name())
|
2021-11-13 15:16:11 +08:00
|
|
|
|
|
2022-09-18 20:56:14 +08:00
|
|
|
|
self.trained_model.compile(loss=self.speech_model.get_loss_function(), optimizer=optimizer)
|
2021-11-13 15:16:11 +08:00
|
|
|
|
print('[ASRT] Compiles Model Successfully.')
|
|
|
|
|
|
|
|
|
|
yielddatas = self._data_generator(batch_size, data_loader)
|
|
|
|
|
|
2022-09-18 20:56:14 +08:00
|
|
|
|
data_count = data_loader.get_data_count() # 获取数据的数量
|
2021-11-13 15:16:11 +08:00
|
|
|
|
# 计算每一个epoch迭代的次数
|
|
|
|
|
num_iterate = data_count // batch_size
|
|
|
|
|
iter_start = last_epoch
|
|
|
|
|
iter_end = last_epoch + epochs
|
2022-09-18 20:56:14 +08:00
|
|
|
|
for epoch in range(iter_start, iter_end): # 迭代轮数
|
2021-11-13 15:16:11 +08:00
|
|
|
|
try:
|
|
|
|
|
epoch += 1
|
|
|
|
|
print('[ASRT Training] train epoch %d/%d .' % (epoch, iter_end))
|
2021-11-17 19:38:14 +08:00
|
|
|
|
data_loader.shuffle()
|
2022-09-18 20:56:14 +08:00
|
|
|
|
self.trained_model.fit_generator(yielddatas, num_iterate, callbacks=call_back)
|
2021-11-13 15:16:11 +08:00
|
|
|
|
except StopIteration:
|
|
|
|
|
print('[error] generator error. please check data format.')
|
|
|
|
|
break
|
2021-11-24 15:11:08 +08:00
|
|
|
|
|
2021-11-13 15:16:11 +08:00
|
|
|
|
if epoch % save_step == 0:
|
2022-09-18 20:56:14 +08:00
|
|
|
|
if not os.path.exists('save_models'): # 判断保存模型的目录是否存在
|
|
|
|
|
os.makedirs('save_models') # 如果不存在,就新建一个,避免之后保存模型的时候炸掉
|
|
|
|
|
if not os.path.exists(os.path.join('save_models', self.speech_model.get_model_name())): # 判断保存模型的目录是否存在
|
|
|
|
|
os.makedirs(
|
|
|
|
|
os.path.join('save_models', self.speech_model.get_model_name())) # 如果不存在,就新建一个,避免之后保存模型的时候炸掉
|
2021-11-13 15:16:11 +08:00
|
|
|
|
|
|
|
|
|
self.save_model(save_filename + '_epoch' + str(epoch))
|
|
|
|
|
|
|
|
|
|
print('[ASRT Info] Model training complete. ')
|
|
|
|
|
|
2022-09-18 20:56:14 +08:00
|
|
|
|
def load_model(self, filename):
|
|
|
|
|
"""
|
2021-11-13 15:16:11 +08:00
|
|
|
|
加载模型参数
|
2022-09-18 20:56:14 +08:00
|
|
|
|
"""
|
2021-11-13 15:16:11 +08:00
|
|
|
|
self.speech_model.load_weights(filename)
|
|
|
|
|
|
2022-09-18 20:56:14 +08:00
|
|
|
|
def save_model(self, filename):
|
|
|
|
|
"""
|
2021-11-13 15:16:11 +08:00
|
|
|
|
保存模型参数
|
2022-09-18 20:56:14 +08:00
|
|
|
|
"""
|
2021-11-13 15:16:11 +08:00
|
|
|
|
self.speech_model.save_weights(filename)
|
|
|
|
|
|
2022-09-18 20:56:14 +08:00
|
|
|
|
def evaluate_model(self, data_loader, data_count=-1, out_report=False, show_ratio=True, show_per_step=100):
|
|
|
|
|
"""
|
2021-11-24 15:11:08 +08:00
|
|
|
|
评估检验模型的识别效果
|
2022-09-18 20:56:14 +08:00
|
|
|
|
"""
|
2021-11-13 15:16:11 +08:00
|
|
|
|
data_nums = data_loader.get_data_count()
|
|
|
|
|
|
2022-09-18 20:56:14 +08:00
|
|
|
|
if data_count <= 0 or data_count > data_nums: # 当data_count为小于等于0或者大于测试数据量的值时,则使用全部数据来测试
|
2021-11-13 15:16:11 +08:00
|
|
|
|
data_count = data_nums
|
|
|
|
|
|
|
|
|
|
try:
|
2022-09-18 20:56:14 +08:00
|
|
|
|
ran_num = random.randint(0, data_nums - 1) # 获取一个随机数
|
2021-11-13 15:16:11 +08:00
|
|
|
|
words_num = 0
|
|
|
|
|
word_error_num = 0
|
2021-11-24 15:11:08 +08:00
|
|
|
|
|
2022-09-18 20:56:14 +08:00
|
|
|
|
nowtime = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
|
2021-11-24 15:11:08 +08:00
|
|
|
|
if out_report:
|
2022-09-18 20:56:14 +08:00
|
|
|
|
txt_obj = open('Test_Report_' + data_loader.dataset_type + '_' + nowtime + '.txt', 'w',
|
|
|
|
|
encoding='UTF-8') # 打开文件并读入
|
|
|
|
|
txt_obj.truncate((data_count + 1) * 300) # 预先分配一定数量的磁盘空间,避免后期在硬盘中文件存储位置频繁移动,以防写入速度越来越慢
|
|
|
|
|
txt_obj.seek(0) # 从文件首开始
|
2021-11-24 15:11:08 +08:00
|
|
|
|
|
2021-11-13 15:16:11 +08:00
|
|
|
|
txt = ''
|
|
|
|
|
i = 0
|
|
|
|
|
while i < data_count:
|
|
|
|
|
wavdata, fs, data_labels = data_loader.get_data((ran_num + i) % data_nums) # 从随机数开始连续向后取一定数量数据
|
|
|
|
|
data_input = self.speech_features.run(wavdata, fs)
|
2022-09-18 20:56:14 +08:00
|
|
|
|
data_input = data_input.reshape(data_input.shape[0], data_input.shape[1], 1)
|
2021-11-13 15:16:11 +08:00
|
|
|
|
# 数据格式出错处理 开始
|
|
|
|
|
# 当输入的wav文件长度过长时自动跳过该文件,转而使用下一个wav文件来运行
|
2021-12-03 14:51:44 +08:00
|
|
|
|
if data_input.shape[0] > self.speech_model.input_shape[0]:
|
2022-09-18 20:56:14 +08:00
|
|
|
|
print('*[Error]', 'wave data lenghth of num', (ran_num + i) % data_nums, 'is too long.',
|
|
|
|
|
'this data\'s length is', data_input.shape[0],
|
|
|
|
|
'expect <=', self.speech_model.input_shape[0],
|
|
|
|
|
'\n A Exception raise when test Speech Model.')
|
2021-11-13 15:16:11 +08:00
|
|
|
|
i += 1
|
|
|
|
|
continue
|
|
|
|
|
# 数据格式出错处理 结束
|
|
|
|
|
|
|
|
|
|
pre = self.predict(data_input)
|
2021-11-24 15:11:08 +08:00
|
|
|
|
|
2022-09-18 20:56:14 +08:00
|
|
|
|
words_n = data_labels.shape[0] # 获取每个句子的字数
|
|
|
|
|
words_num += words_n # 把句子的总字数加上
|
|
|
|
|
edit_distance = get_edit_distance(data_labels, pre) # 获取编辑距离
|
|
|
|
|
if edit_distance <= words_n: # 当编辑距离小于等于句子字数时
|
|
|
|
|
word_error_num += edit_distance # 使用编辑距离作为错误字数
|
|
|
|
|
else: # 否则肯定是增加了一堆乱七八糟的奇奇怪怪的字
|
|
|
|
|
word_error_num += words_n # 就直接加句子本来的总字数就好了
|
2021-11-13 15:16:11 +08:00
|
|
|
|
|
2021-11-24 15:11:08 +08:00
|
|
|
|
if i % show_per_step == 0 and show_ratio:
|
2022-09-18 20:56:14 +08:00
|
|
|
|
print('[ASRT Info] Testing: ', i, '/', data_count)
|
2021-11-24 15:11:08 +08:00
|
|
|
|
|
2021-11-13 15:16:11 +08:00
|
|
|
|
txt = ''
|
2021-11-24 15:11:08 +08:00
|
|
|
|
if out_report:
|
2021-11-13 15:16:11 +08:00
|
|
|
|
txt += str(i) + '\n'
|
|
|
|
|
txt += 'True:\t' + str(data_labels) + '\n'
|
|
|
|
|
txt += 'Pred:\t' + str(pre) + '\n'
|
|
|
|
|
txt += '\n'
|
|
|
|
|
txt_obj.write(txt)
|
2021-11-24 15:11:08 +08:00
|
|
|
|
|
2021-11-13 15:16:11 +08:00
|
|
|
|
i += 1
|
2021-11-24 15:11:08 +08:00
|
|
|
|
|
2022-09-18 20:56:14 +08:00
|
|
|
|
# print('*[测试结果] 语音识别 ' + str_dataset + ' 集语音单字错误率:', word_error_num / words_num * 100, '%')
|
|
|
|
|
print('*[ASRT Test Result] Speech Recognition ' + data_loader.dataset_type + ' set word error ratio: ',
|
|
|
|
|
word_error_num / words_num * 100, '%')
|
2021-11-24 15:11:08 +08:00
|
|
|
|
if out_report:
|
2022-09-18 20:56:14 +08:00
|
|
|
|
txt = '*[ASRT Test Result] Speech Recognition ' + data_loader.dataset_type + ' set word error ratio: ' + str(
|
|
|
|
|
word_error_num / words_num * 100) + ' %'
|
2021-11-13 15:16:11 +08:00
|
|
|
|
txt_obj.write(txt)
|
2022-09-18 20:56:14 +08:00
|
|
|
|
txt_obj.truncate() # 去除文件末尾剩余未使用的空白存储字节
|
2021-11-13 15:16:11 +08:00
|
|
|
|
txt_obj.close()
|
2021-11-24 15:11:08 +08:00
|
|
|
|
|
2021-11-13 15:16:11 +08:00
|
|
|
|
except StopIteration:
|
|
|
|
|
print('[ASRT Error] Model testing raise a error. Please check data format.')
|
|
|
|
|
|
|
|
|
|
def predict(self, data_input):
|
2022-09-18 20:56:14 +08:00
|
|
|
|
"""
|
2021-11-13 15:16:11 +08:00
|
|
|
|
预测结果
|
|
|
|
|
|
|
|
|
|
返回语音识别后的forward结果
|
2022-09-18 20:56:14 +08:00
|
|
|
|
"""
|
2021-11-13 15:16:11 +08:00
|
|
|
|
return self.speech_model.forward(data_input)
|
|
|
|
|
|
|
|
|
|
def recognize_speech(self, wavsignal, fs):
|
2022-09-18 20:56:14 +08:00
|
|
|
|
"""
|
2021-11-13 15:16:11 +08:00
|
|
|
|
最终做语音识别用的函数,识别一个wav序列的语音
|
2022-09-18 20:56:14 +08:00
|
|
|
|
"""
|
2021-11-13 15:16:11 +08:00
|
|
|
|
# 获取输入特征
|
|
|
|
|
data_input = self.speech_features.run(wavsignal, fs)
|
2023-02-04 18:34:38 +08:00
|
|
|
|
data_input = np.array(data_input, dtype=np.float64)
|
2022-09-18 20:56:14 +08:00
|
|
|
|
# print(data_input,data_input.shape)
|
|
|
|
|
data_input = data_input.reshape(data_input.shape[0], data_input.shape[1], 1)
|
2021-11-13 15:16:11 +08:00
|
|
|
|
r1 = self.predict(data_input)
|
|
|
|
|
# 获取拼音列表
|
2021-11-24 15:11:08 +08:00
|
|
|
|
list_symbol_dic, _ = load_pinyin_dict(load_config_file(DEFAULT_CONFIG_FILENAME)['dict_filename'])
|
2021-11-13 15:16:11 +08:00
|
|
|
|
|
2022-09-18 20:56:14 +08:00
|
|
|
|
r_str = []
|
2021-11-13 15:16:11 +08:00
|
|
|
|
for i in r1:
|
|
|
|
|
r_str.append(list_symbol_dic[i])
|
|
|
|
|
|
|
|
|
|
return r_str
|
|
|
|
|
|
|
|
|
|
def recognize_speech_from_file(self, filename):
|
2022-09-18 20:56:14 +08:00
|
|
|
|
"""
|
2021-11-13 15:16:11 +08:00
|
|
|
|
最终做语音识别用的函数,识别指定文件名的语音
|
2022-09-18 20:56:14 +08:00
|
|
|
|
"""
|
|
|
|
|
wavsignal, sample_rate, _, _ = read_wav_data(filename)
|
2021-11-24 15:11:08 +08:00
|
|
|
|
r = self.recognize_speech(wavsignal, sample_rate)
|
2021-11-13 15:16:11 +08:00
|
|
|
|
return r
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def model(self):
|
2022-09-18 20:56:14 +08:00
|
|
|
|
"""
|
2021-11-13 15:16:11 +08:00
|
|
|
|
返回tf.keras model
|
2022-09-18 20:56:14 +08:00
|
|
|
|
"""
|
2021-11-13 15:16:11 +08:00
|
|
|
|
return self.trained_model
|