2015-10-12 04:58:58 +08:00
#!/usr/bin/env python2
#
# Example to classify faces.
# Brandon Amos
# 2015/10/11
#
# Copyright 2015 Carnegie Mellon University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2015-12-15 00:53:11 +08:00
import time
start = time . time ( )
2015-10-12 04:58:58 +08:00
import argparse
import cv2
import itertools
import os
import pickle
from operator import itemgetter
import numpy as np
np . set_printoptions ( precision = 2 )
import pandas as pd
import sys
fileDir = os . path . dirname ( os . path . realpath ( __file__ ) )
sys . path . append ( os . path . join ( fileDir , " .. " ) )
import openface
import openface . helper
from openface . data import iterImgs
from sklearn . preprocessing import LabelEncoder
from sklearn . decomposition import PCA
from sklearn . grid_search import GridSearchCV
from sklearn . manifold import TSNE
from sklearn . svm import SVC
modelDir = os . path . join ( fileDir , ' .. ' , ' models ' )
dlibModelDir = os . path . join ( modelDir , ' dlib ' )
openfaceModelDir = os . path . join ( modelDir , ' openface ' )
2015-10-12 23:30:29 +08:00
2015-10-12 04:58:58 +08:00
def getRep ( imgPath ) :
2015-12-15 00:53:11 +08:00
start = time . time ( )
2015-12-09 06:04:57 +08:00
bgrImg = cv2 . imread ( imgPath )
if bgrImg is None :
2015-10-12 04:58:58 +08:00
raise Exception ( " Unable to load image: {} " . format ( imgPath ) )
2015-12-15 00:53:11 +08:00
2015-12-09 06:04:57 +08:00
rgbImg = cv2 . cvtColor ( bgrImg , cv2 . COLOR_BGR2RGB )
2015-10-12 04:58:58 +08:00
if args . verbose :
2015-12-09 06:04:57 +08:00
print ( " + Original size: {} " . format ( rgbImg . shape ) )
2015-12-15 00:53:11 +08:00
if args . verbose :
print ( " Loading the image took {} seconds. " . format ( time . time ( ) - start ) )
start = time . time ( )
2015-10-12 04:58:58 +08:00
2015-12-09 06:04:57 +08:00
bb = align . getLargestFaceBoundingBox ( rgbImg )
2015-10-12 04:58:58 +08:00
if bb is None :
raise Exception ( " Unable to find a face: {} " . format ( imgPath ) )
2015-12-15 00:53:11 +08:00
if args . verbose :
print ( " Face detection took {} seconds. " . format ( time . time ( ) - start ) )
2015-10-12 04:58:58 +08:00
2015-12-15 00:53:11 +08:00
start = time . time ( )
2015-12-09 06:04:57 +08:00
alignedFace = align . alignImg ( " affine " , args . imgDim , bgrImg , bb )
2015-10-12 04:58:58 +08:00
if alignedFace is None :
raise Exception ( " Unable to align image: {} " . format ( imgPath ) )
2015-12-15 00:53:11 +08:00
if args . verbose :
print ( " Alignment took {} seconds. " . format ( time . time ( ) - start ) )
2015-10-12 04:58:58 +08:00
2015-12-15 00:53:11 +08:00
start = time . time ( )
2015-10-12 04:58:58 +08:00
rep = net . forwardImage ( alignedFace )
2015-12-15 00:53:11 +08:00
if args . verbose :
print ( " Neural network forward pass took {} seconds. " . format ( time . time ( ) - start ) )
2015-10-12 04:58:58 +08:00
return rep
2015-10-12 23:30:29 +08:00
2015-10-12 04:58:58 +08:00
def train ( args ) :
print ( " Loading embeddings. " )
fname = " {} /labels.csv " . format ( args . workDir )
2015-10-12 23:30:29 +08:00
labels = pd . read_csv ( fname , header = None ) . as_matrix ( ) [ : , 1 ]
2015-10-12 04:58:58 +08:00
labels = map ( itemgetter ( 1 ) ,
map ( os . path . split ,
2015-10-12 23:30:29 +08:00
map ( os . path . dirname , labels ) ) ) # Get the directory.
2015-10-12 04:58:58 +08:00
fname = " {} /reps.csv " . format ( args . workDir )
embeddings = pd . read_csv ( fname , header = None ) . as_matrix ( )
le = LabelEncoder ( ) . fit ( labels )
labelsNum = le . transform ( labels )
param_grid = [
{ ' C ' : [ 1 , 10 , 100 , 1000 ] ,
' kernel ' : [ ' linear ' ] } ,
{ ' C ' : [ 1 , 10 , 100 , 1000 ] ,
' gamma ' : [ 0.001 , 0.0001 ] ,
' kernel ' : [ ' rbf ' ] }
]
svm = GridSearchCV (
SVC ( probability = True ) ,
param_grid , verbose = 4 , cv = 5 , n_jobs = 16
) . fit ( embeddings , labelsNum )
print ( " Best estimator: {} " . format ( svm . best_estimator_ ) )
print ( " Best score on left out data: {:.2f} " . format ( svm . best_score_ ) )
with open ( " {} /classifier.pkl " . format ( args . workDir ) , ' w ' ) as f :
pickle . dump ( ( le , svm ) , f )
2015-10-12 23:30:29 +08:00
2015-10-12 04:58:58 +08:00
def infer ( args ) :
2015-10-12 05:35:50 +08:00
with open ( args . classifierModel , ' r ' ) as f :
2015-10-12 04:58:58 +08:00
( le , svm ) = pickle . load ( f )
rep = getRep ( args . img )
2015-12-15 00:53:11 +08:00
start = time . time ( )
2015-10-12 04:58:58 +08:00
predictions = svm . predict_proba ( rep ) [ 0 ]
maxI = np . argmax ( predictions )
person = le . inverse_transform ( maxI )
confidence = predictions [ maxI ]
2015-12-15 00:53:11 +08:00
if args . verbose :
print ( " SVM prediction took {} seconds. " . format ( time . time ( ) - start ) )
2015-10-12 04:58:58 +08:00
print ( " Predict {} with {:.2f} confidence. " . format ( person , confidence ) )
2015-12-15 00:53:11 +08:00
2015-10-12 04:58:58 +08:00
if __name__ == ' __main__ ' :
2015-12-15 00:53:11 +08:00
2015-10-12 04:58:58 +08:00
parser = argparse . ArgumentParser ( )
parser . add_argument ( ' --dlibFacePredictor ' , type = str ,
help = " Path to dlib ' s face predictor. " ,
default = os . path . join ( dlibModelDir ,
" shape_predictor_68_face_landmarks.dat " ) )
parser . add_argument ( ' --dlibRoot ' , type = str ,
2015-10-12 23:30:29 +08:00
default = os . path . expanduser (
" ~/src/dlib-18.16/python_examples " ) ,
2015-10-12 04:58:58 +08:00
help = " dlib directory with the dlib.so Python library. " )
parser . add_argument ( ' --networkModel ' , type = str ,
help = " Path to Torch network model. " ,
default = os . path . join ( openfaceModelDir , ' nn4.v1.t7 ' ) )
parser . add_argument ( ' --imgDim ' , type = int ,
help = " Default image dimension. " , default = 96 )
parser . add_argument ( ' --cuda ' , action = ' store_true ' )
parser . add_argument ( ' --verbose ' , action = ' store_true ' )
subparsers = parser . add_subparsers ( dest = ' mode ' , help = " Mode " )
trainParser = subparsers . add_parser ( ' train ' ,
help = " Train a new classifier. " )
2015-10-12 05:35:50 +08:00
trainParser . add_argument ( ' workDir ' , type = str ,
help = " The input work directory containing ' reps.csv ' and ' labels.csv ' . Obtained from aligning a directory with ' align-dlib ' and getting the representations with ' batch-represent ' . " )
2015-10-12 04:58:58 +08:00
inferParser = subparsers . add_parser ( ' infer ' ,
help = ' Predict who an image contains from a trained classifier. ' )
2015-11-24 22:57:33 +08:00
inferParser . add_argument ( ' classifierModel ' , type = str ,
help = ' The Python pickle representing the classifier. This is NOT the Torch network model, which can be set with --networkModel. ' )
2015-10-12 04:58:58 +08:00
inferParser . add_argument ( ' img ' , type = str ,
2015-10-12 23:30:29 +08:00
help = " Input image. " )
2015-10-12 04:58:58 +08:00
args = parser . parse_args ( )
2015-12-15 00:53:11 +08:00
if args . verbose :
print ( " Argument parsing and import libraries took {} seconds. " . format ( time . time ( ) - start ) )
2015-10-12 04:58:58 +08:00
2015-12-04 02:47:37 +08:00
if args . mode == ' infer ' and args . classifierModel . endswith ( " .t7 " ) :
2015-11-24 22:57:33 +08:00
raise Exception ( """
2015-12-04 02:47:37 +08:00
Torch network model passed as the classification model ,
which should be a Python pickle ( . pkl )
2015-11-24 22:57:33 +08:00
See the documentation for the distinction between the Torch
network and classification models :
http : / / cmusatyalab . github . io / openface / demo - 3 - classifier /
http : / / cmusatyalab . github . io / openface / training - new - models /
Use ` - - networkModel ` to set a non - standard Torch network model . """ )
2015-12-15 00:53:11 +08:00
start = time . time ( )
2015-11-24 22:57:33 +08:00
2015-12-09 06:04:57 +08:00
sys . path = [ args . dlibRoot ] + sys . path
2015-10-12 04:58:58 +08:00
import dlib
2015-10-12 23:30:29 +08:00
from openface . alignment import NaiveDlib # Depends on dlib.
2015-10-12 04:58:58 +08:00
2015-12-09 06:04:57 +08:00
align = NaiveDlib ( args . dlibFacePredictor )
2015-10-12 23:30:29 +08:00
net = openface . TorchWrap (
args . networkModel , imgDim = args . imgDim , cuda = args . cuda )
2015-10-12 04:58:58 +08:00
2015-12-15 00:53:11 +08:00
if args . verbose :
print ( " Loading the dlib and OpenFace models took {} seconds. " . format ( time . time ( ) - start ) )
start = time . time ( )
2015-10-12 04:58:58 +08:00
if args . mode == ' train ' :
train ( args )
elif args . mode == ' infer ' :
infer ( args )