cls->clf
This commit is contained in:
parent
437603e45d
commit
2035eeb8a5
|
@ -96,35 +96,35 @@ def train(args):
|
||||||
print("Training for {} classes.".format(nClasses))
|
print("Training for {} classes.".format(nClasses))
|
||||||
|
|
||||||
if args.classifier == 'LinearSvm':
|
if args.classifier == 'LinearSvm':
|
||||||
cls = SVC(C=1, kernel='linear', probability=True)
|
clf = SVC(C=1, kernel='linear', probability=True)
|
||||||
elif args.classifier == 'GMM':
|
elif args.classifier == 'GMM':
|
||||||
cls = GMM(n_components=nClasses)
|
clf = GMM(n_components=nClasses)
|
||||||
|
|
||||||
cls.fit(embeddings, labelsNum)
|
clf.fit(embeddings, labelsNum)
|
||||||
|
|
||||||
fName = "{}/classifier.pkl".format(args.workDir)
|
fName = "{}/classifier.pkl".format(args.workDir)
|
||||||
print("Saving classifier to '{}'".format(fName))
|
print("Saving classifier to '{}'".format(fName))
|
||||||
with open(fName, 'w') as f:
|
with open(fName, 'w') as f:
|
||||||
pickle.dump((le, cls), f)
|
pickle.dump((le, clf), f)
|
||||||
|
|
||||||
|
|
||||||
def infer(args):
|
def infer(args):
|
||||||
with open(args.classifierModel, 'r') as f:
|
with open(args.classifierModel, 'r') as f:
|
||||||
(le, cls) = pickle.load(f)
|
(le, clf) = pickle.load(f)
|
||||||
|
|
||||||
for img in args.imgs:
|
for img in args.imgs:
|
||||||
print("\n=== {} ===".format(img))
|
print("\n=== {} ===".format(img))
|
||||||
rep = getRep(img).reshape(1, -1)
|
rep = getRep(img).reshape(1, -1)
|
||||||
start = time.time()
|
start = time.time()
|
||||||
predictions = cls.predict_proba(rep).ravel()
|
predictions = clf.predict_proba(rep).ravel()
|
||||||
maxI = np.argmax(predictions)
|
maxI = np.argmax(predictions)
|
||||||
person = le.inverse_transform(maxI)
|
person = le.inverse_transform(maxI)
|
||||||
confidence = predictions[maxI]
|
confidence = predictions[maxI]
|
||||||
if args.verbose:
|
if args.verbose:
|
||||||
print("Prediction took {} seconds.".format(time.time() - start))
|
print("Prediction took {} seconds.".format(time.time() - start))
|
||||||
print("Predict {} with {:.2f} confidence.".format(person, confidence))
|
print("Predict {} with {:.2f} confidence.".format(person, confidence))
|
||||||
if isinstance(cls, GMM):
|
if isinstance(clf, GMM):
|
||||||
dist = np.linalg.norm(rep - cls.means_[maxI])
|
dist = np.linalg.norm(rep - clf.means_[maxI])
|
||||||
print(" + Distance from the mean: {}".format(dist))
|
print(" + Distance from the mean: {}".format(dist))
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue