This commit is contained in:
Brandon Amos 2016-03-18 16:19:08 -04:00
parent 437603e45d
commit 2035eeb8a5
1 changed files with 8 additions and 8 deletions

View File

@ -96,35 +96,35 @@ def train(args):
print("Training for {} classes.".format(nClasses)) print("Training for {} classes.".format(nClasses))
if args.classifier == 'LinearSvm': if args.classifier == 'LinearSvm':
cls = SVC(C=1, kernel='linear', probability=True) clf = SVC(C=1, kernel='linear', probability=True)
elif args.classifier == 'GMM': elif args.classifier == 'GMM':
cls = GMM(n_components=nClasses) clf = GMM(n_components=nClasses)
cls.fit(embeddings, labelsNum) clf.fit(embeddings, labelsNum)
fName = "{}/classifier.pkl".format(args.workDir) fName = "{}/classifier.pkl".format(args.workDir)
print("Saving classifier to '{}'".format(fName)) print("Saving classifier to '{}'".format(fName))
with open(fName, 'w') as f: with open(fName, 'w') as f:
pickle.dump((le, cls), f) pickle.dump((le, clf), f)
def infer(args): def infer(args):
with open(args.classifierModel, 'r') as f: with open(args.classifierModel, 'r') as f:
(le, cls) = pickle.load(f) (le, clf) = pickle.load(f)
for img in args.imgs: for img in args.imgs:
print("\n=== {} ===".format(img)) print("\n=== {} ===".format(img))
rep = getRep(img).reshape(1, -1) rep = getRep(img).reshape(1, -1)
start = time.time() start = time.time()
predictions = cls.predict_proba(rep).ravel() predictions = clf.predict_proba(rep).ravel()
maxI = np.argmax(predictions) maxI = np.argmax(predictions)
person = le.inverse_transform(maxI) person = le.inverse_transform(maxI)
confidence = predictions[maxI] confidence = predictions[maxI]
if args.verbose: if args.verbose:
print("Prediction took {} seconds.".format(time.time() - start)) print("Prediction took {} seconds.".format(time.time() - start))
print("Predict {} with {:.2f} confidence.".format(person, confidence)) print("Predict {} with {:.2f} confidence.".format(person, confidence))
if isinstance(cls, GMM): if isinstance(clf, GMM):
dist = np.linalg.norm(rep - cls.means_[maxI]) dist = np.linalg.norm(rep - clf.means_[maxI])
print(" + Distance from the mean: {}".format(dist)) print(" + Distance from the mean: {}".format(dist))