diff --git a/dataset.py b/dataset.py index 02193b6..af7592f 100755 --- a/dataset.py +++ b/dataset.py @@ -122,7 +122,7 @@ def hierarchical_dataset(root, opt, select_data='/'): dataset_list.append(dataset) concatenated_dataset = ConcatDataset(dataset_list) - + return concatenated_dataset, dataset_log diff --git a/test.py b/test.py index 87d69d0..88b04cf 100755 --- a/test.py +++ b/test.py @@ -2,6 +2,7 @@ import os import time import string import argparse +import re import torch import torch.backends.cudnn as cudnn @@ -133,20 +134,28 @@ def validation(model, criterion, evaluation_loader, converter, opt): pred = pred[:pred_EOS] # prune after "end of sentence" token ([s]) pred_max_prob = pred_max_prob[:pred_EOS] + # To evaluate the model with the 'alphanumeric and case-insensitive' setting + if opt.sensitive and opt.data_filtering_off: + pred = pred.lower() + gt = gt.lower() + alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz' + out_of_alphanumeric_case_insensitve = f'[^{alphanumeric_case_insensitve}]' + pred = re.sub(out_of_alphanumeric_case_insensitve, '', pred) + if pred == gt: n_correct += 1 ''' (old version) ICDAR2017 DOST Normalized Edit Distance https://rrc.cvc.uab.es/?ch=7&com=tasks - "For each word we calculate the normalized edit distance to the length of the ground truth transcription." + "For each word we calculate the normalized edit distance to the length of the ground truth transcription." 
if len(gt) == 0: norm_ED += 1 else: norm_ED += edit_distance(pred, gt) / len(gt) ''' - - # ICDAR2019 Normalized Edit Distance - if len(gt) == 0 or len(pred) ==0: + + # ICDAR2019 Normalized Edit Distance + if len(gt) == 0 or len(pred) == 0: norm_ED += 0 elif len(gt) > len(pred): norm_ED += 1 - edit_distance(pred, gt) / len(gt) @@ -162,7 +171,7 @@ def validation(model, criterion, evaluation_loader, converter, opt): # print(pred, gt, pred==gt, confidence_score) accuracy = n_correct / float(length_of_data) * 100 - norm_ED = norm_ED / float(length_of_data) # ICDAR2019 Normalized Edit Distance + norm_ED = norm_ED / float(length_of_data) # ICDAR2019 Normalized Edit Distance return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data