Add timing information to classifier.py

Thanks @lucafeudi
This commit is contained in:
Brandon Amos 2015-12-14 11:53:11 -05:00
parent 139cfb7756
commit 1b9ad376ff
1 changed files with 30 additions and 0 deletions

View File

@ -18,6 +18,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
start = time.time()
import argparse
import cv2
import itertools
@ -50,23 +54,37 @@ openfaceModelDir = os.path.join(modelDir, 'openface')
def getRep(imgPath):
start = time.time()
bgrImg = cv2.imread(imgPath)
if bgrImg is None:
raise Exception("Unable to load image: {}".format(imgPath))
rgbImg = cv2.cvtColor(bgrImg, cv2.COLOR_BGR2RGB)
if args.verbose:
print(" + Original size: {}".format(rgbImg.shape))
if args.verbose:
print("Loading the image took {} seconds.".format(time.time() - start))
start = time.time()
bb = align.getLargestFaceBoundingBox(rgbImg)
if bb is None:
raise Exception("Unable to find a face: {}".format(imgPath))
if args.verbose:
print("Face detection took {} seconds.".format(time.time() - start))
start = time.time()
alignedFace = align.alignImg("affine", args.imgDim, bgrImg, bb)
if alignedFace is None:
raise Exception("Unable to align image: {}".format(imgPath))
if args.verbose:
print("Alignment took {} seconds.".format(time.time() - start))
start = time.time()
rep = net.forwardImage(alignedFace)
if args.verbose:
print("Neural network forward pass took {} seconds.".format(time.time() - start))
return rep
@ -104,13 +122,18 @@ def infer(args):
with open(args.classifierModel, 'r') as f:
(le, svm) = pickle.load(f)
rep = getRep(args.img)
start = time.time()
predictions = svm.predict_proba(rep)[0]
maxI = np.argmax(predictions)
person = le.inverse_transform(maxI)
confidence = predictions[maxI]
if args.verbose:
print("SVM prediction took {} seconds.".format(time.time() - start))
print("Predict {} with {:.2f} confidence.".format(person, confidence))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--dlibFacePredictor', type=str,
@ -143,6 +166,8 @@ if __name__ == '__main__':
help="Input image.")
args = parser.parse_args()
if args.verbose:
print("Argument parsing and import libraries took {} seconds.".format(time.time() - start))
if args.mode == 'infer' and args.classifierModel.endswith(".t7"):
raise Exception("""
@ -156,6 +181,7 @@ network and classification models:
http://cmusatyalab.github.io/openface/training-new-models/
Use `--networkModel` to set a non-standard Torch network model.""")
start = time.time()
sys.path = [args.dlibRoot] + sys.path
import dlib
@ -165,6 +191,10 @@ Use `--networkModel` to set a non-standard Torch network model.""")
net = openface.TorchWrap(
args.networkModel, imgDim=args.imgDim, cuda=args.cuda)
if args.verbose:
print("Loading the dlib and OpenFace models took {} seconds.".format(time.time() - start))
start = time.time()
if args.mode == 'train':
train(args)
elif args.mode == 'infer':