From 34e6316856f090a7acf591b7c647bc2aebabf4c3 Mon Sep 17 00:00:00 2001 From: Baek JeongHun Date: Sun, 7 Apr 2019 12:41:11 +0000 Subject: [PATCH] fix minor --- model.py | 4 ++-- modules/prediction.py | 3 +-- test.py | 4 ++-- train.py | 4 ++-- 4 files changed, 7 insertions(+), 8 deletions(-) mode change 100644 => 100755 model.py mode change 100644 => 100755 modules/prediction.py mode change 100644 => 100755 test.py mode change 100644 => 100755 train.py diff --git a/model.py b/model.py old mode 100644 new mode 100755 index 148fee2..fd11717 --- a/model.py +++ b/model.py @@ -67,7 +67,7 @@ class Model(nn.Module): else: raise Exception('Prediction is neither CTC or Attn') - def forward(self, input, length, text, is_train=True): + def forward(self, input, text, is_train=True): """ Transformation stage """ if not self.stages['Trans'] == "None": input = self.Transformation(input) @@ -87,6 +87,6 @@ class Model(nn.Module): if self.stages['Pred'] == 'CTC': prediction = self.Prediction(contextual_feature.contiguous()) else: - prediction = self.Prediction(contextual_feature.contiguous(), length, text, is_train) + prediction = self.Prediction(contextual_feature.contiguous(), text, is_train) return prediction diff --git a/modules/prediction.py b/modules/prediction.py old mode 100644 new mode 100755 index ee34b6d..20f9623 --- a/modules/prediction.py +++ b/modules/prediction.py @@ -19,11 +19,10 @@ class Attention(nn.Module): one_hot = one_hot.scatter_(1, input_char, 1) return one_hot - def forward(self, batch_H, length, text, is_train=True, batch_max_length=25): + def forward(self, batch_H, text, is_train=True, batch_max_length=25): """ input: batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_classes] - length : the length of each label. train: [3, 7, ....], test: [25, 25, 25, ...] [batch_size] text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO]. output: probability distribution at each step [batch_size x num_steps x num_classes] """ diff --git a/test.py b/test.py old mode 100644 new mode 100755 index 814e7a1..87a92b2 --- a/test.py +++ b/test.py @@ -85,7 +85,7 @@ def validation(model, criterion, evaluation_loader, converter, opt): start_time = time.time() if 'CTC' in opt.Prediction: - preds = model(image, length_for_pred, text_for_pred) + preds = model(image, text_for_pred) forward_time = time.time() - start_time # Calculate evaluation loss for CTC deocder. @@ -99,7 +99,7 @@ def validation(model, criterion, evaluation_loader, converter, opt): sim_preds = converter.decode(preds.data, preds_size.data) else: - preds = model(image, length_for_pred, text_for_pred, is_train=False) + preds = model(image, text_for_pred, is_train=False) forward_time = time.time() - start_time preds = preds[:, :text_for_loss.shape[1] - 1, :] diff --git a/train.py b/train.py old mode 100644 new mode 100755 index a51065f..6629e49 --- a/train.py +++ b/train.py @@ -131,13 +131,13 @@ def train(opt): batch_size = image.size(0) if 'CTC' in opt.Prediction: - preds = model(image, length, text) + preds = model(image, text) preds_size = torch.IntTensor([preds.size(1)] * batch_size) preds = preds.permute(1, 0, 2) # to use CTCLoss format cost = criterion(preds, text, preds_size, length) / batch_size else: - preds = model(image, length, text) + preds = model(image, text) target = text[:, 1:] # without [GO] Symbol cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))