diff --git a/test.py b/test.py index 8c01134..7a35fd8 100755 --- a/test.py +++ b/test.py @@ -126,7 +126,10 @@ def validation(model, criterion, evaluation_loader, converter, opt): if pred == gt: n_correct += 1 - norm_ED += edit_distance(pred, gt) / len(gt) + if len(gt) == 0: + norm_ED += 1 + else: + norm_ED += edit_distance(pred, gt) / len(gt) accuracy = n_correct / float(length_of_data) * 100