Merge pull request #181 from tjdevWorks/master

Doc String Correction: Attention class forward function
Baek JeongHun 2020-06-25 02:19:28 +09:00 committed by GitHub
commit 425a9b3e1f
1 changed file with 1 addition and 1 deletion


@@ -23,7 +23,7 @@ class Attention(nn.Module):
     def forward(self, batch_H, text, is_train=True, batch_max_length=25):
         """
         input:
-            batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_classes]
+            batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x contextual_feature_channels]
             text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO].
         output: probability distribution at each step [batch_size x num_steps x num_classes]
         """