2016-12-18 09:41:36 +08:00
|
|
|
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
|
|
|
|
/*
|
|
|
|
This is an example illustrating the use of the deep learning tools from the
|
|
|
|
dlib C++ Library. In it, we will show how to use the loss_metric layer to do
|
|
|
|
metric learning on images.
|
|
|
|
|
|
|
|
The main reason you might want to use this kind of algorithm is because you
|
|
|
|
would like to use a k-nearest neighbor classifier or similar algorithm, but
|
|
|
|
you don't know a good way to calculate the distance between two things. A
|
|
|
|
popular example would be face recognition. There are a whole lot of papers
|
|
|
|
that train some kind of deep metric learning algorithm that embeds face
|
|
|
|
images in some vector space where images of the same person are close to each
|
|
|
|
other and images of different people are far apart. Then in that vector
|
|
|
|
space it's very easy to do face recognition with some kind of k-nearest
|
|
|
|
neighbor classifier.
|
|
|
|
|
2017-02-10 01:38:39 +08:00
|
|
|
In this example we will use a version of the ResNet network from the
|
|
|
|
dnn_imagenet_ex.cpp example to learn to map images into some vector space where
|
|
|
|
pictures of the same person are close and pictures of different people are far
|
|
|
|
apart.
|
2016-12-18 09:41:36 +08:00
|
|
|
|
|
|
|
You might want to read the simpler introduction to the deep metric learning
|
|
|
|
API, dnn_metric_learning_ex.cpp, before reading this example. You should
|
|
|
|
also have read the examples that introduce the dlib DNN API before
|
|
|
|
continuing. These are dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp.
|
|
|
|
|
|
|
|
*/
|
2016-12-18 03:29:29 +08:00
|
|
|
|
|
|
|
#include <dlib/dnn.h>
|
|
|
|
#include <dlib/image_io.h>
|
|
|
|
#include <dlib/misc_api.h>
|
|
|
|
|
|
|
|
using namespace dlib;
|
|
|
|
using namespace std;
|
|
|
|
|
2016-12-18 09:41:36 +08:00
|
|
|
// ----------------------------------------------------------------------------------------
|
2016-12-18 03:29:29 +08:00
|
|
|
|
2016-12-18 09:41:36 +08:00
|
|
|
// We will need to create some functions for loading data. This program will
|
|
|
|
// expect to be given a directory structured as follows:
|
|
|
|
// top_level_directory/
|
|
|
|
// person1/
|
|
|
|
// image1.jpg
|
|
|
|
// image2.jpg
|
|
|
|
// image3.jpg
|
|
|
|
// person2/
|
|
|
|
// image4.jpg
|
|
|
|
// image5.jpg
|
|
|
|
// image6.jpg
|
|
|
|
// person3/
|
|
|
|
// image7.jpg
|
|
|
|
// image8.jpg
|
|
|
|
// image9.jpg
|
|
|
|
//
|
|
|
|
// The specific folder and image names don't matter, nor does the number of folders or
|
|
|
|
// images. What does matter is that there is a top level folder, which contains
|
|
|
|
// subfolders, and each subfolder contains images of a single person.
|
|
|
|
|
|
|
|
// This function spiders the top level directory and obtains a list of all the
|
|
|
|
// image files.
|
2016-12-18 03:29:29 +08:00
|
|
|
std::vector<std::vector<string>> load_objects_list (
|
|
|
|
const string& dir
|
|
|
|
)
|
|
|
|
{
|
|
|
|
std::vector<std::vector<string>> objects;
|
|
|
|
for (auto subdir : directory(dir).get_dirs())
|
|
|
|
{
|
|
|
|
std::vector<string> imgs;
|
|
|
|
for (auto img : subdir.get_files())
|
|
|
|
imgs.push_back(img);
|
|
|
|
|
2016-12-19 09:52:45 +08:00
|
|
|
if (imgs.size() != 0)
|
|
|
|
objects.push_back(imgs);
|
2016-12-18 03:29:29 +08:00
|
|
|
}
|
|
|
|
return objects;
|
|
|
|
}
|
|
|
|
|
2016-12-18 09:41:36 +08:00
|
|
|
// This function takes the output of load_objects_list() as input and randomly
|
|
|
|
// selects images for training. It should also be pointed out that it's really
|
|
|
|
// important that each mini-batch contain multiple images of each person. This
|
|
|
|
// is because the metric learning algorithm needs to consider pairs of images
|
|
|
|
// that should be close (i.e. images of the same person) as well as pairs of
|
|
|
|
// images that should be far apart (i.e. images of different people) during each
|
|
|
|
// training step.
|
2016-12-18 03:29:29 +08:00
|
|
|
void load_mini_batch (
|
2016-12-18 09:41:36 +08:00
|
|
|
const size_t num_people, // how many different people to include
|
|
|
|
const size_t samples_per_id, // how many images per person to select.
|
2016-12-18 03:29:29 +08:00
|
|
|
dlib::rand& rnd,
|
|
|
|
const std::vector<std::vector<string>>& objs,
|
|
|
|
std::vector<matrix<rgb_pixel>>& images,
|
|
|
|
std::vector<unsigned long>& labels
|
|
|
|
)
|
|
|
|
{
|
|
|
|
images.clear();
|
|
|
|
labels.clear();
|
2016-12-18 09:41:36 +08:00
|
|
|
DLIB_CASSERT(num_people <= objs.size(), "The dataset doesn't have that many people in it.");
|
2016-12-18 03:29:29 +08:00
|
|
|
|
2016-12-18 09:41:36 +08:00
|
|
|
std::vector<bool> already_selected(objs.size(), false);
|
2016-12-18 03:29:29 +08:00
|
|
|
matrix<rgb_pixel> image;
|
2016-12-18 09:41:36 +08:00
|
|
|
for (size_t i = 0; i < num_people; ++i)
|
2016-12-18 03:29:29 +08:00
|
|
|
{
|
2016-12-18 09:41:36 +08:00
|
|
|
size_t id = rnd.get_random_32bit_number()%objs.size();
|
|
|
|
// don't pick a person we already added to the mini-batch
|
|
|
|
while(already_selected[id])
|
|
|
|
id = rnd.get_random_32bit_number()%objs.size();
|
|
|
|
already_selected[id] = true;
|
|
|
|
|
2016-12-18 03:29:29 +08:00
|
|
|
for (size_t j = 0; j < samples_per_id; ++j)
|
|
|
|
{
|
|
|
|
const auto& obj = objs[id][rnd.get_random_32bit_number()%objs[id].size()];
|
|
|
|
load_image(image, obj);
|
|
|
|
images.push_back(std::move(image));
|
|
|
|
labels.push_back(id);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-02-10 01:38:39 +08:00
|
|
|
// You might want to do some data augmentation at this point. Here we do some simple
|
2016-12-18 03:29:29 +08:00
|
|
|
// color augmentation.
|
|
|
|
for (auto&& crop : images)
|
|
|
|
disturb_colors(crop,rnd);
|
|
|
|
|
|
|
|
|
|
|
|
// All the images going into a mini-batch have to be the same size. And really, all
|
|
|
|
// the images in your entire training dataset should be the same size for what we are
|
|
|
|
// doing to make the most sense.
|
|
|
|
DLIB_CASSERT(images.size() > 0);
|
|
|
|
for (auto&& img : images)
|
|
|
|
{
|
|
|
|
DLIB_CASSERT(img.nr() == images[0].nr() && img.nc() == images[0].nc(),
|
|
|
|
"All the images in a single mini-batch must be the same size.");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
|
2017-02-09 11:53:55 +08:00
|
|
|
// The next page of code defines a ResNet network.  It's basically copied
// and pasted from the dnn_imagenet_ex.cpp example, except we replaced the loss
// layer with loss_metric and made the network somewhat smaller.
|
2016-12-18 09:41:36 +08:00
|
|
|
|
2016-12-18 03:29:29 +08:00
|
|
|
template <template <int,template<typename>class,int,typename> class block, int N, template<typename>class BN, typename SUBNET>
|
|
|
|
using residual = add_prev1<block<N,BN,1,tag1<SUBNET>>>;
|
|
|
|
|
|
|
|
template <template <int,template<typename>class,int,typename> class block, int N, template<typename>class BN, typename SUBNET>
|
|
|
|
using residual_down = add_prev2<avg_pool<2,2,2,2,skip1<tag2<block<N,BN,2,tag1<SUBNET>>>>>>;
|
|
|
|
|
|
|
|
template <int N, template <typename> class BN, int stride, typename SUBNET>
|
|
|
|
using block = BN<con<N,3,3,1,1,relu<BN<con<N,3,3,stride,stride,SUBNET>>>>>;
|
|
|
|
|
|
|
|
|
|
|
|
// Convenience aliases: a residual unit followed by a relu.  res and res_down
// use bn_con as the normalization layer, while the "a" prefixed versions use
// affine in its place.
template <int NUM_FILTERS, typename INPUT>
using res       = relu<residual<block, NUM_FILTERS, bn_con, INPUT>>;

template <int NUM_FILTERS, typename INPUT>
using ares      = relu<residual<block, NUM_FILTERS, affine, INPUT>>;

template <int NUM_FILTERS, typename INPUT>
using res_down  = relu<residual_down<block, NUM_FILTERS, bn_con, INPUT>>;

template <int NUM_FILTERS, typename INPUT>
using ares_down = relu<residual_down<block, NUM_FILTERS, affine, INPUT>>;
|
|
|
|
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
|
2017-02-09 11:53:55 +08:00
|
|
|
template <typename SUBNET> using level0 = res_down<256,SUBNET>;
|
|
|
|
template <typename SUBNET> using level1 = res<256,res<256,res_down<256,SUBNET>>>;
|
|
|
|
template <typename SUBNET> using level2 = res<128,res<128,res_down<128,SUBNET>>>;
|
|
|
|
template <typename SUBNET> using level3 = res<64,res<64,res<64,res_down<64,SUBNET>>>>;
|
|
|
|
template <typename SUBNET> using level4 = res<32,res<32,res<32,SUBNET>>>;
|
2016-12-18 03:29:29 +08:00
|
|
|
|
2017-02-09 11:53:55 +08:00
|
|
|
template <typename SUBNET> using alevel0 = ares_down<256,SUBNET>;
|
|
|
|
template <typename SUBNET> using alevel1 = ares<256,ares<256,ares_down<256,SUBNET>>>;
|
|
|
|
template <typename SUBNET> using alevel2 = ares<128,ares<128,ares_down<128,SUBNET>>>;
|
|
|
|
template <typename SUBNET> using alevel3 = ares<64,ares<64,ares<64,ares_down<64,SUBNET>>>>;
|
|
|
|
template <typename SUBNET> using alevel4 = ares<32,ares<32,ares<32,SUBNET>>>;
|
2016-12-18 03:29:29 +08:00
|
|
|
|
|
|
|
|
|
|
|
// training network type
|
2016-12-19 02:20:37 +08:00
|
|
|
using net_type = loss_metric<fc_no_bias<128,avg_pool_everything<
|
2017-02-09 11:53:55 +08:00
|
|
|
level0<
|
2016-12-18 03:29:29 +08:00
|
|
|
level1<
|
|
|
|
level2<
|
|
|
|
level3<
|
|
|
|
level4<
|
2017-02-09 11:53:55 +08:00
|
|
|
max_pool<3,3,2,2,relu<bn_con<con<32,7,7,2,2,
|
2017-02-10 01:38:39 +08:00
|
|
|
input_rgb_image
|
2017-02-09 11:53:55 +08:00
|
|
|
>>>>>>>>>>>>;
|
2016-12-18 03:29:29 +08:00
|
|
|
|
|
|
|
// testing network type (replaced batch normalization with fixed affine transforms)
|
2016-12-19 02:20:37 +08:00
|
|
|
using anet_type = loss_metric<fc_no_bias<128,avg_pool_everything<
|
2017-02-09 11:53:55 +08:00
|
|
|
alevel0<
|
2016-12-18 03:29:29 +08:00
|
|
|
alevel1<
|
|
|
|
alevel2<
|
|
|
|
alevel3<
|
|
|
|
alevel4<
|
2017-02-09 11:53:55 +08:00
|
|
|
max_pool<3,3,2,2,relu<affine<con<32,7,7,2,2,
|
2017-02-10 01:38:39 +08:00
|
|
|
input_rgb_image
|
2017-02-09 11:53:55 +08:00
|
|
|
>>>>>>>>>>>>;
|
2016-12-18 03:29:29 +08:00
|
|
|
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
int main(int argc, char** argv)
|
|
|
|
{
|
|
|
|
if (argc != 2)
|
|
|
|
{
|
2016-12-18 09:41:36 +08:00
|
|
|
cout << "Give a folder as input. It should contain sub-folders of images and we will " << endl;
|
|
|
|
cout << "learn to distinguish between these sub-folders with metric learning. " << endl;
|
|
|
|
cout << "For example, you can run this program on the very small examples/johns dataset" << endl;
|
|
|
|
cout << "that comes with dlib by running this command:" << endl;
|
|
|
|
cout << " ./dnn_metric_learning_on_images_ex johns" << endl;
|
2016-12-18 03:29:29 +08:00
|
|
|
return 1;
|
|
|
|
}
|
|
|
|
|
|
|
|
auto objs = load_objects_list(argv[1]);
|
|
|
|
|
|
|
|
cout << "objs.size(): "<< objs.size() << endl;
|
|
|
|
|
|
|
|
std::vector<matrix<rgb_pixel>> images;
|
|
|
|
std::vector<unsigned long> labels;
|
|
|
|
|
|
|
|
|
|
|
|
net_type net;
|
|
|
|
|
|
|
|
dnn_trainer<net_type> trainer(net, sgd(0.0005, 0.9));
|
|
|
|
trainer.set_learning_rate(0.1);
|
|
|
|
trainer.be_verbose();
|
|
|
|
trainer.set_synchronization_file("face_metric_sync", std::chrono::minutes(5));
|
2016-12-18 09:41:36 +08:00
|
|
|
// I've set this to something really small to make the example terminate
|
|
|
|
// sooner. But when you really want to train a good model you should set
|
|
|
|
// this to something like 8000 so training doesn't terminate too early.
|
2016-12-18 03:29:29 +08:00
|
|
|
trainer.set_iterations_without_progress_threshold(300);
|
|
|
|
|
2016-12-18 09:41:36 +08:00
|
|
|
// If you have a lot of data then it might not be reasonable to load it all
|
|
|
|
// into RAM. So you will need to be sure you are decompressing your images
|
|
|
|
// and loading them fast enough to keep the GPU occupied. I like to do this
|
|
|
|
// using the following coding pattern: create a bunch of threads that dump
|
|
|
|
// mini-batches into dlib::pipes.
|
2016-12-18 03:29:29 +08:00
|
|
|
dlib::pipe<std::vector<matrix<rgb_pixel>>> qimages(4);
|
|
|
|
dlib::pipe<std::vector<unsigned long>> qlabels(4);
|
|
|
|
auto data_loader = [&qimages, &qlabels, &objs](time_t seed)
|
|
|
|
{
|
|
|
|
dlib::rand rnd(time(0)+seed);
|
|
|
|
std::vector<matrix<rgb_pixel>> images;
|
|
|
|
std::vector<unsigned long> labels;
|
|
|
|
while(qimages.is_enabled())
|
|
|
|
{
|
|
|
|
try
|
|
|
|
{
|
2016-12-18 09:41:36 +08:00
|
|
|
load_mini_batch(5, 5, rnd, objs, images, labels);
|
2016-12-18 03:29:29 +08:00
|
|
|
qimages.enqueue(images);
|
|
|
|
qlabels.enqueue(labels);
|
|
|
|
}
|
|
|
|
catch(std::exception& e)
|
|
|
|
{
|
|
|
|
cout << "EXCEPTION IN LOADING DATA" << endl;
|
|
|
|
cout << e.what() << endl;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
2016-12-18 09:41:36 +08:00
|
|
|
// Run the data_loader from 5 threads. You should set the number of threads
|
|
|
|
// relative to the number of CPU cores you have.
|
2016-12-18 03:29:29 +08:00
|
|
|
std::thread data_loader1([data_loader](){ data_loader(1); });
|
|
|
|
std::thread data_loader2([data_loader](){ data_loader(2); });
|
|
|
|
std::thread data_loader3([data_loader](){ data_loader(3); });
|
|
|
|
std::thread data_loader4([data_loader](){ data_loader(4); });
|
|
|
|
std::thread data_loader5([data_loader](){ data_loader(5); });
|
|
|
|
|
|
|
|
|
|
|
|
// Here we do the training. We keep passing mini-batches to the trainer until the
|
|
|
|
// learning rate has dropped low enough.
|
|
|
|
while(trainer.get_learning_rate() >= 1e-4)
|
|
|
|
{
|
|
|
|
qimages.dequeue(images);
|
|
|
|
qlabels.dequeue(labels);
|
|
|
|
trainer.train_one_step(images, labels);
|
|
|
|
}
|
|
|
|
|
2016-12-18 09:41:36 +08:00
|
|
|
// Wait for training threads to stop
|
2016-12-18 03:29:29 +08:00
|
|
|
trainer.get_net();
|
|
|
|
cout << "done training" << endl;
|
|
|
|
|
|
|
|
// Save the network to disk
|
|
|
|
net.clean();
|
|
|
|
serialize("metric_network_renset.dat") << net;
|
|
|
|
|
|
|
|
// stop all the data loading threads and wait for them to terminate.
|
|
|
|
qimages.disable();
|
|
|
|
qlabels.disable();
|
|
|
|
data_loader1.join();
|
|
|
|
data_loader2.join();
|
|
|
|
data_loader3.join();
|
|
|
|
data_loader4.join();
|
|
|
|
data_loader5.join();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2016-12-18 09:41:36 +08:00
|
|
|
// Now, just to show an example of how you would use the network, let's check how well
|
2016-12-18 03:29:29 +08:00
|
|
|
// it performs on the training data.
|
|
|
|
dlib::rand rnd(time(0));
|
2016-12-18 09:41:36 +08:00
|
|
|
load_mini_batch(5, 5, rnd, objs, images, labels);
|
2016-12-18 03:29:29 +08:00
|
|
|
|
2016-12-19 02:10:13 +08:00
|
|
|
// Normally you would use the non-batch-normalized version of the network to do
|
|
|
|
// testing, which is what we do here.
|
|
|
|
anet_type testing_net = net;
|
|
|
|
|
2016-12-18 03:29:29 +08:00
|
|
|
// Run all the images through the network to get their vector embeddings.
|
2016-12-19 02:10:13 +08:00
|
|
|
std::vector<matrix<float,0,1>> embedded = testing_net(images);
|
2016-12-18 03:29:29 +08:00
|
|
|
|
2016-12-18 09:41:36 +08:00
|
|
|
// Now, check if the embedding puts images with the same labels near each other and
|
|
|
|
// images with different labels far apart.
|
2016-12-18 03:29:29 +08:00
|
|
|
int num_right = 0;
|
|
|
|
int num_wrong = 0;
|
|
|
|
for (size_t i = 0; i < embedded.size(); ++i)
|
|
|
|
{
|
|
|
|
for (size_t j = i+1; j < embedded.size(); ++j)
|
|
|
|
{
|
|
|
|
if (labels[i] == labels[j])
|
|
|
|
{
|
2016-12-18 09:41:36 +08:00
|
|
|
// The loss_metric layer will cause images with the same label to be less
|
2016-12-18 03:29:29 +08:00
|
|
|
// than net.loss_details().get_distance_threshold() distance from each
|
|
|
|
// other. So we can use that distance value as our testing threshold.
|
2016-12-19 02:10:13 +08:00
|
|
|
if (length(embedded[i]-embedded[j]) < testing_net.loss_details().get_distance_threshold())
|
2016-12-18 03:29:29 +08:00
|
|
|
++num_right;
|
|
|
|
else
|
|
|
|
++num_wrong;
|
|
|
|
}
|
|
|
|
else
|
|
|
|
{
|
2016-12-19 02:10:13 +08:00
|
|
|
if (length(embedded[i]-embedded[j]) >= testing_net.loss_details().get_distance_threshold())
|
2016-12-18 03:29:29 +08:00
|
|
|
++num_right;
|
|
|
|
else
|
|
|
|
++num_wrong;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
cout << "num_right: "<< num_right << endl;
|
|
|
|
cout << "num_wrong: "<< num_wrong << endl;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|