lfw: classify->verify

This commit is contained in:
Brandon Amos 2016-02-05 20:07:29 -05:00
parent 7d8c6ba012
commit a5528889c4
1 changed files with 6 additions and 4 deletions

View File

@ -13,6 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This implements the standard LFW verification experiment.
import math
import numpy as np
@ -64,8 +66,8 @@ def main():
embeddings = dict(zip(*[paths, rawEmbeddings]))
pairs = loadPairs(args.lfwPairs)
classifyExp(args.workDir, pairs, embeddings)
plotClassifyExp(args.workDir, args.tag)
verifyExp(args.workDir, pairs, embeddings)
plotVerifyExp(args.workDir, args.tag)
def loadPairs(pairsFname):
@ -165,7 +167,7 @@ def findBestThreshold(thresholds, embeddings, pairsTrain):
return bestThresh
def classifyExp(workDir, pairs, embeddings):
def verifyExp(workDir, pairs, embeddings):
print(" + Computing accuracy.")
folds = KFold(n=6000, n_folds=10, shuffle=False)
thresholds = arange(0, 4, 0.01)
@ -234,7 +236,7 @@ def plotOpenFaceROC(workDir, plotFolds=True, color=None):
return foldPlot, meanPlot, AUC
def plotClassifyExp(workDir, tag):
def plotVerifyExp(workDir, tag):
print("Plotting.")
fig, ax = plt.subplots(1, 1)