From 87499023673aca52e704c7dba0d200983e1c77a2 Mon Sep 17 00:00:00 2001 From: Baek JeongHun Date: Tue, 22 Oct 2019 17:26:58 +0000 Subject: [PATCH] ctc update --- demo.py | 4 ++-- test.py | 2 +- train.py | 11 ++++++----- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/demo.py b/demo.py index 571ec89..ed679dd 100755 --- a/demo.py +++ b/demo.py @@ -56,8 +56,8 @@ def demo(opt): # Select max probabilty (greedy decoding) then decode index to character preds_size = torch.IntTensor([preds.size(1)] * batch_size) - _, preds_index = preds.permute(1, 0, 2).max(2) - preds_index = preds_index.transpose(1, 0).contiguous().view(-1) + _, preds_index = preds.max(2) + preds_index = preds_index.view(-1) preds_str = converter.decode(preds_index.data, preds_size.data) else: diff --git a/test.py b/test.py index eb90eda..a215825 100755 --- a/test.py +++ b/test.py @@ -97,7 +97,7 @@ def validation(model, criterion, evaluation_loader, converter, opt): # Select max probabilty (greedy decoding) then decode index to character _, preds_index = preds.max(2) - preds_index = preds_index.transpose(1, 0).contiguous().view(-1) + preds_index = preds_index.view(-1) preds_str = converter.decode(preds_index.data, preds_size.data) else: diff --git a/train.py b/train.py index 35d0776..d3e21d9 100755 --- a/train.py +++ b/train.py @@ -132,13 +132,14 @@ def train(opt): if 'CTC' in opt.Prediction: preds = model(image, text).log_softmax(2) preds_size = torch.IntTensor([preds.size(1)] * batch_size) - preds = preds.permute(1, 0, 2) # to use CTCLoss format + # permute 'preds' to use CTCloss format + cost = criterion(preds.permute(1, 0, 2), text.to(device), preds_size.to(device), length.to(device)) # For PyTorch 1.3.0 - # (ctc_a) To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss + # (ctc_a) For PyTorch 1.2.0. To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss # https://github.com/jpuigcerver/PyLaia/issues/16 - torch.backends.cudnn.enabled = False - cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device)) - torch.backends.cudnn.enabled = True + # torch.backends.cudnn.enabled = False + # cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device)) + # torch.backends.cudnn.enabled = True # # (ctc_b) To reproduce our pretrained model / paper, use our previous code (below code) instead of (ctc_a). # # With PyTorch 1.2.0, the below code occurs NAN, so you may use PyTorch 1.1.0.