diff --git a/src/lib/trains/base_trainer.py b/src/lib/trains/base_trainer.py index 3ff6efb..1f29865 100755 --- a/src/lib/trains/base_trainer.py +++ b/src/lib/trains/base_trainer.py @@ -9,9 +9,9 @@ from models.data_parallel import DataParallel from utils.utils import AverageMeter -class ModleWithLoss(torch.nn.Module): +class ModelWithLoss(torch.nn.Module): def __init__(self, model, loss): - super(ModleWithLoss, self).__init__() + super(ModelWithLoss, self).__init__() self.model = model self.loss = loss @@ -26,7 +26,7 @@ class BaseTrainer(object): self.opt = opt self.optimizer = optimizer self.loss_stats, self.loss = self._get_losses(opt) - self.model_with_loss = ModleWithLoss(model, self.loss) + self.model_with_loss = ModelWithLoss(model, self.loss) def set_device(self, gpus, chunk_sizes, device): if len(gpus) > 1: