ctc update
This commit is contained in:
parent
ec5319296b
commit
8749902367
4
demo.py
4
demo.py
|
@ -56,8 +56,8 @@ def demo(opt):
|
||||||
|
|
||||||
# Select max probabilty (greedy decoding) then decode index to character
|
# Select max probabilty (greedy decoding) then decode index to character
|
||||||
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
|
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
|
||||||
_, preds_index = preds.permute(1, 0, 2).max(2)
|
_, preds_index = preds.max(2)
|
||||||
preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
|
preds_index = preds_index.view(-1)
|
||||||
preds_str = converter.decode(preds_index.data, preds_size.data)
|
preds_str = converter.decode(preds_index.data, preds_size.data)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
2
test.py
2
test.py
|
@ -97,7 +97,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
|
||||||
|
|
||||||
# Select max probabilty (greedy decoding) then decode index to character
|
# Select max probabilty (greedy decoding) then decode index to character
|
||||||
_, preds_index = preds.max(2)
|
_, preds_index = preds.max(2)
|
||||||
preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
|
preds_index = preds_index.view(-1)
|
||||||
preds_str = converter.decode(preds_index.data, preds_size.data)
|
preds_str = converter.decode(preds_index.data, preds_size.data)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
11
train.py
11
train.py
|
@ -132,13 +132,14 @@ def train(opt):
|
||||||
if 'CTC' in opt.Prediction:
|
if 'CTC' in opt.Prediction:
|
||||||
preds = model(image, text).log_softmax(2)
|
preds = model(image, text).log_softmax(2)
|
||||||
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
|
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
|
||||||
preds = preds.permute(1, 0, 2) # to use CTCLoss format
|
# permute 'preds' to use CTCloss format
|
||||||
|
cost = criterion(preds.permute(1, 0, 2), text.to(device), preds_size.to(device), length.to(device)) # For PyTorch 1.3.0
|
||||||
|
|
||||||
# (ctc_a) To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
|
# (ctc_a) For PyTorch 1.2.0. To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
|
||||||
# https://github.com/jpuigcerver/PyLaia/issues/16
|
# https://github.com/jpuigcerver/PyLaia/issues/16
|
||||||
torch.backends.cudnn.enabled = False
|
# torch.backends.cudnn.enabled = False
|
||||||
cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
|
# cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
|
||||||
torch.backends.cudnn.enabled = True
|
# torch.backends.cudnn.enabled = True
|
||||||
|
|
||||||
# # (ctc_b) To reproduce our pretrained model / paper, use our previous code (below code) instead of (ctc_a).
|
# # (ctc_b) To reproduce our pretrained model / paper, use our previous code (below code) instead of (ctc_a).
|
||||||
# # With PyTorch 1.2.0, the below code occurs NAN, so you may use PyTorch 1.1.0.
|
# # With PyTorch 1.2.0, the below code occurs NAN, so you may use PyTorch 1.1.0.
|
||||||
|
|
Loading…
Reference in New Issue