From 3772d75f405349bdccabf2cb36acd955b7bbd3aa Mon Sep 17 00:00:00 2001
From: Baek JeongHun
Date: Fri, 26 Apr 2019 14:27:34 +0900
Subject: [PATCH] Update prediction.py

---
 modules/prediction.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/modules/prediction.py b/modules/prediction.py
index 20f9623..37afab4 100755
--- a/modules/prediction.py
+++ b/modules/prediction.py
@@ -71,9 +71,9 @@ class AttentionCell(nn.Module):
         # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size]
         batch_H_proj = self.i2h(batch_H)
         prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1)
-        emition = self.score(torch.tanh(batch_H_proj + prev_hidden_proj))  # batch_size x num_encoder_step * 1
+        e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj))  # batch_size x num_encoder_step * 1
 
-        alpha = F.softmax(emition, dim=1)
+        alpha = F.softmax(e, dim=1)
         context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1)  # batch_size x num_channel
         concat_context = torch.cat([context, char_onehots], 1)  # batch_size x (num_channel + num_embedding)
         cur_hidden = self.rnn(concat_context, prev_hidden)
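
For context, below is a minimal, self-contained sketch of the attention step this hunk touches: additive (Bahdanau-style) attention, where encoder outputs and the previous decoder hidden state are projected, combined through tanh, and scored per encoder step before a softmax produces the weights `alpha`. Only the lines inside the hunk come from the patch; the constructor layers (`i2h`, `h2h`, `score`, `rnn`), the forward signature, and the usage shapes are assumptions made to keep the example runnable, not a claim about the rest of prediction.py.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionCell(nn.Module):
    """One step of additive attention over encoder outputs (illustrative sketch)."""

    def __init__(self, input_size, hidden_size, num_embeddings):
        super(AttentionCell, self).__init__()
        # Assumed layer definitions; only the forward lines below appear in the patch hunk.
        self.i2h = nn.Linear(input_size, hidden_size, bias=False)  # project encoder outputs
        self.h2h = nn.Linear(hidden_size, hidden_size)             # project previous decoder hidden state
        self.score = nn.Linear(hidden_size, 1, bias=False)         # scalar score per encoder step
        self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size)

    def forward(self, prev_hidden, batch_H, char_onehots):
        # batch_H: [batch_size x num_encoder_step x input_size]
        batch_H_proj = self.i2h(batch_H)
        prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1)
        # e (the variable renamed from `emition` in the patch): [batch_size x num_encoder_step x 1]
        e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj))
        alpha = F.softmax(e, dim=1)  # attention weights over encoder steps
        # Weighted sum of encoder outputs -> context vector [batch_size x input_size]
        context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1)
        concat_context = torch.cat([context, char_onehots], 1)
        cur_hidden = self.rnn(concat_context, prev_hidden)
        return cur_hidden, alpha


if __name__ == "__main__":
    # Hypothetical shapes, chosen only to exercise the cell once.
    batch_size, num_encoder_step, input_size, hidden_size, num_classes = 2, 5, 32, 64, 10
    cell = AttentionCell(input_size, hidden_size, num_classes)
    prev_hidden = (torch.zeros(batch_size, hidden_size), torch.zeros(batch_size, hidden_size))
    batch_H = torch.randn(batch_size, num_encoder_step, input_size)
    char_onehots = F.one_hot(torch.zeros(batch_size, dtype=torch.long), num_classes).float()
    cur_hidden, alpha = cell(prev_hidden, batch_H, char_onehots)
    print(cur_hidden[0].shape, alpha.shape)  # torch.Size([2, 64]) torch.Size([2, 5, 1])
```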