From cb69b5f7982bd19dd6f02aa5cfab5b163131d1aa Mon Sep 17 00:00:00 2001 From: nl8590687 <3210346136@qq.com> Date: Tue, 10 Apr 2018 14:41:59 +0800 Subject: [PATCH] a little change --- SpeechModel2.py | 6 +- SpeechModel5.py | 140 ++++++++++++++++++++++++++++++++++++-------- SpeechModel5_old.py | 2 +- readdata.py | 2 +- 4 files changed, 120 insertions(+), 30 deletions(-) diff --git a/SpeechModel2.py b/SpeechModel2.py index 20f0c2e..cf70f5b 100644 --- a/SpeechModel2.py +++ b/SpeechModel2.py @@ -34,7 +34,7 @@ class ModelSpeech(): # 语音模型类 self.label_max_string_length = 64 self.AUDIO_LENGTH = 1600 self.AUDIO_FEATURE_LENGTH = 200 - self._model = self.CreateModel() + self._model, self.base_model = self.CreateModel() @@ -105,7 +105,7 @@ class ModelSpeech(): # 语音模型类 test_func = K.function([input_data], [y_pred]) print('[*提示] 创建模型成功,模型编译成功') - return model + return model, model_data def ctc_lambda_func(self, args): y_pred, labels, input_length, label_length = args @@ -150,12 +150,14 @@ class ModelSpeech(): # 语音模型类 加载模型参数 ''' self._model.load_weights(filename) + self.base_model.load_weights(filename + '.base') def SaveModel(self,filename='model_speech/speech_model2',comment=''): ''' 保存模型参数 ''' self._model.save_weights(filename+comment+'.model') + self.base_model.save_weights(filename + comment + '.model.base') def TestModel(self, datapath, str_dataset='dev'): ''' diff --git a/SpeechModel5.py b/SpeechModel5.py index 1c960ba..d1df752 100644 --- a/SpeechModel5.py +++ b/SpeechModel5.py @@ -134,6 +134,8 @@ class ModelSpeech(): # 语音模型类 # captures output of softmax so we can decode the output during visualization self.test_func = K.function([input_data], [y_pred]) + #top_k_decoded, _ = K.ctc_decode(y_pred, input_length, greedy = True, beam_width=100, top_paths=1) + #self.decoder = K.function([input_data, input_length], [top_k_decoded[0]]) print('[*提示] 创建模型成功,模型编译成功') return model, model_data @@ -256,15 +258,101 @@ class ModelSpeech(): # 语音模型类 最终做语音识别用的函数,识别一个wav序列的语音 不过这里现在还有bug ''' - #data = self.data data = DataSpeech('E:\\语音数据集') data.LoadDataList('dev') # 获取输入特征 #data_input = data.GetMfccFeature(wavsignal, fs) data_input = data.GetFrequencyFeature(wavsignal, fs) + input_length = len(data_input) + input_length = input_length // 4 - arr_zero = np.zeros((1, 200), dtype=np.int16) #一个全是0的行向量 + data_input = np.array(data_input, dtype = np.float) + in_len = np.zeros((1),dtype = np.int32) + print(in_len.shape) + in_len[0] = input_length -2 + + + batch_size = 1 + x_in = np.zeros((batch_size, 1600, 200), dtype=np.float) + + for i in range(batch_size): + x_in[i,0:len(data_input)] = data_input + + + + base_pred = self.base_model.predict(x = x_in) + print('base_pred:\n', base_pred) + + + y_p = base_pred + print('base_pred0:\n',base_pred[0][0].shape) + + for j in range(200): + mean = np.sum(y_p[0][j]) / y_p[0][j].shape[0] + print('max y_p:',np.max(y_p[0][j]),'min y_p:',np.min(y_p[0][j]),'mean y_p:',mean,'mid y_p:',y_p[0][j][100]) + print('argmin:',np.argmin(y_p[0][j]),'argmax:',np.argmax(y_p[0][j])) + count=0 + for i in range(y_p[0][j].shape[0]): + if(y_p[0][j][i] < mean): + count += 1 + print('count:',count) + #for j in range(0,200): + # mean = sum(y_p[0][0][j]) / len(y_p[0][0][j]) + # print('max y_p:',max(y_p[0][0][j]),'min y_p:',min(y_p[0][0][j]),'mean y_p:',mean,'mid y_p:',y_p[0][0][j][100]) + # print('argmin:',np.argmin(y_p[0][0][j]),'argmax:',np.argmax(y_p[0][0][j])) + # count=0 + # for i in y_p[0][0][j]: + # if(i < mean): + # count += 1 + # print('count:',count) + #decoded_sequences = self.decoder([base_pred, in_len]) + + #print('decoded_sequences:\n', decoded_sequences) + #input_length = tf.squeeze(input_length) + + #decode_pred = self.model_decode(x=[x_in, in_len]) + #print(decode_pred) + base_pred =base_pred[:, 2:, :] + r = K.ctc_decode(base_pred, in_len, greedy = True, beam_width=100, top_paths=1) + print('r', r) + #r = K.cast(r[0][0], dtype='float32') + #print('r1', r) + #print('解码完成') + + r1 = K.get_value(r[0][0]) + print('r1', r1) + + print('r0', r[1]) + r2 = K.get_value(r[1]) + print(r2) + print('解码完成') + list_symbol_dic = data.list_symbol # 获取拼音列表 + + print('解码完成') + return r1 + + + + + + + + + + + + + + + #data = self.data + #data = DataSpeech('E:\\语音数据集') + #data.LoadDataList('dev') + # 获取输入特征 + #data_input = data.GetMfccFeature(wavsignal, fs) + #data_input = data.GetFrequencyFeature(wavsignal, fs) + + #arr_zero = np.zeros((1, 200), dtype=np.int16) #一个全是0的行向量 #import matplotlib.pyplot as plt #plt.subplot(111) @@ -275,42 +363,42 @@ class ModelSpeech(): # 语音模型类 # data_input = np.row_stack((data_input,arr_zero)) #print(len(data_input)) - list_symbol = data.list_symbol # 获取拼音列表 + #list_symbol = data.list_symbol # 获取拼音列表 - labels = [ list_symbol[0] ] + #labels = [ list_symbol[0] ] #while(len(labels) < 64): # labels.append('') - labels_num = [] - for i in labels: - labels_num.append(data.SymbolToNum(i)) + #labels_num = [] + #for i in labels: + # labels_num.append(data.SymbolToNum(i)) - data_input = np.array(data_input, dtype=np.int16) - data_input = data_input.reshape(data_input.shape[0],data_input.shape[1]) + #data_input = np.array(data_input, dtype=np.int16) + #data_input = data_input.reshape(data_input.shape[0],data_input.shape[1]) - labels_num = np.array(labels_num, dtype=np.int16) - labels_num = labels_num.reshape(labels_num.shape[0]) + #labels_num = np.array(labels_num, dtype=np.int16) + #labels_num = labels_num.reshape(labels_num.shape[0]) - input_length = np.array([data_input.shape[0] // 4 - 3], dtype=np.int16) - input_length = np.array(input_length) - input_length = input_length.reshape(input_length.shape[0]) + #input_length = np.array([data_input.shape[0] // 4 - 3], dtype=np.int16) + #input_length = np.array(input_length) + #input_length = input_length.reshape(input_length.shape[0]) - label_length = np.array([labels_num.shape[0]], dtype=np.int16) - label_length = np.array(label_length) - label_length = label_length.reshape(label_length.shape[0]) + #label_length = np.array([labels_num.shape[0]], dtype=np.int16) + #label_length = np.array(label_length) + #label_length = label_length.reshape(label_length.shape[0]) - x = [data_input, labels_num, input_length, label_length] + #x = [data_input, labels_num, input_length, label_length] #x = next(data.data_genetator(1, self.AUDIO_LENGTH)) #x = kr.utils.np_utils.to_categorical(x) - print(x) - x=np.array(x) + #print(x) + #x=np.array(x) - pred = self._model.predict(x=x) + #pred = self._model.predict(x=x) #pred = self._model.predict_on_batch([data_input, labels_num, input_length, label_length]) - return [labels,pred] + #return [labels,pred] pass @@ -354,8 +442,8 @@ if(__name__=='__main__'): ms = ModelSpeech(datapath) - #ms.LoadModel(modelpath + 'speech_model_e_0_step_1.model') - ms.TrainModel(datapath, epoch = 2, batch_size = 8, save_step = 10) + ms.LoadModel(modelpath + '5test\\speech_model_e_0_step_1400.model') + #ms.TrainModel(datapath, epoch = 2, batch_size = 8, save_step = 10) #ms.TestModel(datapath, str_dataset='dev', data_count = 32) - #r = ms.RecognizeSpeech_FromFile('E:\\语音数据集\\wav\\test\\D4\\D4_750.wav') - #print('*[提示] 语音识别结果:\n',r) + r = ms.RecognizeSpeech_FromFile('E:\\语音数据集\\wav\\test\\D4\\D4_750.wav') + print('*[提示] 语音识别结果:\n',r) diff --git a/SpeechModel5_old.py b/SpeechModel5_old.py index 50effcd..4293465 100644 --- a/SpeechModel5_old.py +++ b/SpeechModel5_old.py @@ -402,7 +402,7 @@ if(__name__=='__main__'): ms = ModelSpeech(datapath) - ms.LoadModel(modelpath + 'speech_model5_e_0_step_1.model') + ms.LoadModel(modelpath + '5test\\speech_model_e_0_step_100.model') #ms.TrainModel(datapath, epoch = 2, batch_size = 16, save_step = 1) #ms.TestModel(datapath, str_dataset='dev', data_count = 32) r = ms.RecognizeSpeech_FromFile('E:\\语音数据集\\wav\\test\\D4\\D4_750.wav') diff --git a/readdata.py b/readdata.py index 2e23ec6..df464b0 100644 --- a/readdata.py +++ b/readdata.py @@ -117,7 +117,7 @@ class DataSpeech(): #print('wavsignal[0][j]:\n',wavsignal[0][j]) #data_line = abs(fft(data_line)) / len(wavsignal[0]) data_line = fft(data_line) / len(wavsignal[0]) - data_input.append(data_line[0:len(data_line)//2]) + data_input.append(data_line[0:len(data_line)//2]) # 除以2是取一半数据,因为是对称的 #print('data_line:\n',data_line) return data_input