From 62c84d04b6b6a2cf07c200141b6b77f5084e9ff5 Mon Sep 17 00:00:00 2001 From: yumjunstar Date: Fri, 18 Aug 2023 02:33:32 +0000 Subject: [PATCH] Char add --- train.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/train.py b/train.py index 7f3a66b..9d0471b 100644 --- a/train.py +++ b/train.py @@ -80,8 +80,14 @@ def train(opt): model.train() if opt.saved_model != '': print(f'loading pretrained model from {opt.saved_model}') + # Fine tunning 목적 if opt.FT: - model.load_state_dict(torch.load(opt.saved_model), strict=False) + checkpoint = torch.load(opt.saved_model) + checkpoint = {k: v for k, v in checkpoint.items() if (k in model.state_dict().keys()) and (model.state_dict()[k].shape == checkpoint[k].shape)} + for name in model.state_dict().keys() : + if name in checkpoint.keys() : + model.state_dict()[name].copy_(checkpoint[name]) + #model.load_state_dict(torch.load(opt.saved_model), strict=False) else: model.load_state_dict(torch.load(opt.saved_model)) print("Model:")