diff --git a/train.py b/train.py index 7f3a66b..9d0471b 100644 --- a/train.py +++ b/train.py @@ -80,8 +80,14 @@ def train(opt): model.train() if opt.saved_model != '': print(f'loading pretrained model from {opt.saved_model}') + # Fine tunning 목적 if opt.FT: - model.load_state_dict(torch.load(opt.saved_model), strict=False) + checkpoint = torch.load(opt.saved_model) + checkpoint = {k: v for k, v in checkpoint.items() if (k in model.state_dict().keys()) and (model.state_dict()[k].shape == checkpoint[k].shape)} + for name in model.state_dict().keys() : + if name in checkpoint.keys() : + model.state_dict()[name].copy_(checkpoint[name]) + #model.load_state_dict(torch.load(opt.saved_model), strict=False) else: model.load_state_dict(torch.load(opt.saved_model)) print("Model:")