diff --git a/test.py b/test.py
index 6603f78..d09bd33 100755
--- a/test.py
+++ b/test.py
@@ -1,12 +1,12 @@
 import os
 import time
+import string
 import argparse
 
 import torch
 import torch.backends.cudnn as cudnn
 import torch.utils.data
 import numpy as np
-from torch_baidu_ctc import CTCLoss
 from nltk.metrics.distance import edit_distance
 
 from utils import CTCLabelConverter, AttnLabelConverter, Averager
@@ -16,9 +16,6 @@ from model import Model
 
 def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=False):
     """ evaluation with 10 benchmark evaluation datasets """
-    list_accuracy = []
-    Total_forward_time = 0
-    Total_evaluation_data_number = 0
     # The evaluation datasets; dataset order is the same as Table 1 in our paper.
     eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857',
                       'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']
@@ -28,31 +25,39 @@ def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=False):
     else:
         evaluation_batch_size = opt.batch_size
 
+    list_accuracy = []
+    total_forward_time = 0
+    total_evaluation_data_number = 0
+    total_correct_number = 0
     print('-' * 80)
     for eval_data in eval_data_list:
         eval_data_path = os.path.join(opt.eval_data, eval_data)
         AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW)
         eval_data = hierarchical_dataset(root=eval_data_path, opt=opt)
-        print('-' * 80)
-        Total_evaluation_data_number += len(eval_data)
         evaluation_loader = torch.utils.data.DataLoader(
             eval_data, batch_size=evaluation_batch_size,
             shuffle=False,
             num_workers=int(opt.workers),
             collate_fn=AlignCollate_evaluation, pin_memory=True)
-        _, accuracy_by_best_model, _, _, _, infer_time = validation(
+        _, accuracy_by_best_model, norm_ED_by_best_model, _, _, infer_time, length_of_data = validation(
             model, criterion, evaluation_loader, converter, opt)
-        Total_forward_time += infer_time
         list_accuracy.append(f'{accuracy_by_best_model:0.3f}')
+        total_forward_time += infer_time
+        total_evaluation_data_number += len(eval_data)
+        total_correct_number += accuracy_by_best_model * length_of_data
+        print('Acc %0.3f\t normalized_ED %0.3f' % (accuracy_by_best_model, norm_ED_by_best_model))
+        print('-' * 80)
 
-    averaged_forward_time = Total_forward_time / Total_evaluation_data_number * 1000
+    averaged_forward_time = total_forward_time / total_evaluation_data_number * 1000
+    total_accuracy = total_correct_number / total_evaluation_data_number
     params_num = sum([np.prod(p.size()) for p in model.parameters()])
 
     evaluation_log = 'accuracy: '
     for name, accuracy in zip(eval_data_list, list_accuracy):
         evaluation_log += f'{name}: {accuracy}\t'
-    evaluation_log += f'averaged_infer_time: {averaged_forward_time:0.3f}, # parameters: {params_num/1e6:0.3f}'
+    evaluation_log += f'total_accuracy: {total_accuracy:0.3f}\t'
+    evaluation_log += f'averaged_infer_time: {averaged_forward_time:0.3f}\t# parameters: {params_num/1e6:0.3f}'
     print(evaluation_log)
     with open(f'./result/{opt.experiment_name}/log_all_evaluation.txt', 'a') as log:
         log.write(evaluation_log + '\n')
 
@@ -67,7 +72,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
     n_correct = 0
     norm_ED = 0
-    max_length = 25
+    max_length = opt.batch_max_length
     length_of_data = 0
     infer_time = 0
     valid_loss_avg = Averager()
 
@@ -85,13 +90,13 @@ def validation(model, criterion, evaluation_loader, converter, opt):
 
         start_time = time.time()
         if 'CTC' in opt.Prediction:
-            preds = model(image, text_for_pred)
+            preds = model(image, text_for_pred).log_softmax(2)
             forward_time = time.time() - start_time
 
             # Calculate evaluation loss for CTC decoder.
             preds_size = torch.IntTensor([preds.size(1)] * batch_size)
             preds = preds.permute(1, 0, 2)  # to use CTCLoss format
-            cost = criterion(preds, text_for_loss, preds_size, length_for_loss) / batch_size
+            cost = criterion(preds, text_for_loss, preds_size, length_for_loss)
 
             # Select max probability (greedy decoding) then decode index to character
             _, preds = preds.max(2)
@@ -126,7 +131,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
 
     accuracy = n_correct / float(length_of_data) * 100
 
-    return valid_loss_avg.val(), accuracy, norm_ED, sim_preds, cpu_texts, infer_time
+    return valid_loss_avg.val(), accuracy, norm_ED, sim_preds, cpu_texts, infer_time, length_of_data
 
 
 def test(opt):
@@ -135,11 +140,10 @@ def test(opt):
         converter = CTCLabelConverter(opt.character)
     else:
         converter = AttnLabelConverter(opt.character)
-
+    opt.num_class = len(converter.character)
+
     if opt.rgb:
         opt.input_channel = 3
-
-    opt.num_class = len(converter.character)
     model = Model(opt)
     print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
           opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
@@ -159,7 +163,7 @@ def test(opt):
 
     """ setup loss """
     if 'CTC' in opt.Prediction:
-        criterion = CTCLoss(reduction='sum')
+        criterion = torch.nn.CTCLoss(zero_infinity=True).cuda()
     else:
         criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda()  # ignore [GO] token = ignore index 0
 
@@ -175,7 +179,7 @@ def test(opt):
             shuffle=False,
             num_workers=int(opt.workers),
             collate_fn=AlignCollate_evaluation, pin_memory=True)
-        _, accuracy_by_best_model, _, _, _, _ = validation(
+        _, accuracy_by_best_model, _, _, _, _, _ = validation(
             model, criterion, evaluation_loader, converter, opt)
 
         print(accuracy_by_best_model)
@@ -210,6 +214,10 @@ if __name__ == '__main__':
 
     opt = parser.parse_args()
 
+    """ vocab / character number configuration """
+    if opt.sensitive:
+        opt.character = string.printable[:-6]  # same as the ASTER setting (94 chars).
+
     cudnn.benchmark = True
     cudnn.deterministic = True
     opt.num_gpu = torch.cuda.device_count()
diff --git a/train.py b/train.py
index 5d64ee1..c5717a3 100755
--- a/train.py
+++ b/train.py
@@ -2,6 +2,7 @@ import os
 import sys
 import time
 import random
+import string
 import argparse
 
 import torch
@@ -10,7 +11,6 @@ import torch.nn.init as init
 import torch.optim as optim
 import torch.utils.data
 import numpy as np
-from torch_baidu_ctc import CTCLoss
 
 from utils import CTCLabelConverter, AttnLabelConverter, Averager
 from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset
@@ -39,6 +39,7 @@ def train(opt):
     else:
         converter = AttnLabelConverter(opt.character)
     opt.num_class = len(converter.character)
+
     if opt.rgb:
         opt.input_channel = 3
     model = Model(opt)
@@ -72,7 +73,7 @@ def train(opt):
 
     """ setup loss """
    if 'CTC' in opt.Prediction:
-        criterion = CTCLoss(reduction='sum')
+        criterion = torch.nn.CTCLoss(zero_infinity=True).cuda()
     else:
         criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda()  # ignore [GO] token = ignore index 0
     # loss averager
@@ -128,10 +129,10 @@ def train(opt):
         batch_size = image.size(0)
 
         if 'CTC' in opt.Prediction:
-            preds = model(image, text)
+            preds = model(image, text).log_softmax(2)
             preds_size = torch.IntTensor([preds.size(1)] * batch_size)
             preds = preds.permute(1, 0, 2)  # to use CTCLoss format
-            cost = criterion(preds, text, preds_size, length) / batch_size
+            cost = criterion(preds, text, preds_size, length)
 
         else:
             preds = model(image, text)
@@ -155,7 +156,7 @@ def train(opt):
             loss_avg.reset()
 
             model.eval()
-            valid_loss, current_accuracy, current_norm_ED, preds, gts, infer_time = validation(
+            valid_loss, current_accuracy, current_norm_ED, preds, gts, infer_time, length_of_data = validation(
                 model, criterion, valid_loader, converter, opt)
             model.train()
 
@@ -245,7 +246,8 @@ if __name__ == '__main__':
 
     """ vocab / character number configuration """
     if opt.sensitive:
-        opt.character += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
+        # opt.character += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
+        opt.character = string.printable[:-6]  # same as the ASTER setting (94 chars).
 
     """ Seed and GPU setting """
     # print("Random Seed: ", opt.manualSeed)
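
Why the CTC pieces change together: torch_baidu_ctc's CTCLoss(reduction='sum') returned a batch-summed loss, hence the old explicit "/ batch_size"; torch.nn.CTCLoss defaults to reduction='mean', which already normalizes the loss, so the division is dropped. nn.CTCLoss also expects log-probabilities in (T, N, C) layout, which is what the added .log_softmax(2) plus the existing .permute(1, 0, 2) produce, and zero_infinity=True zeroes losses that come out infinite (a target longer than the input width permits) instead of letting NaNs reach the gradients. Below is a minimal CPU sketch of the call convention only; the shapes and random targets are illustrative stand-ins, not values from Model(image, text):

    import torch

    T, N, C = 26, 2, 38           # time steps, batch size, classes (illustrative)
    preds = torch.randn(N, T, C)  # stand-in for the model output: (batch, time, class)

    log_probs = preds.log_softmax(2).permute(1, 0, 2)  # nn.CTCLoss wants (T, N, C) log-probs
    preds_size = torch.IntTensor([T] * N)              # every sequence spans the full width
    # targets use indices 1..C-1; index 0 is nn.CTCLoss's default blank
    text = torch.randint(1, C, (N, 10), dtype=torch.long)
    length = torch.IntTensor([10] * N)

    # reduction='mean' (the default) already normalizes, so no "/ batch_size";
    # zero_infinity=True replaces infinite losses with zero instead of NaN gradients.
    criterion = torch.nn.CTCLoss(zero_infinity=True)
    cost = criterion(log_probs, text, preds_size, length)
    print(cost.item())

A side benefit of the swap is dropping the torch_baidu_ctc (warp-ctc) build dependency; CTCLoss ships with PyTorch, though zero_infinity requires PyTorch >= 1.1.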
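On opt.character = string.printable[:-6]: string.printable is digits + ascii_letters + punctuation + whitespace, and its last six characters are exactly the whitespace set, so the slice keeps the 94 visible ASCII characters of the case-sensitive (ASTER) vocabulary. A quick check:

    import string

    charset = string.printable[:-6]       # drop the trailing whitespace characters
    print(len(charset))                   # 94
    print(charset[:13])                   # '0123456789abc'
    print(repr(string.printable[-6:]))    # ' \t\n\r\x0b\x0c' -- the six dropped chars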
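On the new total_accuracy in benchmark_all_eval: accumulating accuracy_by_best_model * length_of_data and dividing by total_evaluation_data_number yields a per-sample accuracy over all ten benchmarks, not the unweighted mean of ten per-dataset accuracies, so large sets like IIIT5k_3000 count proportionally more. A sketch with toy numbers (the 3000/647 sizes match IIIT5k_3000 and SVT, but the accuracies are made up):

    accuracies = [87.0, 92.5]  # accuracy_by_best_model per dataset (%)
    sizes = [3000, 647]        # length_of_data per dataset

    total_correct_number = sum(acc * n for acc, n in zip(accuracies, sizes))
    total_accuracy = total_correct_number / sum(sizes)
    print(f'{total_accuracy:0.3f}')           # 87.976 -- size-weighted
    print(sum(accuracies) / len(accuracies))  # 89.75  -- unweighted mean differs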