Fix loss issue during PyTorch training

parent 9375ecb2db
commit 341b491baa
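Some context for the changes below (not part of the commit itself): torch.nn.CTCLoss(blank=0) expects log-probabilities (hence log_softmax rather than softmax on the final layer) of shape (T, N, C), integer length tensors, and target labels that never use index 0, which is why the dataset now shifts every label by +1. A minimal sketch of the expected call, with placeholder sizes:

import torch
import torch.nn as nn
import torch.nn.functional as F

T, N, C, S = 200, 4, 1428, 32            # time steps, batch, classes (placeholder), max label length
logits = torch.randn(N, T, C)            # raw model output before any activation
log_probs = F.log_softmax(logits, dim=-1).permute(1, 0, 2)   # CTCLoss wants (T, N, C)
targets = torch.randint(1, C, (N, S))    # shifted labels: index 0 stays reserved for the blank
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.randint(10, S, (N,), dtype=torch.long)
loss = nn.CTCLoss(blank=0)(log_probs, targets, input_lengths, target_lengths)
print(loss.item())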
@@ -22,10 +22,9 @@
@author: nl8590687 / Evelynn-n
Definitions and implementations of several PyTorch acoustic models
"""

import torch
import torch.nn as nn
import torch.nn.functional as function
-import tqdm


class SpeechModel251BN(nn.Module):
@@ -37,83 +36,86 @@ class SpeechModel251BN(nn.Module):
        self._model_name = 'SpeechModel251bn'
        self.output_shape = (input_shape[0] // self._pool_size, output_size)

        # block 1
        self.conv0 = nn.Conv2d(1, 32, kernel_size=(3, 3), padding='same')
        torch.nn.init.kaiming_normal_(self.conv0.weight.data)
        self.bn0 = nn.BatchNorm2d(32, eps=0.0002)

        self.conv1 = nn.Conv2d(32, 32, kernel_size=(3, 3), padding='same')
        torch.nn.init.kaiming_normal_(self.conv1.weight.data)
        self.bn1 = nn.BatchNorm2d(32, eps=0.0002)

        # self.maxpool1 = F.max_pool2d(kernel_size=(2, 2), stride=(2, 2))

        # block 2
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3), padding='same')
        torch.nn.init.kaiming_normal_(self.conv2.weight.data)
        self.bn2 = nn.BatchNorm2d(64, eps=0.0002)

        self.conv3 = nn.Conv2d(64, 64, kernel_size=(3, 3), padding='same')
        torch.nn.init.kaiming_normal_(self.conv3.weight.data)
        self.bn3 = nn.BatchNorm2d(64, eps=0.0002)

        # self.maxpool2 = F.max_pool2d(kernel_size=(2, 2), stride=(2, 2))

        # block 3
        self.conv4 = nn.Conv2d(64, 128, kernel_size=(3, 3), padding='same')
        torch.nn.init.kaiming_normal_(self.conv4.weight.data)
        self.bn4 = nn.BatchNorm2d(128, eps=0.0002)

        self.conv5 = nn.Conv2d(128, 128, kernel_size=(3, 3), padding='same')
        torch.nn.init.kaiming_normal_(self.conv5.weight.data)
        self.bn5 = nn.BatchNorm2d(128, eps=0.0002)

        # self.maxpool3 = F.max_pool2d(kernel_size=(2, 2), stride=(2, 2))

        # block 4
        self.conv6 = nn.Conv2d(128, 128, kernel_size=(3, 3), padding='same')
        torch.nn.init.kaiming_normal_(self.conv6.weight.data)
        self.bn6 = nn.BatchNorm2d(128, eps=0.0002)

        self.conv7 = nn.Conv2d(128, 128, kernel_size=(3, 3), padding='same')
        torch.nn.init.kaiming_normal_(self.conv7.weight.data)
        self.bn7 = nn.BatchNorm2d(128, eps=0.0002)

        # self.maxpool4 = F.max_pool2d(kernel_size=(1,1),stride=(1,1))

        # block 5
        self.conv8 = nn.Conv2d(128, 128, kernel_size=(3, 3), padding='same')
        torch.nn.init.kaiming_normal_(self.conv8.weight.data)
        self.bn8 = nn.BatchNorm2d(128, eps=0.0002)

        self.conv9 = nn.Conv2d(128, 128, kernel_size=(3, 3), padding='same')
        torch.nn.init.kaiming_normal_(self.conv9.weight.data)
        self.bn9 = nn.BatchNorm2d(128, eps=0.0002)

        # self.maxpool5 = F.max_pool2d(kernel_size=(1,1),stride=(1,1))

        self.dense0 = nn.Linear(input_shape[1]//8*128, 128)
        torch.nn.init.kaiming_normal_(self.dense0.weight.data)

        self.dense1 = nn.Linear(128, output_size)
        torch.nn.init.kaiming_normal_(self.dense1.weight.data)

        self.ctc_loss = nn.CTCLoss(blank=0)

    def forward(self, x):
        x = x.permute(0, 3, 2, 1)

        # block 1
        x = function.relu(self.bn0(self.conv0(x)))
        x = function.relu(self.bn1(self.conv1(x)))
        # print(x.size())
        x = function.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2))

        # block 2
        x = function.relu(self.bn2(self.conv2(x)))
        x = function.relu(self.bn3(self.conv3(x)))
        # print(x.size())
        x = function.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2))

        # block 3
        x = function.relu(self.bn4(self.conv4(x)))
        x = function.relu(self.bn5(self.conv5(x)))
        # print(x.size())
        x = function.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2))

        # block 4
        x = function.relu(self.bn6(self.conv6(x)))
        x = function.relu(self.bn7(self.conv7(x)))
        # print(x.size())
        x = function.max_pool2d(x, kernel_size=(1, 1), stride=(1, 1))

        # block 5
        x = function.relu(self.bn8(self.conv8(x)))
        x = function.relu(self.bn9(self.conv9(x)))
        # print(x.size())
        x = function.max_pool2d(x, kernel_size=(1, 1), stride=(1, 1))
        # print(x.size())
-        x = x.permute(0, 2, 3, 1)
-        x = x.reshape(x.size(0), x.size(1), -1)
        # x = x.permute(0, 2, 1)
        # x = x.view(x.size(0),x.size(1), -1)
        # print(x.size())

+        x = x.reshape(x.size(0), -1, x.size(3))
+        x = x.permute(0, 2, 1)
        x = function.relu(self.dense0(x))
        # print(x.size())
-        x = function.softmax(self.dense1(x))
        # print(x.size())
+        x = function.log_softmax(self.dense1(x))
        return x

    def compute_loss(self, y_pred, labels, input_length, label_length):
@@ -121,31 +123,6 @@ class SpeechModel251BN(nn.Module):
        loss = self.ctc_loss(y_pred, labels, input_length, label_length)
        return loss

-    def train_model(self, train_loader, optimizer, num_epochs=10, device='cpu'):
-        self.to(device)
-        self.train()
-
-        for epoch in tqdm(range(num_epochs)):
-            epoch_loss = 0.0
-            for batch in train_loader:
-                inputs, labels, input_lengths, label_lengths = batch
-                inputs = inputs.to(device)
-                labels = labels.to(device)
-                input_lengths = input_lengths.to(device)
-                label_lengths = label_lengths.to(device)
-
-                optimizer.zero_grad()
-                y_pred = self.forward(inputs)
-
-                loss = self.compute_loss(y_pred, labels, input_lengths, label_lengths)
-                loss.backward()
-                optimizer.step()
-
-                epoch_loss += loss.item()
-
-            avg_loss = epoch_loss / len(train_loader)
-            print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}")

    def get_model(self):
        return self
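A note on the lengths that compute_loss receives (context, not part of the diff): nn.CTCLoss validates input_length against the time dimension of its log-probability input, and this network shrinks the time axis by its pooling factor (output_shape above divides input_shape[0] by self._pool_size, apparently 8 given the three 2x2 max-pools). A hedged sketch of building such lengths for a padded batch:

import torch

batch_size = 16      # placeholder values, not taken from the repository
raw_frames = 1600    # spectrogram frames fed into the network
pool_factor = 8      # assumed time-axis downsampling of the three 2x2 max-pools
input_length = torch.full((batch_size,), raw_frames // pool_factor, dtype=torch.long)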
@@ -24,22 +24,26 @@
"""

import os
import time

import torch
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import Dataset, DataLoader as TorchDataLoader

from data_loader import DataLoader
from speech_features.speech_features import SpeechFeatureMeta


class SpeechDataset(Dataset):
-    def __init__(self, data_loader, speech_features, input_shape, max_label_length, device='cpu'):
+    def __init__(self, data_loader, speech_features, input_shape, max_label_length):
        self.data_loader = data_loader
        self.input_shape = input_shape
        self.speech_features = speech_features
        self.max_label_length = max_label_length
        self.data_count = self.data_loader.get_data_count()
        self.device = device

    def __len__(self):
        return self.data_count

    def __getitem__(self, index):
        wav_data, sample_rate, data_labels = self.data_loader.get_data(index)

@@ -54,49 +58,74 @@ class SpeechDataset(Dataset):
        # Initialize the input feature array, padded to `input_shape`
        x = torch.zeros(self.input_shape)
        x[:len(data_input)] = torch.tensor(data_input, dtype=torch.float32)
        x = x.permute(2, 0, 1)

        # Initialize the label array, padded to `max_label_length`
        y = torch.zeros(self.max_label_length, dtype=torch.int16)
-        y[:len(data_labels)] = torch.tensor(data_labels, dtype=torch.int16)
+        y[:len(data_labels)] = torch.tensor(data_labels, dtype=torch.int16) + 1

        # Convert to PyTorch tensors
-        input_length = torch.tensor([inlen], dtype=torch.float32)
-        label_length = torch.tensor([len(data_labels)], dtype=torch.float32)
+        input_length = torch.tensor((inlen,), dtype=torch.float32)
+        label_length = torch.tensor((len(data_labels),), dtype=torch.float32)
        return x, y, input_length, label_length


class ModelSpeech:
-    def __init__(self, speech_model, speech_features, max_label_length=64):
+    def __init__(self, speech_model: torch.nn.Module, speech_features: SpeechFeatureMeta, max_label_length: int = 64):
        """Model initialization"""
        self.speech_model = speech_model
        self.trained_model = speech_model.get_model()
        self.speech_features = speech_features
        self.max_label_length = max_label_length

-    def train(self, data_loader, epochs, batch_size, optimizer, save_step=1, last_epoch=0, device='cpu'):
+    def train(self, data_loader: DataLoader, epochs: int, batch_size: int, optimizer: torch.optim.Optimizer,
+              device: str = 'cpu'):
        """Train the model"""
        save_filename = os.path.join('save_models_torch', self.speech_model.get_model_name() + '.pth')
        speechdata = SpeechDataset(data_loader, self.speech_features, input_shape=self.speech_model.input_shape,
                                   max_label_length=self.max_label_length)
        self.trained_model.to(device)
        print('[ASRT] torch model successfully initialized to device: {}'.format(device))
-        data_loader = DataLoader(data_loader, batch_size=batch_size, shuffle=True)
+        data_loader = TorchDataLoader(speechdata, batch_size=batch_size, shuffle=True)
        model = self.speech_model
        for epoch in range(epochs):
-            print('[ASRT] Epoch {}/{}'.format(epoch+1, epochs))
+            print('[ASRT] Epoch {}/{}'.format(epoch + 1, epochs))
            epoch_loss = 0.0
            iter_index = 0
            t0 = time.time()
            for batch in data_loader:
                x, y, input_length, label_length = batch
                x = x.to(device)
                y = y.to(device)
-                input_length = input_length.to(device).unsqueeze(1).long()
-                label_length = label_length.to(device).unsqueeze(1).long()
+                input_length = input_length.to(device).long()
+                label_length = label_length.to(device).long()

                optimizer.zero_grad()
                y_pred = model(x)
                # print(y_pred.shape, y.shape, input_length.shape, label_length.shape)
                loss = model.compute_loss(y_pred, y, input_length, label_length)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()

-            avg_loss = epoch_loss / len(data_loader)
-            print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}")
+                epoch_loss += loss.item()
+                iter_index += 1
+                t1 = time.time()
+                predict_total_time = (t1-t0)*len(data_loader)/iter_index
+                predict_remain_time = predict_total_time - (t1-t0)
+                cur_batch_loss = loss.item()
+                cur_avg_loss = epoch_loss / iter_index
+                print("[ASRT]", f"{predict_remain_time:.2f}/{predict_total_time:.2f} s,",
+                      f"step {iter_index}/{len(data_loader)},", f"current loss: {cur_batch_loss:.4f}",
+                      f"avg loss: {cur_avg_loss:.4f}", end="\r")

+            save_filename = os.path.join('save_models_torch', f"{self.speech_model.get_model_name()}_epoch{epoch+1}.pth")
+            self.save_weight(save_filename)
+            avg_loss = epoch_loss / len(data_loader)
+            total_time = time.time()-t0
+            avg_time_per_step = total_time / len(data_loader)
+            print("[ASRT]", f"epoch {epoch + 1}/{epochs},", f"time cost: {total_time:.2f} s,",
+                  f"{avg_time_per_step:.2f} s/step", f"avg loss: {avg_loss:.4f}")

    def save_weight(self, filename: str):
        save_filename = os.path.join('save_models_torch', filename + ".pth")
        torch.save(self.speech_model.state_dict(), save_filename)

    def load_weight(self, filepath: str):
        self.speech_model.load_state_dict(torch.load(filepath))
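Since the dataset above now stores each label as original_id + 1 (index 0 is reserved for the CTC blank), any decoder has to undo that shift. A minimal greedy-decoding sketch under that assumption; the helper name is hypothetical and not part of the repository:

import torch

def greedy_ctc_decode(log_probs: torch.Tensor) -> list:
    # log_probs: (T, C) log-probabilities for a single utterance.
    best = log_probs.argmax(dim=-1).tolist()   # most likely class per frame
    decoded, prev = [], None
    for idx in best:
        if idx != prev and idx != 0:           # collapse repeats, drop the blank (0)
            decoded.append(idx - 1)            # undo the +1 shift applied in SpeechDataset
        prev = idx
    return decoded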
@@ -0,0 +1,43 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2016-2099 Ailemon.net
#
# This file is part of ASRT Speech Recognition Tool.
#
# ASRT is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# ASRT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ASRT. If not, see <https://www.gnu.org/licenses/>.
# ============================================================================

"""
@author: nl8590687
Entry script for training the PyTorch acoustic model
"""
from torch import optim

from torch_speech_model import *
from speech_features import SpecAugment
from data_loader import DataLoader
from model_zoo.speech_model.pytorch_backend import SpeechModel251BN

if __name__ == "__main__":
    feat = SpecAugment()
    data_loader = DataLoader('train')

    model = SpeechModel251BN()
    speechModel = ModelSpeech(model, feat, max_label_length=64)
    print(model)

    # speechModel.load_weight(os.path.join('save_models_torch', model.get_model_name()+"_save.pth"))
    speechModel.train(data_loader, epochs=10, batch_size=16, optimizer=optim.Adam(model.parameters(), lr=0.001),
                      device="cuda:0")
    speechModel.save_weight(model.get_model_name()+"_save")
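A hedged pre-flight check one might run before calling speechModel.train above. It assumes the ASRT 'train' dataset is configured locally and that SpeechModel251BN exposes an input_shape attribute (ModelSpeech.train relies on it); none of this snippet is part of the commit:

from torch_speech_model import SpeechDataset
from speech_features import SpecAugment
from data_loader import DataLoader
from model_zoo.speech_model.pytorch_backend import SpeechModel251BN

feat = SpecAugment()
data_loader = DataLoader('train')
model = SpeechModel251BN()

check_set = SpeechDataset(data_loader, feat, input_shape=model.input_shape, max_label_length=64)
x, y, input_length, label_length = check_set[0]
# With the label shift above, y should contain values >= 1 wherever a label is present.
print(x.shape, y.shape, input_length, label_length, int(y.max()))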