fix minor

This commit is contained in:
Baek JeongHun 2019-04-07 12:41:11 +00:00
parent e4249bed3d
commit 34e6316856
4 changed files with 7 additions and 8 deletions

4
model.py Normal file → Executable file
View File

@ -67,7 +67,7 @@ class Model(nn.Module):
else: else:
raise Exception('Prediction is neither CTC or Attn') raise Exception('Prediction is neither CTC or Attn')
def forward(self, input, length, text, is_train=True): def forward(self, input, text, is_train=True):
""" Transformation stage """ """ Transformation stage """
if not self.stages['Trans'] == "None": if not self.stages['Trans'] == "None":
input = self.Transformation(input) input = self.Transformation(input)
@ -87,6 +87,6 @@ class Model(nn.Module):
if self.stages['Pred'] == 'CTC': if self.stages['Pred'] == 'CTC':
prediction = self.Prediction(contextual_feature.contiguous()) prediction = self.Prediction(contextual_feature.contiguous())
else: else:
prediction = self.Prediction(contextual_feature.contiguous(), length, text, is_train) prediction = self.Prediction(contextual_feature.contiguous(), text, is_train)
return prediction return prediction

3
modules/prediction.py Normal file → Executable file
View File

@ -19,11 +19,10 @@ class Attention(nn.Module):
one_hot = one_hot.scatter_(1, input_char, 1) one_hot = one_hot.scatter_(1, input_char, 1)
return one_hot return one_hot
def forward(self, batch_H, length, text, is_train=True, batch_max_length=25): def forward(self, batch_H, text, is_train=True, batch_max_length=25):
""" """
input: input:
batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_classes] batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_classes]
length : the length of each label. train: [3, 7, ....], test: [25, 25, 25, ...] [batch_size]
text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO]. text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO].
output: probability distribution at each step [batch_size x num_steps x num_classes] output: probability distribution at each step [batch_size x num_steps x num_classes]
""" """

4
test.py Normal file → Executable file
View File

@ -85,7 +85,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
start_time = time.time() start_time = time.time()
if 'CTC' in opt.Prediction: if 'CTC' in opt.Prediction:
preds = model(image, length_for_pred, text_for_pred) preds = model(image, text_for_pred)
forward_time = time.time() - start_time forward_time = time.time() - start_time
# Calculate evaluation loss for CTC deocder. # Calculate evaluation loss for CTC deocder.
@ -99,7 +99,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
sim_preds = converter.decode(preds.data, preds_size.data) sim_preds = converter.decode(preds.data, preds_size.data)
else: else:
preds = model(image, length_for_pred, text_for_pred, is_train=False) preds = model(image, text_for_pred, is_train=False)
forward_time = time.time() - start_time forward_time = time.time() - start_time
preds = preds[:, :text_for_loss.shape[1] - 1, :] preds = preds[:, :text_for_loss.shape[1] - 1, :]

4
train.py Normal file → Executable file
View File

@ -131,13 +131,13 @@ def train(opt):
batch_size = image.size(0) batch_size = image.size(0)
if 'CTC' in opt.Prediction: if 'CTC' in opt.Prediction:
preds = model(image, length, text) preds = model(image, text)
preds_size = torch.IntTensor([preds.size(1)] * batch_size) preds_size = torch.IntTensor([preds.size(1)] * batch_size)
preds = preds.permute(1, 0, 2) # to use CTCLoss format preds = preds.permute(1, 0, 2) # to use CTCLoss format
cost = criterion(preds, text, preds_size, length) / batch_size cost = criterion(preds, text, preds_size, length) / batch_size
else: else:
preds = model(image, length, text) preds = model(image, text)
target = text[:, 1:] # without [GO] Symbol target = text[:, 1:] # without [GO] Symbol
cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1)) cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))