lfw roc: Make sure every representation exists.
This commit is contained in:
parent
63104030dd
commit
ce72875f54
|
@ -38,10 +38,12 @@ def main():
|
|||
args = parser.parse_args()
|
||||
|
||||
print("Loading embeddings.")
|
||||
paths = pd.read_csv("{}/labels.csv".format(args.workDir)).as_matrix()[:,1]
|
||||
fname = "{}/labels.csv".format(args.workDir)
|
||||
paths = pd.read_csv(fname, header=None).as_matrix()[:,1]
|
||||
paths = map(os.path.basename, paths) # Get the filename.
|
||||
paths = map(lambda path: os.path.splitext(path)[0], paths) # Remove the extension.
|
||||
rawEmbeddings = pd.read_csv("{}/reps.csv".format(args.workDir)).as_matrix()
|
||||
fname = "{}/reps.csv".format(args.workDir)
|
||||
rawEmbeddings = pd.read_csv(fname, header=None).as_matrix()
|
||||
embeddings = dict(zip(*[paths, rawEmbeddings]))
|
||||
|
||||
pairs = loadPairs()
|
||||
|
@ -89,19 +91,20 @@ def analyze_accuracy(workDir, pairs, embeddings):
|
|||
else:
|
||||
raise Exception(
|
||||
"Unexpected pair length: {}".format(len(pair)))
|
||||
|
||||
# print(name1,name2)
|
||||
if name1 not in embeddings or name2 not in embeddings:
|
||||
# Representation of one or both people is not available
|
||||
# since sometimes the face cannot be aligned.
|
||||
# Guess they are the same and note the error.
|
||||
num_errors += 1
|
||||
predict_same = True
|
||||
else:
|
||||
vec1 = embeddings[name1]
|
||||
vec2 = embeddings[name2]
|
||||
diff = vec1-vec2
|
||||
dist = np.dot(diff.T, diff)
|
||||
predict_same = dist < threshold
|
||||
if name1 not in embeddings:
|
||||
print('Error: Representation for {} not found.'.format(name1))
|
||||
sys.exit(-1)
|
||||
if name2 not in embeddings:
|
||||
print('Error: Representation for {} not found.'.format(name2))
|
||||
sys.exit(-1)
|
||||
|
||||
vec1 = embeddings[name1]
|
||||
vec2 = embeddings[name2]
|
||||
diff = vec1-vec2
|
||||
dist = np.dot(diff.T, diff)
|
||||
predict_same = dist < threshold
|
||||
|
||||
if predict_same and actual_same: tp += 1
|
||||
elif predict_same and not actual_same: fp += 1
|
||||
|
@ -134,7 +137,7 @@ def plot_accuracy(workDir):
|
|||
openbrData['Y'] = 1-openbrData['Y']
|
||||
openbrData.plot(x='X', y='Y', legend=True, ax=ax)
|
||||
|
||||
ax.legend(['FaceNet', 'OpenBR v1.0.0'], loc='lower right')
|
||||
ax.legend(['FaceNet nn4.v1', 'OpenBR v1.0.0'], loc='lower right')
|
||||
|
||||
plt.plot([0,1], color='k', linestyle='dashed')
|
||||
|
||||
|
|
Loading…
Reference in New Issue