Merge pull request #209 from ku21fan/master

Add Baidu warpctc option to reproduce the CTC results of our paper.
gwkrsrch 2020-08-03 21:45:12 +09:00 committed by GitHub
commit 3c2c89a88a
4 changed files with 80 additions and 7 deletions

README.md

@@ -11,6 +11,7 @@ Based on this framework, we recorded the 1st place of [ICDAR2013 focused scene t
The difference between our paper and ICDAR challenge is summarized [here](https://github.com/clovaai/deep-text-recognition-benchmark/issues/13).

## Updates
+**Aug 3, 2020**: added a [guideline to use Baidu warpctc](https://github.com/clovaai/deep-text-recognition-benchmark/pull/209), which reproduces the CTC results of our paper. <br>
**Dec 27, 2019**: added [FLOPS](https://github.com/clovaai/deep-text-recognition-benchmark/issues/125) in our paper, and minor updates such as log_dataset.txt and [ICDAR2019-NormalizedED](https://github.com/clovaai/deep-text-recognition-benchmark/blob/86451088248e0490ff8b5f74d33f7d014f6c249a/test.py#L139-L165). <br>
**Oct 22, 2019**: added [confidence score](https://github.com/clovaai/deep-text-recognition-benchmark/issues/82), and arranged the output form of training logs. <br>
**Jul 31, 2019**: the paper was accepted at the International Conference on Computer Vision (ICCV), Seoul 2019, as an oral talk. <br>

test.py

@@ -23,6 +23,10 @@ def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=Fa
    eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857',
                      'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']
+    # # To easily compute the total accuracy of our paper.
+    # eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_867',
+    #                   'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80']
    if calculate_infer_time:
        evaluation_batch_size = 1  # batch_size should be 1 to calculate the GPU inference time per image.
    else:
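The commented-out list keeps one version of each benchmark (dropping the IC03_860, IC13_857, and IC15_1811 duplicates), so the remaining seven sets match the evaluation protocol of the paper. A minimal sketch of what "total accuracy" would then mean, assuming it is the sample-weighted average over the evaluated subsets (the counts below are hypothetical):

```python
# Sketch: total accuracy as the sample-weighted average over subsets.
# per_dataset holds hypothetical (num_correct, num_samples) pairs.
def total_accuracy(per_dataset):
    correct = sum(c for c, _ in per_dataset)
    total = sum(n for _, n in per_dataset)
    return 100.0 * correct / total

# e.g. IIIT5k_3000 (3000 samples) and SVT (647 samples):
print(total_accuracy([(2890, 3000), (600, 647)]))  # -> 95.69...
```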
@@ -100,10 +104,17 @@ def validation(model, criterion, evaluation_loader, converter, opt):
            # Calculate evaluation loss for CTC decoder.
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            # permute 'preds' to use CTCLoss format
-            cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)
+            if opt.baiduCTC:
+                cost = criterion(preds.permute(1, 0, 2), text_for_loss, preds_size, length_for_loss) / batch_size
+            else:
+                cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)

            # Select max probability (greedy decoding) then decode index to character
-            _, preds_index = preds.max(2)
+            if opt.baiduCTC:
+                _, preds_index = preds.max(2)
+                preds_index = preds_index.view(-1)
+            else:
+                _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index.data, preds_size.data)
        else:
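The two branches reflect two different loss APIs: torch.nn.CTCLoss consumes log-probabilities, hence the log_softmax before the permute, while Baidu's warp-ctc binding applies softmax internally, consumes the raw activations, and returns the loss summed over the batch, hence the extra / batch_size. The view(-1) in the decoding branch flattens the greedy indices into the concatenated stream that the Baidu converter's decode expects. A rough sketch of the two calls (the warpctc_pytorch part needs the compiled extension, so it is left commented):

```python
import torch

T, B, C = 26, 2, 37           # time steps, batch size, classes (index 0 = CTC blank)
preds = torch.randn(B, T, C)  # model output in this repo: [batch, T, num_class]

# torch path: log-probabilities shaped [T, B, C]
torch_ctc = torch.nn.CTCLoss(zero_infinity=True)
targets = torch.randint(1, C, (B, 10), dtype=torch.long)  # padded 2-D targets
preds_size = torch.IntTensor([T] * B)
lengths = torch.IntTensor([10] * B)
cost = torch_ctc(preds.log_softmax(2).permute(1, 0, 2), targets, preds_size, lengths)

# Baidu path (sketch): raw activations in [T, B, C], a 1-D concatenated target
# tensor, and a summed loss that is normalized to a per-sample value.
# from warpctc_pytorch import CTCLoss
# warp_ctc = CTCLoss()
# flat_targets = torch.IntTensor([...])  # concatenated label indices, no padding
# cost = warp_ctc(preds.permute(1, 0, 2), flat_targets, preds_size, lengths) / B
```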
@@ -246,6 +257,7 @@ if __name__ == '__main__':
    parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode')
    parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')
    parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode')
+    parser.add_argument('--baiduCTC', action='store_true', help='for Baidu warpctc mode')
    """ Model Architecture """
    parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
    parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')

train.py

@@ -12,7 +12,7 @@ import torch.optim as optim
import torch.utils.data
import numpy as np
-from utils import CTCLabelConverter, AttnLabelConverter, Averager
+from utils import CTCLabelConverter, CTCLabelConverterForBaiduWarpctc, AttnLabelConverter, Averager
from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset
from model import Model
from test import validation
@@ -45,7 +45,10 @@ def train(opt):
    """ model configuration """
    if 'CTC' in opt.Prediction:
-        converter = CTCLabelConverter(opt.character)
+        if opt.baiduCTC:
+            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
+        else:
+            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
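The two converters differ mainly in what encode returns: in this repo, CTCLabelConverter.encode produces a zero-padded [batch_size, batch_max_length] tensor for torch.nn.CTCLoss, while the Baidu variant concatenates all labels into one 1-D tensor, as warp-ctc expects. A toy illustration of the Baidu-style encoding (the 'abc' alphabet is hypothetical):

```python
character = 'abc'
char_to_index = {ch: i + 1 for i, ch in enumerate(character)}  # 0 is reserved for [CTCblank]

labels = ['ab', 'c']
lengths = [len(s) for s in labels]                    # [2, 1]
flat = [char_to_index[ch] for ch in ''.join(labels)]  # [1, 2, 3] -- no padding
# the torch.nn.CTCLoss-style encoding would instead be the padded 2-D tensor
# [[1, 2, 0, ..., 0], [3, 0, 0, ..., 0]] of width batch_max_length.
print(flat, lengths)
```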
@@ -86,7 +89,12 @@ def train(opt):
    """ setup loss """
    if 'CTC' in opt.Prediction:
-        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
+        if opt.baiduCTC:
+            # need to install warpctc. see our guideline.
+            from warpctc_pytorch import CTCLoss
+            criterion = CTCLoss()
+        else:
+            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
@@ -144,8 +152,12 @@ def train(opt):
        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
-            preds = preds.log_softmax(2).permute(1, 0, 2)
-            cost = criterion(preds, text, preds_size, length)
+            if opt.baiduCTC:
+                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
+                cost = criterion(preds, text, preds_size, length) / batch_size
+            else:
+                preds = preds.log_softmax(2).permute(1, 0, 2)
+                cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
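Note that the two normalizations are close but not identical: warp-ctc returns the loss summed over the batch, so / batch_size yields a per-sample mean, whereas torch.nn.CTCLoss with its default reduction='mean' first divides each sample's loss by its target length and then averages over the batch. A small sketch of the difference, assuming those documented reduction semantics:

```python
import torch

# Per-sample CTC losses and target lengths (illustrative values only).
losses = torch.tensor([10.0, 30.0])
target_lengths = torch.tensor([5.0, 10.0])

warpctc_style = losses.sum() / len(losses)           # sum over batch / batch_size -> 20.0
torch_mean_style = (losses / target_lengths).mean()  # torch default 'mean' -> 2.5
print(warpctc_style.item(), torch_mean_style.item())
```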
@@ -232,6 +244,7 @@ if __name__ == '__main__':
    parser.add_argument('--rho', type=float, default=0.95, help='decay rate rho for Adadelta. default=0.95')
    parser.add_argument('--eps', type=float, default=1e-8, help='eps for Adadelta. default=1e-8')
    parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping value. default=5')
+    parser.add_argument('--baiduCTC', action='store_true', help='for Baidu warpctc mode')
    """ Data processing """
    parser.add_argument('--select_data', type=str, default='MJ-ST',
                        help='select training data (default is MJ-ST, which means MJ and ST used as training data)')

utils.py

@@ -52,6 +52,53 @@ class CTCLabelConverter(object):
        return texts

+class CTCLabelConverterForBaiduWarpctc(object):
+    """ Convert between text-label and text-index for baidu warpctc """
+
+    def __init__(self, character):
+        # character (str): set of the possible characters.
+        dict_character = list(character)
+
+        self.dict = {}
+        for i, char in enumerate(dict_character):
+            # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss
+            self.dict[char] = i + 1
+
+        self.character = ['[CTCblank]'] + dict_character  # dummy '[CTCblank]' token for CTCLoss (index 0)
+
+    def encode(self, text, batch_max_length=25):
+        """convert text-label into text-index.
+        input:
+            text: text labels of each image. [batch_size]
+        output:
+            text: concatenated text index for CTCLoss.
+                  [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
+            length: length of each text. [batch_size]
+        """
+        length = [len(s) for s in text]
+        text = ''.join(text)
+        text = [self.dict[char] for char in text]
+
+        return (torch.IntTensor(text), torch.IntTensor(length))
+
+    def decode(self, text_index, length):
+        """ convert text-index into text-label. """
+        texts = []
+        index = 0
+        for l in length:
+            t = text_index[index:index + l]
+
+            char_list = []
+            for i in range(l):
+                if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):  # removing repeated characters and blank.
+                    char_list.append(self.character[t[i]])
+            text = ''.join(char_list)
+
+            texts.append(text)
+            index += l
+        return texts

class AttnLabelConverter(object):
    """ Convert between text-label and text-index """