LFW experiment: Show avg and stdev of accuracies.
This commit is contained in:
parent
d85ee28f93
commit
5bffd697cc
|
@ -140,11 +140,12 @@ def classifyExp(workDir, pairs, embeddings):
|
|||
folds = KFold(n=6000, n_folds=10, shuffle=False)
|
||||
thresholds = arange(0,4,0.01)
|
||||
|
||||
if os.path.exists("{}/log.txt".format(workDir)):
|
||||
print("{}/log.txt already exists. Skipping processing.".format(workDir))
|
||||
if os.path.exists("{}/accuracies.txt".format(workDir)):
|
||||
print("{}/accuracies.txt already exists. Skipping processing.".format(workDir))
|
||||
else:
|
||||
accuracies = []
|
||||
with open("{}/log.txt".format(workDir), "w") as f:
|
||||
with open("{}/accuracies.txt".format(workDir), "w") as f:
|
||||
f.write('fold, threshold, accuracy\n')
|
||||
for idx, (train, test) in enumerate(folds):
|
||||
fname = "{}/l2-roc.fold-{}.csv".format(workDir, idx)
|
||||
writeROC(fname, thresholds, embeddings, pairs[test])
|
||||
|
@ -152,8 +153,9 @@ def classifyExp(workDir, pairs, embeddings):
|
|||
bestThresh = findBestThreshold(thresholds, embeddings, pairs[train])
|
||||
accuracy = evalThresholdAccuracy(embeddings, pairs[test], bestThresh)
|
||||
accuracies.append(accuracy)
|
||||
f.write('fold {} test accuracy: {}\n'.format(idx, accuracy))
|
||||
f.write('avg: {} +/- {}\n'.format(np.mean(accuracies), np.std(accuracies)))
|
||||
f.write('{}, {:0.2f}, {:0.2f}\n'.format(idx, bestThresh, accuracy))
|
||||
f.write('\navg, {:0.2f} +/- {:0.2f}\n'.format(np.mean(accuracies),
|
||||
np.std(accuracies)))
|
||||
|
||||
|
||||
def plotClassifyExp(workDir):
|
||||
|
|
Loading…
Reference in New Issue