增加了跨平台支持
This commit is contained in:
parent
cc0572b3b7
commit
58d02601a7
|
@ -3,6 +3,8 @@
|
||||||
"""
|
"""
|
||||||
@author: nl8590687
|
@author: nl8590687
|
||||||
"""
|
"""
|
||||||
|
import platform as plat
|
||||||
|
|
||||||
# LSTM_CNN
|
# LSTM_CNN
|
||||||
import keras as kr
|
import keras as kr
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -56,7 +58,7 @@ class ModelSpeech(): # 语音模型类
|
||||||
layer_h3 = Conv1D(256, 5, use_bias=True, padding="valid")(layer_h2) # 卷积层
|
layer_h3 = Conv1D(256, 5, use_bias=True, padding="valid")(layer_h2) # 卷积层
|
||||||
layer_h4 = MaxPooling1D(pool_size=2, strides=None, padding="valid")(layer_h3) # 池化层
|
layer_h4 = MaxPooling1D(pool_size=2, strides=None, padding="valid")(layer_h3) # 池化层
|
||||||
layer_h5 = Dropout(0.2)(layer_h4) # 随机中断部分神经网络连接,防止过拟合
|
layer_h5 = Dropout(0.2)(layer_h4) # 随机中断部分神经网络连接,防止过拟合
|
||||||
layer_h6 = Dense(self.MS_OUTPUT_SIZE, use_bias=True, activation="softmax")(layer_h5) # 全连接层
|
layer_h6 = Dense(256, use_bias=True, activation="softmax")(layer_h5) # 全连接层
|
||||||
layer_h7 = LSTM(256, activation='relu', use_bias=True, return_sequences=True)(layer_h6) # LSTM层
|
layer_h7 = LSTM(256, activation='relu', use_bias=True, return_sequences=True)(layer_h6) # LSTM层
|
||||||
layer_h8 = Dropout(0.2)(layer_h7) # 随机中断部分神经网络连接,防止过拟合
|
layer_h8 = Dropout(0.2)(layer_h7) # 随机中断部分神经网络连接,防止过拟合
|
||||||
layer_h9 = Dense(self.MS_OUTPUT_SIZE, use_bias=True, activation="softmax")(layer_h8) # 全连接层
|
layer_h9 = Dense(self.MS_OUTPUT_SIZE, use_bias=True, activation="softmax")(layer_h8) # 全连接层
|
||||||
|
@ -150,16 +152,16 @@ class ModelSpeech(): # 语音模型类
|
||||||
print('[error] generator error. please check data format.')
|
print('[error] generator error. please check data format.')
|
||||||
break
|
break
|
||||||
|
|
||||||
self.SaveModel(comment='_e_'+str(epoch)+'_step_'+str(n_step))
|
self.SaveModel(comment='_e_'+str(epoch)+'_step_'+str(n_step * save_step))
|
||||||
|
|
||||||
|
|
||||||
def LoadModel(self,filename='model_speech/LSTM_CNN_model.model'):
|
def LoadModel(self,filename='model_speech/speech_model_e_0_step_1.model'):
|
||||||
'''
|
'''
|
||||||
加载模型参数
|
加载模型参数
|
||||||
'''
|
'''
|
||||||
self._model.load_weights(filename)
|
self._model.load_weights(filename)
|
||||||
|
|
||||||
def SaveModel(self,filename='model_speech/LSTM_CNN_model',comment=''):
|
def SaveModel(self,filename='model_speech/speech_model',comment=''):
|
||||||
'''
|
'''
|
||||||
保存模型参数
|
保存模型参数
|
||||||
'''
|
'''
|
||||||
|
@ -199,8 +201,23 @@ class ModelSpeech(): # 语音模型类
|
||||||
|
|
||||||
|
|
||||||
if(__name__=='__main__'):
|
if(__name__=='__main__'):
|
||||||
datapath = 'E:\\语音数据集'
|
datapath = ''
|
||||||
|
modelpath = ''
|
||||||
ms = ModelSpeech()
|
ms = ModelSpeech()
|
||||||
|
|
||||||
|
system_type = plat.system() # 由于不同的系统的文件路径表示不一样,需要进行判断
|
||||||
|
if(system_type == 'Windows'):
|
||||||
|
datapath = 'E:\\语音数据集'
|
||||||
|
modelpath = 'model_speech\\'
|
||||||
|
elif(system_type == 'Linux'):
|
||||||
|
datapath = 'dataset'
|
||||||
|
modelpath = 'model_speech/'
|
||||||
|
else:
|
||||||
|
print('*[Message] Unknown System\n')
|
||||||
|
datapath = 'dataset'
|
||||||
|
modelpath = 'model_speech/'
|
||||||
|
|
||||||
|
#ms.LoadModel(modelpath + 'speech_model_e_0_step_1.model')
|
||||||
ms.TrainModel(datapath)
|
ms.TrainModel(datapath)
|
||||||
#ms.TestModel(datapath)
|
#ms.TestModel(datapath)
|
||||||
|
|
||||||
|
|
|
@ -56,7 +56,7 @@ def get_wav_list(filename):
|
||||||
for i in txt_lines:
|
for i in txt_lines:
|
||||||
if(i!=''):
|
if(i!=''):
|
||||||
txt_l=i.split(' ')
|
txt_l=i.split(' ')
|
||||||
dic_filelist[txt_l[0]]='wav/'+txt_l[1]
|
dic_filelist[txt_l[0]] = txt_l[1]
|
||||||
list_wavmark.append(txt_l[0])
|
list_wavmark.append(txt_l[0])
|
||||||
txt_obj.close()
|
txt_obj.close()
|
||||||
return dic_filelist,list_wavmark
|
return dic_filelist,list_wavmark
|
||||||
|
|
47
readdata.py
47
readdata.py
|
@ -1,6 +1,8 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import platform as plat
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from general_function.file_wav import *
|
from general_function.file_wav import *
|
||||||
|
|
||||||
|
@ -20,15 +22,29 @@ class DataSpeech():
|
||||||
参数:
|
参数:
|
||||||
path:数据存放位置根目录
|
path:数据存放位置根目录
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
system_type = plat.system() # 由于不同的系统的文件路径表示不一样,需要进行判断
|
||||||
|
|
||||||
self.datapath = path; # 数据存放位置根目录
|
self.datapath = path; # 数据存放位置根目录
|
||||||
if('\\'!=self.datapath[-1]): # 在目录路径末尾增加斜杠
|
|
||||||
self.datapath=self.datapath+'\\'
|
self.slash = ''
|
||||||
|
if(system_type == 'Windows'):
|
||||||
|
self.slash='\\' # 反斜杠
|
||||||
|
elif(system_type == 'Linux'):
|
||||||
|
self.slash='/' # 正斜杠
|
||||||
|
else:
|
||||||
|
print('*[Message] Unknown System\n')
|
||||||
|
self.slash='/' # 正斜杠
|
||||||
|
|
||||||
|
if(self.slash != self.datapath[-1]): # 在目录路径末尾增加斜杠
|
||||||
|
self.datapath = self.datapath + self.slash
|
||||||
|
|
||||||
self.dic_wavlist = {}
|
self.dic_wavlist = {}
|
||||||
self.dic_symbollist = {}
|
self.dic_symbollist = {}
|
||||||
self.SymbolNum = 0 # 记录拼音符号数量
|
self.SymbolNum = 0 # 记录拼音符号数量
|
||||||
self.list_symbol = self.GetSymbolList() # 全部汉语拼音符号列表
|
self.list_symbol = self.GetSymbolList() # 全部汉语拼音符号列表
|
||||||
self.list_wavnum=[] # wav文件标记列表
|
self.list_wavnum = [] # wav文件标记列表
|
||||||
self.list_symbolnum=[] # symbol标记列表
|
self.list_symbolnum = [] # symbol标记列表
|
||||||
|
|
||||||
self.DataNum = 0 # 记录数据量
|
self.DataNum = 0 # 记录数据量
|
||||||
|
|
||||||
|
@ -45,20 +61,20 @@ class DataSpeech():
|
||||||
'''
|
'''
|
||||||
# 设定选取哪一项作为要使用的数据集
|
# 设定选取哪一项作为要使用的数据集
|
||||||
if(type=='train'):
|
if(type=='train'):
|
||||||
filename_wavlist='doc\\doc\\list\\train.wav.lst'
|
filename_wavlist = 'doc' + self.slash + 'list' + self.slash + 'train.wav.lst'
|
||||||
filename_symbollist='doc\\doc\\trans\\train.syllable.txt'
|
filename_symbollist = 'doc' + self.slash + 'trans' + self.slash + 'train.syllable.txt'
|
||||||
elif(type=='dev'):
|
elif(type=='dev'):
|
||||||
filename_wavlist='doc\\doc\\list\\cv.wav.lst'
|
filename_wavlist = 'doc' + self.slash + 'list' + self.slash + 'cv.wav.lst'
|
||||||
filename_symbollist='doc\\doc\\trans\\cv.syllable.txt'
|
filename_symbollist = 'doc' + self.slash + 'trans' + self.slash + 'cv.syllable.txt'
|
||||||
elif(type=='test'):
|
elif(type=='test'):
|
||||||
filename_wavlist='doc\\doc\\list\\test.wav.lst'
|
filename_wavlist = 'doc' + self.slash + 'list' + self.slash + 'test.wav.lst'
|
||||||
filename_symbollist='doc\\doc\\trans\\test.syllable.txt'
|
filename_symbollist = 'doc' + self.slash + 'trans' + self.slash + 'test.syllable.txt'
|
||||||
else:
|
else:
|
||||||
filename_wavlist='' # 默认留空
|
filename_wavlist = '' # 默认留空
|
||||||
filename_symbollist=''
|
filename_symbollist = ''
|
||||||
# 读取数据列表,wav文件列表和其对应的符号列表
|
# 读取数据列表,wav文件列表和其对应的符号列表
|
||||||
self.dic_wavlist,self.list_wavnum = get_wav_list(self.datapath+filename_wavlist)
|
self.dic_wavlist,self.list_wavnum = get_wav_list(self.datapath + filename_wavlist)
|
||||||
self.dic_symbollist,self.list_symbolnum = get_wav_symbol(self.datapath+filename_symbollist)
|
self.dic_symbollist,self.list_symbolnum = get_wav_symbol(self.datapath + filename_symbollist)
|
||||||
self.DataNum = self.GetDataNum()
|
self.DataNum = self.GetDataNum()
|
||||||
|
|
||||||
def GetDataNum(self):
|
def GetDataNum(self):
|
||||||
|
@ -85,7 +101,8 @@ class DataSpeech():
|
||||||
# 读取一个文件
|
# 读取一个文件
|
||||||
filename = self.dic_wavlist[self.list_wavnum[n_start]]
|
filename = self.dic_wavlist[self.list_wavnum[n_start]]
|
||||||
|
|
||||||
filename=filename.replace('/','\\') # windows系统下需要添加这一行
|
if('Windows' == plat.system()):
|
||||||
|
filename=filename.replace('/','\\') # windows系统下需要执行这一行,对文件路径做特别处理
|
||||||
|
|
||||||
wavsignal,fs=read_wav_data(self.datapath+filename)
|
wavsignal,fs=read_wav_data(self.datapath+filename)
|
||||||
# 获取输入特征
|
# 获取输入特征
|
||||||
|
|
Loading…
Reference in New Issue