diff --git a/test.py b/test.py index d8fbd65..6603f78 100755 --- a/test.py +++ b/test.py @@ -135,6 +135,10 @@ def test(opt): converter = CTCLabelConverter(opt.character) else: converter = AttnLabelConverter(opt.character) + + if opt.rgb: + opt.input_channel = 3 + opt.num_class = len(converter.character) model = Model(opt) print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,