(ctc_a) (ctc_b) and fine tuning
This commit is contained in:
parent
e34c99c386
commit
6dc16df598
5
test.py
5
test.py
|
@ -92,12 +92,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
|
|||
# Calculate evaluation loss for CTC deocder.
|
||||
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
|
||||
preds = preds.permute(1, 0, 2) # to use CTCloss format
|
||||
|
||||
# To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
|
||||
# https://github.com/jpuigcerver/PyLaia/issues/16
|
||||
torch.backends.cudnn.enabled = False
|
||||
cost = criterion(preds, text_for_loss, preds_size, length_for_loss)
|
||||
torch.backends.cudnn.enabled = True
|
||||
|
||||
# Select max probabilty (greedy decoding) then decode index to character
|
||||
_, preds_index = preds.max(2)
|
||||
|
|
29
train.py
29
train.py
|
@ -66,9 +66,12 @@ def train(opt):
|
|||
# data parallel for multi-GPU
|
||||
model = torch.nn.DataParallel(model).to(device)
|
||||
model.train()
|
||||
if opt.continue_model != '':
|
||||
print(f'loading pretrained model from {opt.continue_model}')
|
||||
model.load_state_dict(torch.load(opt.continue_model))
|
||||
if opt.saved_model != '':
|
||||
print(f'loading pretrained model from {opt.saved_model}')
|
||||
if opt.FT:
|
||||
model.load_state_dict(torch.load(opt.saved_model), strict=False)
|
||||
else:
|
||||
model.load_state_dict(torch.load(opt.saved_model))
|
||||
print("Model:")
|
||||
print(model)
|
||||
|
||||
|
@ -110,8 +113,8 @@ def train(opt):
|
|||
|
||||
""" start training """
|
||||
start_iter = 0
|
||||
if opt.continue_model != '':
|
||||
start_iter = int(opt.continue_model.split('_')[-1].split('.')[0])
|
||||
if opt.saved_model != '':
|
||||
start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
|
||||
print(f'continue to train, start_iter: {start_iter}')
|
||||
|
||||
start_time = time.time()
|
||||
|
@ -128,15 +131,21 @@ def train(opt):
|
|||
|
||||
if 'CTC' in opt.Prediction:
|
||||
preds = model(image, text).log_softmax(2)
|
||||
preds_size = torch.IntTensor([preds.size(1)] * batch_size).to(device)
|
||||
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
|
||||
preds = preds.permute(1, 0, 2) # to use CTCLoss format
|
||||
|
||||
# To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
|
||||
# (ctc_a) To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
|
||||
# https://github.com/jpuigcerver/PyLaia/issues/16
|
||||
torch.backends.cudnn.enabled = False
|
||||
cost = criterion(preds, text, preds_size, length)
|
||||
cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
|
||||
torch.backends.cudnn.enabled = True
|
||||
|
||||
# # (ctc_b) To reproduce our pretrained model / paper, use our previous code (below code) instead of (ctc_a).
|
||||
# # With PyTorch 1.2.0, the below code occurs NAN, so you may use PyTorch 1.1.0.
|
||||
# # Thus, the result of CTCLoss is different in PyTorch 1.1.0 and PyTorch 1.2.0.
|
||||
# # See https://github.com/clovaai/deep-text-recognition-benchmark/issues/56#issuecomment-526490707
|
||||
# cost = criterion(preds, text, preds_size, length)
|
||||
|
||||
else:
|
||||
preds = model(image, text[:, :-1]) # align with Attention.forward
|
||||
target = text[:, 1:] # without [GO] Symbol
|
||||
|
@ -208,7 +217,8 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--batch_size', type=int, default=192, help='input batch size')
|
||||
parser.add_argument('--num_iter', type=int, default=300000, help='number of iterations to train for')
|
||||
parser.add_argument('--valInterval', type=int, default=2000, help='Interval between each validation')
|
||||
parser.add_argument('--continue_model', default='', help="path to model to continue training")
|
||||
parser.add_argument('--saved_model', default='', help="path to model to continue training")
|
||||
parser.add_argument('--FT', action='store_true', help='whether to do fine-tuning')
|
||||
parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is Adadelta)')
|
||||
parser.add_argument('--lr', type=float, default=1, help='learning rate, default=1.0 for Adadelta')
|
||||
parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam. default=0.9')
|
||||
|
@ -271,6 +281,7 @@ if __name__ == '__main__':
|
|||
print('if you stuck too long time with multi-GPU setting, try to set --workers 0')
|
||||
# check multi-GPU issue https://github.com/clovaai/deep-text-recognition-benchmark/issues/1
|
||||
opt.workers = opt.workers * opt.num_gpu
|
||||
opt.batch_size = opt.batch_size * opt.num_gpu
|
||||
|
||||
""" previous version
|
||||
print('To equlize batch stats to 1-GPU setting, the batch_size is multiplied with num_gpu and multiplied batch_size is ', opt.batch_size)
|
||||
|
|
2
utils.py
2
utils.py
|
@ -30,7 +30,7 @@ class CTCLabelConverter(object):
|
|||
text = ''.join(text)
|
||||
text = [self.dict[char] for char in text]
|
||||
|
||||
return (torch.IntTensor(text).to(device), torch.IntTensor(length).to(device))
|
||||
return (torch.IntTensor(text), torch.IntTensor(length))
|
||||
|
||||
def decode(self, text_index, length):
|
||||
""" convert text-index into text-label. """
|
||||
|
|
Loading…
Reference in New Issue