From ab74ee4bfc230e5ba2eb38dbce87cebd7016aac5 Mon Sep 17 00:00:00 2001 From: nl8590687 <3210346136@qq.com> Date: Sat, 26 Aug 2017 23:40:28 +0800 Subject: [PATCH] implement to read wav files and list files. --- README.md | 2 ++ general_function/file_wav.py | 62 ++++++++++++++++++++++++++++++------ main.py | 57 ++++++++++++++++++--------------- readdata.py | 2 +- 4 files changed, 87 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index 2a17148..ca13d7c 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,8 @@ This project will use TensorFlow based on RNN and CNN to implement. +本项目尚未完成,想要Fork的同学请手慢。 + ## Model 模型 diff --git a/general_function/file_wav.py b/general_function/file_wav.py index dd6a66f..f6fe5a5 100644 --- a/general_function/file_wav.py +++ b/general_function/file_wav.py @@ -1,19 +1,63 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- import os +import wave +import numpy as np +import matplotlib.pyplot as plt -def read_wav_file(filename): +def read_wav_data(filename): ''' - ȡһwavļһļ + 读取一个wav文件,返回声音信号的时域谱矩阵和播放时间 ''' - #Ӵ + wav = wave.open(filename,"rb") # 打开一个wav格式的声音文件流 + num_frame = wav.getnframes() # 获取帧数 + num_channel=wav.getnchannels() # 获取声道数 + framerate=wav.getframerate() # 获取帧速率 + num_sample_width=wav.getsampwidth() # 获取实例的比特宽度,即每一帧的字节数 + str_data = wav.readframes(num_frame) # 读取全部的帧 + wav.close() # 关闭流 + wave_data = np.fromstring(str_data, dtype = np.short) # 将声音文件数据转换为数组矩阵形式 + wave_data.shape = -1, num_channel # 按照声道数将数组整形,单声道时候是一列数组,双声道时候是两列的矩阵 + wave_data = wave_data.T # 将矩阵转置 + time = np.arange(0, num_frame) * (1.0/framerate) # 计算声音的播放时间,单位为秒 + return wave_data, time + +def wav_show(wave_data, time): # 显示出来声音波形 + #wave_data, time = read_wave_data("C:\\Users\\nl\\Desktop\\A2_0.wav") + #draw the wave + #plt.subplot(211) + plt.plot(time, wave_data[0]) + #plt.subplot(212) + #plt.plot(time, wave_data[1], c = "g") + plt.show() + def get_wav_list(filename): ''' - ȡһwavļб - ps:רмļڴѵ֤ͲԵwavļб + 读取一个wav文件列表,返回一个存储该列表的字典类型值 + ps:在数据中专门有几个文件用于存放用于训练、验证和测试的wav文件列表 ''' - #Ӵ + txt_obj=open(filename,'r') # 打开文件并读入 + txt_text=txt_obj.read() + txt_lines=txt_text.split('\n') # 文本分割 + dic_filelist={} # 初始化字典 + for i in txt_lines: + if(i!=''): + txt_l=i.split(' ') + dic_filelist[txt_l[0]]=txt_l[1] + return dic_filelist + +def get_wav_symbol(filename): + ''' + 读取指定数据集中,所有wav文件对应的语音符号 + 返回一个存储符号集的字典类型值 + ''' + print('test') +#if(__name__=='__main__'): + #dic=get_wav_list('E:\\语音数据集\\doc\\doc\\list\\train.wav.lst') + #for i in dic: + #print(i,dic[i]) + #wave_data, time = read_wav_data("C:\\Users\\nl\\Desktop\\A2_0.wav") + #wav_show(wave_data,time) - - - diff --git a/main.py b/main.py index 80e9599..f29dc2c 100644 --- a/main.py +++ b/main.py @@ -1,22 +1,22 @@ -# -*- coding: encoding -*- +# -*- coding: utf-8 -*- """ @author: nl8590687 """ -#LSTM_CNN +# LSTM_CNN import keras as kr import numpy as np from keras.models import Sequential -from keras.layers import Dense, Dropout, Flatten#,Input,LSTM,Convolution1D,MaxPooling1D,Merge -from keras.layers import Conv1D,LSTM,MaxPooling1D,Merge#Conv2D, MaxPooling2D,Conv1D +from keras.layers import Dense, Dropout, Flatten # ,Input,LSTM,Convolution1D,MaxPooling1D,Merge +from keras.layers import Conv1D,LSTM,MaxPooling1D,Merge # Conv2D, MaxPooling2D,Conv1D class ModelSpeech(): # 语音模型类 def __init__(self,MS_EMBED_SIZE = 64,BATCH_SIZE = 32): # 初始化 self.MS_EMBED_SIZE = MS_EMBED_SIZE # LSTM 的大小 - self.BATCH_SIZE = BATCH_SIZE # 一次训练的batch - self._model = self.createLSTMModel() + self.BATCH_SIZE = BATCH_SIZE # 一次训练的batch + self._model = self.createLSTMModel() - def CreateLSTMModel(self):# 定义训练模型,尚未完成 + def CreateModel(self): # 定义训练模型,尚未完成 # 定义LSTM/CNN模型 _model = Sequential() @@ -27,25 +27,30 @@ class ModelSpeech(): # 语音模型类 _model.add(Dropout(0.3)) _model.add(Flatten()) - - #_model = Sequential() - #_model.add(Merge([m_lstm, aenc], mode="concat", concat_axis=-1)) - _model.add(Dense(1279, activation="softmax")) - _model.compile(optimizer="adam", loss='categorical_crossentropy',metrics=["accuracy"]) - return _model - - def Train(self): - # 训练模型 - def LoadModel(self,filename='model_speech/LSTM_CNN.model'): - self._model.load_weights(filename) + #_model = Sequential() + #_model.add(Merge([m_lstm, aenc], mode="concat", concat_axis=-1)) + _model.add(Dense(1279, activation="softmax")) + _model.compile(optimizer="adam", loss='categorical_crossentropy',metrics=["accuracy"]) + return _model + + def TrainModel(self,datas,epoch = 2,save_step=5000,filename='model_speech/LSTM_CNN_model'): # 训练模型 + print('test') + + def LoadModel(self,filename='model_speech/LSTM_CNN_model'): # 加载模型参数 + self._model.load_weights(filename) + + def SaveModel(self,filename='model_speech/LSTM_CNN_model'): # 保存模型参数 + self._model.save_weights(filename+'.model') + + def TestModel(self): # 测试检验模型效果 + print('test') + + @property + def model(self): # 返回keras model + return self._model - def SaveModel(self,filename='model_speech/LSTM_CNN.model'): - # 保存模型参数 - - def Test(self): - # 测试检验模型效果 - - -print('test') \ No newline at end of file + +print('test') +print(__name__) \ No newline at end of file diff --git a/readdata.py b/readdata.py index 9962574..8fde703 100644 --- a/readdata.py +++ b/readdata.py @@ -1,4 +1,4 @@ -# -*- coding: encoding -*- +# -*- coding: utf-8 -*- import numpy as np