(ctc_a) (ctc_b) and fine tuning

Author: Baek JeongHun
Date: 2019-10-07 05:33:44 +00:00
Parent: e34c99c386
Commit: 6dc16df598
3 changed files with 21 additions and 15 deletions

File: test.py

@@ -92,12 +92,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
             # Calculate evaluation loss for CTC decoder.
             preds_size = torch.IntTensor([preds.size(1)] * batch_size)
             preds = preds.permute(1, 0, 2)  # to use CTCLoss format
-            # To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
-            # https://github.com/jpuigcerver/PyLaia/issues/16
-            torch.backends.cudnn.enabled = False
             cost = criterion(preds, text_for_loss, preds_size, length_for_loss)
-            torch.backends.cudnn.enabled = True
             # Select max probability (greedy decoding), then decode index to character
             _, preds_index = preds.max(2)
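For reference, torch.nn.CTCLoss expects log-probabilities of shape (T, N, C), per-sample input lengths, and concatenated integer targets; the preds_size and permute lines above exist to meet that contract. A minimal, self-contained sketch with dummy shapes (zero_infinity and all sizes are assumptions, not taken from this diff):

    import torch

    criterion = torch.nn.CTCLoss(zero_infinity=True)    # assumed setting
    batch_size, T, num_class = 2, 26, 37                 # dummy sizes
    preds = torch.randn(batch_size, T, num_class).log_softmax(2)
    preds_size = torch.IntTensor([preds.size(1)] * batch_size)  # every sample emits T frames
    preds = preds.permute(1, 0, 2)                       # CTCLoss wants (T, N, C)
    text = torch.IntTensor([1, 2, 3, 2, 1])              # batch targets, concatenated
    length = torch.IntTensor([3, 2])                     # per-sample target lengths; sum == len(text)
    cost = criterion(preds, text, preds_size, length)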

File: train.py

@@ -66,9 +66,12 @@ def train(opt):
     # data parallel for multi-GPU
     model = torch.nn.DataParallel(model).to(device)
     model.train()
-    if opt.continue_model != '':
-        print(f'loading pretrained model from {opt.continue_model}')
-        model.load_state_dict(torch.load(opt.continue_model))
+    if opt.saved_model != '':
+        print(f'loading pretrained model from {opt.saved_model}')
+        if opt.FT:
+            model.load_state_dict(torch.load(opt.saved_model), strict=False)
+        else:
+            model.load_state_dict(torch.load(opt.saved_model))
     print("Model:")
     print(model)
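The new --FT branch loads the checkpoint non-strictly: strict=False tolerates missing and unexpected keys, but a key present in both state dicts must still match in shape. A toy illustration (not repo code) of why that enables fine-tuning with a replaced head:

    import torch.nn as nn

    src = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 3))
    dst = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 5))  # head resized, e.g. a new character set

    state = src.state_dict()
    # Drop the incompatible head first: strict=False skips missing/unexpected
    # keys, but a present key with a mismatched shape still raises.
    state = {k: v for k, v in state.items() if not k.startswith('1.')}
    result = dst.load_state_dict(state, strict=False)
    print(result.missing_keys)  # ['1.weight', '1.bias'] stay freshly initialized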
@@ -110,8 +113,8 @@ def train(opt):
     """ start training """
     start_iter = 0
-    if opt.continue_model != '':
-        start_iter = int(opt.continue_model.split('_')[-1].split('.')[0])
+    if opt.saved_model != '':
+        start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
         print(f'continue to train, start_iter: {start_iter}')
     start_time = time.time()
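The start iteration is recovered purely from the checkpoint filename. With a hypothetical path following an iter_<N>.pth naming scheme:

    saved_model = 'saved_models/TPS-ResNet-BiLSTM-CTC/iter_300000.pth'  # hypothetical
    start_iter = int(saved_model.split('_')[-1].split('.')[0])
    print(start_iter)  # 300000

Note the parsing assumes the name ends in _<iteration>.<ext>; a checkpoint named, say, best_accuracy.pth would raise ValueError here.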
@@ -128,15 +131,21 @@ def train(opt):
         if 'CTC' in opt.Prediction:
             preds = model(image, text).log_softmax(2)
-            preds_size = torch.IntTensor([preds.size(1)] * batch_size).to(device)
+            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
             preds = preds.permute(1, 0, 2)  # to use CTCLoss format
-            # To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
+            # (ctc_a) To avoid a ctc_loss issue, disable cudnn for the computation of ctc_loss
             # https://github.com/jpuigcerver/PyLaia/issues/16
             torch.backends.cudnn.enabled = False
-            cost = criterion(preds, text, preds_size, length)
+            cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
             torch.backends.cudnn.enabled = True
+
+            # # (ctc_b) To reproduce our pretrained model / paper, use our previous code (below) instead of (ctc_a).
+            # # With PyTorch 1.2.0, the code below produces NaN, so you may want to use PyTorch 1.1.0;
+            # # note that CTCLoss results differ between PyTorch 1.1.0 and 1.2.0.
+            # # See https://github.com/clovaai/deep-text-recognition-benchmark/issues/56#issuecomment-526490707
+            # cost = criterion(preds, text, preds_size, length)
         else:
             preds = model(image, text[:, :-1])  # align with Attention.forward
             target = text[:, 1:]  # without [GO] Symbol
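(ctc_a) flips a process-wide flag, so an exception inside the loss call would leave cuDNN disabled for everything that follows. A defensive variant (a sketch, not part of this commit) restores the previous value either way:

    import torch
    from contextlib import contextmanager

    @contextmanager
    def cudnn_disabled():
        prev = torch.backends.cudnn.enabled
        torch.backends.cudnn.enabled = False
        try:
            yield
        finally:
            torch.backends.cudnn.enabled = prev

    # usage sketch:
    # with cudnn_disabled():
    #     cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))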
@@ -208,7 +217,8 @@ if __name__ == '__main__':
     parser.add_argument('--batch_size', type=int, default=192, help='input batch size')
     parser.add_argument('--num_iter', type=int, default=300000, help='number of iterations to train for')
     parser.add_argument('--valInterval', type=int, default=2000, help='Interval between each validation')
-    parser.add_argument('--continue_model', default='', help="path to model to continue training")
+    parser.add_argument('--saved_model', default='', help="path to model to continue training")
+    parser.add_argument('--FT', action='store_true', help='whether to do fine-tuning')
     parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is Adadelta)')
     parser.add_argument('--lr', type=float, default=1, help='learning rate, default=1.0 for Adadelta')
     parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam. default=0.9')
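To see how the two new options behave together, a self-contained sketch of just these flags (the real parser defines many more):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--saved_model', default='', help="path to model to continue training")
    parser.add_argument('--FT', action='store_true', help='whether to do fine-tuning')

    # action='store_true' means FT defaults to False and flips on when the flag appears.
    opt = parser.parse_args(['--saved_model', 'best_accuracy.pth', '--FT'])
    print(opt.saved_model, opt.FT)  # best_accuracy.pth True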
@@ -271,6 +281,7 @@ if __name__ == '__main__':
         print('if you are stuck for a long time with the multi-GPU setting, try --workers 0')
         # check multi-GPU issue https://github.com/clovaai/deep-text-recognition-benchmark/issues/1
         opt.workers = opt.workers * opt.num_gpu
+        opt.batch_size = opt.batch_size * opt.num_gpu
         """ previous version
         print('To equalize batch stats to the 1-GPU setting, batch_size is multiplied by num_gpu; the multiplied batch_size is ', opt.batch_size)
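The added line keeps per-GPU batch statistics equal to the 1-GPU setting, since DataParallel splits each batch across devices. Assuming the --batch_size default of 192 shown above, a --workers default of 4 (an assumption, not in this diff), and a hypothetical 4 GPUs:

    num_gpu = 4                    # hypothetical
    batch_size, workers = 192, 4   # assumed parser defaults
    workers *= num_gpu             # 16 loader workers
    batch_size *= num_gpu          # 768 in total; DataParallel feeds 192 per GPU
    print(workers, batch_size)     # 16 768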

File: utils.py

@@ -30,7 +30,7 @@ class CTCLabelConverter(object):
         text = ''.join(text)
         text = [self.dict[char] for char in text]
-        return (torch.IntTensor(text).to(device), torch.IntTensor(length).to(device))
+        return (torch.IntTensor(text), torch.IntTensor(length))

     def decode(self, text_index, length):
         """ convert text-index into text-label. """