LFW experiment: Show avg and stdev of accuracies.
This commit is contained in:
parent
d85ee28f93
commit
5bffd697cc
|
@ -140,11 +140,12 @@ def classifyExp(workDir, pairs, embeddings):
|
||||||
folds = KFold(n=6000, n_folds=10, shuffle=False)
|
folds = KFold(n=6000, n_folds=10, shuffle=False)
|
||||||
thresholds = arange(0,4,0.01)
|
thresholds = arange(0,4,0.01)
|
||||||
|
|
||||||
if os.path.exists("{}/log.txt".format(workDir)):
|
if os.path.exists("{}/accuracies.txt".format(workDir)):
|
||||||
print("{}/log.txt already exists. Skipping processing.".format(workDir))
|
print("{}/accuracies.txt already exists. Skipping processing.".format(workDir))
|
||||||
else:
|
else:
|
||||||
accuracies = []
|
accuracies = []
|
||||||
with open("{}/log.txt".format(workDir), "w") as f:
|
with open("{}/accuracies.txt".format(workDir), "w") as f:
|
||||||
|
f.write('fold, threshold, accuracy\n')
|
||||||
for idx, (train, test) in enumerate(folds):
|
for idx, (train, test) in enumerate(folds):
|
||||||
fname = "{}/l2-roc.fold-{}.csv".format(workDir, idx)
|
fname = "{}/l2-roc.fold-{}.csv".format(workDir, idx)
|
||||||
writeROC(fname, thresholds, embeddings, pairs[test])
|
writeROC(fname, thresholds, embeddings, pairs[test])
|
||||||
|
@ -152,8 +153,9 @@ def classifyExp(workDir, pairs, embeddings):
|
||||||
bestThresh = findBestThreshold(thresholds, embeddings, pairs[train])
|
bestThresh = findBestThreshold(thresholds, embeddings, pairs[train])
|
||||||
accuracy = evalThresholdAccuracy(embeddings, pairs[test], bestThresh)
|
accuracy = evalThresholdAccuracy(embeddings, pairs[test], bestThresh)
|
||||||
accuracies.append(accuracy)
|
accuracies.append(accuracy)
|
||||||
f.write('fold {} test accuracy: {}\n'.format(idx, accuracy))
|
f.write('{}, {:0.2f}, {:0.2f}\n'.format(idx, bestThresh, accuracy))
|
||||||
f.write('avg: {} +/- {}\n'.format(np.mean(accuracies), np.std(accuracies)))
|
f.write('\navg, {:0.2f} +/- {:0.2f}\n'.format(np.mean(accuracies),
|
||||||
|
np.std(accuracies)))
|
||||||
|
|
||||||
|
|
||||||
def plotClassifyExp(workDir):
|
def plotClassifyExp(workDir):
|
||||||
|
|
Loading…
Reference in New Issue