'fix_ctc_loss_issue'
This commit is contained in:
parent
8f7255fb4b
commit
1c6efa5218
5
test.py
5
test.py
|
@ -92,7 +92,12 @@ def validation(model, criterion, evaluation_loader, converter, opt):
|
|||
# Calculate evaluation loss for CTC deocder.
|
||||
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
|
||||
preds = preds.permute(1, 0, 2) # to use CTCloss format
|
||||
|
||||
# To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
|
||||
# https://github.com/jpuigcerver/PyLaia/issues/16
|
||||
torch.backends.cudnn.enabled = False
|
||||
cost = criterion(preds, text_for_loss, preds_size, length_for_loss)
|
||||
torch.backends.cudnn.enabled = True
|
||||
|
||||
# Select max probabilty (greedy decoding) then decode index to character
|
||||
_, preds_index = preds.max(2)
|
||||
|
|
7
train.py
7
train.py
|
@ -128,9 +128,14 @@ def train(opt):
|
|||
|
||||
if 'CTC' in opt.Prediction:
|
||||
preds = model(image, text).log_softmax(2)
|
||||
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
|
||||
preds_size = torch.IntTensor([preds.size(1)] * batch_size).to(device)
|
||||
preds = preds.permute(1, 0, 2) # to use CTCLoss format
|
||||
|
||||
# To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
|
||||
# https://github.com/jpuigcerver/PyLaia/issues/16
|
||||
torch.backends.cudnn.enabled = False
|
||||
cost = criterion(preds, text, preds_size, length)
|
||||
torch.backends.cudnn.enabled = True
|
||||
|
||||
else:
|
||||
preds = model(image, text[:, :-1]) # align with Attention.forward
|
||||
|
|
2
utils.py
2
utils.py
|
@ -30,7 +30,7 @@ class CTCLabelConverter(object):
|
|||
text = ''.join(text)
|
||||
text = [self.dict[char] for char in text]
|
||||
|
||||
return (torch.IntTensor(text), torch.IntTensor(length))
|
||||
return (torch.IntTensor(text).to(device), torch.IntTensor(length).to(device))
|
||||
|
||||
def decode(self, text_index, length):
|
||||
""" convert text-index into text-label. """
|
||||
|
|
Loading…
Reference in New Issue