diff --git a/demo.py b/demo.py index 571ec89..ed679dd 100755 --- a/demo.py +++ b/demo.py @@ -56,8 +56,8 @@ def demo(opt): # Select max probabilty (greedy decoding) then decode index to character preds_size = torch.IntTensor([preds.size(1)] * batch_size) - _, preds_index = preds.permute(1, 0, 2).max(2) - preds_index = preds_index.transpose(1, 0).contiguous().view(-1) + _, preds_index = preds.max(2) + preds_index = preds_index.view(-1) preds_str = converter.decode(preds_index.data, preds_size.data) else: diff --git a/test.py b/test.py index eb90eda..a215825 100755 --- a/test.py +++ b/test.py @@ -97,7 +97,7 @@ def validation(model, criterion, evaluation_loader, converter, opt): # Select max probabilty (greedy decoding) then decode index to character _, preds_index = preds.max(2) - preds_index = preds_index.transpose(1, 0).contiguous().view(-1) + preds_index = preds_index.view(-1) preds_str = converter.decode(preds_index.data, preds_size.data) else: diff --git a/train.py b/train.py index 35d0776..d3e21d9 100755 --- a/train.py +++ b/train.py @@ -132,13 +132,14 @@ def train(opt): if 'CTC' in opt.Prediction: preds = model(image, text).log_softmax(2) preds_size = torch.IntTensor([preds.size(1)] * batch_size) - preds = preds.permute(1, 0, 2) # to use CTCLoss format + # permute 'preds' to use CTCloss format + cost = criterion(preds.permute(1, 0, 2), text.to(device), preds_size.to(device), length.to(device)) # For PyTorch 1.3.0 - # (ctc_a) To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss + # (ctc_a) For PyTorch 1.2.0. To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss # https://github.com/jpuigcerver/PyLaia/issues/16 - torch.backends.cudnn.enabled = False - cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device)) - torch.backends.cudnn.enabled = True + # torch.backends.cudnn.enabled = False + # cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device)) + # torch.backends.cudnn.enabled = True # # (ctc_b) To reproduce our pretrained model / paper, use our previous code (below code) instead of (ctc_a). # # With PyTorch 1.2.0, the below code occurs NAN, so you may use PyTorch 1.1.0.