Fix typo modle to model
This commit is contained in:
parent
002263e07f
commit
55ced5c633
|
@ -9,9 +9,9 @@ from models.data_parallel import DataParallel
|
||||||
from utils.utils import AverageMeter
|
from utils.utils import AverageMeter
|
||||||
|
|
||||||
|
|
||||||
class ModleWithLoss(torch.nn.Module):
|
class ModelWithLoss(torch.nn.Module):
|
||||||
def __init__(self, model, loss):
|
def __init__(self, model, loss):
|
||||||
super(ModleWithLoss, self).__init__()
|
super(ModelWithLoss, self).__init__()
|
||||||
self.model = model
|
self.model = model
|
||||||
self.loss = loss
|
self.loss = loss
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ class BaseTrainer(object):
|
||||||
self.opt = opt
|
self.opt = opt
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
self.loss_stats, self.loss = self._get_losses(opt)
|
self.loss_stats, self.loss = self._get_losses(opt)
|
||||||
self.model_with_loss = ModleWithLoss(model, self.loss)
|
self.model_with_loss = ModelWithLoss(model, self.loss)
|
||||||
|
|
||||||
def set_device(self, gpus, chunk_sizes, device):
|
def set_device(self, gpus, chunk_sizes, device):
|
||||||
if len(gpus) > 1:
|
if len(gpus) > 1:
|
||||||
|
|
Loading…
Reference in New Issue