'fix_ctc_loss_issue'

This commit is contained in:
Baek JeongHun 2019-08-07 16:42:42 +09:00
parent 8f7255fb4b
commit 1c6efa5218
3 changed files with 12 additions and 2 deletions

View File

@ -92,7 +92,12 @@ def validation(model, criterion, evaluation_loader, converter, opt):
# Calculate evaluation loss for CTC decoder.
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
preds = preds.permute(1, 0, 2) # to use CTCloss format
# To avoid the ctc_loss issue, disable cudnn for the computation of the ctc_loss
# https://github.com/jpuigcerver/PyLaia/issues/16
torch.backends.cudnn.enabled = False
cost = criterion(preds, text_for_loss, preds_size, length_for_loss)
torch.backends.cudnn.enabled = True
# Select max probability (greedy decoding), then decode index to character
_, preds_index = preds.max(2)

View File

@ -128,9 +128,14 @@ def train(opt):
if 'CTC' in opt.Prediction:
preds = model(image, text).log_softmax(2)
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
preds_size = torch.IntTensor([preds.size(1)] * batch_size).to(device)
preds = preds.permute(1, 0, 2) # to use CTCLoss format
# To avoid the ctc_loss issue, disable cudnn for the computation of the ctc_loss
# https://github.com/jpuigcerver/PyLaia/issues/16
torch.backends.cudnn.enabled = False
cost = criterion(preds, text, preds_size, length)
torch.backends.cudnn.enabled = True
else:
preds = model(image, text[:, :-1]) # align with Attention.forward

View File

@ -30,7 +30,7 @@ class CTCLabelConverter(object):
text = ''.join(text)
text = [self.dict[char] for char in text]
return (torch.IntTensor(text), torch.IntTensor(length))
return (torch.IntTensor(text).to(device), torch.IntTensor(length).to(device))
def decode(self, text_index, length):
""" convert text-index into text-label. """