add map location to torch.load when loading saved model

This commit is contained in:
korotaS 2021-01-25 12:45:16 +03:00
parent 68a80fe979
commit c6d7a29aa4
1 changed files with 2 additions and 2 deletions

View File

@ -81,9 +81,9 @@ def train(opt):
if opt.saved_model != '':
print(f'loading pretrained model from {opt.saved_model}')
if opt.FT:
model.load_state_dict(torch.load(opt.saved_model), strict=False)
model.load_state_dict(torch.load(opt.saved_model, map_location=device), strict=False)
else:
model.load_state_dict(torch.load(opt.saved_model))
model.load_state_dict(torch.load(opt.saved_model, map_location=device))
print("Model:")
print(model)