Web demo: Make the 'unknown' class optional.
This commit is contained in:
parent
f12d181a75
commit
dbd4490536
|
@ -62,6 +62,7 @@ parser.add_argument('--networkModel', type=str, help="Path to Torch network mode
|
|||
default=os.path.join(facenetModelDir, 'nn4.v1.t7'))
|
||||
parser.add_argument('--imgDim', type=int, help="Default image dimension.", default=96)
|
||||
parser.add_argument('--cuda', type=bool, default=False)
|
||||
parser.add_argument('--unknown', type=bool, default=False, help='Try to predict unknown people')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -89,7 +90,8 @@ class FaceNetServerProtocol(WebSocketServerProtocol):
|
|||
self.training = True
|
||||
self.people = []
|
||||
self.svm = None
|
||||
self.unknownImgs = np.load("./examples/web/unknown.npy")
|
||||
if args.unknown:
|
||||
self.unknownImgs = np.load("./examples/web/unknown.npy")
|
||||
|
||||
def onConnect(self, request):
|
||||
print("Client connecting: {0}".format(request.peer))
|
||||
|
@ -161,18 +163,20 @@ class FaceNetServerProtocol(WebSocketServerProtocol):
|
|||
X.append(img.rep)
|
||||
y.append(img.identity)
|
||||
|
||||
numUnknown = y.count(-1)
|
||||
numIdentified = len(y) - numUnknown
|
||||
numIdentities = len(set(y+[-1])) - 1
|
||||
if numIdentities == 0:
|
||||
return None
|
||||
numUnknownAdd = (numIdentified/numIdentities) - numUnknown
|
||||
if numUnknownAdd > 0:
|
||||
print("+ Augmenting with {} unknown images.".format(numUnknownAdd))
|
||||
for rep in self.unknownImgs[:numUnknownAdd]:
|
||||
# print(rep)
|
||||
X.append(rep)
|
||||
y.append(-1)
|
||||
|
||||
if args.unknown:
|
||||
numUnknown = y.count(-1)
|
||||
numIdentified = len(y) - numUnknown
|
||||
numUnknownAdd = (numIdentified/numIdentities) - numUnknown
|
||||
if numUnknownAdd > 0:
|
||||
print("+ Augmenting with {} unknown images.".format(numUnknownAdd))
|
||||
for rep in self.unknownImgs[:numUnknownAdd]:
|
||||
# print(rep)
|
||||
X.append(rep)
|
||||
y.append(-1)
|
||||
|
||||
X = np.vstack(X)
|
||||
y = np.array(y)
|
||||
|
@ -264,8 +268,7 @@ class FaceNetServerProtocol(WebSocketServerProtocol):
|
|||
if phash in self.images:
|
||||
identity = self.images[phash].identity
|
||||
else:
|
||||
cv2.imwrite('/tmp/facenet-web-demo.png', alignedFace)
|
||||
rep = np.array(net.forward("/tmp/facenet-web-demo.png"))
|
||||
rep = net.forwardImage(alignedFace)
|
||||
# print(rep)
|
||||
if self.training:
|
||||
self.images[phash] = Face(rep, identity)
|
||||
|
|
Loading…
Reference in New Issue