Improved example

This commit is contained in:
Davis King 2016-12-18 13:10:13 -05:00
parent fd13230486
commit b87ecad51e
1 changed files with 7 additions and 3 deletions

View File

@ -288,8 +288,12 @@ int main(int argc, char** argv)
dlib::rand rnd(time(0));
load_mini_batch(5, 5, rnd, objs, images, labels);
// Normally you would use the non-batch-normalized version of the network to do
// testing, which is what we do here.
anet_type testing_net = net;
// Run all the images through the network to get their vector embeddings.
std::vector<matrix<float,0,1>> embedded = net(images);
std::vector<matrix<float,0,1>> embedded = testing_net(images);
// Now, check if the embedding puts images with the same labels near each other and
// images with different labels far apart.
@ -304,14 +308,14 @@ int main(int argc, char** argv)
// The loss_metric layer will cause images with the same label to be less
// than net.loss_details().get_distance_threshold() distance from each
// other. So we can use that distance value as our testing threshold.
if (length(embedded[i]-embedded[j]) < net.loss_details().get_distance_threshold())
if (length(embedded[i]-embedded[j]) < testing_net.loss_details().get_distance_threshold())
++num_right;
else
++num_wrong;
}
else
{
if (length(embedded[i]-embedded[j]) >= net.loss_details().get_distance_threshold())
if (length(embedded[i]-embedded[j]) >= testing_net.loss_details().get_distance_threshold())
++num_right;
else
++num_wrong;