Update prediction.py
This commit is contained in:
parent
f543f85b78
commit
3772d75f40
|
@ -71,9 +71,9 @@ class AttentionCell(nn.Module):
|
|||
# [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size]
|
||||
batch_H_proj = self.i2h(batch_H)
|
||||
prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1)
|
||||
emition = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1
|
||||
e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1
|
||||
|
||||
alpha = F.softmax(emition, dim=1)
|
||||
alpha = F.softmax(e, dim=1)
|
||||
context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel
|
||||
concat_context = torch.cat([context, char_onehots], 1) # batch_size x (num_channel + num_embedding)
|
||||
cur_hidden = self.rnn(concat_context, prev_hidden)
|
||||
|
|
Loading…
Reference in New Issue