增加了跨平台支持

This commit is contained in:
nl8590687 2018-04-01 17:16:38 +08:00
parent cc0572b3b7
commit 58d02601a7
3 changed files with 55 additions and 21 deletions

View File

@ -3,6 +3,8 @@
"""
@author: nl8590687
"""
import platform as plat
# LSTM_CNN
import keras as kr
import numpy as np
@ -56,7 +58,7 @@ class ModelSpeech(): # 语音模型类
layer_h3 = Conv1D(256, 5, use_bias=True, padding="valid")(layer_h2) # 卷积层
layer_h4 = MaxPooling1D(pool_size=2, strides=None, padding="valid")(layer_h3) # 池化层
layer_h5 = Dropout(0.2)(layer_h4) # 随机中断部分神经网络连接,防止过拟合
layer_h6 = Dense(self.MS_OUTPUT_SIZE, use_bias=True, activation="softmax")(layer_h5) # 全连接层
layer_h6 = Dense(256, use_bias=True, activation="softmax")(layer_h5) # 全连接层
layer_h7 = LSTM(256, activation='relu', use_bias=True, return_sequences=True)(layer_h6) # LSTM层
layer_h8 = Dropout(0.2)(layer_h7) # 随机中断部分神经网络连接,防止过拟合
layer_h9 = Dense(self.MS_OUTPUT_SIZE, use_bias=True, activation="softmax")(layer_h8) # 全连接层
@ -150,16 +152,16 @@ class ModelSpeech(): # 语音模型类
print('[error] generator error. please check data format.')
break
self.SaveModel(comment='_e_'+str(epoch)+'_step_'+str(n_step))
self.SaveModel(comment='_e_'+str(epoch)+'_step_'+str(n_step * save_step))
def LoadModel(self,filename='model_speech/LSTM_CNN_model.model'):
def LoadModel(self,filename='model_speech/speech_model_e_0_step_1.model'):
'''
加载模型参数
'''
self._model.load_weights(filename)
def SaveModel(self,filename='model_speech/LSTM_CNN_model',comment=''):
def SaveModel(self,filename='model_speech/speech_model',comment=''):
'''
保存模型参数
'''
@ -199,8 +201,23 @@ class ModelSpeech(): # 语音模型类
if(__name__=='__main__'):
datapath = 'E:\\语音数据集'
datapath = ''
modelpath = ''
ms = ModelSpeech()
system_type = plat.system() # 由于不同的系统的文件路径表示不一样,需要进行判断
if(system_type == 'Windows'):
datapath = 'E:\\语音数据集'
modelpath = 'model_speech\\'
elif(system_type == 'Linux'):
datapath = 'dataset'
modelpath = 'model_speech/'
else:
print('*[Message] Unknown System\n')
datapath = 'dataset'
modelpath = 'model_speech/'
#ms.LoadModel(modelpath + 'speech_model_e_0_step_1.model')
ms.TrainModel(datapath)
#ms.TestModel(datapath)

View File

@ -56,7 +56,7 @@ def get_wav_list(filename):
for i in txt_lines:
if(i!=''):
txt_l=i.split(' ')
dic_filelist[txt_l[0]]='wav/'+txt_l[1]
dic_filelist[txt_l[0]] = txt_l[1]
list_wavmark.append(txt_l[0])
txt_obj.close()
return dic_filelist,list_wavmark

View File

@ -1,6 +1,8 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import platform as plat
import numpy as np
from general_function.file_wav import *
@ -20,9 +22,23 @@ class DataSpeech():
参数
path数据存放位置根目录
'''
system_type = plat.system() # 由于不同的系统的文件路径表示不一样,需要进行判断
self.datapath = path; # 数据存放位置根目录
if('\\'!=self.datapath[-1]): # 在目录路径末尾增加斜杠
self.datapath=self.datapath+'\\'
self.slash = ''
if(system_type == 'Windows'):
self.slash='\\' # 反斜杠
elif(system_type == 'Linux'):
self.slash='/' # 正斜杠
else:
print('*[Message] Unknown System\n')
self.slash='/' # 正斜杠
if(self.slash != self.datapath[-1]): # 在目录路径末尾增加斜杠
self.datapath = self.datapath + self.slash
self.dic_wavlist = {}
self.dic_symbollist = {}
self.SymbolNum = 0 # 记录拼音符号数量
@ -45,14 +61,14 @@ class DataSpeech():
'''
# 设定选取哪一项作为要使用的数据集
if(type=='train'):
filename_wavlist='doc\\doc\\list\\train.wav.lst'
filename_symbollist='doc\\doc\\trans\\train.syllable.txt'
filename_wavlist = 'doc' + self.slash + 'list' + self.slash + 'train.wav.lst'
filename_symbollist = 'doc' + self.slash + 'trans' + self.slash + 'train.syllable.txt'
elif(type=='dev'):
filename_wavlist='doc\\doc\\list\\cv.wav.lst'
filename_symbollist='doc\\doc\\trans\\cv.syllable.txt'
filename_wavlist = 'doc' + self.slash + 'list' + self.slash + 'cv.wav.lst'
filename_symbollist = 'doc' + self.slash + 'trans' + self.slash + 'cv.syllable.txt'
elif(type=='test'):
filename_wavlist='doc\\doc\\list\\test.wav.lst'
filename_symbollist='doc\\doc\\trans\\test.syllable.txt'
filename_wavlist = 'doc' + self.slash + 'list' + self.slash + 'test.wav.lst'
filename_symbollist = 'doc' + self.slash + 'trans' + self.slash + 'test.syllable.txt'
else:
filename_wavlist = '' # 默认留空
filename_symbollist = ''
@ -85,7 +101,8 @@ class DataSpeech():
# 读取一个文件
filename = self.dic_wavlist[self.list_wavnum[n_start]]
filename=filename.replace('/','\\') # windows系统下需要添加这一行
if('Windows' == plat.system()):
filename=filename.replace('/','\\') # windows系统下需要执行这一行对文件路径做特别处理
wavsignal,fs=read_wav_data(self.datapath+filename)
# 获取输入特征