Update prediction.py
This commit is contained in:
parent
f543f85b78
commit
3772d75f40
|
@ -71,9 +71,9 @@ class AttentionCell(nn.Module):
|
||||||
# [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size]
|
# [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size]
|
||||||
batch_H_proj = self.i2h(batch_H)
|
batch_H_proj = self.i2h(batch_H)
|
||||||
prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1)
|
prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1)
|
||||||
emition = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1
|
e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1
|
||||||
|
|
||||||
alpha = F.softmax(emition, dim=1)
|
alpha = F.softmax(e, dim=1)
|
||||||
context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel
|
context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel
|
||||||
concat_context = torch.cat([context, char_onehots], 1) # batch_size x (num_channel + num_embedding)
|
concat_context = torch.cat([context, char_onehots], 1) # batch_size x (num_channel + num_embedding)
|
||||||
cur_hidden = self.rnn(concat_context, prev_hidden)
|
cur_hidden = self.rnn(concat_context, prev_hidden)
|
||||||
|
|
Loading…
Reference in New Issue