update model arg
This commit is contained in:
parent
d322f4ff00
commit
7da666fc3a
41
model.py
41
model.py
|
@ -24,46 +24,45 @@ from modules.prediction import Attention
|
|||
|
||||
class Model(nn.Module):
|
||||
|
||||
def __init__(self, imgH, imgW, input_channel, output_channel, hidden_size, num_class, batch_max_length,
|
||||
Transformation="None", FeatureExtraction="VGG",
|
||||
SequenceModeling="BiLSTM", Prediction="CTC", F=20):
|
||||
def __init__(self, opt):
|
||||
super(Model, self).__init__()
|
||||
self.stages = {'Trans': Transformation, 'Feat': FeatureExtraction, 'Seq': SequenceModeling, 'Pred': Prediction}
|
||||
self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
|
||||
'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}
|
||||
|
||||
""" Transformation """
|
||||
if Transformation == 'TPS':
|
||||
if opt.Transformation == 'TPS':
|
||||
self.Transformation = TPS_SpatialTransformerNetwork(
|
||||
F=F, I_size=(imgH, imgW), I_r_size=(imgH, imgW), I_channel_num=input_channel)
|
||||
F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), batch_size=opt.batch_size, I_channel_num=opt.input_channel)
|
||||
else:
|
||||
print('No Transformation module specified')
|
||||
|
||||
""" FeatureExtraction """
|
||||
if FeatureExtraction == 'VGG':
|
||||
self.FeatureExtraction = VGG_FeatureExtractor(input_channel, output_channel)
|
||||
elif FeatureExtraction == 'RCNN':
|
||||
self.FeatureExtraction = RCNN_FeatureExtractor(input_channel, output_channel)
|
||||
elif FeatureExtraction == 'ResNet':
|
||||
self.FeatureExtraction = ResNet_FeatureExtractor(input_channel, output_channel)
|
||||
if opt.FeatureExtraction == 'VGG':
|
||||
self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
|
||||
elif opt.FeatureExtraction == 'RCNN':
|
||||
self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
|
||||
elif opt.FeatureExtraction == 'ResNet':
|
||||
self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
|
||||
else:
|
||||
raise Exception('No FeatureExtraction module specified')
|
||||
self.FeatureExtraction_output = output_channel # int(imgH/16-1) * 512
|
||||
self.FeatureExtraction_output = opt.output_channel # int(imgH/16-1) * 512
|
||||
self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1
|
||||
|
||||
""" Sequence modeling"""
|
||||
if SequenceModeling == 'BiLSTM':
|
||||
if opt.SequenceModeling == 'BiLSTM':
|
||||
self.SequenceModeling = nn.Sequential(
|
||||
BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size),
|
||||
BidirectionalLSTM(hidden_size, hidden_size, hidden_size))
|
||||
self.SequenceModeling_output = hidden_size
|
||||
BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
|
||||
BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
|
||||
self.SequenceModeling_output = opt.hidden_size
|
||||
else:
|
||||
print('No SequenceModeling module specified')
|
||||
self.SequenceModeling_output = self.FeatureExtraction_output
|
||||
|
||||
""" Prediction """
|
||||
if Prediction == 'CTC':
|
||||
self.Prediction = nn.Linear(self.SequenceModeling_output, num_class)
|
||||
elif Prediction == 'Attn':
|
||||
self.Prediction = Attention(self.SequenceModeling_output, hidden_size, num_class)
|
||||
if opt.Prediction == 'CTC':
|
||||
self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
|
||||
elif opt.Prediction == 'Attn':
|
||||
self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class)
|
||||
else:
|
||||
raise Exception('Prediction is neither CTC or Attn')
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ import torch.nn.functional as F
|
|||
class TPS_SpatialTransformerNetwork(nn.Module):
|
||||
""" Rectification Network of RARE, namely TPS based STN """
|
||||
|
||||
def __init__(self, F, I_size, I_r_size, I_channel_num=1):
|
||||
def __init__(self, F, I_size, I_r_size, batch_size, I_channel_num=1):
|
||||
""" Based on RARE TPS
|
||||
input:
|
||||
batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
|
||||
|
@ -23,7 +23,7 @@ class TPS_SpatialTransformerNetwork(nn.Module):
|
|||
self.I_r_size = I_r_size # = (I_r_height, I_r_width)
|
||||
self.I_channel_num = I_channel_num
|
||||
self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num)
|
||||
self.GridGenerator = GridGenerator(self.F, self.I_r_size)
|
||||
self.GridGenerator = GridGenerator(self.F, self.I_r_size, batch_size)
|
||||
|
||||
def forward(self, batch_I):
|
||||
batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2
|
||||
|
@ -81,7 +81,7 @@ class LocalizationNetwork(nn.Module):
|
|||
class GridGenerator(nn.Module):
|
||||
""" Grid Generator of RARE, which produces P_prime by multiplying T with P """
|
||||
|
||||
def __init__(self, F, I_r_size, batch_size=192):
|
||||
def __init__(self, F, I_r_size, batch_size):
|
||||
""" Generate P_hat and inv_delta_C for later """
|
||||
super(GridGenerator, self).__init__()
|
||||
self.eps = 1e-6
|
||||
|
|
8
test.py
8
test.py
|
@ -136,11 +136,8 @@ def test(opt):
|
|||
else:
|
||||
converter = AttnLabelConverter(opt.character)
|
||||
opt.num_class = len(converter.character)
|
||||
model = Model(opt.imgH, opt.imgW, opt.input_channel, opt.output_channel, opt.hidden_size,
|
||||
opt.num_class, opt.batch_max_length,
|
||||
Transformation=opt.Transformation, FeatureExtraction=opt.FeatureExtraction,
|
||||
SequenceModeling=opt.SequenceModeling, Prediction=opt.Prediction)
|
||||
print('model input parameters', opt.imgH, opt.imgW, opt.input_channel, opt.output_channel,
|
||||
model = Model(opt)
|
||||
print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
|
||||
opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
|
||||
opt.SequenceModeling, opt.Prediction)
|
||||
model = torch.nn.DataParallel(model).cuda()
|
||||
|
@ -201,6 +198,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')
|
||||
parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM')
|
||||
parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn')
|
||||
parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
|
||||
parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor')
|
||||
parser.add_argument('--output_channel', type=int, default=512,
|
||||
help='the number of output channel of Feature extractor')
|
||||
|
|
8
train.py
8
train.py
|
@ -41,11 +41,8 @@ def train(opt):
|
|||
opt.num_class = len(converter.character)
|
||||
if opt.rgb:
|
||||
opt.input_channel = 3
|
||||
model = Model(opt.imgH, opt.imgW, opt.input_channel, opt.output_channel, opt.hidden_size,
|
||||
opt.num_class, opt.batch_max_length,
|
||||
Transformation=opt.Transformation, FeatureExtraction=opt.FeatureExtraction,
|
||||
SequenceModeling=opt.SequenceModeling, Prediction=opt.Prediction)
|
||||
print('model input parameters', opt.imgH, opt.imgW, opt.input_channel, opt.output_channel,
|
||||
model = Model(opt)
|
||||
print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
|
||||
opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
|
||||
opt.SequenceModeling, opt.Prediction)
|
||||
|
||||
|
@ -230,6 +227,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')
|
||||
parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM')
|
||||
parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn')
|
||||
parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
|
||||
parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor')
|
||||
parser.add_argument('--output_channel', type=int, default=512,
|
||||
help='the number of output channel of Feature extractor')
|
||||
|
|
Loading…
Reference in New Issue