eval with alphanumeric and case sensitive setting
parent 49942e5dc0
commit 5b132b932d
@@ -122,7 +122,7 @@ def hierarchical_dataset(root, opt, select_data='/'):
             dataset_list.append(dataset)

     concatenated_dataset = ConcatDataset(dataset_list)

     return concatenated_dataset, dataset_log
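For context, hierarchical_dataset builds one dataset per selected data source and merges them with PyTorch's ConcatDataset before returning the combined view together with a log string. A minimal, self-contained sketch of that concatenation pattern (the ToyDataset class and sample values are illustrative stand-ins, not the repository's LMDB dataset code):

from torch.utils.data import ConcatDataset, Dataset

class ToyDataset(Dataset):
    """Illustrative stand-in for one per-source dataset (hypothetical)."""
    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

# one dataset per data source, then a single concatenated view over all of them
dataset_list = [ToyDataset(['img_a', 'img_b']), ToyDataset(['img_c'])]
concatenated_dataset = ConcatDataset(dataset_list)
print(len(concatenated_dataset))  # 3, i.e. the sum of the per-source lengths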
test.py (19)
@@ -2,6 +2,7 @@ import os
 import time
 import string
 import argparse
+import re

 import torch
 import torch.backends.cudnn as cudnn
@@ -133,20 +134,28 @@ def validation(model, criterion, evaluation_loader, converter, opt):
                 pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                 pred_max_prob = pred_max_prob[:pred_EOS]

+            # To evaluate the model with 'alphanumeric and case insensitve setting'
+            if opt.sensitive and opt.data_filtering_off:
+                pred = pred.lower()
+                gt = gt.lower()
+                alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz'
+                out_of_alphanumeric_case_insensitve = f'[^{alphanumeric_case_insensitve}]'
+                pred = re.sub(out_of_alphanumeric_case_insensitve, '', pred)
+
             if pred == gt:
                 n_correct += 1

             '''
             (old version) ICDAR2017 DOST Normalized Edit Distance https://rrc.cvc.uab.es/?ch=7&com=tasks
             "For each word we calculate the normalized edit distance to the length of the ground truth transcription."
             if len(gt) == 0:
                 norm_ED += 1
             else:
                 norm_ED += edit_distance(pred, gt) / len(gt)
             '''

             # ICDAR2019 Normalized Edit Distance
-            if len(gt) == 0 or len(pred) ==0:
+            if len(gt) == 0 or len(pred) == 0:
                 norm_ED += 0
             elif len(gt) > len(pred):
                 norm_ED += 1 - edit_distance(pred, gt) / len(gt)
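The block added in this hunk evaluates a case-sensitive model under the alphanumeric, case-insensitive protocol: both strings are lower-cased and characters outside 0-9a-z are stripped from the prediction (the ground truth is expected to be alphanumeric already in that setting). A standalone sketch of the same normalization; the helper name and sample strings below are illustrative, not part of the diff:

import re

ALPHANUMERIC = '0123456789abcdefghijklmnopqrstuvwxyz'
OUT_OF_ALPHANUMERIC = f'[^{ALPHANUMERIC}]'  # matches any character outside 0-9a-z

def normalize_for_eval(pred, gt):
    """Lower-case both strings and drop non-alphanumeric characters from the prediction."""
    pred, gt = pred.lower(), gt.lower()
    pred = re.sub(OUT_OF_ALPHANUMERIC, '', pred)
    return pred, gt

print(normalize_for_eval('Hello, World!', 'HelloWorld'))  # ('helloworld', 'helloworld')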
@@ -162,7 +171,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
             # print(pred, gt, pred==gt, confidence_score)

     accuracy = n_correct / float(length_of_data) * 100
     norm_ED = norm_ED / float(length_of_data)  # ICDAR2019 Normalized Edit Distance

     return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data
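The last hunk averages the accumulated statistics over the evaluation set. A self-contained sketch of the ICDAR2019 normalized edit distance accumulation shown above, with a plain dynamic-programming Levenshtein function standing in for the edit_distance the diff calls; the final else branch falls outside the visible hunk, so its normalization by len(pred) is an assumption:

def levenshtein(a, b):
    """Plain DP edit distance; stand-in for the edit_distance used in test.py."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (ca != cb)))
        prev = cur
    return prev[-1]

def norm_ed_term(pred, gt):
    """Per-sample ICDAR2019 normalized edit distance contribution, mirroring the diff."""
    if len(gt) == 0 or len(pred) == 0:
        return 0
    elif len(gt) > len(pred):
        return 1 - levenshtein(pred, gt) / len(gt)
    else:
        # branch not visible in the hunk above; normalizing by len(pred) is an assumption
        return 1 - levenshtein(pred, gt) / len(pred)

# accumulate per sample, then average, as in the final hunk
samples = [('hello', 'hello'), ('helo', 'hello')]
norm_ED = sum(norm_ed_term(p, g) for p, g in samples)
norm_ED = norm_ED / float(len(samples))
print(norm_ED)  # 0.9 for this toy pair: 1.0 + (1 - 1/5), averaged over 2 samples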