eval with alphanumeric and case sensitive setting

Baek JeongHun 2020-02-25 11:06:34 +00:00
parent 49942e5dc0
commit 5b132b932d
2 changed files with 15 additions and 6 deletions

test.py (13 changes)

@@ -2,6 +2,7 @@ import os
 import time
 import string
 import argparse
+import re
 import torch
 import torch.backends.cudnn as cudnn
@@ -133,6 +134,14 @@ def validation(model, criterion, evaluation_loader, converter, opt):
 pred = pred[:pred_EOS] # prune after "end of sentence" token ([s])
 pred_max_prob = pred_max_prob[:pred_EOS]
+# To evaluate the model with 'alphanumeric and case insensitive setting'
+if opt.sensitive and opt.data_filtering_off:
+    pred = pred.lower()
+    gt = gt.lower()
+    alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz'
+    out_of_alphanumeric_case_insensitve = f'[^{alphanumeric_case_insensitve}]'
+    pred = re.sub(out_of_alphanumeric_case_insensitve, '', pred)
 if pred == gt:
     n_correct += 1
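For context, a minimal standalone sketch of what this added filter does to one (prediction, ground truth) pair; the helper name and example strings are illustrative only, not part of the commit:

import re

def normalize_for_eval(pred, gt):
    # Mirror the added block: lowercase both strings, then keep only digits and
    # lowercase letters in the prediction (the ground truth is only lowercased).
    alphanumeric = '0123456789abcdefghijklmnopqrstuvwxyz'
    pred = re.sub(f'[^{alphanumeric}]', '', pred.lower())
    return pred, gt.lower()

print(normalize_for_eval('Hello!', 'HELLO'))  # ('hello', 'hello')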
@@ -146,7 +155,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
 '''
 # ICDAR2019 Normalized Edit Distance
-if len(gt) == 0 or len(pred) ==0:
+if len(gt) == 0 or len(pred) == 0:
     norm_ED += 0
 elif len(gt) > len(pred):
     norm_ED += 1 - edit_distance(pred, gt) / len(gt)
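This hunk shows only the empty-string and len(gt) > len(pred) branches; assuming the remaining branch divides by len(pred) and that edit_distance is NLTK's Levenshtein edit_distance (which test.py presumably imports), the per-sample term reduces to the sketch below (the helper name is illustrative):

from nltk.metrics.distance import edit_distance

def norm_ed_term(pred, gt):
    # ICDAR2019 Normalized Edit Distance contribution of one sample:
    # 0 if either string is empty, else 1 - ED / max(len(gt), len(pred)).
    if len(gt) == 0 or len(pred) == 0:
        return 0
    return 1 - edit_distance(pred, gt) / max(len(gt), len(pred))

print(norm_ed_term('w0rld', 'world'))  # one edit over five characters -> 0.8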
@@ -162,7 +171,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
 # print(pred, gt, pred==gt, confidence_score)
 accuracy = n_correct / float(length_of_data) * 100
-norm_ED = norm_ED / float(length_of_data) # ICDAR2019 Normalized Edit Distance
+norm_ED = norm_ED / float(length_of_data) # ICDAR2019 Normalized Edit Distance
 return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data
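Tying the changes together, a toy version of the accumulation that validation performs over the evaluation set (data and names are illustrative; the real function also tracks loss, confidence scores and inference time):

import re
from nltk.metrics.distance import edit_distance

pairs = [('Hello!', 'HELLO'), ('W0rld', 'world')]  # (prediction, ground truth)
n_correct, norm_ED = 0, 0.0
for pred, gt in pairs:
    # case-insensitive, alphanumeric-only normalization from the hunk above
    pred = re.sub('[^0123456789abcdefghijklmnopqrstuvwxyz]', '', pred.lower())
    gt = gt.lower()
    if pred == gt:
        n_correct += 1
    if len(gt) > 0 and len(pred) > 0:
        norm_ED += 1 - edit_distance(pred, gt) / max(len(gt), len(pred))

accuracy = n_correct / float(len(pairs)) * 100
norm_ED = norm_ED / float(len(pairs))
print(accuracy, norm_ED)  # 50.0 0.9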