mirror of https://github.com/davisking/dlib.git
Improved example
This commit is contained in:
parent
fd13230486
commit
b87ecad51e
|
@ -288,8 +288,12 @@ int main(int argc, char** argv)
|
|||
dlib::rand rnd(time(0));
|
||||
load_mini_batch(5, 5, rnd, objs, images, labels);
|
||||
|
||||
// Normally you would use the non-batch-normalized version of the network to do
|
||||
// testing, which is what we do here.
|
||||
anet_type testing_net = net;
|
||||
|
||||
// Run all the images through the network to get their vector embeddings.
|
||||
std::vector<matrix<float,0,1>> embedded = net(images);
|
||||
std::vector<matrix<float,0,1>> embedded = testing_net(images);
|
||||
|
||||
// Now, check if the embedding puts images with the same labels near each other and
|
||||
// images with different labels far apart.
|
||||
|
@ -304,14 +308,14 @@ int main(int argc, char** argv)
|
|||
// The loss_metric layer will cause images with the same label to be less
|
||||
// than net.loss_details().get_distance_threshold() distance from each
|
||||
// other. So we can use that distance value as our testing threshold.
|
||||
if (length(embedded[i]-embedded[j]) < net.loss_details().get_distance_threshold())
|
||||
if (length(embedded[i]-embedded[j]) < testing_net.loss_details().get_distance_threshold())
|
||||
++num_right;
|
||||
else
|
||||
++num_wrong;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (length(embedded[i]-embedded[j]) >= net.loss_details().get_distance_threshold())
|
||||
if (length(embedded[i]-embedded[j]) >= testing_net.loss_details().get_distance_threshold())
|
||||
++num_right;
|
||||
else
|
||||
++num_wrong;
|
||||
|
|
Loading…
Reference in New Issue