diff --git a/dataset.py b/dataset.py
index 2f6103e..02193b6 100755
--- a/dataset.py
+++ b/dataset.py
@@ -22,8 +22,12 @@ class Batch_Balanced_Dataset(object):
         For example, when select_data is "MJ-ST" and batch_ratio is "0.5-0.5",
         the 50% of the batch is filled with MJ and the other 50% of the batch is filled with ST.
         """
-        print('-' * 80)
+        log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a')
+        dashed_line = '-' * 80
+        print(dashed_line)
+        log.write(dashed_line + '\n')
         print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}')
+        log.write(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}\n')
         assert len(opt.select_data) == len(opt.batch_ratio)

         _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
@@ -33,9 +37,11 @@ class Batch_Balanced_Dataset(object):
         Total_batch_size = 0
         for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio):
             _batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1)
-            print('-' * 80)
-            _dataset = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d])
+            print(dashed_line)
+            log.write(dashed_line + '\n')
+            _dataset, _dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d])
             total_number_dataset = len(_dataset)
+            log.write(_dataset_log)

             """
             The total number of data can be modified with opt.total_data_usage_ratio.
@@ -47,8 +53,10 @@ class Batch_Balanced_Dataset(object):
             indices = range(total_number_dataset)
             _dataset, _ = [Subset(_dataset, indices[offset - length:offset])
                            for offset, length in zip(_accumulate(dataset_split), dataset_split)]
-            print(f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}')
-            print(f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}')
+            selected_d_log = f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}\n'
+            selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}'
+            print(selected_d_log)
+            log.write(selected_d_log + '\n')
             batch_size_list.append(str(_batch_size))
             Total_batch_size += _batch_size

@@ -59,10 +67,16 @@ class Batch_Balanced_Dataset(object):
                 collate_fn=_AlignCollate, pin_memory=True)
             self.data_loader_list.append(_data_loader)
             self.dataloader_iter_list.append(iter(_data_loader))
-        print('-' * 80)
-        print('Total_batch_size: ', '+'.join(batch_size_list), '=', str(Total_batch_size))
+
+        Total_batch_size_log = f'{dashed_line}\n'
+        batch_size_sum = '+'.join(batch_size_list)
+        Total_batch_size_log += f'Total_batch_size: {batch_size_sum} = {Total_batch_size}\n'
+        Total_batch_size_log += f'{dashed_line}'
         opt.batch_size = Total_batch_size
-        print('-' * 80)
+
+        print(Total_batch_size_log)
+        log.write(Total_batch_size_log + '\n')
+        log.close()

     def get_batch(self):
         balanced_batch_images = []
@@ -89,7 +103,9 @@ class Batch_Balanced_Dataset(object):
 def hierarchical_dataset(root, opt, select_data='/'):
     """ select_data='/' contains all sub-directory of root directory """
     dataset_list = []
-    print(f'dataset_root: {root}\t dataset: {select_data[0]}')
+    dataset_log = f'dataset_root: {root}\t dataset: {select_data[0]}'
+    print(dataset_log)
+    dataset_log += '\n'
     for dirpath, dirnames, filenames in os.walk(root+'/'):
         if not dirnames:
             select_flag = False
@@ -100,12 +116,14 @@ def hierarchical_dataset(root, opt, select_data='/'):

             if select_flag:
                 dataset = LmdbDataset(dirpath, opt)
-                print(f'sub-directory:\t/{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}')
+                sub_dataset_log = f'sub-directory:\t/{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}'
+                print(sub_dataset_log)
+                dataset_log += f'{sub_dataset_log}\n'
                 dataset_list.append(dataset)

     concatenated_dataset = ConcatDataset(dataset_list)
-
-    return concatenated_dataset
+
+    return concatenated_dataset, dataset_log


 class LmdbDataset(Dataset):
@@ -129,8 +147,8 @@ class LmdbDataset(Dataset):
         else:
             """ Filtering part
             If you want to evaluate IC15-2077 & CUTE datasets which have special character labels,
-            use --data_filtering_off and evaluation with this snippet (only evaluate on alphabets and digits).
-            https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/dataset.py#L186-L188
+            use --data_filtering_off and only evaluate on alphabets and digits.
+            see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L190-L192
             """
             self.filtered_index_list = []
             for index in range(self.nSamples):
diff --git a/demo.py b/demo.py
index 71d059d..45f3d04 100755
--- a/demo.py
+++ b/demo.py
@@ -52,7 +52,7 @@ def demo(opt):
             text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

             if 'CTC' in opt.Prediction:
-                preds = model(image, text_for_pred).log_softmax(2)
+                preds = model(image, text_for_pred)

                 # Select max probabilty (greedy decoding) then decode index to character
                 preds_size = torch.IntTensor([preds.size(1)] * batch_size)
@@ -67,9 +67,14 @@ def demo(opt):
                 _, preds_index = preds.max(2)
                 preds_str = converter.decode(preds_index, length_for_pred)

-            print('-' * 80)
-            print(f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score')
-            print('-' * 80)
+
+            log = open(f'./log_demo_result.txt', 'a')
+            dashed_line = '-' * 80
+            head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'
+
+            print(f'{dashed_line}\n{head}\n{dashed_line}')
+            log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')
+
             preds_prob = F.softmax(preds, dim=2)
             preds_max_prob, _ = preds_prob.max(dim=2)
             for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
@@ -81,9 +86,10 @@ def demo(opt):
                 # calculate confidence score (= multiply of pred_max_prob)
                 confidence_score = pred_max_prob.cumprod(dim=0)[-1]

-                # print(f'{img_name}\t{pred}\t{confidence_score:0.4f}')
                 print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')
+                log.write(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n')

+            log.close()

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
diff --git a/test.py b/test.py
index c18abe7..87d69d0 100755
--- a/test.py
+++ b/test.py
@@ -31,11 +31,14 @@ def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=Fa
     total_forward_time = 0
     total_evaluation_data_number = 0
     total_correct_number = 0
-    print('-' * 80)
+    log = open(f'./result/{opt.experiment_name}/log_all_evaluation.txt', 'a')
+    dashed_line = '-' * 80
+    print(dashed_line)
+    log.write(dashed_line + '\n')
     for eval_data in eval_data_list:
         eval_data_path = os.path.join(opt.eval_data, eval_data)
         AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
-        eval_data = hierarchical_dataset(root=eval_data_path, opt=opt)
+        eval_data, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt)
         evaluation_loader = torch.utils.data.DataLoader(
             eval_data, batch_size=evaluation_batch_size,
             shuffle=False,
@@ -48,8 +51,11 @@ def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=Fa
         total_forward_time += infer_time
         total_evaluation_data_number += len(eval_data)
         total_correct_number += accuracy_by_best_model * length_of_data
-        print('Acc %0.3f\t normalized_ED %0.3f' % (accuracy_by_best_model, norm_ED_by_best_model))
-        print('-' * 80)
+        log.write(eval_data_log)
+        print(f'Acc {accuracy_by_best_model:0.3f}\t normalized_ED {norm_ED_by_best_model:0.3f}')
+        log.write(f'Acc {accuracy_by_best_model:0.3f}\t normalized_ED {norm_ED_by_best_model:0.3f}\n')
+        print(dashed_line)
+        log.write(dashed_line + '\n')

     averaged_forward_time = total_forward_time / total_evaluation_data_number * 1000
     total_accuracy = total_correct_number / total_evaluation_data_number
@@ -61,8 +67,8 @@ def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=Fa
     evaluation_log += f'total_accuracy: {total_accuracy:0.3f}\t'
     evaluation_log += f'averaged_infer_time: {averaged_forward_time:0.3f}\t# parameters: {params_num/1e6:0.3f}'
     print(evaluation_log)
-    with open(f'./result/{opt.experiment_name}/log_all_evaluation.txt', 'a') as log:
-        log.write(evaluation_log + '\n')
+    log.write(evaluation_log + '\n')
+    log.close()

     return None

@@ -87,13 +93,13 @@ def validation(model, criterion, evaluation_loader, converter, opt):

         start_time = time.time()
         if 'CTC' in opt.Prediction:
-            preds = model(image, text_for_pred).log_softmax(2)
+            preds = model(image, text_for_pred)
             forward_time = time.time() - start_time

             # Calculate evaluation loss for CTC deocder.
             preds_size = torch.IntTensor([preds.size(1)] * batch_size)
             # permute 'preds' to use CTCloss format
-            cost = criterion(preds.permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)
+            cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)

             # Select max probabilty (greedy decoding) then decode index to character
             _, preds_index = preds.max(2)
@@ -129,10 +135,23 @@ def validation(model, criterion, evaluation_loader, converter, opt):

             if pred == gt:
                 n_correct += 1
+
+            '''
+            (old version) ICDAR2017 DOST Normalized Edit Distance https://rrc.cvc.uab.es/?ch=7&com=tasks
+            "For each word we calculate the normalized edit distance to the length of the ground truth transcription."
             if len(gt) == 0:
                 norm_ED += 1
             else:
                 norm_ED += edit_distance(pred, gt) / len(gt)
+            '''
+
+            # ICDAR2019 Normalized Edit Distance
+            if len(gt) == 0 or len(pred) ==0:
+                norm_ED += 0
+            elif len(gt) > len(pred):
+                norm_ED += 1 - edit_distance(pred, gt) / len(gt)
+            else:
+                norm_ED += 1 - edit_distance(pred, gt) / len(pred)

             # calculate confidence score (= multiply of pred_max_prob)
             try:
@@ -143,6 +162,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
             # print(pred, gt, pred==gt, confidence_score)

     accuracy = n_correct / float(length_of_data) * 100
+    norm_ED = norm_ED / float(length_of_data)  # ICDAR2019 Normalized Edit Distance

     return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data

@@ -185,8 +205,9 @@ def test(opt):
         if opt.benchmark_all_eval:  # evaluation with 10 benchmark evaluation datasets
             benchmark_all_eval(model, criterion, converter, opt)
         else:
+            log = open(f'./result/{opt.experiment_name}/log_evaluation.txt', 'a')
             AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
-            eval_data = hierarchical_dataset(root=opt.eval_data, opt=opt)
+            eval_data, eval_data_log = hierarchical_dataset(root=opt.eval_data, opt=opt)
             evaluation_loader = torch.utils.data.DataLoader(
                 eval_data, batch_size=opt.batch_size,
                 shuffle=False,
@@ -194,10 +215,10 @@ def test(opt):
                 collate_fn=AlignCollate_evaluation, pin_memory=True)
             _, accuracy_by_best_model, _, _, _, _, _, _ = validation(
                 model, criterion, evaluation_loader, converter, opt)
-
-            print(accuracy_by_best_model)
-            with open('./result/{0}/log_evaluation.txt'.format(opt.experiment_name), 'a') as log:
-                log.write(str(accuracy_by_best_model) + '\n')
+            log.write(eval_data_log)
+            print(f'{accuracy_by_best_model:0.3f}')
+            log.write(f'{accuracy_by_best_model:0.3f}\n')
+            log.close()


 if __name__ == '__main__':
diff --git a/train.py b/train.py
index e24e511..c814294 100755
--- a/train.py
+++ b/train.py
@@ -24,21 +24,25 @@ def train(opt):
     if not opt.data_filtering_off:
         print('Filtering the images containing characters which are not in opt.character')
         print('Filtering the images whose label is longer than opt.batch_max_length')
-        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/dataset.py#L130
+        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

     opt.select_data = opt.select_data.split('-')
     opt.batch_ratio = opt.batch_ratio.split('-')
     train_dataset = Batch_Balanced_Dataset(opt)

+    log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a')
     AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
-    valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
+    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
     valid_loader = torch.utils.data.DataLoader(
         valid_dataset, batch_size=opt.batch_size,
         shuffle=True,  # 'True' to check training progress with validation function.
         num_workers=int(opt.workers),
         collate_fn=AlignCollate_valid, pin_memory=True)
+    log.write(valid_dataset_log)
     print('-' * 80)
-
+    log.write('-' * 80 + '\n')
+    log.close()
+
     """ model configuration """
     if 'CTC' in opt.Prediction:
         converter = CTCLabelConverter(opt.character)
@@ -176,13 +180,9 @@ def train(opt):

                 # training loss and validation loss
                 loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
-                print(loss_log)
-                log.write(loss_log + '\n')
                 loss_avg.reset()

                 current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'
-                print(current_model_log)
-                log.write(current_model_log + '\n')

                 # keep best accuracy model (on valid dataset)
                 if current_accuracy > best_accuracy:
@@ -192,22 +192,24 @@
                     best_norm_ED = current_norm_ED
                     torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                 best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'
-                print(best_model_log)
-                log.write(best_model_log + '\n')
+
+                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
+                print(loss_model_log)
+                log.write(loss_model_log + '\n')

                 # show some predicted results
-                print('-' * 80)
-                print(f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F')
-                log.write(f'{"Ground Truth":25s} | {"Prediction":25s} | {"Confidence Score"}\n')
-                print('-' * 80)
+                dashed_line = '-' * 80
+                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
+                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                 for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                     if 'Attn' in opt.Prediction:
                         gt = gt[:gt.find('[s]')]
                         pred = pred[:pred.find('[s]')]

-                    print(f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}')
-                    log.write(f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n')
-                print('-' * 80)
+                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
+                predicted_result_log += f'{dashed_line}'
+                print(predicted_result_log)
+                log.write(predicted_result_log + '\n')

                 # save model per 1e+5 iter.
                 if (i + 1) % 1e+5 == 0:
@@ -249,17 +251,20 @@ if __name__ == '__main__':
     parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
     parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
     parser.add_argument('--rgb', action='store_true', help='use rgb input')
-    parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
+    parser.add_argument('--character', type=str,
+                        default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
     parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode')
     parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')
     parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode')

     """ Model Architecture """
     parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
-    parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')
+    parser.add_argument('--FeatureExtraction', type=str, required=True,
+                        help='FeatureExtraction stage. VGG|RCNN|ResNet')
     parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM')
     parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn')
     parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
-    parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor')
+    parser.add_argument('--input_channel', type=int, default=1,
+                        help='the number of input channel of Feature extractor')
     parser.add_argument('--output_channel', type=int, default=512, help='the number of output channel of Feature extractor')
     parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')
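
For reference, the per-source batch sizing that Batch_Balanced_Dataset computes (and now also logs) can be checked in isolation. A minimal sketch of the arithmetic; the option values (--select_data MJ-ST, --batch_ratio 0.5-0.5, --batch_size 192) are assumed purely for illustration.

```python
# Sketch of the batch split done in Batch_Balanced_Dataset.__init__ (illustrative values only).
select_data = ['MJ', 'ST']      # opt.select_data after .split('-')
batch_ratio = ['0.5', '0.5']    # opt.batch_ratio after .split('-')
batch_size = 192                # opt.batch_size

# each selected source contributes at least one sample per batch
batch_size_list = [max(round(batch_size * float(r)), 1) for r in batch_ratio]
print(batch_size_list, sum(batch_size_list))  # [96, 96] 192 -> opt.batch_size is overwritten with the sum
```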
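The CTC change in demo.py and test.py drops .log_softmax(2) from the model call and applies it only where torch.nn.CTCLoss needs log-probabilities; greedy decoding is unaffected because the per-step argmax is the same before and after log_softmax. A self-contained sketch of that split, with toy shapes and dummy targets assumed for illustration:

```python
import torch

# assumed toy shapes: T time steps, N batch size, C classes (index 0 = CTC blank)
T, N, C = 26, 2, 37
preds = torch.randn(N, T, C)  # stands in for `preds = model(image, text_for_pred)` (raw scores)

# greedy decoding: argmax per time step is unchanged by log_softmax (a monotonic transform)
_, preds_index = preds.max(2)

# CTCLoss expects log-probabilities of shape (T, N, C), so log_softmax is applied only here
criterion = torch.nn.CTCLoss()
preds_size = torch.IntTensor([T] * N)
text_for_loss = torch.randint(1, C, (N, 5), dtype=torch.long)  # dummy targets
length_for_loss = torch.IntTensor([5] * N)
cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)
print(cost.item())
```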
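The validation change replaces the ICDAR2017-style accumulation (edit distance divided by len(gt)) with the ICDAR2019-style 1 - ED / len(longer string), averaged over the whole evaluation set. A small standalone sketch of the new metric; the function name and sample pairs are hypothetical, and edit_distance is assumed to be nltk's (the same helper the validation code calls):

```python
from nltk.metrics.distance import edit_distance  # assumption: the nltk helper used by test.py

def icdar2019_norm_ED(pairs):
    """pairs: list of (pred, gt) strings; returns the averaged ICDAR2019 normalized edit distance."""
    norm_ED = 0
    for pred, gt in pairs:
        if len(gt) == 0 or len(pred) == 0:
            norm_ED += 0  # an empty prediction or ground truth scores 0
        elif len(gt) > len(pred):
            norm_ED += 1 - edit_distance(pred, gt) / len(gt)
        else:
            norm_ED += 1 - edit_distance(pred, gt) / len(pred)
    return norm_ED / float(len(pairs))

print(icdar2019_norm_ED([('availabe', 'available'), ('hello', 'hello')]))  # (8/9 + 1) / 2 ≈ 0.944
```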