diff --git a/demos/classifier.py b/demos/classifier.py index 7ab80a1..119206a 100755 --- a/demos/classifier.py +++ b/demos/classifier.py @@ -145,9 +145,10 @@ if __name__ == '__main__': args = parser.parse_args() - if args.classifierModel.endswith(".t7"): + if args.mode == 'infer' and args.classifierModel.endswith(".t7"): raise Exception(""" -Torch network model passed as the classification model. +Torch network model passed as the classification model, +which should be a Python pickle (.pkl) See the documentation for the distinction between the Torch network and classification models: