eval with alphanumeric and case sensitive setting

This commit is contained in:
Baek JeongHun 2020-02-25 11:06:34 +00:00
parent 49942e5dc0
commit 5b132b932d
2 changed files with 15 additions and 6 deletions

View File

@ -122,7 +122,7 @@ def hierarchical_dataset(root, opt, select_data='/'):
dataset_list.append(dataset) dataset_list.append(dataset)
concatenated_dataset = ConcatDataset(dataset_list) concatenated_dataset = ConcatDataset(dataset_list)
return concatenated_dataset, dataset_log return concatenated_dataset, dataset_log

19
test.py
View File

@ -2,6 +2,7 @@ import os
import time import time
import string import string
import argparse import argparse
import re
import torch import torch
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
@ -133,20 +134,28 @@ def validation(model, criterion, evaluation_loader, converter, opt):
pred = pred[:pred_EOS] # prune after "end of sentence" token ([s]) pred = pred[:pred_EOS] # prune after "end of sentence" token ([s])
pred_max_prob = pred_max_prob[:pred_EOS] pred_max_prob = pred_max_prob[:pred_EOS]
# To evaluate the model with 'alphanumeric and case insensitive setting'
if opt.sensitive and opt.data_filtering_off:
pred = pred.lower()
gt = gt.lower()
alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz'
out_of_alphanumeric_case_insensitve = f'[^{alphanumeric_case_insensitve}]'
pred = re.sub(out_of_alphanumeric_case_insensitve, '', pred)
if pred == gt: if pred == gt:
n_correct += 1 n_correct += 1
''' '''
(old version) ICDAR2017 DOST Normalized Edit Distance https://rrc.cvc.uab.es/?ch=7&com=tasks (old version) ICDAR2017 DOST Normalized Edit Distance https://rrc.cvc.uab.es/?ch=7&com=tasks
"For each word we calculate the normalized edit distance to the length of the ground truth transcription." "For each word we calculate the normalized edit distance to the length of the ground truth transcription."
if len(gt) == 0: if len(gt) == 0:
norm_ED += 1 norm_ED += 1
else: else:
norm_ED += edit_distance(pred, gt) / len(gt) norm_ED += edit_distance(pred, gt) / len(gt)
''' '''
# ICDAR2019 Normalized Edit Distance # ICDAR2019 Normalized Edit Distance
if len(gt) == 0 or len(pred) ==0: if len(gt) == 0 or len(pred) == 0:
norm_ED += 0 norm_ED += 0
elif len(gt) > len(pred): elif len(gt) > len(pred):
norm_ED += 1 - edit_distance(pred, gt) / len(gt) norm_ED += 1 - edit_distance(pred, gt) / len(gt)
@ -162,7 +171,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
# print(pred, gt, pred==gt, confidence_score) # print(pred, gt, pred==gt, confidence_score)
accuracy = n_correct / float(length_of_data) * 100 accuracy = n_correct / float(length_of_data) * 100
norm_ED = norm_ED / float(length_of_data) # ICDAR2019 Normalized Edit Distance norm_ED = norm_ED / float(length_of_data) # ICDAR2019 Normalized Edit Distance
return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data