Merge pull request #209 from ku21fan/master
add Baidu warpctc option to reproduce CTC results of our paper.
This commit is contained in:
commit
3c2c89a88a
|
@ -11,6 +11,7 @@ Based on this framework, we recorded the 1st place of [ICDAR2013 focused scene t
|
|||
The difference between our paper and ICDAR challenge is summarized [here](https://github.com/clovaai/deep-text-recognition-benchmark/issues/13).
|
||||
|
||||
## Updates
|
||||
**Aug 3, 2020**: added [guideline to use Baidu warpctc](https://github.com/clovaai/deep-text-recognition-benchmark/pull/209) which reproduces CTC results of our paper. <br>
|
||||
**Dec 27, 2019**: added [FLOPS](https://github.com/clovaai/deep-text-recognition-benchmark/issues/125) in our paper, and minor updates such as log_dataset.txt and [ICDAR2019-NormalizedED](https://github.com/clovaai/deep-text-recognition-benchmark/blob/86451088248e0490ff8b5f74d33f7d014f6c249a/test.py#L139-L165). <br>
|
||||
**Oct 22, 2019**: added [confidence score](https://github.com/clovaai/deep-text-recognition-benchmark/issues/82), and arranged the output form of training logs. <br>
|
||||
**Jul 31, 2019**: The paper is accepted at International Conference on Computer Vision (ICCV), Seoul 2019, as an oral talk. <br>
|
||||
|
|
16
test.py
16
test.py
|
@ -23,6 +23,10 @@ def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=Fa
|
|||
eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857',
|
||||
'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']
|
||||
|
||||
# # To easily compute the total accuracy of our paper.
|
||||
# eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_867',
|
||||
# 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80']
|
||||
|
||||
if calculate_infer_time:
|
||||
evaluation_batch_size = 1 # batch_size should be 1 to calculate the GPU inference time per image.
|
||||
else:
|
||||
|
@ -100,10 +104,17 @@ def validation(model, criterion, evaluation_loader, converter, opt):
|
|||
# Calculate evaluation loss for CTC decoder.
|
||||
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
|
||||
# permute 'preds' to use CTCloss format
|
||||
cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)
|
||||
if opt.baiduCTC:
|
||||
cost = criterion(preds.permute(1, 0, 2), text_for_loss, preds_size, length_for_loss) / batch_size
|
||||
else:
|
||||
cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)
|
||||
|
||||
# Select max probability (greedy decoding) then decode index to character
|
||||
_, preds_index = preds.max(2)
|
||||
if opt.baiduCTC:
|
||||
_, preds_index = preds.max(2)
|
||||
preds_index = preds_index.view(-1)
|
||||
else:
|
||||
_, preds_index = preds.max(2)
|
||||
preds_str = converter.decode(preds_index.data, preds_size.data)
|
||||
|
||||
else:
|
||||
|
@ -246,6 +257,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode')
|
||||
parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')
|
||||
parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode')
|
||||
parser.add_argument('--baiduCTC', action='store_true', help='use Baidu warpctc CTCLoss instead of torch.nn.CTCLoss')
|
||||
""" Model Architecture """
|
||||
parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
|
||||
parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')
|
||||
|
|
23
train.py
23
train.py
|
@ -12,7 +12,7 @@ import torch.optim as optim
|
|||
import torch.utils.data
|
||||
import numpy as np
|
||||
|
||||
from utils import CTCLabelConverter, AttnLabelConverter, Averager
|
||||
from utils import CTCLabelConverter, CTCLabelConverterForBaiduWarpctc, AttnLabelConverter, Averager
|
||||
from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset
|
||||
from model import Model
|
||||
from test import validation
|
||||
|
@ -45,7 +45,10 @@ def train(opt):
|
|||
|
||||
""" model configuration """
|
||||
if 'CTC' in opt.Prediction:
|
||||
converter = CTCLabelConverter(opt.character)
|
||||
if opt.baiduCTC:
|
||||
converter = CTCLabelConverterForBaiduWarpctc(opt.character)
|
||||
else:
|
||||
converter = CTCLabelConverter(opt.character)
|
||||
else:
|
||||
converter = AttnLabelConverter(opt.character)
|
||||
opt.num_class = len(converter.character)
|
||||
|
@ -86,7 +89,12 @@ def train(opt):
|
|||
|
||||
""" setup loss """
|
||||
if 'CTC' in opt.Prediction:
|
||||
criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
|
||||
if opt.baiduCTC:
|
||||
# need to install warpctc. see our guideline.
|
||||
from warpctc_pytorch import CTCLoss
|
||||
criterion = CTCLoss()
|
||||
else:
|
||||
criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
|
||||
else:
|
||||
criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0
|
||||
# loss averager
|
||||
|
@ -144,8 +152,12 @@ def train(opt):
|
|||
if 'CTC' in opt.Prediction:
|
||||
preds = model(image, text)
|
||||
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
|
||||
preds = preds.log_softmax(2).permute(1, 0, 2)
|
||||
cost = criterion(preds, text, preds_size, length)
|
||||
if opt.baiduCTC:
|
||||
preds = preds.permute(1, 0, 2) # to use CTCLoss format
|
||||
cost = criterion(preds, text, preds_size, length) / batch_size
|
||||
else:
|
||||
preds = preds.log_softmax(2).permute(1, 0, 2)
|
||||
cost = criterion(preds, text, preds_size, length)
|
||||
|
||||
else:
|
||||
preds = model(image, text[:, :-1]) # align with Attention.forward
|
||||
|
@ -232,6 +244,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--rho', type=float, default=0.95, help='decay rate rho for Adadelta. default=0.95')
|
||||
parser.add_argument('--eps', type=float, default=1e-8, help='eps for Adadelta. default=1e-8')
|
||||
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping value. default=5')
|
||||
parser.add_argument('--baiduCTC', action='store_true', help='use Baidu warpctc CTCLoss instead of torch.nn.CTCLoss')
|
||||
""" Data processing """
|
||||
parser.add_argument('--select_data', type=str, default='MJ-ST',
|
||||
help='select training data (default is MJ-ST, which means MJ and ST used as training data)')
|
||||
|
|
47
utils.py
47
utils.py
|
@ -52,6 +52,53 @@ class CTCLabelConverter(object):
|
|||
return texts
|
||||
|
||||
|
||||
class CTCLabelConverterForBaiduWarpctc(object):
    """ Convert between text-label and text-index for baidu warpctc """

    def __init__(self, character):
        """Build the character <-> index tables.

        input:
            character (str): set of the possible characters.
        """
        dict_character = list(character)

        # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss,
        # so real characters are numbered from 1.
        self.dict = {char: i + 1 for i, char in enumerate(dict_character)}

        # dummy '[CTCblank]' token for CTCLoss (index 0)
        self.character = ['[CTCblank]'] + dict_character

    def encode(self, text, batch_max_length=25):
        """convert text-label into text-index.

        input:
            text: text labels of each image. [batch_size]
            batch_max_length: unused here; kept so the signature stays
                call-compatible with existing callers.

        output:
            text: concatenated text index for CTCLoss.
                [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
            length: length of each text. [batch_size]

        raises:
            KeyError: if a label contains a character that was not in
                `character` at construction time.
        """
        length = [len(s) for s in text]
        # All labels are flattened into one sequence; `length` is what lets
        # the loss split them apart again.
        indices = [self.dict[char] for char in ''.join(text)]

        return (torch.IntTensor(indices), torch.IntTensor(length))

    def decode(self, text_index, length):
        """ convert text-index into text-label.

        input:
            text_index: flat sequence of predicted indices for the whole
                batch (tensor or list), samples laid out back to back.
            length: number of entries of `text_index` belonging to each
                sample (tensor or list).

        output:
            texts: list of decoded strings, one per entry of `length`.
        """
        # Convert to plain Python ints once instead of indexing/comparing
        # 0-dim tensors element by element inside the loop.
        if hasattr(text_index, 'tolist'):
            text_index = text_index.tolist()
        if hasattr(length, 'tolist'):
            length = length.tolist()

        texts = []
        index = 0
        for seq_len in length:
            chunk = text_index[index:index + seq_len]

            char_list = []
            previous = None
            for current in chunk:
                # removing repeated characters and blank.
                if current != 0 and current != previous:
                    char_list.append(self.character[current])
                previous = current

            texts.append(''.join(char_list))
            index += seq_len
        return texts
|
||||
|
||||
|
||||
class AttnLabelConverter(object):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
|
|
Loading…
Reference in New Issue