fix minor: remove the `length` argument from Model.forward and Attention.forward and update their callers
This commit is contained in:
parent
e4249bed3d
commit
34e6316856
|
@ -67,7 +67,7 @@ class Model(nn.Module):
|
||||||
else:
|
else:
|
||||||
raise Exception('Prediction is neither CTC or Attn')
|
raise Exception('Prediction is neither CTC or Attn')
|
||||||
|
|
||||||
def forward(self, input, length, text, is_train=True):
|
def forward(self, input, text, is_train=True):
|
||||||
""" Transformation stage """
|
""" Transformation stage """
|
||||||
if not self.stages['Trans'] == "None":
|
if not self.stages['Trans'] == "None":
|
||||||
input = self.Transformation(input)
|
input = self.Transformation(input)
|
||||||
|
@ -87,6 +87,6 @@ class Model(nn.Module):
|
||||||
if self.stages['Pred'] == 'CTC':
|
if self.stages['Pred'] == 'CTC':
|
||||||
prediction = self.Prediction(contextual_feature.contiguous())
|
prediction = self.Prediction(contextual_feature.contiguous())
|
||||||
else:
|
else:
|
||||||
prediction = self.Prediction(contextual_feature.contiguous(), length, text, is_train)
|
prediction = self.Prediction(contextual_feature.contiguous(), text, is_train)
|
||||||
|
|
||||||
return prediction
|
return prediction
|
||||||
|
|
|
@ -19,11 +19,10 @@ class Attention(nn.Module):
|
||||||
one_hot = one_hot.scatter_(1, input_char, 1)
|
one_hot = one_hot.scatter_(1, input_char, 1)
|
||||||
return one_hot
|
return one_hot
|
||||||
|
|
||||||
def forward(self, batch_H, length, text, is_train=True, batch_max_length=25):
|
def forward(self, batch_H, text, is_train=True, batch_max_length=25):
|
||||||
"""
|
"""
|
||||||
input:
|
input:
|
||||||
batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_classes]
|
batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_classes]
|
||||||
length : the length of each label. train: [3, 7, ....], test: [25, 25, 25, ...] [batch_size]
|
|
||||||
text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO].
|
text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO].
|
||||||
output: probability distribution at each step [batch_size x num_steps x num_classes]
|
output: probability distribution at each step [batch_size x num_steps x num_classes]
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -85,7 +85,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
if 'CTC' in opt.Prediction:
|
if 'CTC' in opt.Prediction:
|
||||||
preds = model(image, length_for_pred, text_for_pred)
|
preds = model(image, text_for_pred)
|
||||||
forward_time = time.time() - start_time
|
forward_time = time.time() - start_time
|
||||||
|
|
||||||
# Calculate evaluation loss for CTC decoder.
|
# Calculate evaluation loss for CTC decoder.
|
||||||
|
@ -99,7 +99,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
|
||||||
sim_preds = converter.decode(preds.data, preds_size.data)
|
sim_preds = converter.decode(preds.data, preds_size.data)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
preds = model(image, length_for_pred, text_for_pred, is_train=False)
|
preds = model(image, text_for_pred, is_train=False)
|
||||||
forward_time = time.time() - start_time
|
forward_time = time.time() - start_time
|
||||||
|
|
||||||
preds = preds[:, :text_for_loss.shape[1] - 1, :]
|
preds = preds[:, :text_for_loss.shape[1] - 1, :]
|
||||||
|
|
|
@ -131,13 +131,13 @@ def train(opt):
|
||||||
batch_size = image.size(0)
|
batch_size = image.size(0)
|
||||||
|
|
||||||
if 'CTC' in opt.Prediction:
|
if 'CTC' in opt.Prediction:
|
||||||
preds = model(image, length, text)
|
preds = model(image, text)
|
||||||
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
|
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
|
||||||
preds = preds.permute(1, 0, 2) # to use CTCLoss format
|
preds = preds.permute(1, 0, 2) # to use CTCLoss format
|
||||||
cost = criterion(preds, text, preds_size, length) / batch_size
|
cost = criterion(preds, text, preds_size, length) / batch_size
|
||||||
|
|
||||||
else:
|
else:
|
||||||
preds = model(image, length, text)
|
preds = model(image, text)
|
||||||
target = text[:, 1:] # without [GO] Symbol
|
target = text[:, 1:] # without [GO] Symbol
|
||||||
cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))
|
cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue