一个暂时能训练的版本,可能会炸
This commit is contained in:
parent
7c59bf7874
commit
7d69729035
|
@ -28,6 +28,8 @@ class ModelSpeech(): # 语音模型类
|
|||
self.MS_OUTPUT_SIZE = MS_OUTPUT_SIZE # 神经网络最终输出的每一个字符向量维度的大小
|
||||
self.BATCH_SIZE = BATCH_SIZE # 一次训练的batch
|
||||
self.label_max_string_length = 64
|
||||
self.AUDIO_LENGTH = 1600
|
||||
self.AUDIO_FEATURE_LENGTH = 39
|
||||
self._model = self.CreateModel()
|
||||
|
||||
|
||||
|
@ -47,7 +49,7 @@ class ModelSpeech(): # 语音模型类
|
|||
当前未完成,针对多输出的CTC层尚未实现
|
||||
'''
|
||||
# 每一帧使用13维mfcc特征及其13维一阶差分和13维二阶差分表示,最大信号序列长度由self.AUDIO_LENGTH决定(当前为1600)
|
||||
input_data = Input(name='the_input', shape=(1500,39))
|
||||
input_data = Input(name='the_input', shape=(self.AUDIO_LENGTH,self.AUDIO_FEATURE_LENGTH))
|
||||
|
||||
layer_h1 = Conv1D(256, 5, use_bias=True, padding="valid")(input_data) # 卷积层
|
||||
layer_h2 = MaxPooling1D(pool_size=2, strides=None, padding="valid")(layer_h1) # 池化层
|
||||
|
@ -77,7 +79,7 @@ class ModelSpeech(): # 语音模型类
|
|||
loss_out = Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')([y_pred, labels, input_length, label_length])
|
||||
|
||||
# clipnorm seems to speed up convergence
|
||||
sgd = SGD(lr=0.02, decay=1e-6, momentum=0.9, nesterov=True, clipnorm=5)
|
||||
sgd = SGD(lr=0.002, decay=1e-6, momentum=0.9, nesterov=True, clipnorm=5)
|
||||
|
||||
model = Model(inputs=[input_data, labels, input_length, label_length], outputs=loss_out)
|
||||
|
||||
|
@ -137,7 +139,7 @@ class ModelSpeech(): # 语音模型类
|
|||
try:
|
||||
print('[message] epoch %d . Have train datas %d+'%(epoch, n_step*save_step))
|
||||
# data_genetator是一个生成器函数
|
||||
yielddatas = data.data_genetator(self.BATCH_SIZE)
|
||||
yielddatas = data.data_genetator(self.BATCH_SIZE, self.AUDIO_LENGTH)
|
||||
#self._model.fit_generator(yielddatas, save_step, nb_worker=2)
|
||||
self._model.fit_generator(yielddatas, save_step)
|
||||
n_step += 1
|
||||
|
|
|
@ -108,13 +108,13 @@ class DataSpeech():
|
|||
data_label = np.array(feat_out)
|
||||
return data_input, data_label
|
||||
|
||||
def data_genetator(self, batch_size=32):
|
||||
def data_genetator(self, batch_size=32, audio_length = 1600):
|
||||
'''
|
||||
数据生成器函数,用于Keras的fit_generator训练
|
||||
batch_size: 一次产生的数据量
|
||||
需要再修改。。。
|
||||
'''
|
||||
X = np.zeros((batch_size, 1500,39), dtype=np.int16)
|
||||
X = np.zeros((batch_size, audio_length,39), dtype=np.int16)
|
||||
#y = np.zeros((batch_size, 64, self.SymbolNum), dtype=np.int16)
|
||||
y = np.zeros((batch_size, 64), dtype=np.int16)
|
||||
|
||||
|
@ -123,7 +123,7 @@ class DataSpeech():
|
|||
labels = []
|
||||
for i in range(0,batch_size):
|
||||
#input_length.append([1500])
|
||||
label_length.append([39])
|
||||
label_length.append([30])
|
||||
labels.append([1])
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue