diff --git a/general_function/file_wav.py b/general_function/file_wav.py index b0f8bd9..fb4f64b 100644 --- a/general_function/file_wav.py +++ b/general_function/file_wav.py @@ -52,11 +52,14 @@ def get_wav_list(filename): txt_text=txt_obj.read() txt_lines=txt_text.split('\n') # 文本分割 dic_filelist={} # 初始化字典 + list_wavmark=[] # 初始化wav列表 for i in txt_lines: if(i!=''): txt_l=i.split(' ') - dic_filelist[txt_l[0]]=txt_l[1] - return dic_filelist + dic_filelist[txt_l[0]]='wav/'+txt_l[1] + list_wavmark.append(txt_l[0]) + txt_obj.close() + return dic_filelist,list_wavmark def get_wav_symbol(filename): ''' @@ -67,11 +70,14 @@ def get_wav_symbol(filename): txt_text=txt_obj.read() txt_lines=txt_text.split('\n') # 文本分割 dic_symbol_list={} # 初始化字典 + list_symbolmark=[] # 初始化symbol列表 for i in txt_lines: if(i!=''): txt_l=i.split(' ') dic_symbol_list[txt_l[0]]=txt_l[1:] - return dic_symbol_list + list_symbolmark.append(txt_l[0]) + txt_obj.close() + return dic_symbol_list,list_symbolmark if(__name__=='__main__'): #dic=get_wav_symbol('E:\\语音数据集\\doc\\doc\\trans\\train.syllable.txt') diff --git a/log.md b/log.md index e0cb4fa..6b3a614 100644 --- a/log.md +++ b/log.md @@ -1 +1,18 @@ -# ASRT_SpeechRecognition 基于深度学习的语音识别系统 ## Introduction 这里是更新记录日志文件 如果有什么问题,团队内部需要在这里直接写出来 ## Log ### 2017-08-29 准备使用现有的包[python_speech_features](https://github.com/jameslyons/python_speech_features)来实现特征的提取,以及求一阶二阶差分。 ### 2017-08-28 开始准备制作语音信号处理方面的功能 ### 2017-08-22 准备使用Keras基于LSTM/CNN尝试实现 \ No newline at end of file +# ASRT_SpeechRecognition +基于深度学习的语音识别系统 + +## Introduction + +这里是更新记录日志文件 + +如果有什么问题,团队内部需要在这里直接写出来 + +## Log +### 2017-08-31 +数据处理部分的代码基本完成,现在准备撸模型 +### 2017-08-29 +准备使用现有的包[python_speech_features](https://github.com/jameslyons/python_speech_features)来实现特征的提取,以及求一阶二阶差分。 +### 2017-08-28 +开始准备制作语音信号处理方面的功能 +### 2017-08-22 +准备使用Keras基于LSTM/CNN尝试实现 diff --git a/main.py b/main.py index d8e66a6..ebfe4f4 100644 --- a/main.py +++ b/main.py @@ -40,23 +40,30 @@ class ModelSpeech(): # 语音模型类 _model.compile(optimizer="adam", loss='categorical_crossentropy',metrics=["accuracy"]) return _model - def TrainModel(self,datas,epoch = 2,save_step=5000,filename='model_speech/LSTM_CNN_model'): + def TrainModel(self,datapath,epoch = 2,save_step=1000,filename='model_speech/LSTM_CNN_model'): ''' 训练模型 + 参数: + datapath: 数据保存的路径 + epoch: 迭代轮数 + save_step: 每多少步保存一次模型 + filename: 默认保存文件名,不含文件后缀名 ''' + for epoch in range(epoch): + pass pass - def LoadModel(self,filename='model_speech/LSTM_CNN_model'): + def LoadModel(self,filename='model_speech/LSTM_CNN_model.model'): ''' 加载模型参数 ''' self._model.load_weights(filename) - def SaveModel(self,filename='model_speech/LSTM_CNN_model'): + def SaveModel(self,filename='model_speech/LSTM_CNN_model',comment=''): ''' 保存模型参数 ''' - self._model.save_weights(filename+'.model') + self._model.save_weights(filename+comment+'.model') def TestModel(self): ''' diff --git a/readdata.py b/readdata.py index f1cdbdc..142b592 100644 --- a/readdata.py +++ b/readdata.py @@ -11,39 +11,141 @@ from python_speech_features import logfbank #import scipy.io.wavfile as wav class DataSpeech(): + + def __init__(self,path): ''' 初始化 参数: path:数据存放位置根目录 ''' + self.datapath = path; # 数据存放位置根目录 + if('\\'!=self.datapath[-1]): # 在目录路径末尾增加斜杠 + self.datapath=self.datapath+'\\' + self.dic_wavlist = {} + self.dic_symbollist = {} + self.list_symbol = self.GetSymbolList() # 全部汉语拼音符号列表 + self.list_wavnum=[] # wav文件标记列表 + self.list_symbolnum=[] # symbol标记列表 pass - - def GetData(self,n): + + def LoadDataList(self,type): ''' - 读取数据,返回神经网络输入值和输出值矩阵 + 加载用于计算的数据列表 参数: - n:第几个数据 + type:选取的数据集类型 + train 训练集 + dev 开发集 + test 测试集 ''' - pass + # 设定选取哪一项作为要使用的数据集 + if(type=='train'): + filename_wavlist='doc\\doc\\list\\train.wav.lst' + filename_symbollist='doc\\doc\\trans\\train.syllable.txt' + elif(type=='dev'): + filename_wavlist='doc\\doc\\list\\cv.wav.lst' + filename_symbollist='doc\\doc\\trans\\cv.syllable.txt' + elif(type=='test'): + filename_wavlist='doc\\doc\\list\\test.wav.lst' + filename_symbollist='doc\\doc\\trans\\test.syllable.txt' + else: + filename_wavlist='' # 默认留空 + filename_symbollist='' + # 读取数据列表,wav文件列表和其对应的符号列表 + self.dic_wavlist,self.list_wavnum = get_wav_list(self.datapath+filename_wavlist) + self.dic_symbollist,self.list_symbolnum = get_wav_symbol(self.datapath+filename_symbollist) def GetDataNum(self): ''' 获取数据的数量 + 当wav数量和symbol数量一致的时候返回正确的值,否则返回-1,代表出错。 ''' - pass + if(len(self.dic_wavlist) == len(self.dic_symbollist)): + return len(self.dic_wavlist) + else: + return -1 + + def GetData(self,n_start,n_amount=1): + ''' + 读取数据,返回神经网络输入值和输出值矩阵(可直接用于神经网络训练的那种) + 参数: + n_start:从编号为n_start数据开始选取数据 + n_amount:选取的数据数量,默认为1,即一次一个wav文件 + 返回: + 三个包含wav特征矩阵的神经网络输入值,和一个标定的类别矩阵神经网络输出值 + ''' + # 读取一个文件 + filename = self.dic_wavlist[self.list_wavnum[n_start]] + + filename=filename.replace('/','\\') # windows系统下需要添加这一行 + + wavsignal,fs=read_wav_data(self.datapath+filename) + # 获取输入特征 + feat_mfcc=mfcc(wavsignal[0],fs) + feat_mfcc_d=delta(feat_mfcc,2) + feat_mfcc_dd=delta(feat_mfcc_d,2) + # 获取输出特征 + list_symbol=self.dic_symbollist[self.list_symbolnum[n_start]] + feat_out=[] + for i in list_symbol: + if(''!=i): + n=self.SymbolToNum(i) + v=self.NumToVector(n) + feat_out.append(v) + # 返回值分别是mfcc特征向量的矩阵及其一阶差分和二阶差分矩阵,以及对应的拼音符号矩阵 + return feat_mfcc,feat_mfcc_d,feat_mfcc_dd,np.array(feat_out) + def GetSymbolList(self): + ''' + 加载拼音符号列表,用于标记符号 + 返回一个列表list类型变量 + ''' + txt_obj=open(self.datapath+'dict.txt','r',encoding='UTF-8') # 打开文件并读入 + txt_text=txt_obj.read() + txt_lines=txt_text.split('\n') # 文本分割 + list_symbol=[] # 初始化符号列表 + for i in txt_lines: + if(i!=''): + txt_l=i.split('\t') + list_symbol.append(txt_l[0]) + txt_obj.close() + list_symbol.append(' ') + return list_symbol + + def SymbolToNum(self,symbol): + ''' + 符号转为数字 + ''' + return self.list_symbol.index(symbol) + + def NumToVector(self,num): + ''' + 数字转为对应的向量 + ''' + v_tmp=[] + for i in range(0,len(self.list_symbol)): + if(i==num): + v_tmp.append(1) + else: + v_tmp.append(0) + v=np.array([v_tmp]) + return v if(__name__=='__main__'): - wave_data, fs = read_wav_data("general_function\\A2_0.wav") - print(wave_data) + #wave_data, fs = read_wav_data("general_function\\A2_0.wav") + #print(wave_data) #(fs,wave_data)=wav.read('E:\\国创项目工程\代码\\ASRT_SpeechRecognition\\general_function\\A2_0.wav') - wav_show(wave_data[0],fs) + #wav_show(wave_data[0],fs) #mfcc_feat = mfcc(wave_data[0],fs) # 计算MFCC特征 - #print(mfcc_feat[100:110,:]) + #print(mfcc_feat[0:3,:]) #d_mfcc_feat_1 = delta(mfcc_feat, 2) #print(d_mfcc_feat_1[0,:]) #d_mfcc_feat_2 = delta(d_mfcc_feat_1, 2) #print(d_mfcc_feat_2[0,:]) + #path='E:\\语音数据集' + #l=DataSpeech(path) + #l.LoadDataList('train') + #print(l.GetDataNum()) + #print(l.GetData(0)) pass \ No newline at end of file