lfw: classify->verify
This commit is contained in:
parent
7d8c6ba012
commit
a5528889c4
|
@ -13,6 +13,8 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# This implements the standard LFW verification experiment.
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
|
@ -64,8 +66,8 @@ def main():
|
|||
embeddings = dict(zip(*[paths, rawEmbeddings]))
|
||||
|
||||
pairs = loadPairs(args.lfwPairs)
|
||||
classifyExp(args.workDir, pairs, embeddings)
|
||||
plotClassifyExp(args.workDir, args.tag)
|
||||
verifyExp(args.workDir, pairs, embeddings)
|
||||
plotVerifyExp(args.workDir, args.tag)
|
||||
|
||||
|
||||
def loadPairs(pairsFname):
|
||||
|
@ -165,7 +167,7 @@ def findBestThreshold(thresholds, embeddings, pairsTrain):
|
|||
return bestThresh
|
||||
|
||||
|
||||
def classifyExp(workDir, pairs, embeddings):
|
||||
def verifyExp(workDir, pairs, embeddings):
|
||||
print(" + Computing accuracy.")
|
||||
folds = KFold(n=6000, n_folds=10, shuffle=False)
|
||||
thresholds = arange(0, 4, 0.01)
|
||||
|
@ -234,7 +236,7 @@ def plotOpenFaceROC(workDir, plotFolds=True, color=None):
|
|||
return foldPlot, meanPlot, AUC
|
||||
|
||||
|
||||
def plotClassifyExp(workDir, tag):
|
||||
def plotVerifyExp(workDir, tag):
|
||||
print("Plotting.")
|
||||
|
||||
fig, ax = plt.subplots(1, 1)
|
||||
|
|
Loading…
Reference in New Issue