update model arg

This commit is contained in:
Baek JeongHun 2019-04-09 09:06:32 +00:00
parent d322f4ff00
commit 7da666fc3a
5 changed files with 29 additions and 34 deletions

View File

@ -24,46 +24,45 @@ from modules.prediction import Attention
class Model(nn.Module):
def __init__(self, imgH, imgW, input_channel, output_channel, hidden_size, num_class, batch_max_length,
Transformation="None", FeatureExtraction="VGG",
SequenceModeling="BiLSTM", Prediction="CTC", F=20):
def __init__(self, opt):
super(Model, self).__init__()
self.stages = {'Trans': Transformation, 'Feat': FeatureExtraction, 'Seq': SequenceModeling, 'Pred': Prediction}
self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}
""" Transformation """
if Transformation == 'TPS':
if opt.Transformation == 'TPS':
self.Transformation = TPS_SpatialTransformerNetwork(
F=F, I_size=(imgH, imgW), I_r_size=(imgH, imgW), I_channel_num=input_channel)
F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), batch_size=opt.batch_size, I_channel_num=opt.input_channel)
else:
print('No Transformation module specified')
""" FeatureExtraction """
if FeatureExtraction == 'VGG':
self.FeatureExtraction = VGG_FeatureExtractor(input_channel, output_channel)
elif FeatureExtraction == 'RCNN':
self.FeatureExtraction = RCNN_FeatureExtractor(input_channel, output_channel)
elif FeatureExtraction == 'ResNet':
self.FeatureExtraction = ResNet_FeatureExtractor(input_channel, output_channel)
if opt.FeatureExtraction == 'VGG':
self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
elif opt.FeatureExtraction == 'RCNN':
self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
elif opt.FeatureExtraction == 'ResNet':
self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
else:
raise Exception('No FeatureExtraction module specified')
self.FeatureExtraction_output = output_channel # int(imgH/16-1) * 512
self.FeatureExtraction_output = opt.output_channel # int(imgH/16-1) * 512
self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1
""" Sequence modeling"""
if SequenceModeling == 'BiLSTM':
if opt.SequenceModeling == 'BiLSTM':
self.SequenceModeling = nn.Sequential(
BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size),
BidirectionalLSTM(hidden_size, hidden_size, hidden_size))
self.SequenceModeling_output = hidden_size
BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
self.SequenceModeling_output = opt.hidden_size
else:
print('No SequenceModeling module specified')
self.SequenceModeling_output = self.FeatureExtraction_output
""" Prediction """
if Prediction == 'CTC':
self.Prediction = nn.Linear(self.SequenceModeling_output, num_class)
elif Prediction == 'Attn':
self.Prediction = Attention(self.SequenceModeling_output, hidden_size, num_class)
if opt.Prediction == 'CTC':
self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
elif opt.Prediction == 'Attn':
self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class)
else:
raise Exception('Prediction is neither CTC or Attn')

6
modules/transformation.py Normal file → Executable file
View File

@ -7,7 +7,7 @@ import torch.nn.functional as F
class TPS_SpatialTransformerNetwork(nn.Module):
""" Rectification Network of RARE, namely TPS based STN """
def __init__(self, F, I_size, I_r_size, I_channel_num=1):
def __init__(self, F, I_size, I_r_size, batch_size, I_channel_num=1):
""" Based on RARE TPS
input:
batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
@ -23,7 +23,7 @@ class TPS_SpatialTransformerNetwork(nn.Module):
self.I_r_size = I_r_size # = (I_r_height, I_r_width)
self.I_channel_num = I_channel_num
self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num)
self.GridGenerator = GridGenerator(self.F, self.I_r_size)
self.GridGenerator = GridGenerator(self.F, self.I_r_size, batch_size)
def forward(self, batch_I):
batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2
@ -81,7 +81,7 @@ class LocalizationNetwork(nn.Module):
class GridGenerator(nn.Module):
""" Grid Generator of RARE, which produces P_prime by multipling T with P """
def __init__(self, F, I_r_size, batch_size=192):
def __init__(self, F, I_r_size, batch_size):
""" Generate P_hat and inv_delta_C for later """
super(GridGenerator, self).__init__()
self.eps = 1e-6

View File

@ -136,11 +136,8 @@ def test(opt):
else:
converter = AttnLabelConverter(opt.character)
opt.num_class = len(converter.character)
model = Model(opt.imgH, opt.imgW, opt.input_channel, opt.output_channel, opt.hidden_size,
opt.num_class, opt.batch_max_length,
Transformation=opt.Transformation, FeatureExtraction=opt.FeatureExtraction,
SequenceModeling=opt.SequenceModeling, Prediction=opt.Prediction)
print('model input parameters', opt.imgH, opt.imgW, opt.input_channel, opt.output_channel,
model = Model(opt)
print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
opt.SequenceModeling, opt.Prediction)
model = torch.nn.DataParallel(model).cuda()
@ -201,6 +198,7 @@ if __name__ == '__main__':
parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')
parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM')
parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn')
parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor')
parser.add_argument('--output_channel', type=int, default=512,
help='the number of output channel of Feature extractor')

View File

@ -41,11 +41,8 @@ def train(opt):
opt.num_class = len(converter.character)
if opt.rgb:
opt.input_channel = 3
model = Model(opt.imgH, opt.imgW, opt.input_channel, opt.output_channel, opt.hidden_size,
opt.num_class, opt.batch_max_length,
Transformation=opt.Transformation, FeatureExtraction=opt.FeatureExtraction,
SequenceModeling=opt.SequenceModeling, Prediction=opt.Prediction)
print('model input parameters', opt.imgH, opt.imgW, opt.input_channel, opt.output_channel,
model = Model(opt)
print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
opt.SequenceModeling, opt.Prediction)
@ -230,6 +227,7 @@ if __name__ == '__main__':
parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')
parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM')
parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn')
parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor')
parser.add_argument('--output_channel', type=int, default=512,
help='the number of output channel of Feature extractor')

0
utils.py Normal file → Executable file
View File