2019-04-05 18:45:29 +08:00
|
|
|
"""
|
|
|
|
Copyright (c) 2019-present NAVER Corp.
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
limitations under the License.
|
|
|
|
"""
|
|
|
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
from modules.transformation import TPS_SpatialTransformerNetwork
|
|
|
|
from modules.feature_extraction import VGG_FeatureExtractor, RCNN_FeatureExtractor, ResNet_FeatureExtractor
|
|
|
|
from modules.sequence_modeling import BidirectionalLSTM
|
|
|
|
from modules.prediction import Attention
|
|
|
|
|
|
|
|
|
|
|
|
class Model(nn.Module):
|
|
|
|
|
|
|
|
def __init__(self, imgH, imgW, input_channel, output_channel, hidden_size, num_class, batch_max_length,
|
|
|
|
Transformation="None", FeatureExtraction="VGG",
|
|
|
|
SequenceModeling="BiLSTM", Prediction="CTC", F=20):
|
|
|
|
super(Model, self).__init__()
|
|
|
|
self.stages = {'Trans': Transformation, 'Feat': FeatureExtraction, 'Seq': SequenceModeling, 'Pred': Prediction}
|
|
|
|
|
|
|
|
""" Transformation """
|
|
|
|
if Transformation == 'TPS':
|
|
|
|
self.Transformation = TPS_SpatialTransformerNetwork(
|
|
|
|
F=F, I_size=(imgH, imgW), I_r_size=(imgH, imgW), I_channel_num=input_channel)
|
|
|
|
else:
|
|
|
|
print('No Transformation module specified')
|
|
|
|
|
|
|
|
""" FeatureExtraction """
|
|
|
|
if FeatureExtraction == 'VGG':
|
|
|
|
self.FeatureExtraction = VGG_FeatureExtractor(input_channel, output_channel)
|
|
|
|
elif FeatureExtraction == 'RCNN':
|
|
|
|
self.FeatureExtraction = RCNN_FeatureExtractor(input_channel, output_channel)
|
|
|
|
elif FeatureExtraction == 'ResNet':
|
|
|
|
self.FeatureExtraction = ResNet_FeatureExtractor(input_channel, output_channel)
|
|
|
|
else:
|
|
|
|
raise Exception('No FeatureExtraction module specified')
|
|
|
|
self.FeatureExtraction_output = output_channel # int(imgH/16-1) * 512
|
|
|
|
self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1
|
|
|
|
|
|
|
|
""" Sequence modeling"""
|
|
|
|
if SequenceModeling == 'BiLSTM':
|
|
|
|
self.SequenceModeling = nn.Sequential(
|
|
|
|
BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size),
|
|
|
|
BidirectionalLSTM(hidden_size, hidden_size, hidden_size))
|
|
|
|
self.SequenceModeling_output = hidden_size
|
|
|
|
else:
|
|
|
|
print('No SequenceModeling module specified')
|
|
|
|
self.SequenceModeling_output = self.FeatureExtraction_output
|
|
|
|
|
|
|
|
""" Prediction """
|
|
|
|
if Prediction == 'CTC':
|
|
|
|
self.Prediction = nn.Linear(self.SequenceModeling_output, num_class)
|
|
|
|
elif Prediction == 'Attn':
|
|
|
|
self.Prediction = Attention(self.SequenceModeling_output, hidden_size, num_class)
|
|
|
|
else:
|
|
|
|
raise Exception('Prediction is neither CTC or Attn')
|
|
|
|
|
2019-04-07 20:41:11 +08:00
|
|
|
def forward(self, input, text, is_train=True):
|
2019-04-05 18:45:29 +08:00
|
|
|
""" Transformation stage """
|
|
|
|
if not self.stages['Trans'] == "None":
|
|
|
|
input = self.Transformation(input)
|
|
|
|
|
|
|
|
""" Feature extraction stage """
|
|
|
|
visual_feature = self.FeatureExtraction(input)
|
|
|
|
visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h]
|
|
|
|
visual_feature = visual_feature.squeeze(3)
|
|
|
|
|
|
|
|
""" Sequence modeling stage """
|
|
|
|
if self.stages['Seq'] == 'BiLSTM':
|
|
|
|
contextual_feature = self.SequenceModeling(visual_feature)
|
|
|
|
else:
|
|
|
|
contextual_feature = visual_feature # for convenience. this is NOT contextually modeled by BiLSTM
|
|
|
|
|
|
|
|
""" Prediction stage """
|
|
|
|
if self.stages['Pred'] == 'CTC':
|
|
|
|
prediction = self.Prediction(contextual_feature.contiguous())
|
|
|
|
else:
|
2019-04-07 20:41:11 +08:00
|
|
|
prediction = self.Prediction(contextual_feature.contiguous(), text, is_train)
|
2019-04-05 18:45:29 +08:00
|
|
|
|
|
|
|
return prediction
|