lfw roc: Make sure every representation exists.

This commit is contained in:
Brandon Amos 2015-09-24 14:21:47 -04:00
parent 63104030dd
commit ce72875f54
1 changed files with 18 additions and 15 deletions

View File

@ -38,10 +38,12 @@ def main():
args = parser.parse_args()
print("Loading embeddings.")
paths = pd.read_csv("{}/labels.csv".format(args.workDir)).as_matrix()[:,1]
fname = "{}/labels.csv".format(args.workDir)
paths = pd.read_csv(fname, header=None).as_matrix()[:,1]
paths = map(os.path.basename, paths) # Get the filename.
paths = map(lambda path: os.path.splitext(path)[0], paths) # Remove the extension.
rawEmbeddings = pd.read_csv("{}/reps.csv".format(args.workDir)).as_matrix()
fname = "{}/reps.csv".format(args.workDir)
rawEmbeddings = pd.read_csv(fname, header=None).as_matrix()
embeddings = dict(zip(*[paths, rawEmbeddings]))
pairs = loadPairs()
@ -89,19 +91,20 @@ def analyze_accuracy(workDir, pairs, embeddings):
else:
raise Exception(
"Unexpected pair length: {}".format(len(pair)))
# print(name1,name2)
if name1 not in embeddings or name2 not in embeddings:
# Representation of one or both people is not available
# since sometimes the face cannot be aligned.
# Guess they are the same and note the error.
num_errors += 1
predict_same = True
else:
vec1 = embeddings[name1]
vec2 = embeddings[name2]
diff = vec1-vec2
dist = np.dot(diff.T, diff)
predict_same = dist < threshold
if name1 not in embeddings:
print('Error: Representation for {} not found.'.format(name1))
sys.exit(-1)
if name2 not in embeddings:
print('Error: Representation for {} not found.'.format(name2))
sys.exit(-1)
vec1 = embeddings[name1]
vec2 = embeddings[name2]
diff = vec1-vec2
dist = np.dot(diff.T, diff)
predict_same = dist < threshold
if predict_same and actual_same: tp += 1
elif predict_same and not actual_same: fp += 1
@ -134,7 +137,7 @@ def plot_accuracy(workDir):
openbrData['Y'] = 1-openbrData['Y']
openbrData.plot(x='X', y='Y', legend=True, ax=ax)
ax.legend(['FaceNet', 'OpenBR v1.0.0'], loc='lower right')
ax.legend(['FaceNet nn4.v1', 'OpenBR v1.0.0'], loc='lower right')
plt.plot([0,1], color='k', linestyle='dashed')