add test model function
This commit is contained in:
parent
e4104f091a
commit
aa6a8053f9
|
@ -8,10 +8,12 @@ import os
|
|||
|
||||
from general_function.file_wav import *
|
||||
from general_function.file_dict import *
|
||||
from general_function.gen_func import *
|
||||
|
||||
# LSTM_CNN
|
||||
import keras as kr
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
from keras.models import Sequential, Model
|
||||
from keras.layers import Dense, Dropout, Input # , Flatten,LSTM,Convolution1D,MaxPooling1D,Merge
|
||||
|
@ -174,6 +176,7 @@ class ModelSpeech(): # 语音模型类
|
|||
break
|
||||
|
||||
self.SaveModel(comment='_e_'+str(epoch)+'_step_'+str(n_step * save_step))
|
||||
ms.TestModel(self.datapath, str_dataset='dev', data_count = 16)
|
||||
|
||||
|
||||
def LoadModel(self, filename = 'model_speech/speech_model_e_0_step_1.model'):
|
||||
|
@ -203,16 +206,22 @@ class ModelSpeech(): # 语音模型类
|
|||
data_count = num_data
|
||||
|
||||
try:
|
||||
gen = data.data_genetator(data_count)
|
||||
#for i in range(1):
|
||||
# [X, y, input_length, label_length ], labels = gen
|
||||
#r = self._model.test_on_batch([X, y, input_length, label_length ], labels)
|
||||
r = self._model.evaluate_generator(generator = gen, steps = 1, max_queue_size = data_count, workers = 1, use_multiprocessing = False)
|
||||
print(r)
|
||||
ran_num = random.randint(0,num_data - 1) # 获取一个随机数
|
||||
|
||||
words_num = 0
|
||||
word_error_num = 0
|
||||
for i in range(data_count):
|
||||
data_input, data_labels = data.GetData((ran_num + i) % num_data) # 从随机数开始连续向后取一定数量数据
|
||||
pre = self.Predict(data_input, data_input.shape[0] // 4)
|
||||
|
||||
words_num += max(data_labels.shape[0], pre.shape[0])
|
||||
word_error_num += GetEditDistance(data_labels, pre)
|
||||
|
||||
print('*[测试结果] 语音识别语音单字错误率:', word_error_num / words_num * 100, '%')
|
||||
except StopIteration:
|
||||
print('[Error] Model Test Error. please check data format.')
|
||||
|
||||
def Predict(self, batch_size, data_input, in_len):
|
||||
def Predict(self, data_input, input_len):
|
||||
'''
|
||||
预测结果
|
||||
返回语音识别后的拼音符号列表
|
||||
|
@ -220,8 +229,8 @@ class ModelSpeech(): # 语音模型类
|
|||
batch_size = 1
|
||||
|
||||
in_len = np.zeros((batch_size),dtype = np.int32)
|
||||
print(in_len.shape)
|
||||
in_len[0] = in_len[0] - 2
|
||||
#print(in_len.shape)
|
||||
in_len[0] = input_len - 2
|
||||
|
||||
|
||||
x_in = np.zeros((batch_size, 1600, 200), dtype=np.float)
|
||||
|
@ -230,10 +239,10 @@ class ModelSpeech(): # 语音模型类
|
|||
x_in[i,0:len(data_input)] = data_input
|
||||
|
||||
base_pred = self.base_model.predict(x = x_in)
|
||||
print('base_pred:\n', base_pred)
|
||||
#print('base_pred:\n', base_pred)
|
||||
|
||||
y_p = base_pred
|
||||
print('base_pred0:\n',base_pred[0][0].shape)
|
||||
#print('base_pred0:\n',base_pred[0][0].shape)
|
||||
|
||||
#for j in range(200):
|
||||
# mean = np.sum(y_p[0][j]) / y_p[0][j].shape[0]
|
||||
|
@ -247,48 +256,27 @@ class ModelSpeech(): # 语音模型类
|
|||
|
||||
base_pred =base_pred[:, 2:, :]
|
||||
r = K.ctc_decode(base_pred, in_len, greedy = True, beam_width=100, top_paths=1)
|
||||
print('r', r)
|
||||
#print('r', r)
|
||||
#r = K.cast(r[0][0], dtype='float32')
|
||||
#print('r1', r)
|
||||
#print('解码完成')
|
||||
|
||||
r1 = K.get_value(r[0][0])
|
||||
print('r1', r1)
|
||||
#print('r1', r1)
|
||||
|
||||
print('r0', r[1])
|
||||
#print('r0', r[1])
|
||||
r2 = K.get_value(r[1])
|
||||
print(r2)
|
||||
print('解码完成')
|
||||
#print('r2', r2)
|
||||
#print('解码完成')
|
||||
list_symbol_dic = GetSymbolList(self.datapath) # 获取拼音列表
|
||||
|
||||
r1=r1[0]
|
||||
|
||||
r_str=[]
|
||||
for i in r1:
|
||||
r_str.append(list_symbol_dic[i])
|
||||
|
||||
#print(r_str)
|
||||
|
||||
return r_str
|
||||
return r1
|
||||
pass
|
||||
|
||||
def show_edit_distance(self, num):
|
||||
num_left = num
|
||||
mean_norm_ed = 0.0
|
||||
mean_ed = 0.0
|
||||
while num_left > 0:
|
||||
word_batch = next(self.text_img_gen)[0]
|
||||
num_proc = min(word_batch['the_input'].shape[0], num_left)
|
||||
decoded_res = decode_batch(self.test_func, word_batch['the_input'][0:num_proc])
|
||||
for j in range(num_proc):
|
||||
edit_dist = editdistance.eval(decoded_res[j], word_batch['source_str'][j])
|
||||
mean_ed += float(edit_dist)
|
||||
mean_norm_ed += float(edit_dist) / len(word_batch['source_str'][j])
|
||||
num_left -= num_proc
|
||||
mean_norm_ed = mean_norm_ed / num
|
||||
mean_ed = mean_ed / num
|
||||
print('\nOut of %d samples: Mean edit distance: %.3f Mean normalized edit distance: %0.3f'
|
||||
% (num, mean_ed, mean_norm_ed))
|
||||
|
||||
|
||||
|
||||
def RecognizeSpeech(self, wavsignal, fs):
|
||||
'''
|
||||
|
@ -305,9 +293,15 @@ class ModelSpeech(): # 语音模型类
|
|||
data_input = np.array(data_input, dtype = np.float)
|
||||
|
||||
|
||||
r = self.Predict(1, data_input, input_length)
|
||||
r1 = self.Predict(data_input, input_length)
|
||||
|
||||
return r
|
||||
r_str=[]
|
||||
for i in r1:
|
||||
r_str.append(list_symbol_dic[i])
|
||||
|
||||
#print(r_str)
|
||||
|
||||
return r_str
|
||||
pass
|
||||
|
||||
def RecognizeSpeech_FromFile(self, filename):
|
||||
|
@ -350,8 +344,8 @@ if(__name__=='__main__'):
|
|||
|
||||
ms = ModelSpeech(datapath)
|
||||
|
||||
#ms.LoadModel(modelpath + 'speech_model_e_0_step_1.model')
|
||||
ms.TrainModel(datapath, epoch = 2, batch_size = 8, save_step = 1)
|
||||
#ms.TestModel(datapath, str_dataset='dev', data_count = 32)
|
||||
ms.LoadModel(modelpath + 'm1\\speech_model_e_1_step_100.model')
|
||||
#ms.TrainModel(datapath, epoch = 2, batch_size = 8, save_step = 1)
|
||||
ms.TestModel(datapath, str_dataset='dev', data_count = 8)
|
||||
#r = ms.RecognizeSpeech_FromFile('E:\\语音数据集\\wav\\test\\D4\\D4_750.wav')
|
||||
#print('*[提示] 语音识别结果:\n',r)
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
'''
|
||||
一些通用函数
|
||||
'''
|
||||
|
||||
import difflib
|
||||
|
||||
def GetEditDistance(str1, str2):
|
||||
leven_cost = 0
|
||||
s = difflib.SequenceMatcher(None, str1, str2)
|
||||
for tag, i1, i2, j1, j2 in s.get_opcodes():
|
||||
#print('{:7} a[{}: {}] --> b[{}: {}] {} --> {}'.format(tag, i1, i2, j1, j2, str1[i1: i2], str2[j1: j2]))
|
||||
|
||||
if tag == 'replace':
|
||||
leven_cost += max(i2-i1, j2-j1)
|
||||
elif tag == 'insert':
|
||||
leven_cost += (j2-j1)
|
||||
elif tag == 'delete':
|
||||
leven_cost += (i2-i1)
|
||||
return leven_cost
|
Loading…
Reference in New Issue