2019-04-05 18:45:29 +08:00
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
2019-08-03 16:03:46 +08:00
|
|
|
# Default compute device: the current CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
|
2019-04-05 18:45:29 +08:00
|
|
|
|
|
|
|
|
|
|
|
class Attention(nn.Module):
    """Attention-based character decoder.

    Every decoding step runs ``AttentionCell`` over the encoder features and the
    one-hot encoding of the previous character, then ``generator`` maps the new
    LSTM hidden state to per-class scores.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super(Attention, self).__init__()
        self.attention_cell = AttentionCell(input_size, hidden_size, num_classes)
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.generator = nn.Linear(hidden_size, num_classes)

    def _char_to_onehot(self, input_char, onehot_dim=38):
        """Return a float32 one-hot encoding of a batch of class indices.

        Args:
            input_char: int64 tensor of shape [batch_size] with class indices in
                ``[0, onehot_dim)`` (``scatter_`` requires int64 indices).
            onehot_dim: width of each one-hot vector (number of classes).

        Returns:
            Float32 tensor of shape [batch_size x onehot_dim], allocated on the
            same device as ``input_char``.
        """
        input_char = input_char.unsqueeze(1)
        batch_size = input_char.size(0)
        # Allocate next to the indices instead of on the module-level ``device``:
        # scatter_ needs both tensors on one device, and the model may not live
        # on the global default device (e.g. model.cpu() on a CUDA machine).
        one_hot = torch.zeros(batch_size, onehot_dim, dtype=torch.float, device=input_char.device)
        return one_hot.scatter_(1, input_char, 1)

    def forward(self, batch_H, text, is_train=True, batch_max_length=25):
        """
        input:
            batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x contextual_feature_channels]
            text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO].
                   Only read when ``is_train`` is True; must be on the same device as ``batch_H``.
            is_train : if True, feed the ground-truth previous character at each step (teacher forcing);
                       otherwise decode greedily, feeding back the highest-scoring class from the previous step.
            batch_max_length : maximum label length; the decoder runs batch_max_length + 1 steps.
        output: unnormalized class scores (``generator`` outputs, no softmax applied) at each step
                [batch_size x num_steps x num_classes]
        """
        batch_size = batch_H.size(0)
        num_steps = batch_max_length + 1  # +1 for [s] at end of sentence.
        # All decoder state is created on batch_H's device rather than the
        # module-level ``device`` so the module runs wherever its inputs are.
        state_device = batch_H.device

        hidden = (torch.zeros(batch_size, self.hidden_size, dtype=torch.float, device=state_device),
                  torch.zeros(batch_size, self.hidden_size, dtype=torch.float, device=state_device))

        if is_train:
            output_hiddens = torch.zeros(batch_size, num_steps, self.hidden_size,
                                         dtype=torch.float, device=state_device)
            for i in range(num_steps):
                # one-hot vectors for the i-th char in the batch (text[:, 0] is [GO])
                char_onehots = self._char_to_onehot(text[:, i], onehot_dim=self.num_classes)
                # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_onehots : one-hot(y_{t-1})
                hidden, _ = self.attention_cell(hidden, batch_H, char_onehots)
                output_hiddens[:, i, :] = hidden[0]  # LSTM hidden index (0: hidden, 1: Cell)
            # Project every step's hidden state to class scores in one call.
            probs = self.generator(output_hiddens)

        else:
            # Greedy decoding starts from the [GO] token, index 0.
            targets = torch.zeros(batch_size, dtype=torch.long, device=state_device)
            probs = torch.zeros(batch_size, num_steps, self.num_classes,
                                dtype=torch.float, device=state_device)

            for i in range(num_steps):
                char_onehots = self._char_to_onehot(targets, onehot_dim=self.num_classes)
                hidden, _ = self.attention_cell(hidden, batch_H, char_onehots)
                probs_step = self.generator(hidden[0])
                probs[:, i, :] = probs_step
                # Feed the highest-scoring class back in as the next step's input.
                targets = probs_step.argmax(dim=1)

        return probs  # batch_size x num_steps x num_classes
|
|
|
|
|
|
|
|
|
|
|
|
class AttentionCell(nn.Module):
    """A single decoder step: additive attention over encoder features, then an LSTM update.

    The attention energy for each encoder step is
    ``score(tanh(i2h(batch_H) + h2h(h_prev)))``; the softmax-normalised weights
    pool ``batch_H`` into a context vector. That context is concatenated with
    the one-hot previous character and fed to an ``nn.LSTMCell``.
    """

    def __init__(self, input_size, hidden_size, num_embeddings):
        super(AttentionCell, self).__init__()
        # Sub-modules are created in this fixed order; it determines parameter
        # initialisation order and state_dict layout.
        self.i2h = nn.Linear(input_size, hidden_size, bias=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)  # only one of the two projections needs a bias
        self.score = nn.Linear(hidden_size, 1, bias=False)
        self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size)
        self.hidden_size = hidden_size

    def forward(self, prev_hidden, batch_H, char_onehots):
        """Run one attention + LSTM step.

        Args:
            prev_hidden: (h, c) LSTM state from the previous step; h is [batch_size x hidden_size].
            batch_H: encoder features [batch_size x num_encoder_step x num_channel].
            char_onehots: one-hot previous character [batch_size x num_embeddings].

        Returns:
            ((h, c), alpha) -- the new LSTM state and the attention weights
            [batch_size x num_encoder_step x 1].
        """
        prev_h = prev_hidden[0]

        # Project both sides into the hidden space; the decoder term broadcasts over encoder steps.
        encoder_term = self.i2h(batch_H)              # [batch_size x num_encoder_step x hidden_size]
        decoder_term = self.h2h(prev_h).unsqueeze(1)  # [batch_size x 1 x hidden_size]
        energies = self.score(torch.tanh(encoder_term + decoder_term))  # [batch_size x num_encoder_step x 1]

        # Normalise over the encoder-step axis.
        alpha = F.softmax(energies, dim=1)

        # [batch_size x 1 x num_encoder_step] @ [batch_size x num_encoder_step x num_channel] -> [batch_size x num_channel]
        weights_row = alpha.transpose(1, 2)
        context = torch.bmm(weights_row, batch_H).squeeze(1)

        lstm_input = torch.cat((context, char_onehots), dim=1)  # [batch_size x (num_channel + num_embeddings)]
        next_state = self.rnn(lstm_input, prev_hidden)
        return next_state, alpha
|