From aa6a8053f944f871f36ba62a78419147d7ec828d Mon Sep 17 00:00:00 2001 From: nl8590687 <3210346136@qq.com> Date: Wed, 11 Apr 2018 18:05:57 +0800 Subject: [PATCH] add test model function --- SpeechModel.py | 84 +++++++++++++++++------------------- general_function/gen_func.py | 22 ++++++++++ 2 files changed, 61 insertions(+), 45 deletions(-) create mode 100644 general_function/gen_func.py diff --git a/SpeechModel.py b/SpeechModel.py index 992f3bf..a8649cd 100644 --- a/SpeechModel.py +++ b/SpeechModel.py @@ -8,10 +8,12 @@ import os from general_function.file_wav import * from general_function.file_dict import * +from general_function.gen_func import * # LSTM_CNN import keras as kr import numpy as np +import random from keras.models import Sequential, Model from keras.layers import Dense, Dropout, Input # , Flatten,LSTM,Convolution1D,MaxPooling1D,Merge @@ -174,6 +176,7 @@ class ModelSpeech(): # 语音模型类 break self.SaveModel(comment='_e_'+str(epoch)+'_step_'+str(n_step * save_step)) + ms.TestModel(self.datapath, str_dataset='dev', data_count = 16) def LoadModel(self, filename = 'model_speech/speech_model_e_0_step_1.model'): @@ -203,16 +206,22 @@ class ModelSpeech(): # 语音模型类 data_count = num_data try: - gen = data.data_genetator(data_count) - #for i in range(1): - # [X, y, input_length, label_length ], labels = gen - #r = self._model.test_on_batch([X, y, input_length, label_length ], labels) - r = self._model.evaluate_generator(generator = gen, steps = 1, max_queue_size = data_count, workers = 1, use_multiprocessing = False) - print(r) + ran_num = random.randint(0,num_data - 1) # 获取一个随机数 + + words_num = 0 + word_error_num = 0 + for i in range(data_count): + data_input, data_labels = data.GetData((ran_num + i) % num_data) # 从随机数开始连续向后取一定数量数据 + pre = self.Predict(data_input, data_input.shape[0] // 4) + + words_num += max(data_labels.shape[0], pre.shape[0]) + word_error_num += GetEditDistance(data_labels, pre) + + print('*[测试结果] 语音识别语音单字错误率:', word_error_num / words_num * 100, '%') except StopIteration: print('[Error] Model Test Error. please check data format.') - def Predict(self, batch_size, data_input, in_len): + def Predict(self, data_input, input_len): ''' 预测结果 返回语音识别后的拼音符号列表 @@ -220,8 +229,8 @@ class ModelSpeech(): # 语音模型类 batch_size = 1 in_len = np.zeros((batch_size),dtype = np.int32) - print(in_len.shape) - in_len[0] = in_len[0] - 2 + #print(in_len.shape) + in_len[0] = input_len - 2 x_in = np.zeros((batch_size, 1600, 200), dtype=np.float) @@ -230,10 +239,10 @@ class ModelSpeech(): # 语音模型类 x_in[i,0:len(data_input)] = data_input base_pred = self.base_model.predict(x = x_in) - print('base_pred:\n', base_pred) + #print('base_pred:\n', base_pred) y_p = base_pred - print('base_pred0:\n',base_pred[0][0].shape) + #print('base_pred0:\n',base_pred[0][0].shape) #for j in range(200): # mean = np.sum(y_p[0][j]) / y_p[0][j].shape[0] @@ -247,48 +256,27 @@ class ModelSpeech(): # 语音模型类 base_pred =base_pred[:, 2:, :] r = K.ctc_decode(base_pred, in_len, greedy = True, beam_width=100, top_paths=1) - print('r', r) + #print('r', r) #r = K.cast(r[0][0], dtype='float32') #print('r1', r) #print('解码完成') r1 = K.get_value(r[0][0]) - print('r1', r1) + #print('r1', r1) - print('r0', r[1]) + #print('r0', r[1]) r2 = K.get_value(r[1]) - print(r2) - print('解码完成') + #print('r2', r2) + #print('解码完成') list_symbol_dic = GetSymbolList(self.datapath) # 获取拼音列表 r1=r1[0] - r_str=[] - for i in r1: - r_str.append(list_symbol_dic[i]) - - #print(r_str) - - return r_str + return r1 pass - def show_edit_distance(self, num): - num_left = num - mean_norm_ed = 0.0 - mean_ed = 0.0 - while num_left > 0: - word_batch = next(self.text_img_gen)[0] - num_proc = min(word_batch['the_input'].shape[0], num_left) - decoded_res = decode_batch(self.test_func, word_batch['the_input'][0:num_proc]) - for j in range(num_proc): - edit_dist = editdistance.eval(decoded_res[j], word_batch['source_str'][j]) - mean_ed += float(edit_dist) - mean_norm_ed += float(edit_dist) / len(word_batch['source_str'][j]) - num_left -= num_proc - mean_norm_ed = mean_norm_ed / num - mean_ed = mean_ed / num - print('\nOut of %d samples: Mean edit distance: %.3f Mean normalized edit distance: %0.3f' - % (num, mean_ed, mean_norm_ed)) + + def RecognizeSpeech(self, wavsignal, fs): ''' @@ -305,9 +293,15 @@ class ModelSpeech(): # 语音模型类 data_input = np.array(data_input, dtype = np.float) - r = self.Predict(1, data_input, input_length) + r1 = self.Predict(data_input, input_length) - return r + r_str=[] + for i in r1: + r_str.append(list_symbol_dic[i]) + + #print(r_str) + + return r_str pass def RecognizeSpeech_FromFile(self, filename): @@ -350,8 +344,8 @@ if(__name__=='__main__'): ms = ModelSpeech(datapath) - #ms.LoadModel(modelpath + 'speech_model_e_0_step_1.model') - ms.TrainModel(datapath, epoch = 2, batch_size = 8, save_step = 1) - #ms.TestModel(datapath, str_dataset='dev', data_count = 32) + ms.LoadModel(modelpath + 'm1\\speech_model_e_1_step_100.model') + #ms.TrainModel(datapath, epoch = 2, batch_size = 8, save_step = 1) + ms.TestModel(datapath, str_dataset='dev', data_count = 8) #r = ms.RecognizeSpeech_FromFile('E:\\语音数据集\\wav\\test\\D4\\D4_750.wav') #print('*[提示] 语音识别结果:\n',r) diff --git a/general_function/gen_func.py b/general_function/gen_func.py new file mode 100644 index 0000000..ffd8e13 --- /dev/null +++ b/general_function/gen_func.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +''' +一些通用函数 +''' + +import difflib + +def GetEditDistance(str1, str2): + leven_cost = 0 + s = difflib.SequenceMatcher(None, str1, str2) + for tag, i1, i2, j1, j2 in s.get_opcodes(): + #print('{:7} a[{}: {}] --> b[{}: {}] {} --> {}'.format(tag, i1, i2, j1, j2, str1[i1: i2], str2[j1: j2])) + + if tag == 'replace': + leven_cost += max(i2-i1, j2-j1) + elif tag == 'insert': + leven_cost += (j2-j1) + elif tag == 'delete': + leven_cost += (i2-i1) + return leven_cost \ No newline at end of file