fix encode with batch_max_length

Baek JeongHun 2019-07-31 08:03:53 +00:00
parent a239f46df3
commit bda81065a5
2 changed files with 2 additions and 2 deletions

@@ -85,7 +85,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
     length_for_pred = torch.cuda.IntTensor([opt.batch_max_length] * batch_size)
     text_for_pred = torch.cuda.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0)
-    text_for_loss, length_for_loss = converter.encode(labels)
+    text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length)
     start_time = time.time()
     if 'CTC' in opt.Prediction:
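
For context, a hedged sketch of why forwarding the option matters. Only the hunk above is from the commit; the failure mode described here is an assumption drawn from the surrounding context lines, where text_for_pred is sized with opt.batch_max_length + 1.

# Illustrative sketch, not repository code. If opt.batch_max_length were,
# say, 35, the old call converter.encode(labels) would fall back to the
# method's default of batch_max_length=25; any converter that pads its
# targets to that length would then return tensors that disagree with
# text_for_pred above. Forwarding the option keeps both sides in sync:
text_for_loss, length_for_loss = converter.encode(
    labels, batch_max_length=opt.batch_max_length)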

@@ -15,7 +15,7 @@ class CTCLabelConverter(object):
         self.character = ['[blank]'] + dict_character  # dummy '[blank]' token for CTCLoss (index 0)
-    def encode(self, text):
+    def encode(self, text, batch_max_length=25):
         """convert text-label into text-index.
         input:
             text: text labels of each image. [batch_size]
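
A minimal sketch of the converter with the updated signature, assuming the usual CTC setup where labels are concatenated into one flat index tensor for CTCLoss. Only the hunk lines above are from the commit; the method body, and the idea that batch_max_length exists mainly for call-signature compatibility, are assumptions.

import torch

class CTCLabelConverter(object):
    # Illustrative reconstruction; not the commit's code.
    def __init__(self, character):
        dict_character = list(character)
        # index 0 is reserved for the dummy '[blank]' token used by CTCLoss
        self.dict = {char: i + 1 for i, char in enumerate(dict_character)}
        self.character = ['[blank]'] + dict_character

    def encode(self, text, batch_max_length=25):
        """convert text-label into text-index.
        input:
            text: text labels of each image. [batch_size]
            batch_max_length: max label length the model accepts; accepted
                here (assumption) so CTC and attention-style converters can
                be called the same way, even though plain CTC targets need
                no padding.
        output:
            text: concatenated label indices for CTCLoss. [sum(length)]
            length: length of each label. [batch_size]
        """
        length = [len(s) for s in text]
        flat = ''.join(text)  # CTCLoss consumes a flat target plus lengths
        indices = [self.dict[char] for char in flat]
        return torch.IntTensor(indices), torch.IntTensor(length)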