comment and multi-gpu setting update
This commit is contained in:
parent
eb5570fd8c
commit
c2e28f5c0d
3
model.py
3
model.py
|
@ -26,13 +26,14 @@ class Model(nn.Module):
|
||||||
|
|
||||||
def __init__(self, opt):
|
def __init__(self, opt):
|
||||||
super(Model, self).__init__()
|
super(Model, self).__init__()
|
||||||
|
self.opt = opt
|
||||||
self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
|
self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
|
||||||
'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}
|
'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}
|
||||||
|
|
||||||
""" Transformation """
|
""" Transformation """
|
||||||
if opt.Transformation == 'TPS':
|
if opt.Transformation == 'TPS':
|
||||||
self.Transformation = TPS_SpatialTransformerNetwork(
|
self.Transformation = TPS_SpatialTransformerNetwork(
|
||||||
F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), batch_size=int(opt.batch_size/opt.num_gpu), I_channel_num=opt.input_channel)
|
F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
|
||||||
else:
|
else:
|
||||||
print('No Transformation module specified')
|
print('No Transformation module specified')
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
class VGG_FeatureExtractor(nn.Module):
|
class VGG_FeatureExtractor(nn.Module):
|
||||||
|
""" FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """
|
||||||
|
|
||||||
def __init__(self, input_channel, output_channel=512):
|
def __init__(self, input_channel, output_channel=512):
|
||||||
super(VGG_FeatureExtractor, self).__init__()
|
super(VGG_FeatureExtractor, self).__init__()
|
||||||
|
@ -28,6 +29,7 @@ class VGG_FeatureExtractor(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class RCNN_FeatureExtractor(nn.Module):
|
class RCNN_FeatureExtractor(nn.Module):
|
||||||
|
""" FeatureExtractor of GRCNN (https://papers.nips.cc/paper/6637-gated-recurrent-convolution-neural-network-for-ocr.pdf) """
|
||||||
|
|
||||||
def __init__(self, input_channel, output_channel=512):
|
def __init__(self, input_channel, output_channel=512):
|
||||||
super(RCNN_FeatureExtractor, self).__init__()
|
super(RCNN_FeatureExtractor, self).__init__()
|
||||||
|
@ -50,6 +52,7 @@ class RCNN_FeatureExtractor(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class ResNet_FeatureExtractor(nn.Module):
|
class ResNet_FeatureExtractor(nn.Module):
|
||||||
|
""" FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """
|
||||||
|
|
||||||
def __init__(self, input_channel, output_channel=512):
|
def __init__(self, input_channel, output_channel=512):
|
||||||
super(ResNet_FeatureExtractor, self).__init__()
|
super(ResNet_FeatureExtractor, self).__init__()
|
||||||
|
|
|
@ -7,7 +7,7 @@ import torch.nn.functional as F
|
||||||
class TPS_SpatialTransformerNetwork(nn.Module):
|
class TPS_SpatialTransformerNetwork(nn.Module):
|
||||||
""" Rectification Network of RARE, namely TPS based STN """
|
""" Rectification Network of RARE, namely TPS based STN """
|
||||||
|
|
||||||
def __init__(self, F, I_size, I_r_size, batch_size, I_channel_num=1):
|
def __init__(self, F, I_size, I_r_size, I_channel_num=1):
|
||||||
""" Based on RARE TPS
|
""" Based on RARE TPS
|
||||||
input:
|
input:
|
||||||
batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
|
batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
|
||||||
|
@ -23,7 +23,7 @@ class TPS_SpatialTransformerNetwork(nn.Module):
|
||||||
self.I_r_size = I_r_size # = (I_r_height, I_r_width)
|
self.I_r_size = I_r_size # = (I_r_height, I_r_width)
|
||||||
self.I_channel_num = I_channel_num
|
self.I_channel_num = I_channel_num
|
||||||
self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num)
|
self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num)
|
||||||
self.GridGenerator = GridGenerator(self.F, self.I_r_size, batch_size)
|
self.GridGenerator = GridGenerator(self.F, self.I_r_size)
|
||||||
|
|
||||||
def forward(self, batch_I):
|
def forward(self, batch_I):
|
||||||
batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2
|
batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2
|
||||||
|
@ -81,7 +81,7 @@ class LocalizationNetwork(nn.Module):
|
||||||
class GridGenerator(nn.Module):
|
class GridGenerator(nn.Module):
|
||||||
""" Grid Generator of RARE, which produces P_prime by multiplying T with P """
|
""" Grid Generator of RARE, which produces P_prime by multiplying T with P """
|
||||||
|
|
||||||
def __init__(self, F, I_r_size, batch_size):
|
def __init__(self, F, I_r_size):
|
||||||
""" Generate P_hat and inv_delta_C for later """
|
""" Generate P_hat and inv_delta_C for later """
|
||||||
super(GridGenerator, self).__init__()
|
super(GridGenerator, self).__init__()
|
||||||
self.eps = 1e-6
|
self.eps = 1e-6
|
||||||
|
|
15
train.py
15
train.py
|
@ -259,14 +259,17 @@ if __name__ == '__main__':
|
||||||
opt.num_gpu = torch.cuda.device_count()
|
opt.num_gpu = torch.cuda.device_count()
|
||||||
# print('device count', opt.num_gpu)
|
# print('device count', opt.num_gpu)
|
||||||
if opt.num_gpu > 1:
|
if opt.num_gpu > 1:
|
||||||
opt.num_iter = int(opt.num_iter / opt.num_gpu)
|
|
||||||
opt.batch_size = opt.batch_size * opt.num_gpu
|
|
||||||
opt.workers = opt.workers * opt.num_gpu
|
|
||||||
print('------ Use multi-GPU setting ------')
|
print('------ Use multi-GPU setting ------')
|
||||||
print('To equalize the number of epochs to 1-GPU setting, num_iter is divided with num_gpu by default.')
|
|
||||||
# If you don't care about it, just comment out these lines.
|
|
||||||
print(f'The batch_size is multiplied with num_gpu and multiplied batch_size is {opt.batch_size}')
|
|
||||||
print('if you stuck too long time with multi-GPU setting, try to set --workers 0')
|
print('if you stuck too long time with multi-GPU setting, try to set --workers 0')
|
||||||
# check multi-GPU issue https://github.com/clovaai/deep-text-recognition-benchmark/issues/1
|
# check multi-GPU issue https://github.com/clovaai/deep-text-recognition-benchmark/issues/1
|
||||||
|
opt.workers = opt.workers * opt.num_gpu
|
||||||
|
|
||||||
|
""" previous version
|
||||||
|
print('To equlize batch stats to 1-GPU setting, the batch_size is multiplied with num_gpu and multiplied batch_size is ', opt.batch_size)
|
||||||
|
opt.batch_size = opt.batch_size * opt.num_gpu
|
||||||
|
print('To equalize the number of epochs to 1-GPU setting, num_iter is divided with num_gpu by default.')
|
||||||
|
If you don't care about it, just comment out these lines.
|
||||||
|
opt.num_iter = int(opt.num_iter / opt.num_gpu)
|
||||||
|
"""
|
||||||
|
|
||||||
train(opt)
|
train(opt)
|
||||||
|
|
Loading…
Reference in New Issue