Web demo: Make the 'unknown' class optional.

This commit is contained in:
Brandon Amos 2015-10-03 22:52:58 -04:00
parent f12d181a75
commit dbd4490536
1 changed files with 15 additions and 12 deletions

View File

@ -62,6 +62,7 @@ parser.add_argument('--networkModel', type=str, help="Path to Torch network mode
default=os.path.join(facenetModelDir, 'nn4.v1.t7'))
parser.add_argument('--imgDim', type=int, help="Default image dimension.", default=96)
parser.add_argument('--cuda', type=bool, default=False)
parser.add_argument('--unknown', type=bool, default=False, help='Try to predict unknown people')
args = parser.parse_args()
@ -89,7 +90,8 @@ class FaceNetServerProtocol(WebSocketServerProtocol):
self.training = True
self.people = []
self.svm = None
self.unknownImgs = np.load("./examples/web/unknown.npy")
if args.unknown:
self.unknownImgs = np.load("./examples/web/unknown.npy")
def onConnect(self, request):
print("Client connecting: {0}".format(request.peer))
@ -161,18 +163,20 @@ class FaceNetServerProtocol(WebSocketServerProtocol):
X.append(img.rep)
y.append(img.identity)
numUnknown = y.count(-1)
numIdentified = len(y) - numUnknown
numIdentities = len(set(y+[-1])) - 1
if numIdentities == 0:
return None
numUnknownAdd = (numIdentified/numIdentities) - numUnknown
if numUnknownAdd > 0:
print("+ Augmenting with {} unknown images.".format(numUnknownAdd))
for rep in self.unknownImgs[:numUnknownAdd]:
# print(rep)
X.append(rep)
y.append(-1)
if args.unknown:
numUnknown = y.count(-1)
numIdentified = len(y) - numUnknown
numUnknownAdd = (numIdentified/numIdentities) - numUnknown
if numUnknownAdd > 0:
print("+ Augmenting with {} unknown images.".format(numUnknownAdd))
for rep in self.unknownImgs[:numUnknownAdd]:
# print(rep)
X.append(rep)
y.append(-1)
X = np.vstack(X)
y = np.array(y)
@ -264,8 +268,7 @@ class FaceNetServerProtocol(WebSocketServerProtocol):
if phash in self.images:
identity = self.images[phash].identity
else:
cv2.imwrite('/tmp/facenet-web-demo.png', alignedFace)
rep = np.array(net.forward("/tmp/facenet-web-demo.png"))
rep = net.forwardImage(alignedFace)
# print(rep)
if self.training:
self.images[phash] = Face(rep, identity)