2019-04-05 18:45:29 +08:00
|
|
|
import os
|
|
|
|
import time
|
2019-05-09 11:21:18 +08:00
|
|
|
import string
|
2019-04-05 18:45:29 +08:00
|
|
|
import argparse
|
2020-02-25 19:06:34 +08:00
|
|
|
import re
|
2019-04-05 18:45:29 +08:00
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.backends.cudnn as cudnn
|
|
|
|
import torch.utils.data
|
2019-10-22 19:46:45 +08:00
|
|
|
import torch.nn.functional as F
|
2019-04-05 18:45:29 +08:00
|
|
|
import numpy as np
|
|
|
|
from nltk.metrics.distance import edit_distance
|
|
|
|
|
|
|
|
from utils import CTCLabelConverter, AttnLabelConverter, Averager
|
|
|
|
from dataset import hierarchical_dataset, AlignCollate
|
|
|
|
from model import Model
|
2019-08-03 16:03:46 +08:00
|
|
|
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
2019-04-05 18:45:29 +08:00
|
|
|
|
|
|
|
|
|
|
|
def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=False):
    """Evaluate `model` on the 10 benchmark evaluation datasets.

    Every dataset listed below is loaded from ``opt.eval_data/<name>`` and run
    through `validation`. Per-dataset results and an aggregated summary are
    printed and appended to ``./result/{opt.experiment_name}/log_all_evaluation.txt``.

    Args:
        model: network forwarded to `validation`.
        criterion: loss forwarded to `validation`.
        converter: label converter forwarded to `validation`.
        opt: options namespace; this function reads ``eval_data``, ``batch_size``,
            ``workers``, ``imgH``, ``imgW``, ``PAD`` and ``experiment_name``.
        calculate_infer_time: when True the loader uses a batch size of 1 so the
            accumulated forward time can be read as a per-image time.

    Returns:
        None. Results are only printed and logged.
    """
    # The evaluation datasets, dataset order is same with Table 1 in our paper.
    eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857',
                      'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']

    # batch_size should be 1 to calculate the GPU inference time per image.
    evaluation_batch_size = 1 if calculate_infer_time else opt.batch_size

    list_accuracy = []
    total_forward_time = 0
    total_evaluation_data_number = 0
    total_correct_number = 0
    dashed_line = '-' * 80

    result_dir = f'./result/{opt.experiment_name}'
    os.makedirs(result_dir, exist_ok=True)

    # The context manager closes the log even if loading or validation raises.
    with open(f'{result_dir}/log_all_evaluation.txt', 'a') as log:
        print(dashed_line)
        log.write(dashed_line + '\n')

        for eval_data_name in eval_data_list:
            eval_data_path = os.path.join(opt.eval_data, eval_data_name)
            AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
            eval_dataset, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt)
            evaluation_loader = torch.utils.data.DataLoader(
                eval_dataset, batch_size=evaluation_batch_size,
                shuffle=False,
                num_workers=int(opt.workers),
                collate_fn=AlignCollate_evaluation, pin_memory=True)

            _, accuracy_by_best_model, norm_ED_by_best_model, _, _, _, infer_time, length_of_data = validation(
                model, criterion, evaluation_loader, converter, opt)
            list_accuracy.append(f'{accuracy_by_best_model:0.3f}')
            total_forward_time += infer_time
            total_evaluation_data_number += len(eval_dataset)
            # accuracy is a percentage, so this accumulates "percent * samples"
            # and the final division yields a sample-weighted percentage.
            total_correct_number += accuracy_by_best_model * length_of_data

            log.write(eval_data_log)
            result_line = f'Acc {accuracy_by_best_model:0.3f}\t normalized_ED {norm_ED_by_best_model:0.3f}'
            print(result_line)
            log.write(result_line + '\n')
            print(dashed_line)
            log.write(dashed_line + '\n')

        # time.time() differences are seconds; scale to milliseconds per sample.
        averaged_forward_time = total_forward_time / total_evaluation_data_number * 1000
        total_accuracy = total_correct_number / total_evaluation_data_number
        params_num = sum(p.numel() for p in model.parameters())

        evaluation_log = 'accuracy: '
        evaluation_log += ''.join(
            f'{name}: {accuracy}\t' for name, accuracy in zip(eval_data_list, list_accuracy))
        evaluation_log += f'total_accuracy: {total_accuracy:0.3f}\t'
        evaluation_log += f'averaged_infer_time: {averaged_forward_time:0.3f}\t# parameters: {params_num/1e6:0.3f}'
        print(evaluation_log)
        log.write(evaluation_log + '\n')

    return None
|
|
|
|
|
|
|
|
|
|
|
|
def validation(model, criterion, evaluation_loader, converter, opt):
    """Run `model` over `evaluation_loader` and score its predictions.

    Args:
        model: callable invoked as ``model(image, text_for_pred)`` for CTC
            prediction and ``model(image, text_for_pred, is_train=False)`` otherwise.
        criterion: loss function; called with the CTC argument layout when
            ``'CTC' in opt.Prediction`` and with flattened logits/targets otherwise.
        evaluation_loader: iterable yielding ``(image_tensors, labels)`` batches.
        converter: label converter providing ``encode`` and ``decode``.
        opt: options namespace; reads ``Prediction``, ``batch_max_length``,
            ``sensitive`` and ``data_filtering_off``.

    Returns:
        Tuple ``(valid_loss, accuracy, norm_ED, preds_str, confidence_score_list,
        labels, infer_time, length_of_data)`` where ``accuracy`` is a percentage,
        ``norm_ED`` is the mean ICDAR2019 normalized edit distance score,
        ``infer_time`` is the summed forward time in seconds, and ``preds_str``,
        ``confidence_score_list`` and ``labels`` are those of the last batch
        (they are rebound on every iteration).

    Raises:
        ZeroDivisionError: if the loader yields no samples.
    """
    n_correct = 0
    norm_ED = 0
    length_of_data = 0
    infer_time = 0
    valid_loss_avg = Averager()

    # Loop-invariant settings, resolved once instead of once per sample.
    is_ctc = 'CTC' in opt.Prediction
    is_attn = 'Attn' in opt.Prediction
    # To evaluate a 'case sensitive model' with the alphanumeric and case insensitive setting.
    normalize_to_alphanumeric = opt.sensitive and opt.data_filtering_off
    alphanumeric_case_insensitive = '0123456789abcdefghijklmnopqrstuvwxyz'
    out_of_alphanumeric = re.compile(f'[^{alphanumeric_case_insensitive}]')

    for image_tensors, labels in evaluation_loader:
        batch_size = image_tensors.size(0)
        length_of_data = length_of_data + batch_size
        image = image_tensors.to(device)
        # For max length prediction
        length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

        text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length)

        start_time = time.time()
        if is_ctc:
            preds = model(image, text_for_pred)
            forward_time = time.time() - start_time

            # Calculate evaluation loss for CTC decoder.
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            # permute 'preds' to use CTCloss format
            cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)

            # Select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_index = preds_index.view(-1)
            preds_str = converter.decode(preds_index.data, preds_size.data)

        else:
            preds = model(image, text_for_pred, is_train=False)
            forward_time = time.time() - start_time

            preds = preds[:, :text_for_loss.shape[1] - 1, :]
            target = text_for_loss[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1))

            # select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)
            labels = converter.decode(text_for_loss[:, 1:], length_for_loss)

        infer_time += forward_time
        valid_loss_avg.add(cost)

        # calculate accuracy & confidence score
        preds_prob = F.softmax(preds, dim=2)
        preds_max_prob, _ = preds_prob.max(dim=2)
        confidence_score_list = []
        for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
            if is_attn:
                # str.find returns -1 when '[s]' is absent; slicing with -1 would
                # silently drop the last character, so only prune when it is found.
                gt_EOS = gt.find('[s]')
                if gt_EOS != -1:
                    gt = gt[:gt_EOS]
                pred_EOS = pred.find('[s]')
                if pred_EOS != -1:
                    pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                    pred_max_prob = pred_max_prob[:pred_EOS]

            if normalize_to_alphanumeric:
                pred = out_of_alphanumeric.sub('', pred.lower())
                gt = out_of_alphanumeric.sub('', gt.lower())

            if pred == gt:
                n_correct += 1

            # ICDAR2019 Normalized Edit Distance: an empty gt or pred scores 0,
            # otherwise 1 - edit_distance / (length of the longer string).
            if len(gt) != 0 and len(pred) != 0:
                norm_ED += 1 - edit_distance(pred, gt) / max(len(gt), len(pred))

            # calculate confidence score (= product of pred_max_prob)
            if pred_max_prob.numel() > 0:
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]
            else:
                # empty pred case, when the "end of sentence" token ([s]) comes first
                confidence_score = 0
            confidence_score_list.append(confidence_score)

    accuracy = n_correct / float(length_of_data) * 100
    norm_ED = norm_ED / float(length_of_data)  # ICDAR2019 Normalized Edit Distance

    return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data
|
2019-04-05 18:45:29 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test(opt):
    """Build the model described by `opt`, load its checkpoint and evaluate it.

    The checkpoint at ``opt.saved_model`` is loaded into a ``DataParallel``
    wrapped `Model`, copied into ``./result/{opt.experiment_name}/`` and then
    evaluated either on all benchmark datasets (``opt.benchmark_all_eval``) or on
    the single dataset at ``opt.eval_data``; in the latter case the accuracy is
    printed and appended to ``log_evaluation.txt`` in the result directory.

    Args:
        opt: options namespace. This function sets ``opt.num_class``,
            ``opt.experiment_name`` and, when ``opt.rgb`` is set, ``opt.input_channel``.

    Returns:
        None.
    """
    # Local import keeps this block self-contained; it replaces a shell `cp`.
    import shutil

    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))
    # Experiment name = checkpoint path without its first component, joined by '_'.
    opt.experiment_name = '_'.join(opt.saved_model.split('/')[1:])
    # print(model)

    # keep evaluation model and result logs
    result_dir = f'./result/{opt.experiment_name}'
    os.makedirs(result_dir, exist_ok=True)
    # shutil.copy avoids passing a user-supplied path through the shell and
    # also works on platforms without a `cp` command.
    shutil.copy(opt.saved_model, f'{result_dir}/')

    # setup loss
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    # evaluation
    model.eval()
    with torch.no_grad():
        if opt.benchmark_all_eval:  # evaluation with 10 benchmark evaluation datasets
            benchmark_all_eval(model, criterion, converter, opt)
        else:
            # The context manager closes the log even if evaluation raises.
            with open(f'{result_dir}/log_evaluation.txt', 'a') as log:
                AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
                eval_data, eval_data_log = hierarchical_dataset(root=opt.eval_data, opt=opt)
                evaluation_loader = torch.utils.data.DataLoader(
                    eval_data, batch_size=opt.batch_size,
                    shuffle=False,
                    num_workers=int(opt.workers),
                    collate_fn=AlignCollate_evaluation, pin_memory=True)
                _, accuracy_by_best_model, _, _, _, _, _, _ = validation(
                    model, criterion, evaluation_loader, converter, opt)
                log.write(eval_data_log)
                print(f'{accuracy_by_best_model:0.3f}')
                log.write(f'{accuracy_by_best_model:0.3f}\n')
|
2019-04-05 18:45:29 +08:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: parse the options, then run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--eval_data', required=True,
        help='path to evaluation dataset')
    parser.add_argument(
        '--benchmark_all_eval', action='store_true',
        help='evaluate 10 benchmark evaluation datasets')
    parser.add_argument(
        '--workers', type=int, default=4,
        help='number of data loading workers')
    parser.add_argument(
        '--batch_size', type=int, default=192,
        help='input batch size')
    parser.add_argument(
        '--saved_model', required=True,
        help="path to saved_model to evaluation")

    # Data processing
    parser.add_argument(
        '--batch_max_length', type=int, default=25,
        help='maximum-label-length')
    parser.add_argument(
        '--imgH', type=int, default=32,
        help='the height of the input image')
    parser.add_argument(
        '--imgW', type=int, default=100,
        help='the width of the input image')
    parser.add_argument(
        '--rgb', action='store_true',
        help='use rgb input')
    parser.add_argument(
        '--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz',
        help='character label')
    parser.add_argument(
        '--sensitive', action='store_true',
        help='for sensitive character mode')
    parser.add_argument(
        '--PAD', action='store_true',
        help='whether to keep ratio then pad for image resize')
    parser.add_argument(
        '--data_filtering_off', action='store_true',
        help='for data_filtering_off mode')

    # Model Architecture
    parser.add_argument(
        '--Transformation', type=str, required=True,
        help='Transformation stage. None|TPS')
    parser.add_argument(
        '--FeatureExtraction', type=str, required=True,
        help='FeatureExtraction stage. VGG|RCNN|ResNet')
    parser.add_argument(
        '--SequenceModeling', type=str, required=True,
        help='SequenceModeling stage. None|BiLSTM')
    parser.add_argument(
        '--Prediction', type=str, required=True,
        help='Prediction stage. CTC|Attn')
    parser.add_argument(
        '--num_fiducial', type=int, default=20,
        help='number of fiducial points of TPS-STN')
    parser.add_argument(
        '--input_channel', type=int, default=1,
        help='the number of input channel of Feature extractor')
    parser.add_argument(
        '--output_channel', type=int, default=512,
        help='the number of output channel of Feature extractor')
    parser.add_argument(
        '--hidden_size', type=int, default=256,
        help='the size of the LSTM hidden state')

    opt = parser.parse_args()

    # vocab / character number configuration
    if opt.sensitive:
        # same with ASTER setting (use 94 char).
        opt.character = string.printable[:-6]

    cudnn.benchmark = True
    cudnn.deterministic = True
    opt.num_gpu = torch.cuda.device_count()

    test(opt)
|