diff --git a/demos/classifier.py b/demos/classifier.py index 118931b..fb1f75b 100755 --- a/demos/classifier.py +++ b/demos/classifier.py @@ -35,6 +35,8 @@ import pandas as pd import openface +from sklearn.pipeline import Pipeline +from sklearn.lda import LDA from sklearn.preprocessing import LabelEncoder from sklearn.svm import SVC from sklearn.mixture import GMM @@ -100,6 +102,11 @@ def train(args): elif args.classifier == 'GMM': clf = GMM(n_components=nClasses) + if args.ldaDim > 0: + clf_final = clf + clf = Pipeline([('lda', LDA(n_components=args.ldaDim)), + ('clf', clf_final)]) + clf.fit(embeddings, labelsNum) fName = "{}/classifier.pkl".format(args.workDir) @@ -147,6 +154,7 @@ if __name__ == '__main__': subparsers = parser.add_subparsers(dest='mode', help="Mode") trainParser = subparsers.add_parser('train', help="Train a new classifier.") + trainParser.add_argument('--ldaDim', type=int, default=-1) trainParser.add_argument('--classifier', type=str, choices=['LinearSvm', 'GMM'], help='The type of classifier to use.',