From c6d7a29aa46f14fa0ab857d5976930f0d559f6e6 Mon Sep 17 00:00:00 2001 From: korotaS Date: Mon, 25 Jan 2021 12:45:16 +0300 Subject: [PATCH] add map location to torch.load when loading saved model --- train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 7f3a66b..f15a819 100644 --- a/train.py +++ b/train.py @@ -81,9 +81,9 @@ def train(opt): if opt.saved_model != '': print(f'loading pretrained model from {opt.saved_model}') if opt.FT: - model.load_state_dict(torch.load(opt.saved_model), strict=False) + model.load_state_dict(torch.load(opt.saved_model, map_location=device), strict=False) else: - model.load_state_dict(torch.load(opt.saved_model)) + model.load_state_dict(torch.load(opt.saved_model, map_location=device)) print("Model:") print(model)