fix encode with batch_max_length
This commit is contained in:
parent
a239f46df3
commit
bda81065a5
2
test.py
2
test.py
|
@@ -85,7 +85,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
|
|||
length_for_pred = torch.cuda.IntTensor([opt.batch_max_length] * batch_size)
|
||||
text_for_pred = torch.cuda.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0)
|
||||
|
||||
text_for_loss, length_for_loss = converter.encode(labels)
|
||||
text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length)
|
||||
|
||||
start_time = time.time()
|
||||
if 'CTC' in opt.Prediction:
|
||||
|
|
2
utils.py
2
utils.py
|
@@ -15,7 +15,7 @@ class CTCLabelConverter(object):
|
|||
|
||||
self.character = ['[blank]'] + dict_character # dummy '[blank]' token for CTCLoss (index 0)
|
||||
|
||||
def encode(self, text):
|
||||
def encode(self, text, batch_max_length=25):
|
||||
"""convert text-label into text-index.
|
||||
input:
|
||||
text: text labels of each image. [batch_size]
|
||||
|
|
Loading…
Reference in New Issue