From 58d02601a7ff18933241ba4b39e9b7ebe17af903 Mon Sep 17 00:00:00 2001 From: nl8590687 <3210346136@qq.com> Date: Sun, 1 Apr 2018 17:16:38 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BA=86=E8=B7=A8=E5=B9=B3?= =?UTF-8?q?=E5=8F=B0=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- SpeechModel.py | 27 +++++++++++++++++---- general_function/file_wav.py | 2 +- readdata.py | 47 ++++++++++++++++++++++++------------ 3 files changed, 55 insertions(+), 21 deletions(-) diff --git a/SpeechModel.py b/SpeechModel.py index 0afd6eb..0df33f9 100644 --- a/SpeechModel.py +++ b/SpeechModel.py @@ -3,6 +3,8 @@ """ @author: nl8590687 """ +import platform as plat + # LSTM_CNN import keras as kr import numpy as np @@ -56,7 +58,7 @@ class ModelSpeech(): # 语音模型类 layer_h3 = Conv1D(256, 5, use_bias=True, padding="valid")(layer_h2) # 卷积层 layer_h4 = MaxPooling1D(pool_size=2, strides=None, padding="valid")(layer_h3) # 池化层 layer_h5 = Dropout(0.2)(layer_h4) # 随机中断部分神经网络连接,防止过拟合 - layer_h6 = Dense(self.MS_OUTPUT_SIZE, use_bias=True, activation="softmax")(layer_h5) # 全连接层 + layer_h6 = Dense(256, use_bias=True, activation="softmax")(layer_h5) # 全连接层 layer_h7 = LSTM(256, activation='relu', use_bias=True, return_sequences=True)(layer_h6) # LSTM层 layer_h8 = Dropout(0.2)(layer_h7) # 随机中断部分神经网络连接,防止过拟合 layer_h9 = Dense(self.MS_OUTPUT_SIZE, use_bias=True, activation="softmax")(layer_h8) # 全连接层 @@ -150,16 +152,16 @@ class ModelSpeech(): # 语音模型类 print('[error] generator error. please check data format.') break - self.SaveModel(comment='_e_'+str(epoch)+'_step_'+str(n_step)) + self.SaveModel(comment='_e_'+str(epoch)+'_step_'+str(n_step * save_step)) - def LoadModel(self,filename='model_speech/LSTM_CNN_model.model'): + def LoadModel(self,filename='model_speech/speech_model_e_0_step_1.model'): ''' 加载模型参数 ''' self._model.load_weights(filename) - def SaveModel(self,filename='model_speech/LSTM_CNN_model',comment=''): + def SaveModel(self,filename='model_speech/speech_model',comment=''): ''' 保存模型参数 ''' @@ -199,8 +201,23 @@ class ModelSpeech(): # 语音模型类 if(__name__=='__main__'): - datapath = 'E:\\语音数据集' + datapath = '' + modelpath = '' ms = ModelSpeech() + + system_type = plat.system() # 由于不同的系统的文件路径表示不一样,需要进行判断 + if(system_type == 'Windows'): + datapath = 'E:\\语音数据集' + modelpath = 'model_speech\\' + elif(system_type == 'Linux'): + datapath = 'dataset' + modelpath = 'model_speech/' + else: + print('*[Message] Unknown System\n') + datapath = 'dataset' + modelpath = 'model_speech/' + + #ms.LoadModel(modelpath + 'speech_model_e_0_step_1.model') ms.TrainModel(datapath) #ms.TestModel(datapath) diff --git a/general_function/file_wav.py b/general_function/file_wav.py index fb4f64b..eb60999 100644 --- a/general_function/file_wav.py +++ b/general_function/file_wav.py @@ -56,7 +56,7 @@ def get_wav_list(filename): for i in txt_lines: if(i!=''): txt_l=i.split(' ') - dic_filelist[txt_l[0]]='wav/'+txt_l[1] + dic_filelist[txt_l[0]] = txt_l[1] list_wavmark.append(txt_l[0]) txt_obj.close() return dic_filelist,list_wavmark diff --git a/readdata.py b/readdata.py index f44300e..b16dff3 100644 --- a/readdata.py +++ b/readdata.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import platform as plat + import numpy as np from general_function.file_wav import * @@ -20,15 +22,29 @@ class DataSpeech(): 参数: path:数据存放位置根目录 ''' + + system_type = plat.system() # 由于不同的系统的文件路径表示不一样,需要进行判断 + self.datapath = path; # 数据存放位置根目录 - if('\\'!=self.datapath[-1]): # 在目录路径末尾增加斜杠 - self.datapath=self.datapath+'\\' + + self.slash = '' + if(system_type == 'Windows'): + self.slash='\\' # 反斜杠 + elif(system_type == 'Linux'): + self.slash='/' # 正斜杠 + else: + print('*[Message] Unknown System\n') + self.slash='/' # 正斜杠 + + if(self.slash != self.datapath[-1]): # 在目录路径末尾增加斜杠 + self.datapath = self.datapath + self.slash + self.dic_wavlist = {} self.dic_symbollist = {} self.SymbolNum = 0 # 记录拼音符号数量 self.list_symbol = self.GetSymbolList() # 全部汉语拼音符号列表 - self.list_wavnum=[] # wav文件标记列表 - self.list_symbolnum=[] # symbol标记列表 + self.list_wavnum = [] # wav文件标记列表 + self.list_symbolnum = [] # symbol标记列表 self.DataNum = 0 # 记录数据量 @@ -45,20 +61,20 @@ class DataSpeech(): ''' # 设定选取哪一项作为要使用的数据集 if(type=='train'): - filename_wavlist='doc\\doc\\list\\train.wav.lst' - filename_symbollist='doc\\doc\\trans\\train.syllable.txt' + filename_wavlist = 'doc' + self.slash + 'list' + self.slash + 'train.wav.lst' + filename_symbollist = 'doc' + self.slash + 'trans' + self.slash + 'train.syllable.txt' elif(type=='dev'): - filename_wavlist='doc\\doc\\list\\cv.wav.lst' - filename_symbollist='doc\\doc\\trans\\cv.syllable.txt' + filename_wavlist = 'doc' + self.slash + 'list' + self.slash + 'cv.wav.lst' + filename_symbollist = 'doc' + self.slash + 'trans' + self.slash + 'cv.syllable.txt' elif(type=='test'): - filename_wavlist='doc\\doc\\list\\test.wav.lst' - filename_symbollist='doc\\doc\\trans\\test.syllable.txt' + filename_wavlist = 'doc' + self.slash + 'list' + self.slash + 'test.wav.lst' + filename_symbollist = 'doc' + self.slash + 'trans' + self.slash + 'test.syllable.txt' else: - filename_wavlist='' # 默认留空 - filename_symbollist='' + filename_wavlist = '' # 默认留空 + filename_symbollist = '' # 读取数据列表,wav文件列表和其对应的符号列表 - self.dic_wavlist,self.list_wavnum = get_wav_list(self.datapath+filename_wavlist) - self.dic_symbollist,self.list_symbolnum = get_wav_symbol(self.datapath+filename_symbollist) + self.dic_wavlist,self.list_wavnum = get_wav_list(self.datapath + filename_wavlist) + self.dic_symbollist,self.list_symbolnum = get_wav_symbol(self.datapath + filename_symbollist) self.DataNum = self.GetDataNum() def GetDataNum(self): @@ -85,7 +101,8 @@ class DataSpeech(): # 读取一个文件 filename = self.dic_wavlist[self.list_wavnum[n_start]] - filename=filename.replace('/','\\') # windows系统下需要添加这一行 + if('Windows' == plat.system()): + filename=filename.replace('/','\\') # windows系统下需要执行这一行,对文件路径做特别处理 wavsignal,fs=read_wav_data(self.datapath+filename) # 获取输入特征