diff --git a/modules/prediction.py b/modules/prediction.py index 20f9623..37afab4 100755 --- a/modules/prediction.py +++ b/modules/prediction.py @@ -71,9 +71,9 @@ class AttentionCell(nn.Module): # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size] batch_H_proj = self.i2h(batch_H) prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1) - emition = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1 + e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1 - alpha = F.softmax(emition, dim=1) + alpha = F.softmax(e, dim=1) context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel concat_context = torch.cat([context, char_onehots], 1) # batch_size x (num_channel + num_embedding) cur_hidden = self.rnn(concat_context, prev_hidden)