style: 修复代码风格

This commit is contained in:
nl 2021-11-21 20:20:26 +08:00
parent 1a94ede098
commit 1fba22ee4c
2 changed files with 16 additions and 14 deletions

View File

@ -32,18 +32,19 @@ from speech_features import Spectrogram
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
audio_length = 1600
audio_feature_length = 200
channels = 1
AUDIO_LENGTH = 1600
AUDIO_FEATURE_LENGTH = 200
CHANNELS = 1
# 默认输出的拼音的表示大小是1428即1427个拼音+1个空白块
output_size = 1428
OUTPUT_SIZE = 1428
sm251 = SpeechModel251(
input_shape=(audio_length, audio_feature_length, channels),
output_size=output_size
input_shape=(AUDIO_LENGTH, AUDIO_FEATURE_LENGTH, CHANNELS),
output_size=OUTPUT_SIZE
)
feat = Spectrogram()
evalue_data = DataLoader('dev')
ms = ModelSpeech(sm251, feat, max_label_length=64)
ms.load_model('save_models/' + sm251.get_model_name() + '.h5')
ms.evaluate_model(data_loader=evalue_data, data_count=-1, out_report=True, show_ratio=True, show_per_step=100)
ms.evaluate_model(data_loader=evalue_data, data_count=-1,
out_report=True, show_ratio=True, show_per_step=100)

View File

@ -34,14 +34,14 @@ from speech_features import Spectrogram
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
audio_length = 1600
audio_feature_length = 200
channels = 1
AUDIO_LENGTH = 1600
AUDIO_FEATURE_LENGTH = 200
CHANNELS = 1
# 默认输出的拼音的表示大小是1428即1427个拼音+1个空白块
output_size = 1428
OUTPUT_SIZE = 1428
sm251 = SpeechModel251(
input_shape=(audio_length, audio_feature_length, channels),
output_size=output_size
input_shape=(AUDIO_LENGTH, AUDIO_FEATURE_LENGTH, CHANNELS),
output_size=OUTPUT_SIZE
)
feat = Spectrogram()
train_data = DataLoader('train')
@ -49,5 +49,6 @@ opt = Adam(lr = 0.0001, beta_1 = 0.9, beta_2 = 0.999, decay = 0.0, epsilon = 10e
ms = ModelSpeech(sm251, feat, max_label_length=64)
#ms.load_model('save_models/' + sm251.get_model_name() + '.h5')
ms.train_model(optimizer=opt, data_loader=train_data, epochs=1, save_step=1, batch_size=16, last_epoch=0)
ms.train_model(optimizer=opt, data_loader=train_data,
epochs=1, save_step=1, batch_size=16, last_epoch=0)
ms.save_model('save_models/' + sm251.get_model_name())