This commit is contained in:
yumjunstar 2023-08-18 02:33:32 +00:00
parent a8ab4b3274
commit 62c84d04b6
1 changed files with 7 additions and 1 deletions

View File

@ -80,8 +80,14 @@ def train(opt):
model.train()
if opt.saved_model != '':
print(f'loading pretrained model from {opt.saved_model}')
# Fine tunning 목적
if opt.FT:
model.load_state_dict(torch.load(opt.saved_model), strict=False)
checkpoint = torch.load(opt.saved_model)
checkpoint = {k: v for k, v in checkpoint.items() if (k in model.state_dict().keys()) and (model.state_dict()[k].shape == checkpoint[k].shape)}
for name in model.state_dict().keys() :
if name in checkpoint.keys() :
model.state_dict()[name].copy_(checkpoint[name])
#model.load_state_dict(torch.load(opt.saved_model), strict=False)
else:
model.load_state_dict(torch.load(opt.saved_model))
print("Model:")