eval with alphanumeric and case sensitive setting

This commit is contained in:
Baek JeongHun 2020-02-25 11:06:34 +00:00
parent 49942e5dc0
commit 5b132b932d
2 changed files with 15 additions and 6 deletions

View File

@ -122,7 +122,7 @@ def hierarchical_dataset(root, opt, select_data='/'):
dataset_list.append(dataset)
concatenated_dataset = ConcatDataset(dataset_list)
return concatenated_dataset, dataset_log

19
test.py
View File

@ -2,6 +2,7 @@ import os
import time
import string
import argparse
import re
import torch
import torch.backends.cudnn as cudnn
@ -133,20 +134,28 @@ def validation(model, criterion, evaluation_loader, converter, opt):
pred = pred[:pred_EOS] # prune after "end of sentence" token ([s])
pred_max_prob = pred_max_prob[:pred_EOS]
# To evaluate the model with 'alphanumeric and case insensitive setting'
if opt.sensitive and opt.data_filtering_off:
pred = pred.lower()
gt = gt.lower()
alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz'
out_of_alphanumeric_case_insensitve = f'[^{alphanumeric_case_insensitve}]'
pred = re.sub(out_of_alphanumeric_case_insensitve, '', pred)
if pred == gt:
n_correct += 1
'''
(old version) ICDAR2017 DOST Normalized Edit Distance https://rrc.cvc.uab.es/?ch=7&com=tasks
"For each word we calculate the normalized edit distance to the length of the ground truth transcription."
"For each word we calculate the normalized edit distance to the length of the ground truth transcription."
if len(gt) == 0:
norm_ED += 1
else:
norm_ED += edit_distance(pred, gt) / len(gt)
'''
# ICDAR2019 Normalized Edit Distance
if len(gt) == 0 or len(pred) ==0:
# ICDAR2019 Normalized Edit Distance
if len(gt) == 0 or len(pred) == 0:
norm_ED += 0
elif len(gt) > len(pred):
norm_ED += 1 - edit_distance(pred, gt) / len(gt)
@ -162,7 +171,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
# print(pred, gt, pred==gt, confidence_score)
accuracy = n_correct / float(length_of_data) * 100
norm_ED = norm_ED / float(length_of_data) # ICDAR2019 Normalized Edit Distance
norm_ED = norm_ED / float(length_of_data) # ICDAR2019 Normalized Edit Distance
return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data