添加ctc层基本框架

This commit is contained in:
nl8590687 2018-03-10 22:30:20 +08:00
parent 4ca877a046
commit 2ec4b937c9
2 changed files with 77 additions and 7 deletions

30
main.py
View File

@ -9,10 +9,11 @@ import numpy as np
from keras.models import Sequential, Model
from keras.layers import Dense, Dropout, Input # , Flatten,LSTM,Convolution1D,MaxPooling1D,Merge
from keras.layers import Conv1D,LSTM,MaxPooling1D, Lambda #, Merge, Conv2D, MaxPooling2D,Conv1D
from keras.layers import Conv1D,LSTM,MaxPooling1D, Lambda, TimeDistributed, Activation #, Merge, Conv2D, MaxPooling2D,Conv1D
from keras import backend as K
from readdata import DataSpeech
from neural_network import ctc_layer
class ModelSpeech(): # 语音模型类
def __init__(self,MS_OUTPUT_SIZE = 1283,BATCH_SIZE = 32):
@ -33,9 +34,9 @@ class ModelSpeech(): # 语音模型类
隐藏层四循环层LSTM层
隐藏层五Dropout层需要断开的神经元的比例为0.2防止过拟合
隐藏层六全连接层神经元数量为self.MS_OUTPUT_SIZE使用softmax作为激活函数
输出层lambda即CTC层使用CTC的loss作为损失函数实现多输出
输出层自定义即CTC层使用CTC的loss作为损失函数实现连接性时序多输出
当前未完成针对多输出的CTC层尚未添加
当前未完成针对多输出的CTC层尚未实现
'''
# 每一帧使用13维mfcc特征及其13维一阶差分和13维二阶差分表示最大信号序列长度为1500
layer_input = Input((1500,39))
@ -43,17 +44,32 @@ class ModelSpeech(): # 语音模型类
layer_h1 = Conv1D(256, 5, use_bias=True, padding="valid")(layer_input) # 卷积层
layer_h2 = MaxPooling1D(pool_size=2, strides=None, padding="valid")(layer_h1) # 池化层
layer_h3 = Dropout(0.2)(layer_h2) # 随机中断部分神经网络连接,防止过拟合
layer_h4 = LSTM(256, activation='relu', use_bias=True)(layer_h3) # LSTM层
layer_h4 = LSTM(256, activation='relu', use_bias=True, return_sequences=True)(layer_h3) # LSTM层
layer_h5 = Dropout(0.2)(layer_h4) # 随机中断部分神经网络连接,防止过拟合
layer_h6 = Dense(self.MS_OUTPUT_SIZE, activation="softmax")(layer_h5) # 全连接层
layer_out = ctc_layer()(layer_h6) # CTC层 尚未实现!
#labels = Input(name='the_labels', shape=[60], dtype='float32')
layer_out = Lambda(ctc_lambda_func,output_shape=(self.MS_OUTPUT_SIZE, ), name='ctc')(layer_h6) # CTC
#layer_out = Lambda(ctc_lambda_func,output_shape=(self.MS_OUTPUT_SIZE, ), name='ctc')(layer_h6) # CTC
#layer_out = TimeDistributed(Dense(self.MS_OUTPUT_SIZE, activation="softmax"))(layer_h5)
_model = Model(inputs = layer_input, outputs = layer_out)
#_model.compile(optimizer="sgd", loss='categorical_crossentropy',metrics=["accuracy"])
_model.compile(optimizer="sgd", loss='ctc',metrics=["accuracy"])
#_model = Sequential()
#_model.add(Conv1D(256, 5, use_bias=True, padding="valid", input_shape=(1500,39)))
#_model.add(MaxPooling1D(pool_size=2, strides=None, padding="valid"))
#_model.add(Dropout(0.2))
#_model.add(LSTM(256, activation='relu', use_bias=True, return_sequences=True))
#_model.add(Dropout(0.2))
#_model.add(TimeDistributed(Dense(self.MS_OUTPUT_SIZE)))
#_model.add(Activation("softmax"))
_model.compile(optimizer="sgd", loss='categorical_crossentropy',metrics=["accuracy"])
#_model.compile(optimizer="sgd", loss='ctc',metrics=["accuracy"])
return _model
def ctc_lambda_func(args):

View File

@ -0,0 +1,54 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
'''
本代码用来实现神经网络中的CTC层
CTC层即Connectionist Temporal Classification 连接时序分类
将这里实现的CTC层通过自定义层的方式加入到keras的神经网络中
尚未完成
'''
from keras.layers.core import Layer
from keras.engine import InputSpec
from keras import backend as K
try:
from keras import initializations
except ImportError:
from keras import initializers as initializations
import tensorflow as tf
# 继承父类Layer
class ctc_layer(Layer):
    '''
    Custom Keras layer intended to implement CTC
    (Connectionist Temporal Classification) decoding.

    NOTE(review): work in progress — the layer only decodes; it does not
    provide a CTC loss. See the Keras backend CTC ops and the CTC paper
    for the reference semantics.
    '''

    def __init__(self, input_dim, output_dim, **kwargs):
        '''
        input_dim  -- per-sample input sequence length passed to ctc_decode
                      (assumed; TODO confirm against callers)
        output_dim -- size of the output label dimension
        '''
        super(ctc_layer, self).__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        #self.input_spec = InputSpec(min_ndim=3)

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        # Fixed: the original shape tuple contained a stray triple-quoted
        # string juxtaposed to an expression (a syntax error) and a -1
        # dimension, which add_weight does not accept. Use the declared
        # (input_dim, output_dim) shape instead.
        self.kernel = self.add_weight(name='kernel',
                                      shape=(self.input_dim, self.output_dim),
                                      initializer='uniform',
                                      trainable=True)
        # Fixed: the original called super(MyLayer, self) — `MyLayer` is
        # undefined here (copied from the Keras docs example).
        super(ctc_layer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x, mask=None):
        # NOTE(review): K.ctc_decode expects input_length to be a 1-D tensor
        # of per-sample sequence lengths and returns a *list* of decoded
        # tensors; passing the scalar self.input_dim and indexing the result
        # as a dense tensor may need rework — confirm once the layer is wired
        # into a model.
        decoded_dense, log_prob = K.ctc_decode(x, self.input_dim)
        decoded_sequence = K.ctc_label_dense_to_sparse(decoded_dense, decoded_dense.shape[0])
        return decoded_sequence

    def get_config(self):
        # Fixed: returning None (bare `pass`) breaks model serialization;
        # report this layer's constructor arguments on top of the base config.
        config = super(ctc_layer, self).get_config()
        config.update({'input_dim': self.input_dim,
                       'output_dim': self.output_dim})
        return config