log_dataset.txt, IC19-NED, minor updates
parent 6593928855
commit 8920de93e8
dataset.py (44 changed lines)

@@ -22,8 +22,12 @@ class Batch_Balanced_Dataset(object):
         For example, when select_data is "MJ-ST" and batch_ratio is "0.5-0.5",
         the 50% of the batch is filled with MJ and the other 50% of the batch is filled with ST.
         """
-        print('-' * 80)
+        log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a')
+        dashed_line = '-' * 80
+        print(dashed_line)
+        log.write(dashed_line + '\n')
         print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}')
+        log.write(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}\n')
         assert len(opt.select_data) == len(opt.batch_ratio)

         _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
@@ -33,9 +37,11 @@ class Batch_Balanced_Dataset(object):
         Total_batch_size = 0
         for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio):
             _batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1)
-            print('-' * 80)
-            _dataset = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d])
+            print(dashed_line)
+            log.write(dashed_line + '\n')
+            _dataset, _dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d])
             total_number_dataset = len(_dataset)
+            log.write(_dataset_log)

             """
             The total number of data can be modified with opt.total_data_usage_ratio.
@@ -47,8 +53,10 @@ class Batch_Balanced_Dataset(object):
             indices = range(total_number_dataset)
             _dataset, _ = [Subset(_dataset, indices[offset - length:offset])
                            for offset, length in zip(_accumulate(dataset_split), dataset_split)]
-            print(f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}')
-            print(f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}')
+            selected_d_log = f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}\n'
+            selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}'
+            print(selected_d_log)
+            log.write(selected_d_log + '\n')
             batch_size_list.append(str(_batch_size))
             Total_batch_size += _batch_size

@@ -59,10 +67,16 @@ class Batch_Balanced_Dataset(object):
                 collate_fn=_AlignCollate, pin_memory=True)
             self.data_loader_list.append(_data_loader)
             self.dataloader_iter_list.append(iter(_data_loader))
-        print('-' * 80)
-        print('Total_batch_size: ', '+'.join(batch_size_list), '=', str(Total_batch_size))
+
+        Total_batch_size_log = f'{dashed_line}\n'
+        batch_size_sum = '+'.join(batch_size_list)
+        Total_batch_size_log += f'Total_batch_size: {batch_size_sum} = {Total_batch_size}\n'
+        Total_batch_size_log += f'{dashed_line}'
         opt.batch_size = Total_batch_size
-        print('-' * 80)
+
+        print(Total_batch_size_log)
+        log.write(Total_batch_size_log + '\n')
+        log.close()

     def get_batch(self):
         balanced_batch_images = []
@@ -89,7 +103,9 @@ class Batch_Balanced_Dataset(object):
 def hierarchical_dataset(root, opt, select_data='/'):
     """ select_data='/' contains all sub-directory of root directory """
     dataset_list = []
-    print(f'dataset_root: {root}\t dataset: {select_data[0]}')
+    dataset_log = f'dataset_root: {root}\t dataset: {select_data[0]}'
+    print(dataset_log)
+    dataset_log += '\n'
     for dirpath, dirnames, filenames in os.walk(root+'/'):
         if not dirnames:
             select_flag = False
@@ -100,12 +116,14 @@ def hierarchical_dataset(root, opt, select_data='/'):

             if select_flag:
                 dataset = LmdbDataset(dirpath, opt)
-                print(f'sub-directory:\t/{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}')
+                sub_dataset_log = f'sub-directory:\t/{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}'
+                print(sub_dataset_log)
+                dataset_log += f'{sub_dataset_log}\n'
                 dataset_list.append(dataset)

     concatenated_dataset = ConcatDataset(dataset_list)

-    return concatenated_dataset
+    return concatenated_dataset, dataset_log


 class LmdbDataset(Dataset):
@@ -129,8 +147,8 @@ class LmdbDataset(Dataset):
         else:
             """ Filtering part
             If you want to evaluate IC15-2077 & CUTE datasets which have special character labels,
-            use --data_filtering_off and evaluation with this snippet (only evaluate on alphabets and digits).
-            https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/dataset.py#L186-L188
+            use --data_filtering_off and only evaluate on alphabets and digits.
+            see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L190-L192
             """
             self.filtered_index_list = []
             for index in range(self.nSamples):
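
Note on the batch-balancing arithmetic above: each selected dataset gets max(round(opt.batch_size * batch_ratio), 1) samples per batch, and the sum of those per-dataset sizes becomes the new opt.batch_size. A minimal standalone sketch of that computation (illustrative only; the variable names mirror the options in the diff, and the batch size of 192 is an assumed example value):

    # Sketch of Batch_Balanced_Dataset's per-dataset batch sizing.
    select_data = ['MJ', 'ST']    # mirrors opt.select_data after split('-')
    batch_ratio = ['0.5', '0.5']  # mirrors opt.batch_ratio after split('-')
    batch_size = 192              # assumed example value for opt.batch_size

    batch_size_list = []
    for selected_d, batch_ratio_d in zip(select_data, batch_ratio):
        # max(..., 1) guarantees every selected dataset contributes at least one sample
        _batch_size = max(round(batch_size * float(batch_ratio_d)), 1)
        batch_size_list.append(_batch_size)

    total_batch_size = sum(batch_size_list)  # becomes the effective opt.batch_size
    print('Total_batch_size:', '+'.join(map(str, batch_size_list)), '=', total_batch_size)  # 96+96 = 192

With uneven ratios the rounding can make the total drift slightly from the requested batch size, which is why the code overwrites opt.batch_size with the actual sum.
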
demo.py (16 changed lines)

@@ -52,7 +52,7 @@ def demo(opt):
             text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

             if 'CTC' in opt.Prediction:
-                preds = model(image, text_for_pred).log_softmax(2)
+                preds = model(image, text_for_pred)

                 # Select max probabilty (greedy decoding) then decode index to character
                 preds_size = torch.IntTensor([preds.size(1)] * batch_size)
@@ -67,9 +67,14 @@ def demo(opt):
                 _, preds_index = preds.max(2)
                 preds_str = converter.decode(preds_index, length_for_pred)

-            print('-' * 80)
-            print(f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score')
-            print('-' * 80)
+            log = open(f'./log_demo_result.txt', 'a')
+            dashed_line = '-' * 80
+            head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'
+
+            print(f'{dashed_line}\n{head}\n{dashed_line}')
+            log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')
+
             preds_prob = F.softmax(preds, dim=2)
             preds_max_prob, _ = preds_prob.max(dim=2)
             for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
@@ -81,9 +86,10 @@ def demo(opt):
                 # calculate confidence score (= multiply of pred_max_prob)
                 confidence_score = pred_max_prob.cumprod(dim=0)[-1]

-                print(f'{img_name}\t{pred}\t{confidence_score:0.4f}')
+                # print(f'{img_name}\t{pred}\t{confidence_score:0.4f}')
+                print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')
+                log.write(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n')

+            log.close()

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
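
The change at line 52 drops .log_softmax(2) from the CTC inference path. Greedy decoding only takes an argmax over the class dimension, and log_softmax is monotonic per time step, so the decoded indices are identical either way; the confidence score is now computed from F.softmax(preds, dim=2) directly. A small sanity-check sketch (standalone, with illustrative tensor shapes; not code from the repo):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    preds = torch.randn(2, 26, 37)  # (batch, time, classes): raw logits, shapes illustrative

    # argmax is invariant under log_softmax, so greedy CTC decoding is unaffected
    assert torch.equal(preds.max(2)[1], preds.log_softmax(2).max(2)[1])

    # confidence now comes from real probabilities rather than log-probabilities
    preds_prob = F.softmax(preds, dim=2)
    preds_max_prob, _ = preds_prob.max(dim=2)
    for pred_max_prob in preds_max_prob:                     # one (time,) vector per image
        confidence_score = pred_max_prob.cumprod(dim=0)[-1]  # product over time steps

The counterpart change in test.py below moves log_softmax(2) into the criterion call, since PyTorch's CTCLoss expects log-probabilities as input.
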
test.py (45 changed lines)

@@ -31,11 +31,14 @@ def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=Fa
     total_forward_time = 0
     total_evaluation_data_number = 0
     total_correct_number = 0
-    print('-' * 80)
+    log = open(f'./result/{opt.experiment_name}/log_all_evaluation.txt', 'a')
+    dashed_line = '-' * 80
+    print(dashed_line)
+    log.write(dashed_line + '\n')
     for eval_data in eval_data_list:
         eval_data_path = os.path.join(opt.eval_data, eval_data)
         AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
-        eval_data = hierarchical_dataset(root=eval_data_path, opt=opt)
+        eval_data, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt)
         evaluation_loader = torch.utils.data.DataLoader(
             eval_data, batch_size=evaluation_batch_size,
             shuffle=False,
@@ -48,8 +51,11 @@ def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=Fa
         total_forward_time += infer_time
         total_evaluation_data_number += len(eval_data)
         total_correct_number += accuracy_by_best_model * length_of_data
-        print('Acc %0.3f\t normalized_ED %0.3f' % (accuracy_by_best_model, norm_ED_by_best_model))
-        print('-' * 80)
+        log.write(eval_data_log)
+        print(f'Acc {accuracy_by_best_model:0.3f}\t normalized_ED {norm_ED_by_best_model:0.3f}')
+        log.write(f'Acc {accuracy_by_best_model:0.3f}\t normalized_ED {norm_ED_by_best_model:0.3f}\n')
+        print(dashed_line)
+        log.write(dashed_line + '\n')

     averaged_forward_time = total_forward_time / total_evaluation_data_number * 1000
     total_accuracy = total_correct_number / total_evaluation_data_number
@@ -61,8 +67,8 @@ def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=Fa
     evaluation_log += f'total_accuracy: {total_accuracy:0.3f}\t'
     evaluation_log += f'averaged_infer_time: {averaged_forward_time:0.3f}\t# parameters: {params_num/1e6:0.3f}'
     print(evaluation_log)
-    with open(f'./result/{opt.experiment_name}/log_all_evaluation.txt', 'a') as log:
-        log.write(evaluation_log + '\n')
+    log.write(evaluation_log + '\n')
+    log.close()

     return None
@@ -87,13 +93,13 @@ def validation(model, criterion, evaluation_loader, converter, opt):

         start_time = time.time()
         if 'CTC' in opt.Prediction:
-            preds = model(image, text_for_pred).log_softmax(2)
+            preds = model(image, text_for_pred)
             forward_time = time.time() - start_time

             # Calculate evaluation loss for CTC deocder.
             preds_size = torch.IntTensor([preds.size(1)] * batch_size)
             # permute 'preds' to use CTCloss format
-            cost = criterion(preds.permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)
+            cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)

             # Select max probabilty (greedy decoding) then decode index to character
             _, preds_index = preds.max(2)
@@ -129,10 +135,23 @@ def validation(model, criterion, evaluation_loader, converter, opt):

             if pred == gt:
                 n_correct += 1
-            if len(gt) == 0:
-                norm_ED += 1
-            else:
-                norm_ED += edit_distance(pred, gt) / len(gt)
+
+            '''
+            (old version) ICDAR2017 DOST Normalized Edit Distance https://rrc.cvc.uab.es/?ch=7&com=tasks
+            "For each word we calculate the normalized edit distance to the length of the ground truth transcription."
+            if len(gt) == 0:
+                norm_ED += 1
+            else:
+                norm_ED += edit_distance(pred, gt) / len(gt)
+            '''
+
+            # ICDAR2019 Normalized Edit Distance
+            if len(gt) == 0 or len(pred) ==0:
+                norm_ED += 0
+            elif len(gt) > len(pred):
+                norm_ED += 1 - edit_distance(pred, gt) / len(gt)
+            else:
+                norm_ED += 1 - edit_distance(pred, gt) / len(pred)

             # calculate confidence score (= multiply of pred_max_prob)
             try:
@@ -143,6 +162,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
             # print(pred, gt, pred==gt, confidence_score)

     accuracy = n_correct / float(length_of_data) * 100
+    norm_ED = norm_ED / float(length_of_data)  # ICDAR2019 Normalized Edit Distance

     return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data

@@ -185,8 +205,9 @@ def test(opt):
     if opt.benchmark_all_eval:  # evaluation with 10 benchmark evaluation datasets
         benchmark_all_eval(model, criterion, converter, opt)
     else:
+        log = open(f'./result/{opt.experiment_name}/log_evaluation.txt', 'a')
         AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
-        eval_data = hierarchical_dataset(root=opt.eval_data, opt=opt)
+        eval_data, eval_data_log = hierarchical_dataset(root=opt.eval_data, opt=opt)
         evaluation_loader = torch.utils.data.DataLoader(
             eval_data, batch_size=opt.batch_size,
             shuffle=False,
@@ -194,10 +215,10 @@ def test(opt):
             collate_fn=AlignCollate_evaluation, pin_memory=True)
         _, accuracy_by_best_model, _, _, _, _, _, _ = validation(
             model, criterion, evaluation_loader, converter, opt)
-
-        print(accuracy_by_best_model)
-        with open('./result/{0}/log_evaluation.txt'.format(opt.experiment_name), 'a') as log:
-            log.write(str(accuracy_by_best_model) + '\n')
+        log.write(eval_data_log)
+        print(f'{accuracy_by_best_model:0.3f}')
+        log.write(f'{accuracy_by_best_model:0.3f}\n')
+        log.close()


 if __name__ == '__main__':
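
The IC19-NED part of the commit swaps the old ICDAR2017-style normalized edit distance (kept in the diff as a quoted comment block) for the ICDAR2019 definition: each sample scores 1 - ED(pred, gt) / max(len(gt), len(pred)), empty strings on either side score 0, and the per-sample scores are averaged over the evaluation set, so higher is better and 1.0 is a perfect match. A standalone sketch of that scoring (assuming nltk's edit_distance, an assumption about the surrounding imports; the example strings are invented):

    from nltk.metrics.distance import edit_distance

    def icdar2019_norm_ed(preds, gts):
        """Average ICDAR2019 normalized edit distance, mirroring the branch logic in validation()."""
        norm_ED = 0
        for pred, gt in zip(preds, gts):
            if len(gt) == 0 or len(pred) == 0:
                norm_ED += 0  # an empty string on either side scores 0
            elif len(gt) > len(pred):
                norm_ED += 1 - edit_distance(pred, gt) / len(gt)
            else:
                norm_ED += 1 - edit_distance(pred, gt) / len(pred)
        return norm_ED / float(len(gts))

    # one substitution over 9 characters: 1 - 1/9 ≈ 0.889
    print(icdar2019_norm_ed(['availahle'], ['available']))

The elif/else branches together divide by max(len(gt), len(pred)), matching the hunk at line 129 above.
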
train.py (41 changed lines)

@@ -24,20 +24,24 @@ def train(opt):
     if not opt.data_filtering_off:
         print('Filtering the images containing characters which are not in opt.character')
         print('Filtering the images whose label is longer than opt.batch_max_length')
-        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/dataset.py#L130
+        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

     opt.select_data = opt.select_data.split('-')
     opt.batch_ratio = opt.batch_ratio.split('-')
     train_dataset = Batch_Balanced_Dataset(opt)

+    log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a')
     AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
-    valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
+    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
     valid_loader = torch.utils.data.DataLoader(
         valid_dataset, batch_size=opt.batch_size,
         shuffle=True,  # 'True' to check training progress with validation function.
         num_workers=int(opt.workers),
         collate_fn=AlignCollate_valid, pin_memory=True)
+    log.write(valid_dataset_log)
     print('-' * 80)
+    log.write('-' * 80 + '\n')
+    log.close()

     """ model configuration """
     if 'CTC' in opt.Prediction:
@@ -176,13 +180,9 @@ def train(opt):

                 # training loss and validation loss
                 loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
-                print(loss_log)
-                log.write(loss_log + '\n')
                 loss_avg.reset()

                 current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'
-                print(current_model_log)
-                log.write(current_model_log + '\n')

                 # keep best accuracy model (on valid dataset)
                 if current_accuracy > best_accuracy:
@@ -192,22 +192,24 @@ def train(opt):
                     best_norm_ED = current_norm_ED
                     torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                 best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'
-                print(best_model_log)
-                log.write(best_model_log + '\n')

+                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
+                print(loss_model_log)
+                log.write(loss_model_log + '\n')
+
                 # show some predicted results
-                print('-' * 80)
-                print(f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F')
-                log.write(f'{"Ground Truth":25s} | {"Prediction":25s} | {"Confidence Score"}\n')
-                print('-' * 80)
+                dashed_line = '-' * 80
+                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
+                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                 for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                     if 'Attn' in opt.Prediction:
                         gt = gt[:gt.find('[s]')]
                         pred = pred[:pred.find('[s]')]

-                    print(f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}')
-                    log.write(f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n')
-                print('-' * 80)
+                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
+                predicted_result_log += f'{dashed_line}'
+                print(predicted_result_log)
+                log.write(predicted_result_log + '\n')

                 # save model per 1e+5 iter.
                 if (i + 1) % 1e+5 == 0:
@@ -249,17 +251,20 @@ if __name__ == '__main__':
     parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
     parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
    parser.add_argument('--rgb', action='store_true', help='use rgb input')
-    parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
+    parser.add_argument('--character', type=str,
+                        default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
     parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode')
     parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')
     parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode')
     """ Model Architecture """
     parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
-    parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')
+    parser.add_argument('--FeatureExtraction', type=str, required=True,
+                        help='FeatureExtraction stage. VGG|RCNN|ResNet')
     parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM')
     parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn')
     parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
-    parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor')
+    parser.add_argument('--input_channel', type=int, default=1,
+                        help='the number of input channel of Feature extractor')
     parser.add_argument('--output_channel', type=int, default=512,
                         help='the number of output channel of Feature extractor')
     parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')
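
A pattern note: every file in this commit converges on the same logging discipline, building each report as one string and then mirroring it to stdout and to the experiment's log file, instead of interleaving bare print calls. A tiny helper in that spirit (illustrative only; the commit itself keeps the print/log.write pairs inline rather than introducing such a helper):

    def report(log, message):
        # Mirror a message to stdout and to an already-open log file,
        # as the diffs above do by hand with paired print/log.write calls.
        print(message)
        log.write(message + '\n')

    # hypothetical usage, following train.py's log_dataset.txt handling:
    # log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a')
    # report(log, '-' * 80)
    # log.close()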