comment and multi-gpu setting update

This commit is contained in:
Baek JeongHun 2019-04-16 07:28:53 +00:00
parent eb5570fd8c
commit c2e28f5c0d
4 changed files with 18 additions and 11 deletions

View File

@ -26,13 +26,14 @@ class Model(nn.Module):
def __init__(self, opt): def __init__(self, opt):
super(Model, self).__init__() super(Model, self).__init__()
self.opt = opt
self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction, self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
'Seq': opt.SequenceModeling, 'Pred': opt.Prediction} 'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}
""" Transformation """ """ Transformation """
if opt.Transformation == 'TPS': if opt.Transformation == 'TPS':
self.Transformation = TPS_SpatialTransformerNetwork( self.Transformation = TPS_SpatialTransformerNetwork(
F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), batch_size=int(opt.batch_size/opt.num_gpu), I_channel_num=opt.input_channel) F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
else: else:
print('No Transformation module specified') print('No Transformation module specified')

3
modules/feature_extraction.py Normal file → Executable file
View File

@ -3,6 +3,7 @@ import torch.nn.functional as F
class VGG_FeatureExtractor(nn.Module): class VGG_FeatureExtractor(nn.Module):
""" FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """
def __init__(self, input_channel, output_channel=512): def __init__(self, input_channel, output_channel=512):
super(VGG_FeatureExtractor, self).__init__() super(VGG_FeatureExtractor, self).__init__()
@ -28,6 +29,7 @@ class VGG_FeatureExtractor(nn.Module):
class RCNN_FeatureExtractor(nn.Module): class RCNN_FeatureExtractor(nn.Module):
""" FeatureExtractor of GRCNN (https://papers.nips.cc/paper/6637-gated-recurrent-convolution-neural-network-for-ocr.pdf) """
def __init__(self, input_channel, output_channel=512): def __init__(self, input_channel, output_channel=512):
super(RCNN_FeatureExtractor, self).__init__() super(RCNN_FeatureExtractor, self).__init__()
@ -50,6 +52,7 @@ class RCNN_FeatureExtractor(nn.Module):
class ResNet_FeatureExtractor(nn.Module): class ResNet_FeatureExtractor(nn.Module):
""" FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """
def __init__(self, input_channel, output_channel=512): def __init__(self, input_channel, output_channel=512):
super(ResNet_FeatureExtractor, self).__init__() super(ResNet_FeatureExtractor, self).__init__()

View File

@ -7,7 +7,7 @@ import torch.nn.functional as F
class TPS_SpatialTransformerNetwork(nn.Module): class TPS_SpatialTransformerNetwork(nn.Module):
""" Rectification Network of RARE, namely TPS based STN """ """ Rectification Network of RARE, namely TPS based STN """
def __init__(self, F, I_size, I_r_size, batch_size, I_channel_num=1): def __init__(self, F, I_size, I_r_size, I_channel_num=1):
""" Based on RARE TPS """ Based on RARE TPS
input: input:
batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
@ -23,7 +23,7 @@ class TPS_SpatialTransformerNetwork(nn.Module):
self.I_r_size = I_r_size # = (I_r_height, I_r_width) self.I_r_size = I_r_size # = (I_r_height, I_r_width)
self.I_channel_num = I_channel_num self.I_channel_num = I_channel_num
self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num)
self.GridGenerator = GridGenerator(self.F, self.I_r_size, batch_size) self.GridGenerator = GridGenerator(self.F, self.I_r_size)
def forward(self, batch_I): def forward(self, batch_I):
batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2
@ -81,7 +81,7 @@ class LocalizationNetwork(nn.Module):
class GridGenerator(nn.Module): class GridGenerator(nn.Module):
""" Grid Generator of RARE, which produces P_prime by multipling T with P """ """ Grid Generator of RARE, which produces P_prime by multipling T with P """
def __init__(self, F, I_r_size, batch_size): def __init__(self, F, I_r_size):
""" Generate P_hat and inv_delta_C for later """ """ Generate P_hat and inv_delta_C for later """
super(GridGenerator, self).__init__() super(GridGenerator, self).__init__()
self.eps = 1e-6 self.eps = 1e-6

View File

@ -259,14 +259,17 @@ if __name__ == '__main__':
opt.num_gpu = torch.cuda.device_count() opt.num_gpu = torch.cuda.device_count()
# print('device count', opt.num_gpu) # print('device count', opt.num_gpu)
if opt.num_gpu > 1: if opt.num_gpu > 1:
opt.num_iter = int(opt.num_iter / opt.num_gpu)
opt.batch_size = opt.batch_size * opt.num_gpu
opt.workers = opt.workers * opt.num_gpu
print('------ Use multi-GPU setting ------') print('------ Use multi-GPU setting ------')
print('To equalize the number of epochs to 1-GPU setting, num_iter is divided with num_gpu by default.')
# If you dont care about it, just commnet out these line.)
print(f'The batch_size is multiplied with num_gpu and multiplied batch_size is {opt.batch_size}')
print('if you stuck too long time with multi-GPU setting, try to set --workers 0') print('if you stuck too long time with multi-GPU setting, try to set --workers 0')
# check multi-GPU issue https://github.com/clovaai/deep-text-recognition-benchmark/issues/1 # check multi-GPU issue https://github.com/clovaai/deep-text-recognition-benchmark/issues/1
opt.workers = opt.workers * opt.num_gpu
""" previous version
print('To equlize batch stats to 1-GPU setting, the batch_size is multiplied with num_gpu and multiplied batch_size is ', opt.batch_size)
opt.batch_size = opt.batch_size * opt.num_gpu
print('To equalize the number of epochs to 1-GPU setting, num_iter is divided with num_gpu by default.')
If you dont care about it, just commnet out these line.)
opt.num_iter = int(opt.num_iter / opt.num_gpu)
"""
train(opt) train(opt)