diff --git a/model.py b/model.py index 3743f9e..be0c612 100755 --- a/model.py +++ b/model.py @@ -26,13 +26,14 @@ class Model(nn.Module): def __init__(self, opt): super(Model, self).__init__() + self.opt = opt self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction, 'Seq': opt.SequenceModeling, 'Pred': opt.Prediction} """ Transformation """ if opt.Transformation == 'TPS': self.Transformation = TPS_SpatialTransformerNetwork( - F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), batch_size=int(opt.batch_size/opt.num_gpu), I_channel_num=opt.input_channel) + F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel) else: print('No Transformation module specified') diff --git a/modules/feature_extraction.py b/modules/feature_extraction.py old mode 100644 new mode 100755 index 8fe245b..b5f3004 --- a/modules/feature_extraction.py +++ b/modules/feature_extraction.py @@ -3,6 +3,7 @@ import torch.nn.functional as F class VGG_FeatureExtractor(nn.Module): + """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """ def __init__(self, input_channel, output_channel=512): super(VGG_FeatureExtractor, self).__init__() @@ -28,6 +29,7 @@ class VGG_FeatureExtractor(nn.Module): class RCNN_FeatureExtractor(nn.Module): + """ FeatureExtractor of GRCNN (https://papers.nips.cc/paper/6637-gated-recurrent-convolution-neural-network-for-ocr.pdf) """ def __init__(self, input_channel, output_channel=512): super(RCNN_FeatureExtractor, self).__init__() @@ -50,6 +52,7 @@ class RCNN_FeatureExtractor(nn.Module): class ResNet_FeatureExtractor(nn.Module): + """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """ def __init__(self, input_channel, output_channel=512): super(ResNet_FeatureExtractor, self).__init__() diff --git a/modules/transformation.py b/modules/transformation.py index ecf62f2..893147d 100755 --- a/modules/transformation.py +++ b/modules/transformation.py @@ -7,7 +7,7 @@ import torch.nn.functional as F class TPS_SpatialTransformerNetwork(nn.Module): """ Rectification Network of RARE, namely TPS based STN """ - def __init__(self, F, I_size, I_r_size, batch_size, I_channel_num=1): + def __init__(self, F, I_size, I_r_size, I_channel_num=1): """ Based on RARE TPS input: batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] @@ -23,7 +23,7 @@ class TPS_SpatialTransformerNetwork(nn.Module): self.I_r_size = I_r_size # = (I_r_height, I_r_width) self.I_channel_num = I_channel_num self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) - self.GridGenerator = GridGenerator(self.F, self.I_r_size, batch_size) + self.GridGenerator = GridGenerator(self.F, self.I_r_size) def forward(self, batch_I): batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 @@ -81,7 +81,7 @@ class LocalizationNetwork(nn.Module): class GridGenerator(nn.Module): """ Grid Generator of RARE, which produces P_prime by multipling T with P """ - def __init__(self, F, I_r_size, batch_size): + def __init__(self, F, I_r_size): """ Generate P_hat and inv_delta_C for later """ super(GridGenerator, self).__init__() self.eps = 1e-6 diff --git a/train.py b/train.py index a025365..d0c74d6 100755 --- a/train.py +++ b/train.py @@ -180,7 +180,7 @@ def train(opt): torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth') best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}' print(best_model_log) - log.write(best_model_log+'\n') + log.write(best_model_log + '\n') # save model per 1e+5 iter. if (i + 1) % 1e+5 == 0: @@ -259,14 +259,17 @@ if __name__ == '__main__': opt.num_gpu = torch.cuda.device_count() # print('device count', opt.num_gpu) if opt.num_gpu > 1: - opt.num_iter = int(opt.num_iter / opt.num_gpu) - opt.batch_size = opt.batch_size * opt.num_gpu - opt.workers = opt.workers * opt.num_gpu print('------ Use multi-GPU setting ------') - print('To equalize the number of epochs to 1-GPU setting, num_iter is divided with num_gpu by default.') - # If you dont care about it, just commnet out these line.) - print(f'The batch_size is multiplied with num_gpu and multiplied batch_size is {opt.batch_size}') print('if you stuck too long time with multi-GPU setting, try to set --workers 0') # check multi-GPU issue https://github.com/clovaai/deep-text-recognition-benchmark/issues/1 + opt.workers = opt.workers * opt.num_gpu + + """ previous version + print('To equlize batch stats to 1-GPU setting, the batch_size is multiplied with num_gpu and multiplied batch_size is ', opt.batch_size) + opt.batch_size = opt.batch_size * opt.num_gpu + print('To equalize the number of epochs to 1-GPU setting, num_iter is divided with num_gpu by default.') + If you dont care about it, just commnet out these line.) + opt.num_iter = int(opt.num_iter / opt.num_gpu) + """ train(opt)