ctc update

This commit is contained in:
Baek JeongHun 2019-10-22 17:26:58 +00:00
parent ec5319296b
commit 8749902367
3 changed files with 9 additions and 8 deletions

View File

@ -56,8 +56,8 @@ def demo(opt):
# Select max probabilty (greedy decoding) then decode index to character
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
_, preds_index = preds.permute(1, 0, 2).max(2)
preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
_, preds_index = preds.max(2)
preds_index = preds_index.view(-1)
preds_str = converter.decode(preds_index.data, preds_size.data)
else:

View File

@ -97,7 +97,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
# Select max probabilty (greedy decoding) then decode index to character
_, preds_index = preds.max(2)
preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
preds_index = preds_index.view(-1)
preds_str = converter.decode(preds_index.data, preds_size.data)
else:

View File

@ -132,13 +132,14 @@ def train(opt):
if 'CTC' in opt.Prediction:
preds = model(image, text).log_softmax(2)
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
preds = preds.permute(1, 0, 2) # to use CTCLoss format
# permute 'preds' to use CTCloss format
cost = criterion(preds.permute(1, 0, 2), text.to(device), preds_size.to(device), length.to(device)) # For PyTorch 1.3.0
# (ctc_a) To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
# (ctc_a) For PyTorch 1.2.0. To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
# https://github.com/jpuigcerver/PyLaia/issues/16
torch.backends.cudnn.enabled = False
cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
torch.backends.cudnn.enabled = True
# torch.backends.cudnn.enabled = False
# cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
# torch.backends.cudnn.enabled = True
# # (ctc_b) To reproduce our pretrained model / paper, use our previous code (below code) instead of (ctc_a).
# # With PyTorch 1.2.0, the below code occurs NAN, so you may use PyTorch 1.1.0.