style: 修复代码风格
This commit is contained in:
parent
1a94ede098
commit
1fba22ee4c
|
@ -32,18 +32,19 @@ from speech_features import Spectrogram
|
|||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
|
||||
audio_length = 1600
|
||||
audio_feature_length = 200
|
||||
channels = 1
|
||||
AUDIO_LENGTH = 1600
|
||||
AUDIO_FEATURE_LENGTH = 200
|
||||
CHANNELS = 1
|
||||
# 默认输出的拼音的表示大小是1428,即1427个拼音+1个空白块
|
||||
output_size = 1428
|
||||
OUTPUT_SIZE = 1428
|
||||
sm251 = SpeechModel251(
|
||||
input_shape=(audio_length, audio_feature_length, channels),
|
||||
output_size=output_size
|
||||
input_shape=(AUDIO_LENGTH, AUDIO_FEATURE_LENGTH, CHANNELS),
|
||||
output_size=OUTPUT_SIZE
|
||||
)
|
||||
feat = Spectrogram()
|
||||
evalue_data = DataLoader('dev')
|
||||
ms = ModelSpeech(sm251, feat, max_label_length=64)
|
||||
|
||||
ms.load_model('save_models/' + sm251.get_model_name() + '.h5')
|
||||
ms.evaluate_model(data_loader=evalue_data, data_count=-1, out_report=True, show_ratio=True, show_per_step=100)
|
||||
ms.evaluate_model(data_loader=evalue_data, data_count=-1,
|
||||
out_report=True, show_ratio=True, show_per_step=100)
|
||||
|
|
|
@ -34,14 +34,14 @@ from speech_features import Spectrogram
|
|||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
|
||||
audio_length = 1600
|
||||
audio_feature_length = 200
|
||||
channels = 1
|
||||
AUDIO_LENGTH = 1600
|
||||
AUDIO_FEATURE_LENGTH = 200
|
||||
CHANNELS = 1
|
||||
# 默认输出的拼音的表示大小是1428,即1427个拼音+1个空白块
|
||||
output_size = 1428
|
||||
OUTPUT_SIZE = 1428
|
||||
sm251 = SpeechModel251(
|
||||
input_shape=(audio_length, audio_feature_length, channels),
|
||||
output_size=output_size
|
||||
input_shape=(AUDIO_LENGTH, AUDIO_FEATURE_LENGTH, CHANNELS),
|
||||
output_size=OUTPUT_SIZE
|
||||
)
|
||||
feat = Spectrogram()
|
||||
train_data = DataLoader('train')
|
||||
|
@ -49,5 +49,6 @@ opt = Adam(lr = 0.0001, beta_1 = 0.9, beta_2 = 0.999, decay = 0.0, epsilon = 10e
|
|||
ms = ModelSpeech(sm251, feat, max_label_length=64)
|
||||
|
||||
#ms.load_model('save_models/' + sm251.get_model_name() + '.h5')
|
||||
ms.train_model(optimizer=opt, data_loader=train_data, epochs=1, save_step=1, batch_size=16, last_epoch=0)
|
||||
ms.train_model(optimizer=opt, data_loader=train_data,
|
||||
epochs=1, save_step=1, batch_size=16, last_epoch=0)
|
||||
ms.save_model('save_models/' + sm251.get_model_name())
|
||||
|
|
Loading…
Reference in New Issue