Doc String Correction: Attention class forward function
This commit is contained in:
parent
d38c3cbfdb
commit
4be75b3877
|
@ -23,7 +23,7 @@ class Attention(nn.Module):
|
||||||
def forward(self, batch_H, text, is_train=True, batch_max_length=25):
|
def forward(self, batch_H, text, is_train=True, batch_max_length=25):
|
||||||
"""
|
"""
|
||||||
input:
|
input:
|
||||||
batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_classes]
|
batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x contextual_feature_channels]
|
||||||
text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO].
|
text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO].
|
||||||
output: probability distribution at each step [batch_size x num_steps x num_classes]
|
output: probability distribution at each step [batch_size x num_steps x num_classes]
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue