diff --git a/train.py b/train.py index 7f3a66b..f15a819 100644 --- a/train.py +++ b/train.py @@ -81,9 +81,9 @@ def train(opt): if opt.saved_model != '': print(f'loading pretrained model from {opt.saved_model}') if opt.FT: - model.load_state_dict(torch.load(opt.saved_model), strict=False) + model.load_state_dict(torch.load(opt.saved_model, map_location=device), strict=False) else: - model.load_state_dict(torch.load(opt.saved_model)) + model.load_state_dict(torch.load(opt.saved_model, map_location=device)) print("Model:") print(model)