upgrade to PyTorch 1.1.0 (use torch.nn.CTCLoss) and test.py update

This commit is contained in:
Baek JeongHun 2019-05-09 03:21:18 +00:00
parent 7333391d33
commit cf390a0873
2 changed files with 35 additions and 25 deletions

46
test.py
View File

@ -1,12 +1,12 @@
import os
import time
import string
import argparse
import torch
import torch.backends.cudnn as cudnn
import torch.utils.data
import numpy as np
from torch_baidu_ctc import CTCLoss
from nltk.metrics.distance import edit_distance
from utils import CTCLabelConverter, AttnLabelConverter, Averager
@ -16,9 +16,6 @@ from model import Model
def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=False):
""" evaluation with 10 benchmark evaluation datasets """
list_accuracy = []
Total_forward_time = 0
Total_evaluation_data_number = 0
# The evaluation datasets, dataset order is same with Table 1 in our paper.
eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857',
'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']
@ -28,31 +25,39 @@ def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=Fa
else:
evaluation_batch_size = opt.batch_size
list_accuracy = []
total_forward_time = 0
total_evaluation_data_number = 0
total_correct_number = 0
print('-' * 80)
for eval_data in eval_data_list:
eval_data_path = os.path.join(opt.eval_data, eval_data)
AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW)
eval_data = hierarchical_dataset(root=eval_data_path, opt=opt)
print('-' * 80)
Total_evaluation_data_number += len(eval_data)
evaluation_loader = torch.utils.data.DataLoader(
eval_data, batch_size=evaluation_batch_size,
shuffle=False,
num_workers=int(opt.workers),
collate_fn=AlignCollate_evaluation, pin_memory=True)
_, accuracy_by_best_model, _, _, _, infer_time = validation(
_, accuracy_by_best_model, norm_ED_by_best_model, _, _, infer_time, length_of_data = validation(
model, criterion, evaluation_loader, converter, opt)
Total_forward_time += infer_time
list_accuracy.append(f'{accuracy_by_best_model:0.3f}')
total_forward_time += infer_time
total_evaluation_data_number += len(eval_data)
total_correct_number += accuracy_by_best_model * length_of_data
print('Acc %0.3f\t normalized_ED %0.3f' % (accuracy_by_best_model, norm_ED_by_best_model))
print('-' * 80)
averaged_forward_time = Total_forward_time / Total_evaluation_data_number * 1000
averaged_forward_time = total_forward_time / total_evaluation_data_number * 1000
total_accuracy = total_correct_number / total_evaluation_data_number
params_num = sum([np.prod(p.size()) for p in model.parameters()])
evaluation_log = 'accuracy: '
for name, accuracy in zip(eval_data_list, list_accuracy):
evaluation_log += f'{name}: {accuracy}\t'
evaluation_log += f'averaged_infer_time: {averaged_forward_time:0.3f}, # parameters: {params_num/1e6:0.3f}'
evaluation_log += f'total_accuracy: {total_accuracy:0.3f}\t'
evaluation_log += f'averaged_infer_time: {averaged_forward_time:0.3f}\t# parameters: {params_num/1e6:0.3f}'
print(evaluation_log)
with open(f'./result/{opt.experiment_name}/log_all_evaluation.txt', 'a') as log:
log.write(evaluation_log + '\n')
@ -67,7 +72,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
n_correct = 0
norm_ED = 0
max_length = 25
max_length = opt.batch_max_length
length_of_data = 0
infer_time = 0
valid_loss_avg = Averager()
@ -85,13 +90,13 @@ def validation(model, criterion, evaluation_loader, converter, opt):
start_time = time.time()
if 'CTC' in opt.Prediction:
preds = model(image, text_for_pred)
preds = model(image, text_for_pred).log_softmax(2)
forward_time = time.time() - start_time
# Calculate evaluation loss for CTC deocder.
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
preds = preds.permute(1, 0, 2) # to use CTCloss format
cost = criterion(preds, text_for_loss, preds_size, length_for_loss) / batch_size
cost = criterion(preds, text_for_loss, preds_size, length_for_loss)
# Select max probabilty (greedy decoding) then decode index to character
_, preds = preds.max(2)
@ -126,7 +131,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
accuracy = n_correct / float(length_of_data) * 100
return valid_loss_avg.val(), accuracy, norm_ED, sim_preds, cpu_texts, infer_time
return valid_loss_avg.val(), accuracy, norm_ED, sim_preds, cpu_texts, infer_time, length_of_data
def test(opt):
@ -135,11 +140,10 @@ def test(opt):
converter = CTCLabelConverter(opt.character)
else:
converter = AttnLabelConverter(opt.character)
opt.num_class = len(converter.character)
if opt.rgb:
opt.input_channel = 3
opt.num_class = len(converter.character)
model = Model(opt)
print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
@ -159,7 +163,7 @@ def test(opt):
""" setup loss """
if 'CTC' in opt.Prediction:
criterion = CTCLoss(reduction='sum')
criterion = torch.nn.CTCLoss(zero_infinity=True).cuda()
else:
criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda() # ignore [GO] token = ignore index 0
@ -175,7 +179,7 @@ def test(opt):
shuffle=False,
num_workers=int(opt.workers),
collate_fn=AlignCollate_evaluation, pin_memory=True)
_, accuracy_by_best_model, _, _, _, _ = validation(
_, accuracy_by_best_model, _, _, _, _, _ = validation(
model, criterion, evaluation_loader, converter, opt)
print(accuracy_by_best_model)
@ -210,6 +214,10 @@ if __name__ == '__main__':
opt = parser.parse_args()
""" vocab / character number configuration """
if opt.sensitive:
opt.character = string.printable[:-6] # same with ASTER setting (use 94 char).
cudnn.benchmark = True
cudnn.deterministic = True
opt.num_gpu = torch.cuda.device_count()

View File

@ -2,6 +2,7 @@ import os
import sys
import time
import random
import string
import argparse
import torch
@ -10,7 +11,6 @@ import torch.nn.init as init
import torch.optim as optim
import torch.utils.data
import numpy as np
from torch_baidu_ctc import CTCLoss
from utils import CTCLabelConverter, AttnLabelConverter, Averager
from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset
@ -39,6 +39,7 @@ def train(opt):
else:
converter = AttnLabelConverter(opt.character)
opt.num_class = len(converter.character)
if opt.rgb:
opt.input_channel = 3
model = Model(opt)
@ -72,7 +73,7 @@ def train(opt):
""" setup loss """
if 'CTC' in opt.Prediction:
criterion = CTCLoss(reduction='sum')
criterion = torch.nn.CTCLoss(zero_infinity=True).cuda()
else:
criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda() # ignore [GO] token = ignore index 0
# loss averager
@ -128,10 +129,10 @@ def train(opt):
batch_size = image.size(0)
if 'CTC' in opt.Prediction:
preds = model(image, text)
preds = model(image, text).log_softmax(2)
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
preds = preds.permute(1, 0, 2) # to use CTCLoss format
cost = criterion(preds, text, preds_size, length) / batch_size
cost = criterion(preds, text, preds_size, length)
else:
preds = model(image, text)
@ -155,7 +156,7 @@ def train(opt):
loss_avg.reset()
model.eval()
valid_loss, current_accuracy, current_norm_ED, preds, gts, infer_time = validation(
valid_loss, current_accuracy, current_norm_ED, preds, gts, infer_time, length_of_data = validation(
model, criterion, valid_loader, converter, opt)
model.train()
@ -245,7 +246,8 @@ if __name__ == '__main__':
""" vocab / character number configuration """
if opt.sensitive:
opt.character += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
# opt.character += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
opt.character = string.printable[:-6] # same with ASTER setting (use 94 char).
""" Seed and GPU setting """
# print("Random Seed: ", opt.manualSeed)