ASRT_SpeechRecognition/readdata24_limitless.py

303 lines
9.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2016-2099 Ailemon.net
#
# This file is part of ASRT Speech Recognition Tool.
#
# ASRT is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# ASRT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ASRT. If not, see <https://www.gnu.org/licenses/>.
# ============================================================================
'''
@author: nl8590687
一个对于单一音频时间长度不作限制的版本,正在测试
'''
import platform as plat
import os
import numpy as np
from general_function.file_wav import *
from general_function.file_dict import *
import random
#import scipy.io.wavfile as wav
from scipy.fftpack import fft
class DataSpeech():
    '''
    Data provider that serves (spectrogram, label) pairs drawn from the
    THCHS-30 and ST-CMDS list files found under a common root directory.

    This variant places no limit on the length of a single audio clip, so
    batches are returned as Python lists of variable-length arrays.
    '''

    # Dataset type -> (THCHS-30 list-file prefix, ST-CMDS list-file prefix).
    _LIST_PREFIXES = {
        'train': ('train', 'train'),
        'dev': ('cv', 'dev'),
        'test': ('test', 'test'),
    }

    def __init__(self, path, type, LoadToMem = False, MemWavCount = 10000):
        '''
        Initialise the loader and read the data lists.

        Parameters:
            path: root directory that holds the datasets
            type: which split to use: 'train', 'dev' or 'test'
            LoadToMem: stored on the instance; not otherwise used in this class
            MemWavCount: stored on the instance; not otherwise used in this class

        Raises:
            ValueError: if `type` is not one of the supported splits
        '''
        # Path separators differ between operating systems.
        system_type = plat.system()
        self.datapath = path  # dataset root directory
        self.type = type  # 'train', 'dev' or 'test'

        if system_type == 'Windows':
            self.slash = '\\'  # backslash
        else:
            if system_type != 'Linux':
                print('*[Message] Unknown System\n')
            self.slash = '/'  # forward slash

        # Make sure the root directory ends with a separator. An empty path is
        # left untouched so that it keeps meaning "the current directory".
        if self.datapath and not self.datapath.endswith(self.slash):
            self.datapath = self.datapath + self.slash

        self.dic_wavlist_thchs30 = {}
        self.dic_symbollist_thchs30 = {}
        self.dic_wavlist_stcmds = {}
        self.dic_symbollist_stcmds = {}

        self.SymbolNum = 0  # number of pinyin symbols
        self.list_symbol = self.GetSymbolList()  # every pinyin symbol

        self.list_wavnum = []  # wav file id list
        self.list_symbolnum = []  # symbol id list

        self.DataNum = 0  # number of samples
        self.LoadDataList()

        self.wavs_data = []
        self.LoadToMem = LoadToMem
        self.MemWavCount = MemWavCount

    def LoadDataList(self):
        '''
        Load the wav list and the symbol list of both corpora for the split
        selected by self.type ('train', 'dev' or 'test'), then refresh
        self.DataNum.

        Raises:
            ValueError: if self.type is not a supported split
        '''
        try:
            prefix_thchs30, prefix_stcmds = self._LIST_PREFIXES[self.type]
        except (KeyError, TypeError):
            # The list-file names are undefined for any other value; fail
            # with a clear message instead of an unbound-variable error.
            raise ValueError(
                "Unknown dataset type %r, expected 'train', 'dev' or 'test'"
                % (self.type,)
            ) from None

        dir_thchs30 = 'thchs30' + self.slash
        dir_stcmds = 'st-cmds' + self.slash
        filename_wavlist_thchs30 = dir_thchs30 + prefix_thchs30 + '.wav.lst'
        filename_wavlist_stcmds = dir_stcmds + prefix_stcmds + '.wav.txt'
        filename_symbollist_thchs30 = dir_thchs30 + prefix_thchs30 + '.syllable.txt'
        filename_symbollist_stcmds = dir_stcmds + prefix_stcmds + '.syllable.txt'

        # Read the wav file lists and their matching symbol lists.
        self.dic_wavlist_thchs30, self.list_wavnum_thchs30 = get_wav_list(self.datapath + filename_wavlist_thchs30)
        self.dic_wavlist_stcmds, self.list_wavnum_stcmds = get_wav_list(self.datapath + filename_wavlist_stcmds)

        self.dic_symbollist_thchs30, self.list_symbolnum_thchs30 = get_wav_symbol(self.datapath + filename_symbollist_thchs30)
        self.dic_symbollist_stcmds, self.list_symbolnum_stcmds = get_wav_symbol(self.datapath + filename_symbollist_stcmds)

        self.DataNum = self.GetDataNum()

    def GetDataNum(self):
        '''
        Return the total number of samples across both corpora.

        Returns -1 when, for either corpus, the number of wav entries differs
        from the number of symbol entries.
        '''
        num_wavlist_thchs30 = len(self.dic_wavlist_thchs30)
        num_symbollist_thchs30 = len(self.dic_symbollist_thchs30)
        num_wavlist_stcmds = len(self.dic_wavlist_stcmds)
        num_symbollist_stcmds = len(self.dic_symbollist_stcmds)

        if num_wavlist_thchs30 != num_symbollist_thchs30:
            return -1
        if num_wavlist_stcmds != num_symbollist_stcmds:
            return -1
        return num_wavlist_thchs30 + num_wavlist_stcmds

    def GetData(self, n_start, n_amount=1):
        '''
        Read one sample and return its network input and label array.

        Parameters:
            n_start: index of the sample to read
            n_amount: accepted for interface compatibility; unused, one wav
                file is read per call

        Returns:
            (data_input, data_label): data_input is the result of
            GetFrequencyFeature3 with a trailing axis of size 1 appended;
            data_label is a 1-D array of symbol indices.
        '''
        # Interleaving ratio: every `bili`-th index is served from THCHS-30,
        # the remaining indices from ST-CMDS.
        bili = 11 if self.type == 'train' else 2

        group = n_start // bili
        yushu = n_start % bili
        if yushu == 0:
            # Wrap around like the ST-CMDS branch does, so that a large index
            # cannot run past the end of the THCHS-30 list.
            idx = group % len(self.list_wavnum_thchs30)
            filename = self.dic_wavlist_thchs30[self.list_wavnum_thchs30[idx]]
            list_symbol = self.dic_symbollist_thchs30[self.list_symbolnum_thchs30[idx]]
        else:
            length = len(self.list_wavnum_stcmds)
            idx = (group * (bili - 1) + yushu - 1) % length
            filename = self.dic_wavlist_stcmds[self.list_wavnum_stcmds[idx]]
            list_symbol = self.dic_symbollist_stcmds[self.list_symbolnum_stcmds[idx]]

        if 'Windows' == plat.system():
            # The list files use '/', which must be converted on Windows.
            filename = filename.replace('/', '\\')

        wavsignal, fs = read_wav_data(self.datapath + filename)

        # Output labels: symbol indices, skipping empty entries.
        feat_out = [self.SymbolToNum(symbol) for symbol in list_symbol if '' != symbol]

        # Input features, with a trailing axis of size 1 appended.
        data_input = GetFrequencyFeature3(wavsignal, fs)
        data_input = data_input.reshape(data_input.shape[0], data_input.shape[1], 1)

        data_label = np.array(feat_out)
        return data_input, data_label

    def data_genetator(self, batch_size=32, audio_length = 1600):
        '''
        Endless batch generator intended for Keras generator-based training.

        Parameters:
            batch_size: number of samples per yielded batch
            audio_length: accepted for interface compatibility; unused

        Yields:
            ([X, y, input_length, label_length], labels) where X and y are
            lists of per-sample arrays (lengths may differ between samples),
            input_length is a (batch_size, 1) array, label_length is a
            (batch_size, 1) np.matrix and labels is a (batch_size, 1) array
            of zeros.

        Raises:
            ValueError: on first iteration if no consistent data is loaded
        '''
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
        labels = np.zeros((batch_size, 1), dtype=np.float64)

        if self.DataNum <= 0:
            # random.randint would raise ValueError here anyway; say why.
            raise ValueError(
                'No usable data: DataNum is %d (wav and symbol lists are '
                'empty or inconsistent)' % self.DataNum
            )

        while True:
            X = []
            y = []
            input_length = []
            label_length = []

            for _ in range(batch_size):
                ran_num = random.randint(0, self.DataNum - 1)  # random sample index
                data_input, data_labels = self.GetData(ran_num)

                # Frame count divided by 8 plus the remainder. If this causes
                # errors in practice, try adding 1 only when there is a
                # remainder, or dropping the remainder term entirely.
                n_frames = data_input.shape[0]
                input_length.append(n_frames // 8 + n_frames % 8)

                X.append(data_input)
                y.append(data_labels)
                label_length.append([len(data_labels)])

            label_length = np.matrix(label_length)
            input_length = np.array([input_length]).T

            yield [X, y, input_length, label_length], labels

    def GetSymbolList(self):
        '''
        Load the pinyin symbol list from 'dict.txt' in the working directory.

        The first tab-separated field of every non-empty line is one symbol;
        a trailing '_' entry is appended. Also updates self.SymbolNum.

        Returns:
            list of symbol strings
        '''
        # The context manager closes the file even if reading fails.
        with open('dict.txt', 'r', encoding='UTF-8') as txt_obj:
            txt_lines = txt_obj.read().split('\n')

        list_symbol = [line.split('\t')[0] for line in txt_lines if line != '']
        list_symbol.append('_')

        self.SymbolNum = len(list_symbol)
        return list_symbol

    def GetSymbolNum(self):
        '''
        Return the number of pinyin symbols.
        '''
        return len(self.list_symbol)

    def SymbolToNum(self, symbol):
        '''
        Convert a symbol to its index in the symbol list.

        An empty string maps to self.SymbolNum. An unknown symbol raises
        ValueError (from list.index).
        '''
        if symbol == '':
            return self.SymbolNum
        return self.list_symbol.index(symbol)

    def NumToVector(self, num):
        '''
        Convert a symbol index to a one-hot vector over the symbol list.

        An index outside the list yields an all-zero vector.
        '''
        return np.array([1 if i == num else 0 for i in range(len(self.list_symbol))])
if __name__ == '__main__':
    # Manual smoke test, kept disabled. To try it, point `path` at the
    # dataset root directory and uncomment the lines below.
    # path = 'E:\\speech_dataset'
    # l = DataSpeech(path, 'train')
    # print(l.GetDataNum())
    # print(l.GetData(0))
    # aa = l.data_genetator()
    # for i in aa:
    #     a, b = i
    #     print(a, b)
    pass