fix minor
This commit is contained in:
parent
e4249bed3d
commit
34e6316856
|
@ -67,7 +67,7 @@ class Model(nn.Module):
|
|||
else:
|
||||
raise Exception('Prediction is neither CTC or Attn')
|
||||
|
||||
def forward(self, input, length, text, is_train=True):
|
||||
def forward(self, input, text, is_train=True):
|
||||
""" Transformation stage """
|
||||
if not self.stages['Trans'] == "None":
|
||||
input = self.Transformation(input)
|
||||
|
@ -87,6 +87,6 @@ class Model(nn.Module):
|
|||
if self.stages['Pred'] == 'CTC':
|
||||
prediction = self.Prediction(contextual_feature.contiguous())
|
||||
else:
|
||||
prediction = self.Prediction(contextual_feature.contiguous(), length, text, is_train)
|
||||
prediction = self.Prediction(contextual_feature.contiguous(), text, is_train)
|
||||
|
||||
return prediction
|
||||
|
|
|
@ -19,11 +19,10 @@ class Attention(nn.Module):
|
|||
one_hot = one_hot.scatter_(1, input_char, 1)
|
||||
return one_hot
|
||||
|
||||
def forward(self, batch_H, length, text, is_train=True, batch_max_length=25):
|
||||
def forward(self, batch_H, text, is_train=True, batch_max_length=25):
|
||||
"""
|
||||
input:
|
||||
batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_classes]
|
||||
length : the length of each label. train: [3, 7, ....], test: [25, 25, 25, ...] [batch_size]
|
||||
text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO].
|
||||
output: probability distribution at each step [batch_size x num_steps x num_classes]
|
||||
"""
|
||||
|
|
|
@ -85,7 +85,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
|
|||
|
||||
start_time = time.time()
|
||||
if 'CTC' in opt.Prediction:
|
||||
preds = model(image, length_for_pred, text_for_pred)
|
||||
preds = model(image, text_for_pred)
|
||||
forward_time = time.time() - start_time
|
||||
|
||||
# Calculate evaluation loss for CTC deocder.
|
||||
|
@ -99,7 +99,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
|
|||
sim_preds = converter.decode(preds.data, preds_size.data)
|
||||
|
||||
else:
|
||||
preds = model(image, length_for_pred, text_for_pred, is_train=False)
|
||||
preds = model(image, text_for_pred, is_train=False)
|
||||
forward_time = time.time() - start_time
|
||||
|
||||
preds = preds[:, :text_for_loss.shape[1] - 1, :]
|
||||
|
|
|
@ -131,13 +131,13 @@ def train(opt):
|
|||
batch_size = image.size(0)
|
||||
|
||||
if 'CTC' in opt.Prediction:
|
||||
preds = model(image, length, text)
|
||||
preds = model(image, text)
|
||||
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
|
||||
preds = preds.permute(1, 0, 2) # to use CTCLoss format
|
||||
cost = criterion(preds, text, preds_size, length) / batch_size
|
||||
|
||||
else:
|
||||
preds = model(image, length, text)
|
||||
preds = model(image, text)
|
||||
target = text[:, 1:] # without [GO] Symbol
|
||||
cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))
|
||||
|
||||
|
|
Loading…
Reference in New Issue