upgrade to PyTorch 1.1.0 (use torch.nn.CTCLoss) and test.py update
This commit is contained in:
parent
7333391d33
commit
cf390a0873
46
test.py
46
test.py
|
@ -1,12 +1,12 @@
|
|||
import os
|
||||
import time
|
||||
import string
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.utils.data
|
||||
import numpy as np
|
||||
from torch_baidu_ctc import CTCLoss
|
||||
from nltk.metrics.distance import edit_distance
|
||||
|
||||
from utils import CTCLabelConverter, AttnLabelConverter, Averager
|
||||
|
@ -16,9 +16,6 @@ from model import Model
|
|||
|
||||
def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=False):
|
||||
""" evaluation with 10 benchmark evaluation datasets """
|
||||
list_accuracy = []
|
||||
Total_forward_time = 0
|
||||
Total_evaluation_data_number = 0
|
||||
# The evaluation datasets, dataset order is same with Table 1 in our paper.
|
||||
eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857',
|
||||
'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']
|
||||
|
@ -28,31 +25,39 @@ def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=Fa
|
|||
else:
|
||||
evaluation_batch_size = opt.batch_size
|
||||
|
||||
list_accuracy = []
|
||||
total_forward_time = 0
|
||||
total_evaluation_data_number = 0
|
||||
total_correct_number = 0
|
||||
print('-' * 80)
|
||||
for eval_data in eval_data_list:
|
||||
eval_data_path = os.path.join(opt.eval_data, eval_data)
|
||||
AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW)
|
||||
eval_data = hierarchical_dataset(root=eval_data_path, opt=opt)
|
||||
print('-' * 80)
|
||||
Total_evaluation_data_number += len(eval_data)
|
||||
evaluation_loader = torch.utils.data.DataLoader(
|
||||
eval_data, batch_size=evaluation_batch_size,
|
||||
shuffle=False,
|
||||
num_workers=int(opt.workers),
|
||||
collate_fn=AlignCollate_evaluation, pin_memory=True)
|
||||
|
||||
_, accuracy_by_best_model, _, _, _, infer_time = validation(
|
||||
_, accuracy_by_best_model, norm_ED_by_best_model, _, _, infer_time, length_of_data = validation(
|
||||
model, criterion, evaluation_loader, converter, opt)
|
||||
Total_forward_time += infer_time
|
||||
list_accuracy.append(f'{accuracy_by_best_model:0.3f}')
|
||||
total_forward_time += infer_time
|
||||
total_evaluation_data_number += len(eval_data)
|
||||
total_correct_number += accuracy_by_best_model * length_of_data
|
||||
print('Acc %0.3f\t normalized_ED %0.3f' % (accuracy_by_best_model, norm_ED_by_best_model))
|
||||
print('-' * 80)
|
||||
|
||||
averaged_forward_time = Total_forward_time / Total_evaluation_data_number * 1000
|
||||
averaged_forward_time = total_forward_time / total_evaluation_data_number * 1000
|
||||
total_accuracy = total_correct_number / total_evaluation_data_number
|
||||
params_num = sum([np.prod(p.size()) for p in model.parameters()])
|
||||
|
||||
evaluation_log = 'accuracy: '
|
||||
for name, accuracy in zip(eval_data_list, list_accuracy):
|
||||
evaluation_log += f'{name}: {accuracy}\t'
|
||||
evaluation_log += f'averaged_infer_time: {averaged_forward_time:0.3f}, # parameters: {params_num/1e6:0.3f}'
|
||||
evaluation_log += f'total_accuracy: {total_accuracy:0.3f}\t'
|
||||
evaluation_log += f'averaged_infer_time: {averaged_forward_time:0.3f}\t# parameters: {params_num/1e6:0.3f}'
|
||||
print(evaluation_log)
|
||||
with open(f'./result/{opt.experiment_name}/log_all_evaluation.txt', 'a') as log:
|
||||
log.write(evaluation_log + '\n')
|
||||
|
@ -67,7 +72,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
|
|||
|
||||
n_correct = 0
|
||||
norm_ED = 0
|
||||
max_length = 25
|
||||
max_length = opt.batch_max_length
|
||||
length_of_data = 0
|
||||
infer_time = 0
|
||||
valid_loss_avg = Averager()
|
||||
|
@ -85,13 +90,13 @@ def validation(model, criterion, evaluation_loader, converter, opt):
|
|||
|
||||
start_time = time.time()
|
||||
if 'CTC' in opt.Prediction:
|
||||
preds = model(image, text_for_pred)
|
||||
preds = model(image, text_for_pred).log_softmax(2)
|
||||
forward_time = time.time() - start_time
|
||||
|
||||
# Calculate evaluation loss for CTC deocder.
|
||||
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
|
||||
preds = preds.permute(1, 0, 2) # to use CTCloss format
|
||||
cost = criterion(preds, text_for_loss, preds_size, length_for_loss) / batch_size
|
||||
cost = criterion(preds, text_for_loss, preds_size, length_for_loss)
|
||||
|
||||
# Select max probabilty (greedy decoding) then decode index to character
|
||||
_, preds = preds.max(2)
|
||||
|
@ -126,7 +131,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
|
|||
|
||||
accuracy = n_correct / float(length_of_data) * 100
|
||||
|
||||
return valid_loss_avg.val(), accuracy, norm_ED, sim_preds, cpu_texts, infer_time
|
||||
return valid_loss_avg.val(), accuracy, norm_ED, sim_preds, cpu_texts, infer_time, length_of_data
|
||||
|
||||
|
||||
def test(opt):
|
||||
|
@ -135,11 +140,10 @@ def test(opt):
|
|||
converter = CTCLabelConverter(opt.character)
|
||||
else:
|
||||
converter = AttnLabelConverter(opt.character)
|
||||
|
||||
opt.num_class = len(converter.character)
|
||||
|
||||
if opt.rgb:
|
||||
opt.input_channel = 3
|
||||
|
||||
opt.num_class = len(converter.character)
|
||||
model = Model(opt)
|
||||
print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
|
||||
opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
|
||||
|
@ -159,7 +163,7 @@ def test(opt):
|
|||
|
||||
""" setup loss """
|
||||
if 'CTC' in opt.Prediction:
|
||||
criterion = CTCLoss(reduction='sum')
|
||||
criterion = torch.nn.CTCLoss(zero_infinity=True).cuda()
|
||||
else:
|
||||
criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda() # ignore [GO] token = ignore index 0
|
||||
|
||||
|
@ -175,7 +179,7 @@ def test(opt):
|
|||
shuffle=False,
|
||||
num_workers=int(opt.workers),
|
||||
collate_fn=AlignCollate_evaluation, pin_memory=True)
|
||||
_, accuracy_by_best_model, _, _, _, _ = validation(
|
||||
_, accuracy_by_best_model, _, _, _, _, _ = validation(
|
||||
model, criterion, evaluation_loader, converter, opt)
|
||||
|
||||
print(accuracy_by_best_model)
|
||||
|
@ -210,6 +214,10 @@ if __name__ == '__main__':
|
|||
|
||||
opt = parser.parse_args()
|
||||
|
||||
""" vocab / character number configuration """
|
||||
if opt.sensitive:
|
||||
opt.character = string.printable[:-6] # same with ASTER setting (use 94 char).
|
||||
|
||||
cudnn.benchmark = True
|
||||
cudnn.deterministic = True
|
||||
opt.num_gpu = torch.cuda.device_count()
|
||||
|
|
14
train.py
14
train.py
|
@ -2,6 +2,7 @@ import os
|
|||
import sys
|
||||
import time
|
||||
import random
|
||||
import string
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
@ -10,7 +11,6 @@ import torch.nn.init as init
|
|||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
import numpy as np
|
||||
from torch_baidu_ctc import CTCLoss
|
||||
|
||||
from utils import CTCLabelConverter, AttnLabelConverter, Averager
|
||||
from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset
|
||||
|
@ -39,6 +39,7 @@ def train(opt):
|
|||
else:
|
||||
converter = AttnLabelConverter(opt.character)
|
||||
opt.num_class = len(converter.character)
|
||||
|
||||
if opt.rgb:
|
||||
opt.input_channel = 3
|
||||
model = Model(opt)
|
||||
|
@ -72,7 +73,7 @@ def train(opt):
|
|||
|
||||
""" setup loss """
|
||||
if 'CTC' in opt.Prediction:
|
||||
criterion = CTCLoss(reduction='sum')
|
||||
criterion = torch.nn.CTCLoss(zero_infinity=True).cuda()
|
||||
else:
|
||||
criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda() # ignore [GO] token = ignore index 0
|
||||
# loss averager
|
||||
|
@ -128,10 +129,10 @@ def train(opt):
|
|||
batch_size = image.size(0)
|
||||
|
||||
if 'CTC' in opt.Prediction:
|
||||
preds = model(image, text)
|
||||
preds = model(image, text).log_softmax(2)
|
||||
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
|
||||
preds = preds.permute(1, 0, 2) # to use CTCLoss format
|
||||
cost = criterion(preds, text, preds_size, length) / batch_size
|
||||
cost = criterion(preds, text, preds_size, length)
|
||||
|
||||
else:
|
||||
preds = model(image, text)
|
||||
|
@ -155,7 +156,7 @@ def train(opt):
|
|||
loss_avg.reset()
|
||||
|
||||
model.eval()
|
||||
valid_loss, current_accuracy, current_norm_ED, preds, gts, infer_time = validation(
|
||||
valid_loss, current_accuracy, current_norm_ED, preds, gts, infer_time, length_of_data = validation(
|
||||
model, criterion, valid_loader, converter, opt)
|
||||
model.train()
|
||||
|
||||
|
@ -245,7 +246,8 @@ if __name__ == '__main__':
|
|||
|
||||
""" vocab / character number configuration """
|
||||
if opt.sensitive:
|
||||
opt.character += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
|
||||
# opt.character += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
|
||||
opt.character = string.printable[:-6] # same with ASTER setting (use 94 char).
|
||||
|
||||
""" Seed and GPU setting """
|
||||
# print("Random Seed: ", opt.manualSeed)
|
||||
|
|
Loading…
Reference in New Issue