增加了跨平台支持

This commit is contained in:
nl8590687 2018-04-01 17:16:38 +08:00
parent cc0572b3b7
commit 58d02601a7
3 changed files with 55 additions and 21 deletions

View File

@ -3,6 +3,8 @@
""" """
@author: nl8590687 @author: nl8590687
""" """
import platform as plat
# LSTM_CNN # LSTM_CNN
import keras as kr import keras as kr
import numpy as np import numpy as np
@ -56,7 +58,7 @@ class ModelSpeech(): # 语音模型类
layer_h3 = Conv1D(256, 5, use_bias=True, padding="valid")(layer_h2) # 卷积层 layer_h3 = Conv1D(256, 5, use_bias=True, padding="valid")(layer_h2) # 卷积层
layer_h4 = MaxPooling1D(pool_size=2, strides=None, padding="valid")(layer_h3) # 池化层 layer_h4 = MaxPooling1D(pool_size=2, strides=None, padding="valid")(layer_h3) # 池化层
layer_h5 = Dropout(0.2)(layer_h4) # 随机中断部分神经网络连接,防止过拟合 layer_h5 = Dropout(0.2)(layer_h4) # 随机中断部分神经网络连接,防止过拟合
layer_h6 = Dense(self.MS_OUTPUT_SIZE, use_bias=True, activation="softmax")(layer_h5) # 全连接层 layer_h6 = Dense(256, use_bias=True, activation="softmax")(layer_h5) # 全连接层
layer_h7 = LSTM(256, activation='relu', use_bias=True, return_sequences=True)(layer_h6) # LSTM层 layer_h7 = LSTM(256, activation='relu', use_bias=True, return_sequences=True)(layer_h6) # LSTM层
layer_h8 = Dropout(0.2)(layer_h7) # 随机中断部分神经网络连接,防止过拟合 layer_h8 = Dropout(0.2)(layer_h7) # 随机中断部分神经网络连接,防止过拟合
layer_h9 = Dense(self.MS_OUTPUT_SIZE, use_bias=True, activation="softmax")(layer_h8) # 全连接层 layer_h9 = Dense(self.MS_OUTPUT_SIZE, use_bias=True, activation="softmax")(layer_h8) # 全连接层
@ -150,16 +152,16 @@ class ModelSpeech(): # 语音模型类
print('[error] generator error. please check data format.') print('[error] generator error. please check data format.')
break break
self.SaveModel(comment='_e_'+str(epoch)+'_step_'+str(n_step)) self.SaveModel(comment='_e_'+str(epoch)+'_step_'+str(n_step * save_step))
def LoadModel(self,filename='model_speech/LSTM_CNN_model.model'): def LoadModel(self,filename='model_speech/speech_model_e_0_step_1.model'):
''' '''
加载模型参数 加载模型参数
''' '''
self._model.load_weights(filename) self._model.load_weights(filename)
def SaveModel(self,filename='model_speech/LSTM_CNN_model',comment=''): def SaveModel(self,filename='model_speech/speech_model',comment=''):
''' '''
保存模型参数 保存模型参数
''' '''
@ -199,8 +201,23 @@ class ModelSpeech(): # 语音模型类
if(__name__=='__main__'): if(__name__=='__main__'):
datapath = 'E:\\语音数据集' datapath = ''
modelpath = ''
ms = ModelSpeech() ms = ModelSpeech()
system_type = plat.system() # 由于不同的系统的文件路径表示不一样,需要进行判断
if(system_type == 'Windows'):
datapath = 'E:\\语音数据集'
modelpath = 'model_speech\\'
elif(system_type == 'Linux'):
datapath = 'dataset'
modelpath = 'model_speech/'
else:
print('*[Message] Unknown System\n')
datapath = 'dataset'
modelpath = 'model_speech/'
#ms.LoadModel(modelpath + 'speech_model_e_0_step_1.model')
ms.TrainModel(datapath) ms.TrainModel(datapath)
#ms.TestModel(datapath) #ms.TestModel(datapath)

View File

@ -56,7 +56,7 @@ def get_wav_list(filename):
for i in txt_lines: for i in txt_lines:
if(i!=''): if(i!=''):
txt_l=i.split(' ') txt_l=i.split(' ')
dic_filelist[txt_l[0]]='wav/'+txt_l[1] dic_filelist[txt_l[0]] = txt_l[1]
list_wavmark.append(txt_l[0]) list_wavmark.append(txt_l[0])
txt_obj.close() txt_obj.close()
return dic_filelist,list_wavmark return dic_filelist,list_wavmark

View File

@ -1,6 +1,8 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import platform as plat
import numpy as np import numpy as np
from general_function.file_wav import * from general_function.file_wav import *
@ -20,15 +22,29 @@ class DataSpeech():
参数 参数
path数据存放位置根目录 path数据存放位置根目录
''' '''
system_type = plat.system() # 由于不同的系统的文件路径表示不一样,需要进行判断
self.datapath = path; # 数据存放位置根目录 self.datapath = path; # 数据存放位置根目录
if('\\'!=self.datapath[-1]): # 在目录路径末尾增加斜杠
self.datapath=self.datapath+'\\' self.slash = ''
if(system_type == 'Windows'):
self.slash='\\' # 反斜杠
elif(system_type == 'Linux'):
self.slash='/' # 正斜杠
else:
print('*[Message] Unknown System\n')
self.slash='/' # 正斜杠
if(self.slash != self.datapath[-1]): # 在目录路径末尾增加斜杠
self.datapath = self.datapath + self.slash
self.dic_wavlist = {} self.dic_wavlist = {}
self.dic_symbollist = {} self.dic_symbollist = {}
self.SymbolNum = 0 # 记录拼音符号数量 self.SymbolNum = 0 # 记录拼音符号数量
self.list_symbol = self.GetSymbolList() # 全部汉语拼音符号列表 self.list_symbol = self.GetSymbolList() # 全部汉语拼音符号列表
self.list_wavnum=[] # wav文件标记列表 self.list_wavnum = [] # wav文件标记列表
self.list_symbolnum=[] # symbol标记列表 self.list_symbolnum = [] # symbol标记列表
self.DataNum = 0 # 记录数据量 self.DataNum = 0 # 记录数据量
@ -45,20 +61,20 @@ class DataSpeech():
''' '''
# 设定选取哪一项作为要使用的数据集 # 设定选取哪一项作为要使用的数据集
if(type=='train'): if(type=='train'):
filename_wavlist='doc\\doc\\list\\train.wav.lst' filename_wavlist = 'doc' + self.slash + 'list' + self.slash + 'train.wav.lst'
filename_symbollist='doc\\doc\\trans\\train.syllable.txt' filename_symbollist = 'doc' + self.slash + 'trans' + self.slash + 'train.syllable.txt'
elif(type=='dev'): elif(type=='dev'):
filename_wavlist='doc\\doc\\list\\cv.wav.lst' filename_wavlist = 'doc' + self.slash + 'list' + self.slash + 'cv.wav.lst'
filename_symbollist='doc\\doc\\trans\\cv.syllable.txt' filename_symbollist = 'doc' + self.slash + 'trans' + self.slash + 'cv.syllable.txt'
elif(type=='test'): elif(type=='test'):
filename_wavlist='doc\\doc\\list\\test.wav.lst' filename_wavlist = 'doc' + self.slash + 'list' + self.slash + 'test.wav.lst'
filename_symbollist='doc\\doc\\trans\\test.syllable.txt' filename_symbollist = 'doc' + self.slash + 'trans' + self.slash + 'test.syllable.txt'
else: else:
filename_wavlist='' # 默认留空 filename_wavlist = '' # 默认留空
filename_symbollist='' filename_symbollist = ''
# 读取数据列表wav文件列表和其对应的符号列表 # 读取数据列表wav文件列表和其对应的符号列表
self.dic_wavlist,self.list_wavnum = get_wav_list(self.datapath+filename_wavlist) self.dic_wavlist,self.list_wavnum = get_wav_list(self.datapath + filename_wavlist)
self.dic_symbollist,self.list_symbolnum = get_wav_symbol(self.datapath+filename_symbollist) self.dic_symbollist,self.list_symbolnum = get_wav_symbol(self.datapath + filename_symbollist)
self.DataNum = self.GetDataNum() self.DataNum = self.GetDataNum()
def GetDataNum(self): def GetDataNum(self):
@ -85,7 +101,8 @@ class DataSpeech():
# 读取一个文件 # 读取一个文件
filename = self.dic_wavlist[self.list_wavnum[n_start]] filename = self.dic_wavlist[self.list_wavnum[n_start]]
filename=filename.replace('/','\\') # windows系统下需要添加这一行 if('Windows' == plat.system()):
filename=filename.replace('/','\\') # windows系统下需要执行这一行对文件路径做特别处理
wavsignal,fs=read_wav_data(self.datapath+filename) wavsignal,fs=read_wav_data(self.datapath+filename)
# 获取输入特征 # 获取输入特征