增加了跨平台支持
This commit is contained in:
parent
cc0572b3b7
commit
58d02601a7
|
@ -3,6 +3,8 @@
|
|||
"""
|
||||
@author: nl8590687
|
||||
"""
|
||||
import platform as plat
|
||||
|
||||
# LSTM_CNN
|
||||
import keras as kr
|
||||
import numpy as np
|
||||
|
@ -56,7 +58,7 @@ class ModelSpeech(): # 语音模型类
|
|||
layer_h3 = Conv1D(256, 5, use_bias=True, padding="valid")(layer_h2) # 卷积层
|
||||
layer_h4 = MaxPooling1D(pool_size=2, strides=None, padding="valid")(layer_h3) # 池化层
|
||||
layer_h5 = Dropout(0.2)(layer_h4) # 随机中断部分神经网络连接,防止过拟合
|
||||
layer_h6 = Dense(self.MS_OUTPUT_SIZE, use_bias=True, activation="softmax")(layer_h5) # 全连接层
|
||||
layer_h6 = Dense(256, use_bias=True, activation="softmax")(layer_h5) # 全连接层
|
||||
layer_h7 = LSTM(256, activation='relu', use_bias=True, return_sequences=True)(layer_h6) # LSTM层
|
||||
layer_h8 = Dropout(0.2)(layer_h7) # 随机中断部分神经网络连接,防止过拟合
|
||||
layer_h9 = Dense(self.MS_OUTPUT_SIZE, use_bias=True, activation="softmax")(layer_h8) # 全连接层
|
||||
|
@ -150,16 +152,16 @@ class ModelSpeech(): # 语音模型类
|
|||
print('[error] generator error. please check data format.')
|
||||
break
|
||||
|
||||
self.SaveModel(comment='_e_'+str(epoch)+'_step_'+str(n_step))
|
||||
self.SaveModel(comment='_e_'+str(epoch)+'_step_'+str(n_step * save_step))
|
||||
|
||||
|
||||
def LoadModel(self,filename='model_speech/LSTM_CNN_model.model'):
|
||||
def LoadModel(self,filename='model_speech/speech_model_e_0_step_1.model'):
|
||||
'''
|
||||
加载模型参数
|
||||
'''
|
||||
self._model.load_weights(filename)
|
||||
|
||||
def SaveModel(self,filename='model_speech/LSTM_CNN_model',comment=''):
|
||||
def SaveModel(self,filename='model_speech/speech_model',comment=''):
|
||||
'''
|
||||
保存模型参数
|
||||
'''
|
||||
|
@ -199,8 +201,23 @@ class ModelSpeech(): # 语音模型类
|
|||
|
||||
|
||||
if(__name__=='__main__'):
|
||||
datapath = 'E:\\语音数据集'
|
||||
datapath = ''
|
||||
modelpath = ''
|
||||
ms = ModelSpeech()
|
||||
|
||||
system_type = plat.system() # 由于不同的系统的文件路径表示不一样,需要进行判断
|
||||
if(system_type == 'Windows'):
|
||||
datapath = 'E:\\语音数据集'
|
||||
modelpath = 'model_speech\\'
|
||||
elif(system_type == 'Linux'):
|
||||
datapath = 'dataset'
|
||||
modelpath = 'model_speech/'
|
||||
else:
|
||||
print('*[Message] Unknown System\n')
|
||||
datapath = 'dataset'
|
||||
modelpath = 'model_speech/'
|
||||
|
||||
#ms.LoadModel(modelpath + 'speech_model_e_0_step_1.model')
|
||||
ms.TrainModel(datapath)
|
||||
#ms.TestModel(datapath)
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ def get_wav_list(filename):
|
|||
for i in txt_lines:
|
||||
if(i!=''):
|
||||
txt_l=i.split(' ')
|
||||
dic_filelist[txt_l[0]]='wav/'+txt_l[1]
|
||||
dic_filelist[txt_l[0]] = txt_l[1]
|
||||
list_wavmark.append(txt_l[0])
|
||||
txt_obj.close()
|
||||
return dic_filelist,list_wavmark
|
||||
|
|
35
readdata.py
35
readdata.py
|
@ -1,6 +1,8 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import platform as plat
|
||||
|
||||
import numpy as np
|
||||
from general_function.file_wav import *
|
||||
|
||||
|
@ -20,9 +22,23 @@ class DataSpeech():
|
|||
参数:
|
||||
path:数据存放位置根目录
|
||||
'''
|
||||
|
||||
system_type = plat.system() # 由于不同的系统的文件路径表示不一样,需要进行判断
|
||||
|
||||
self.datapath = path; # 数据存放位置根目录
|
||||
if('\\'!=self.datapath[-1]): # 在目录路径末尾增加斜杠
|
||||
self.datapath=self.datapath+'\\'
|
||||
|
||||
self.slash = ''
|
||||
if(system_type == 'Windows'):
|
||||
self.slash='\\' # 反斜杠
|
||||
elif(system_type == 'Linux'):
|
||||
self.slash='/' # 正斜杠
|
||||
else:
|
||||
print('*[Message] Unknown System\n')
|
||||
self.slash='/' # 正斜杠
|
||||
|
||||
if(self.slash != self.datapath[-1]): # 在目录路径末尾增加斜杠
|
||||
self.datapath = self.datapath + self.slash
|
||||
|
||||
self.dic_wavlist = {}
|
||||
self.dic_symbollist = {}
|
||||
self.SymbolNum = 0 # 记录拼音符号数量
|
||||
|
@ -45,14 +61,14 @@ class DataSpeech():
|
|||
'''
|
||||
# 设定选取哪一项作为要使用的数据集
|
||||
if(type=='train'):
|
||||
filename_wavlist='doc\\doc\\list\\train.wav.lst'
|
||||
filename_symbollist='doc\\doc\\trans\\train.syllable.txt'
|
||||
filename_wavlist = 'doc' + self.slash + 'list' + self.slash + 'train.wav.lst'
|
||||
filename_symbollist = 'doc' + self.slash + 'trans' + self.slash + 'train.syllable.txt'
|
||||
elif(type=='dev'):
|
||||
filename_wavlist='doc\\doc\\list\\cv.wav.lst'
|
||||
filename_symbollist='doc\\doc\\trans\\cv.syllable.txt'
|
||||
filename_wavlist = 'doc' + self.slash + 'list' + self.slash + 'cv.wav.lst'
|
||||
filename_symbollist = 'doc' + self.slash + 'trans' + self.slash + 'cv.syllable.txt'
|
||||
elif(type=='test'):
|
||||
filename_wavlist='doc\\doc\\list\\test.wav.lst'
|
||||
filename_symbollist='doc\\doc\\trans\\test.syllable.txt'
|
||||
filename_wavlist = 'doc' + self.slash + 'list' + self.slash + 'test.wav.lst'
|
||||
filename_symbollist = 'doc' + self.slash + 'trans' + self.slash + 'test.syllable.txt'
|
||||
else:
|
||||
filename_wavlist = '' # 默认留空
|
||||
filename_symbollist = ''
|
||||
|
@ -85,7 +101,8 @@ class DataSpeech():
|
|||
# 读取一个文件
|
||||
filename = self.dic_wavlist[self.list_wavnum[n_start]]
|
||||
|
||||
filename=filename.replace('/','\\') # windows系统下需要添加这一行
|
||||
if('Windows' == plat.system()):
|
||||
filename=filename.replace('/','\\') # windows系统下需要执行这一行,对文件路径做特别处理
|
||||
|
||||
wavsignal,fs=read_wav_data(self.datapath+filename)
|
||||
# 获取输入特征
|
||||
|
|
Loading…
Reference in New Issue