add map location to torch.load when loading saved model
This commit is contained in:
parent
68a80fe979
commit
c6d7a29aa4
4
train.py
4
train.py
|
@ -81,9 +81,9 @@ def train(opt):
|
|||
if opt.saved_model != '':
|
||||
print(f'loading pretrained model from {opt.saved_model}')
|
||||
if opt.FT:
|
||||
model.load_state_dict(torch.load(opt.saved_model), strict=False)
|
||||
model.load_state_dict(torch.load(opt.saved_model, map_location=device), strict=False)
|
||||
else:
|
||||
model.load_state_dict(torch.load(opt.saved_model))
|
||||
model.load_state_dict(torch.load(opt.saved_model, map_location=device))
|
||||
print("Model:")
|
||||
print(model)
|
||||
|
||||
|
|
Loading…
Reference in New Issue