Update prediction.py

This commit is contained in:
Baek JeongHun 2019-04-26 14:27:34 +09:00 committed by GitHub
parent f543f85b78
commit 3772d75f40
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -71,9 +71,9 @@ class AttentionCell(nn.Module):
# [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size] # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size]
batch_H_proj = self.i2h(batch_H) batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1) prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1)
emition = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1 e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1
alpha = F.softmax(emition, dim=1) alpha = F.softmax(e, dim=1)
context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel
concat_context = torch.cat([context, char_onehots], 1) # batch_size x (num_channel + num_embedding) concat_context = torch.cat([context, char_onehots], 1) # batch_size x (num_channel + num_embedding)
cur_hidden = self.rnn(concat_context, prev_hidden) cur_hidden = self.rnn(concat_context, prev_hidden)