2015-09-26 16:34:58 +08:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
#
|
|
|
|
# Copyright 2015 Carnegie Mellon University
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2015-10-04 05:16:44 +08:00
|
|
|
import math
|
2015-09-26 16:34:58 +08:00
|
|
|
import numpy as np
|
|
|
|
import pandas as pd
|
|
|
|
from scipy.interpolate import interp1d
|
|
|
|
|
|
|
|
from sklearn.cross_validation import KFold
|
|
|
|
from sklearn.metrics import accuracy_score
|
|
|
|
|
|
|
|
import matplotlib as mpl
# Select the 'Agg' backend before pyplot is imported below. The plotting code
# visible in this file writes its figure with fig.savefig.
mpl.use('Agg')
import matplotlib.pyplot as plt
# Apply matplotlib's 'bmh' style sheet to the figures drawn by this script.
plt.style.use('bmh')
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
|
|
|
|
from scipy import arange
|
|
|
|
|
2015-10-13 19:40:47 +08:00
|
|
|
|
2015-09-26 16:34:58 +08:00
|
|
|
def main():
    """Run the LFW verification experiment from the command line.

    Command-line options:
        --workDir:  directory containing ``labels.csv`` and ``reps.csv``
                    (default ``reps``); result files are written there too.
        --lfwPairs: path to the LFW pairs file
                    (default ``~/openface/data/lfw/pairs.txt``).

    The second column of ``labels.csv`` holds image paths; each path is
    reduced to its file name without extension and used as the key for the
    row at the same position in ``reps.csv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--workDir', type=str, default='reps')
    parser.add_argument('--lfwPairs', type=str,
                        default=os.path.expanduser("~/openface/data/lfw/pairs.txt"))
    args = parser.parse_args()

    print("Loading embeddings.")
    fname = "{}/labels.csv".format(args.workDir)
    # `.values` is used instead of DataFrame.as_matrix(): as_matrix() was
    # removed in pandas 1.0, while `.values` works on old and new releases.
    paths = pd.read_csv(fname, header=None).values[:, 1]
    # Keep only the file name and drop its extension.
    names = [os.path.splitext(os.path.basename(path))[0] for path in paths]
    fname = "{}/reps.csv".format(args.workDir)
    rawEmbeddings = pd.read_csv(fname, header=None).values
    embeddings = dict(zip(names, rawEmbeddings))

    pairs = loadPairs(args.lfwPairs)
    classifyExp(args.workDir, pairs, embeddings)
    plotClassifyExp(args.workDir)
|
|
|
|
|
2015-10-13 19:40:47 +08:00
|
|
|
|
2015-10-11 09:26:10 +08:00
|
|
|
def loadPairs(pairsFname):
    """Read the LFW pairs file and return its rows as a NumPy array.

    The first line of the file is treated as a header and skipped. Every
    remaining non-blank line is split on whitespace into a list of strings.

    Args:
        pairsFname: Path to the pairs file.

    Returns:
        An object-dtype NumPy array with one entry per pair, so it can be
        indexed with arrays of fold indices.

    Raises:
        AssertionError: If the file does not contain exactly 6000 pairs.
    """
    print(" + Reading pairs.")
    with open(pairsFname, 'r') as f:
        next(f, None)  # Skip the header line.
        # Blank lines (e.g. a trailing newline) carry no pair and are ignored.
        pairs = [line.split() for line in f if line.strip()]
    assert len(pairs) == 6000, \
        "Expected 6000 pairs, found {}".format(len(pairs))
    # Rows have 3 or 4 fields. Recent NumPy releases refuse to build an array
    # from ragged sequences unless dtype=object is given explicitly.
    return np.array(pairs, dtype=object)
|
|
|
|
|
2015-10-13 19:40:47 +08:00
|
|
|
|
2015-09-26 16:34:58 +08:00
|
|
|
def getEmbeddings(pair, embeddings):
    """Look up the two embeddings referenced by one pairs-file row.

    A 3-field row is ``(name, idx1, idx2)``: two images of the same person.
    A 4-field row is ``(name1, idx1, name2, idx2)``: images of two people.
    Image indices are zero-padded to four digits to form keys such as
    ``name_0001``.

    Args:
        pair: Sequence of 3 or 4 strings.
        embeddings: Mapping from image key to embedding.

    Returns:
        ``(x1, x2, actual_same)`` where ``actual_same`` is True for 3-field
        rows and False for 4-field rows.

    Raises:
        Exception: If the row has neither 3 nor 4 fields.
        KeyError: If a referenced image key is missing from ``embeddings``.
    """
    fieldCount = len(pair)
    if fieldCount not in (3, 4):
        raise Exception(
            "Unexpected pair length: {}".format(fieldCount))

    actual_same = fieldCount == 3
    firstKey = "{}_{}".format(pair[0], pair[1].zfill(4))
    if actual_same:
        secondKey = "{}_{}".format(pair[0], pair[2].zfill(4))
    else:
        secondKey = "{}_{}".format(pair[2], pair[3].zfill(4))

    return (embeddings[firstKey], embeddings[secondKey], actual_same)
|
|
|
|
|
2015-10-13 19:40:47 +08:00
|
|
|
|
2015-09-26 16:34:58 +08:00
|
|
|
def writeROC(fname, thresholds, embeddings, pairsTest):
    """Write per-threshold confusion counts and rates for a set of pairs.

    For each threshold a pair is predicted "same" when the squared L2
    distance between its two embeddings is below the threshold. One CSV row
    ``threshold,tp,tn,fp,fn,tpr,fpr`` is written per threshold. As soon as
    both rates reach 1.0 the counts are repeated once more under threshold
    4.0 (without a trailing newline) and the function returns.

    Args:
        fname: Path of the CSV file to create (overwritten if present).
        thresholds: Iterable of distance thresholds, evaluated in order.
        embeddings: Mapping from image key to embedding vector.
        pairsTest: Iterable of pairs-file rows to evaluate.
    """
    # The distance and ground-truth label of a pair do not depend on the
    # threshold, so compute them once instead of once per threshold.
    scored = []
    for pair in pairsTest:
        (x1, x2, actual_same) = getEmbeddings(pair, embeddings)
        diff = x1 - x2
        scored.append((np.dot(diff.T, diff), actual_same))

    with open(fname, "w") as f:
        f.write("threshold,tp,tn,fp,fn,tpr,fpr\n")
        for threshold in thresholds:
            tp = tn = fp = fn = 0
            for dist, actual_same in scored:
                predict_same = dist < threshold
                if predict_same and actual_same:
                    tp += 1
                elif predict_same:
                    fp += 1
                elif actual_same:
                    fn += 1
                else:
                    tn += 1

            # Guard against empty classes to avoid division by zero.
            tpr = float(tp) / float(tp + fn) if tp + fn != 0 else 0
            fpr = float(fp) / float(fp + tn) if fp + tn != 0 else 0

            f.write(",".join(str(x)
                             for x in [threshold, tp, tn, fp, fn, tpr, fpr]))
            f.write("\n")
            if tpr == 1.0 and fpr == 1.0:
                # No further improvements: larger thresholds give the same
                # counts, so close the curve with a final row and stop.
                f.write(",".join(str(x)
                                 for x in [4.0, tp, tn, fp, fn, tpr, fpr]))
                return
|
|
|
|
|
2015-10-13 19:40:47 +08:00
|
|
|
|
2015-09-26 16:34:58 +08:00
|
|
|
def evalThresholdAccuracy(embeddings, pairs, threshold):
    """Return the classification accuracy of one distance threshold.

    A pair is predicted "same" when the squared L2 distance between its two
    embeddings is below ``threshold``; predictions are scored against the
    ground truth implied by each row's length.

    Args:
        embeddings: Mapping from image key to embedding vector.
        pairs: Iterable of pairs-file rows.
        threshold: Squared-distance cut-off.

    Returns:
        The value of ``accuracy_score`` over all pairs.
    """
    truth = []
    guesses = []
    for row in pairs:
        first, second, is_same = getEmbeddings(row, embeddings)
        delta = first - second
        guesses.append(np.dot(delta.T, delta) < threshold)
        truth.append(is_same)

    return accuracy_score(np.array(truth), np.array(guesses))
|
|
|
|
|
2015-10-13 19:40:47 +08:00
|
|
|
|
2015-09-26 16:34:58 +08:00
|
|
|
def findBestThreshold(thresholds, embeddings, pairsTrain):
    """Pick a threshold by scanning candidates until accuracy first drops.

    Candidates are tried in the order given. A candidate replaces the current
    choice whenever its accuracy is at least as high as the best seen so far;
    the scan stops at the first candidate that scores lower.

    Args:
        thresholds: Iterable of candidate thresholds.
        embeddings: Mapping from image key to embedding vector.
        pairsTrain: Pairs-file rows used to score each candidate.

    Returns:
        The last candidate accepted, or 0 if ``thresholds`` is empty.
    """
    chosen = 0
    topAccuracy = 0
    for candidate in thresholds:
        score = evalThresholdAccuracy(embeddings, pairsTrain, candidate)
        if not (score >= topAccuracy):
            # Accuracy fell: stop searching and keep the previous choice.
            break
        topAccuracy, chosen = score, candidate
    return chosen
|
|
|
|
|
2015-10-13 19:40:47 +08:00
|
|
|
|
2015-09-26 16:34:58 +08:00
|
|
|
def classifyExp(workDir, pairs, embeddings):
    """Run the 10-fold threshold-classification experiment.

    For every fold this writes ``l2-roc.fold-<idx>.csv`` for the test split,
    picks a threshold on the training split, and records the test accuracy.
    Per-fold results and the mean +/- standard deviation are written to
    ``accuracies.txt``. All files are created inside ``workDir``.

    If ``accuracies.txt`` already exists nothing is recomputed.
    NOTE: the file is opened for writing before the folds are processed, so
    an interrupted run leaves a partial file that makes later runs skip;
    delete it to force a rerun.

    Args:
        workDir: Directory for the output files.
        pairs: NumPy array of pairs-file rows (indexed with fold indices).
        embeddings: Mapping from image key to embedding vector.
    """
    print(" + Computing accuracy.")

    if os.path.exists("{}/accuracies.txt".format(workDir)):
        print("{}/accuracies.txt already exists. Skipping processing.".format(workDir))
        return

    # Size the folds from the data instead of a hard-coded 6000.
    folds = KFold(n=len(pairs), n_folds=10, shuffle=False)
    # np.arange replaces the deprecated scipy.arange alias of the same
    # function.
    thresholds = np.arange(0, 4, 0.01)

    accuracies = []
    with open("{}/accuracies.txt".format(workDir), "w") as f:
        f.write('fold, threshold, accuracy\n')
        for idx, (train, test) in enumerate(folds):
            fname = "{}/l2-roc.fold-{}.csv".format(workDir, idx)
            writeROC(fname, thresholds, embeddings, pairs[test])

            # The threshold is chosen on the training split only and then
            # scored on the held-out split.
            bestThresh = findBestThreshold(
                thresholds, embeddings, pairs[train])
            accuracy = evalThresholdAccuracy(
                embeddings, pairs[test], bestThresh)
            accuracies.append(accuracy)
            f.write('{}, {:0.2f}, {:0.2f}\n'.format(
                idx, bestThresh, accuracy))
        avg = np.mean(accuracies)
        std = np.std(accuracies)
        f.write('\navg, {:0.4f} +/- {:0.4f}\n'.format(avg, std))
        print(' + {:0.4f}'.format(avg))
|
2015-09-26 16:34:58 +08:00
|
|
|
|
|
|
|
|
2015-10-04 05:16:44 +08:00
|
|
|
def getAUC(fprs, tprs):
    """Return the area under an ROC curve using the trapezoidal rule.

    The points are sorted by false-positive rate (ties broken by
    true-positive rate). If the curve does not reach ``fpr == 1.0`` it is
    extended horizontally to 1.0 using the last true-positive rate. The
    inputs are not modified.

    Args:
        fprs: Iterable of false-positive rates.
        tprs: Iterable of true-positive rates, paired with ``fprs``.

    Returns:
        The integral of tpr over fpr.

    Raises:
        ValueError: If no points are supplied.
    """
    points = sorted(zip(fprs, tprs))
    if not points:
        raise ValueError("getAUC requires at least one (fpr, tpr) point")

    sortedFprs = [fpr for fpr, _ in points]
    sortedTprs = [tpr for _, tpr in points]
    if sortedFprs[-1] != 1.0:
        sortedFprs.append(1.0)
        sortedTprs.append(sortedTprs[-1])

    # NumPy 2.0 renamed trapz to trapezoid and deprecated the old name;
    # prefer the new one when available and fall back on older releases.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    return integrate(sortedTprs, sortedFprs)
|
|
|
|
|
2015-10-13 19:40:47 +08:00
|
|
|
|
2016-01-08 07:28:05 +08:00
|
|
|
def plotOpenFaceROC(workDir, plotFolds=True, color=None):
    """Plot the mean ROC curve over the ten fold files in ``workDir``.

    Reads ``l2-roc.fold-0.csv`` to ``l2-roc.fold-9.csv``, interpolates each
    fold's tpr as a function of fpr, and averages the ten curves on 1000
    evenly spaced fpr values in [0, 1]. NaN interpolation results count as
    0.0 in the average.

    Args:
        workDir: Directory containing the fold CSV files.
        plotFolds: When true, also draw every individual fold in grey.
        color: Optional matplotlib color for the mean curve.

    Returns:
        ``(foldPlot, meanPlot, AUC)``: the line of the last fold drawn (None
        when ``plotFolds`` is false), the line of the mean curve, and the
        area under the mean curve.
    """
    interpolators = []
    for foldIdx in range(10):
        csvPath = "{}/l2-roc.fold-{}.csv".format(workDir, foldIdx)
        foldData = pd.read_csv(csvPath)
        interpolators.append(interp1d(foldData['fpr'], foldData['tpr']))
        grid = np.linspace(0, 1, 1000)
        if not plotFolds:
            foldPlot = None
        else:
            foldPlot, = plt.plot(grid, interpolators[-1](grid),
                                 color='grey', alpha=0.5)

    fprs = []
    tprs = []
    for fpr in np.linspace(0, 1, 1000):
        total = 0.0
        for interpolate in interpolators:
            value = interpolate(fpr)
            # Treat undefined interpolation results as zero.
            total += 0.0 if math.isnan(value) else value
        fprs.append(fpr)
        tprs.append(total / 10.0)

    # Only pass a color when one was requested so the default color cycle
    # applies otherwise.
    lineStyle = {'color': color} if color else {}
    meanPlot, = plt.plot(fprs, tprs, **lineStyle)
    AUC = getAUC(fprs, tprs)
    return foldPlot, meanPlot, AUC
|
|
|
|
|
2016-01-08 07:43:25 +08:00
|
|
|
|
2016-01-08 07:28:05 +08:00
|
|
|
def plotClassifyExp(workDir):
    """Draw the ROC comparison figure and save it as ``<workDir>/roc.pdf``.

    The figure overlays the OpenFace curves (nn4.v1 read from the
    ``lfw.nn4.v1`` directory, nn4.v2 read from ``workDir``) with reference
    curves loaded from the ``comparisons/`` directory. Both directories are
    resolved relative to the current working directory. Each legend entry
    includes the curve's AUC.

    Args:
        workDir: Directory holding the fold CSV files for the nn4.v2 curve;
            the PDF is written there as well.
    """
    print("Plotting.")

    fig, ax = plt.subplots(1, 1)

    # The OpenBR file stores a DET curve; 1 - Y turns it into a TPR axis.
    openbrData = pd.read_csv("comparisons/openbr.v1.1.0.DET.csv")
    openbrData['Y'] = 1 - openbrData['Y']
    brPlot, = plt.plot(openbrData['X'], openbrData['Y'])
    brAUC = getAUC(openbrData['X'], openbrData['Y'])

    foldPlot_v1, meanPlot_v1, AUC_v1 = plotOpenFaceROC("lfw.nn4.v1", False)
    foldPlot_v2, meanPlot_v2, AUC_v2 = plotOpenFaceROC(workDir, color='k')

    # The remaining comparison files are space-separated with the TPR in
    # column 0 and the FPR in column 1.
    humanData = pd.read_table(
        "comparisons/kumar_human_crop.txt", header=None, sep=' ')
    humanPlot, = plt.plot(humanData[1], humanData[0])
    humanAUC = getAUC(humanData[1], humanData[0])

    deepfaceData = pd.read_table(
        "comparisons/deepface_ensemble.txt", header=None, sep=' ')
    dfPlot, = plt.plot(deepfaceData[1], deepfaceData[0], '--',
                       alpha=0.75)
    deepfaceAUC = getAUC(deepfaceData[1], deepfaceData[0])

    baiduData = pd.read_table(
        "comparisons/BaiduIDLFinal.TPFP", header=None, sep=' ')
    bPlot, = plt.plot(baiduData[1], baiduData[0])
    baiduAUC = getAUC(baiduData[1], baiduData[0])

    eigData = pd.read_table(
        "comparisons/eigenfaces-original-roc.txt", header=None, sep=' ')
    eigPlot, = plt.plot(eigData[1], eigData[0])
    eigAUC = getAUC(eigData[1], eigData[0])

    ax.legend([humanPlot, bPlot, dfPlot, brPlot, eigPlot,
               meanPlot_v1, meanPlot_v2, foldPlot_v2],
              ['Human, Cropped [AUC={:.3f}]'.format(humanAUC),
               'Baidu [{:.3f}]'.format(baiduAUC),
               'DeepFace Ensemble [{:.3f}]'.format(deepfaceAUC),
               'OpenBR v1.1.0 [{:.3f}]'.format(brAUC),
               'Eigenfaces (img-restrict) [{:.3f}]'.format(eigAUC),
               'OpenFace nn4.v1 [{:.3f}]'.format(AUC_v1),
               'OpenFace nn4.v2 [{:.3f}]'.format(AUC_v2),
               'OpenFace nn4.v2 folds'],
              loc='lower right')

    # Diagonal reference line, drawn after the legend so it gets no entry.
    plt.plot([0, 1], color='k', linestyle=':')

    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    # Limits and grid visibility are passed positionally: the keyword names
    # used previously (xmin/xmax, b) are not accepted by newer matplotlib
    # releases, whereas the positional form works across versions.
    plt.xlim(0, 1)

    plt.grid(True, which='major', color='k', linestyle='-')
    plt.grid(True, which='minor', color='k', linestyle='-', alpha=0.2)
    plt.minorticks_on()
    fig.savefig(os.path.join(workDir, "roc.pdf"))
|
|
|
|
|
2015-10-13 19:40:47 +08:00
|
|
|
# Run the experiment only when this file is executed as a script, so that
# importing the module has no side effects beyond the definitions above.
if __name__ == '__main__':
    main()
|