LFW experiment: Show avg and stdev of accuracies.

This commit is contained in:
Brandon Amos 2015-09-26 04:41:48 -04:00
parent d85ee28f93
commit 5bffd697cc
1 changed files with 7 additions and 5 deletions

View File

@ -140,11 +140,12 @@ def classifyExp(workDir, pairs, embeddings):
folds = KFold(n=6000, n_folds=10, shuffle=False) folds = KFold(n=6000, n_folds=10, shuffle=False)
thresholds = arange(0,4,0.01) thresholds = arange(0,4,0.01)
if os.path.exists("{}/log.txt".format(workDir)): if os.path.exists("{}/accuracies.txt".format(workDir)):
print("{}/log.txt already exists. Skipping processing.".format(workDir)) print("{}/accuracies.txt already exists. Skipping processing.".format(workDir))
else: else:
accuracies = [] accuracies = []
with open("{}/log.txt".format(workDir), "w") as f: with open("{}/accuracies.txt".format(workDir), "w") as f:
f.write('fold, threshold, accuracy\n')
for idx, (train, test) in enumerate(folds): for idx, (train, test) in enumerate(folds):
fname = "{}/l2-roc.fold-{}.csv".format(workDir, idx) fname = "{}/l2-roc.fold-{}.csv".format(workDir, idx)
writeROC(fname, thresholds, embeddings, pairs[test]) writeROC(fname, thresholds, embeddings, pairs[test])
@ -152,8 +153,9 @@ def classifyExp(workDir, pairs, embeddings):
bestThresh = findBestThreshold(thresholds, embeddings, pairs[train]) bestThresh = findBestThreshold(thresholds, embeddings, pairs[train])
accuracy = evalThresholdAccuracy(embeddings, pairs[test], bestThresh) accuracy = evalThresholdAccuracy(embeddings, pairs[test], bestThresh)
accuracies.append(accuracy) accuracies.append(accuracy)
f.write('fold {} test accuracy: {}\n'.format(idx, accuracy)) f.write('{}, {:0.2f}, {:0.2f}\n'.format(idx, bestThresh, accuracy))
f.write('avg: {} +/- {}\n'.format(np.mean(accuracies), np.std(accuracies))) f.write('\navg, {:0.2f} +/- {:0.2f}\n'.format(np.mean(accuracies),
np.std(accuracies)))
def plotClassifyExp(workDir): def plotClassifyExp(workDir):