diff --git a/model.py b/model.py index fd11717..b37aa67 100755 --- a/model.py +++ b/model.py @@ -24,46 +24,45 @@ from modules.prediction import Attention class Model(nn.Module): - def __init__(self, imgH, imgW, input_channel, output_channel, hidden_size, num_class, batch_max_length, - Transformation="None", FeatureExtraction="VGG", - SequenceModeling="BiLSTM", Prediction="CTC", F=20): + def __init__(self, opt): super(Model, self).__init__() - self.stages = {'Trans': Transformation, 'Feat': FeatureExtraction, 'Seq': SequenceModeling, 'Pred': Prediction} + self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction, + 'Seq': opt.SequenceModeling, 'Pred': opt.Prediction} """ Transformation """ - if Transformation == 'TPS': + if opt.Transformation == 'TPS': self.Transformation = TPS_SpatialTransformerNetwork( - F=F, I_size=(imgH, imgW), I_r_size=(imgH, imgW), I_channel_num=input_channel) + F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), batch_size=opt.batch_size, I_channel_num=opt.input_channel) else: print('No Transformation module specified') """ FeatureExtraction """ - if FeatureExtraction == 'VGG': - self.FeatureExtraction = VGG_FeatureExtractor(input_channel, output_channel) - elif FeatureExtraction == 'RCNN': - self.FeatureExtraction = RCNN_FeatureExtractor(input_channel, output_channel) - elif FeatureExtraction == 'ResNet': - self.FeatureExtraction = ResNet_FeatureExtractor(input_channel, output_channel) + if opt.FeatureExtraction == 'VGG': + self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel) + elif opt.FeatureExtraction == 'RCNN': + self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel) + elif opt.FeatureExtraction == 'ResNet': + self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel) else: raise Exception('No FeatureExtraction module specified') - self.FeatureExtraction_output = output_channel # int(imgH/16-1) * 512 + self.FeatureExtraction_output = opt.output_channel # int(imgH/16-1) * 512 self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1 """ Sequence modeling""" - if SequenceModeling == 'BiLSTM': + if opt.SequenceModeling == 'BiLSTM': self.SequenceModeling = nn.Sequential( - BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size), - BidirectionalLSTM(hidden_size, hidden_size, hidden_size)) - self.SequenceModeling_output = hidden_size + BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size), + BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size)) + self.SequenceModeling_output = opt.hidden_size else: print('No SequenceModeling module specified') self.SequenceModeling_output = self.FeatureExtraction_output """ Prediction """ - if Prediction == 'CTC': - self.Prediction = nn.Linear(self.SequenceModeling_output, num_class) - elif Prediction == 'Attn': - self.Prediction = Attention(self.SequenceModeling_output, hidden_size, num_class) + if opt.Prediction == 'CTC': + self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class) + elif opt.Prediction == 'Attn': + self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class) else: raise Exception('Prediction is neither CTC or Attn') diff --git a/modules/transformation.py b/modules/transformation.py old mode 100644 new mode 100755 index 76c1ba4..1ed2dd4 --- a/modules/transformation.py +++ b/modules/transformation.py @@ -7,7 +7,7 @@ import torch.nn.functional as F class TPS_SpatialTransformerNetwork(nn.Module): """ Rectification Network of RARE, namely TPS based STN """ - def __init__(self, F, I_size, I_r_size, I_channel_num=1): + def __init__(self, F, I_size, I_r_size, batch_size, I_channel_num=1): """ Based on RARE TPS input: batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] @@ -23,7 +23,7 @@ class TPS_SpatialTransformerNetwork(nn.Module): self.I_r_size = I_r_size # = (I_r_height, I_r_width) self.I_channel_num = I_channel_num self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) - self.GridGenerator = GridGenerator(self.F, self.I_r_size) + self.GridGenerator = GridGenerator(self.F, self.I_r_size, batch_size) def forward(self, batch_I): batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 @@ -81,7 +81,7 @@ class LocalizationNetwork(nn.Module): class GridGenerator(nn.Module): """ Grid Generator of RARE, which produces P_prime by multipling T with P """ - def __init__(self, F, I_r_size, batch_size=192): + def __init__(self, F, I_r_size, batch_size): """ Generate P_hat and inv_delta_C for later """ super(GridGenerator, self).__init__() self.eps = 1e-6 diff --git a/test.py b/test.py index 99e6092..a31d2dd 100755 --- a/test.py +++ b/test.py @@ -136,11 +136,8 @@ def test(opt): else: converter = AttnLabelConverter(opt.character) opt.num_class = len(converter.character) - model = Model(opt.imgH, opt.imgW, opt.input_channel, opt.output_channel, opt.hidden_size, - opt.num_class, opt.batch_max_length, - Transformation=opt.Transformation, FeatureExtraction=opt.FeatureExtraction, - SequenceModeling=opt.SequenceModeling, Prediction=opt.Prediction) - print('model input parameters', opt.imgH, opt.imgW, opt.input_channel, opt.output_channel, + model = Model(opt) + print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction) model = torch.nn.DataParallel(model).cuda() @@ -201,6 +198,7 @@ if __name__ == '__main__': parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet') parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM') parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn') + parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor') parser.add_argument('--output_channel', type=int, default=512, help='the number of output channel of Feature extractor') diff --git a/train.py b/train.py index 645fd10..1d2f233 100755 --- a/train.py +++ b/train.py @@ -41,11 +41,8 @@ def train(opt): opt.num_class = len(converter.character) if opt.rgb: opt.input_channel = 3 - model = Model(opt.imgH, opt.imgW, opt.input_channel, opt.output_channel, opt.hidden_size, - opt.num_class, opt.batch_max_length, - Transformation=opt.Transformation, FeatureExtraction=opt.FeatureExtraction, - SequenceModeling=opt.SequenceModeling, Prediction=opt.Prediction) - print('model input parameters', opt.imgH, opt.imgW, opt.input_channel, opt.output_channel, + model = Model(opt) + print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction) @@ -230,6 +227,7 @@ if __name__ == '__main__': parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet') parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM') parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn') + parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor') parser.add_argument('--output_channel', type=int, default=512, help='the number of output channel of Feature extractor') diff --git a/utils.py b/utils.py old mode 100644 new mode 100755