mirror of https://github.com/davisking/dlib.git
173 lines
6.3 KiB
C++
173 lines
6.3 KiB
C++
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
|
|
/*
|
|
This example shows how to do semantic segmentation on an image using a net pretrained
|
|
on the PASCAL VOC2012 dataset. For an introduction to what segmentation is, see the
|
|
accompanying header file dnn_semantic_segmentation_ex.h.
|
|
|
|
Instructions on how to run the example:
|
|
1. Download the PASCAL VOC2012 data, and untar it somewhere.
|
|
http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
|
|
2. Build the dnn_semantic_segmentation_train_ex example program.
|
|
3. Run:
|
|
./dnn_semantic_segmentation_train_ex /path/to/VOC2012
|
|
4. Wait while the network is being trained.
|
|
5. Build the dnn_semantic_segmentation_ex example program.
|
|
6. Run:
|
|
./dnn_semantic_segmentation_ex /path/to/VOC2012-or-other-images
|
|
|
|
An alternative to steps 2-4 above is to download a pre-trained network
|
|
from here: http://dlib.net/files/semantic_segmentation_voc2012net.dnn
|
|
|
|
It would be a good idea to become familiar with dlib's DNN tooling before reading this
|
|
example. So you should read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp
|
|
before reading this example program.
|
|
*/
|
|
|
|
#include "dnn_semantic_segmentation_ex.h"
|
|
|
|
#include <iostream>
|
|
#include <dlib/data_io.h>
|
|
#include <dlib/gui_widgets.h>
|
|
|
|
using namespace std;
|
|
using namespace dlib;
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
// The PASCAL VOC2012 dataset contains 20 ground-truth classes + background. Each class
|
|
// is represented using an RGB color value. We also associate each class with an index in the
|
|
// range [0, 20], used internally by the network. To generate nice RGB representations of
|
|
// inference results, we need to be able to convert the index values to the corresponding
|
|
// RGB values.
|
|
|
|
// Given an index in the range [0, 20], find the corresponding PASCAL VOC2012 class
|
|
// (e.g., 'dog').
|
|
const Voc2012class& find_voc2012_class(const uint16_t& index_label)
|
|
{
|
|
return find_voc2012_class(
|
|
[&index_label](const Voc2012class& voc2012class)
|
|
{
|
|
return index_label == voc2012class.index;
|
|
}
|
|
);
|
|
}
|
|
|
|
// Convert an index in the range [0, 20] to a corresponding RGB class label.
|
|
inline rgb_pixel index_label_to_rgb_label(uint16_t index_label)
|
|
{
|
|
return find_voc2012_class(index_label).rgb_label;
|
|
}
|
|
|
|
// Convert an image containing indexes in the range [0, 20] to a corresponding
|
|
// image containing RGB class labels.
|
|
void index_label_image_to_rgb_label_image(
|
|
const matrix<uint16_t>& index_label_image,
|
|
matrix<rgb_pixel>& rgb_label_image
|
|
)
|
|
{
|
|
const long nr = index_label_image.nr();
|
|
const long nc = index_label_image.nc();
|
|
|
|
rgb_label_image.set_size(nr, nc);
|
|
|
|
for (long r = 0; r < nr; ++r)
|
|
{
|
|
for (long c = 0; c < nc; ++c)
|
|
{
|
|
rgb_label_image(r, c) = index_label_to_rgb_label(index_label_image(r, c));
|
|
}
|
|
}
|
|
}
|
|
|
|
// Find the most prominent class label from amongst the per-pixel predictions.
|
|
std::string get_most_prominent_non_background_classlabel(const matrix<uint16_t>& index_label_image)
|
|
{
|
|
const long nr = index_label_image.nr();
|
|
const long nc = index_label_image.nc();
|
|
|
|
std::vector<unsigned int> counters(class_count);
|
|
|
|
for (long r = 0; r < nr; ++r)
|
|
{
|
|
for (long c = 0; c < nc; ++c)
|
|
{
|
|
const uint16_t label = index_label_image(r, c);
|
|
++counters[label];
|
|
}
|
|
}
|
|
|
|
const auto max_element = std::max_element(counters.begin() + 1, counters.end());
|
|
const uint16_t most_prominent_index_label = max_element - counters.begin();
|
|
|
|
return find_voc2012_class(most_prominent_index_label).classlabel;
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
int main(int argc, char** argv) try
|
|
{
|
|
if (argc != 2)
|
|
{
|
|
cout << "You call this program like this: " << endl;
|
|
cout << "./dnn_semantic_segmentation_train_ex /path/to/images" << endl;
|
|
cout << endl;
|
|
cout << "You will also need a trained 'semantic_segmentation_voc2012net.dnn' file." << endl;
|
|
cout << "You can either train it yourself (see example program" << endl;
|
|
cout << "dnn_semantic_segmentation_train_ex), or download a" << endl;
|
|
cout << "copy from here: http://dlib.net/files/semantic_segmentation_voc2012net.dnn" << endl;
|
|
return 1;
|
|
}
|
|
|
|
// Read the file containing the trained network from the working directory.
|
|
anet_type net;
|
|
deserialize("semantic_segmentation_voc2012net.dnn") >> net;
|
|
|
|
// Show inference results in a window.
|
|
image_window win;
|
|
|
|
matrix<rgb_pixel> input_image;
|
|
matrix<uint16_t> index_label_image;
|
|
matrix<rgb_pixel> rgb_label_image;
|
|
|
|
// Find supported image files.
|
|
const std::vector<file> files = dlib::get_files_in_directory_tree(argv[1],
|
|
dlib::match_endings(".jpeg .jpg .png"));
|
|
|
|
cout << "Found " << files.size() << " images, processing..." << endl;
|
|
|
|
for (const file& file : files)
|
|
{
|
|
// Load the input image.
|
|
load_image(input_image, file.full_name());
|
|
|
|
// Create predictions for each pixel. At this point, the type of each prediction
|
|
// is an index (a value between 0 and 20). Note that the net may return an image
|
|
// that is not exactly the same size as the input.
|
|
const matrix<uint16_t> temp = net(input_image);
|
|
|
|
// Crop the returned image to be exactly the same size as the input.
|
|
const chip_details chip_details(
|
|
centered_rect(temp.nc() / 2, temp.nr() / 2, input_image.nc(), input_image.nr()),
|
|
chip_dims(input_image.nr(), input_image.nc())
|
|
);
|
|
extract_image_chip(temp, chip_details, index_label_image, interpolate_nearest_neighbor());
|
|
|
|
// Convert the indexes to RGB values.
|
|
index_label_image_to_rgb_label_image(index_label_image, rgb_label_image);
|
|
|
|
// Show the input image on the left, and the predicted RGB labels on the right.
|
|
win.set_image(join_rows(input_image, rgb_label_image));
|
|
|
|
// Find the most prominent class label from amongst the per-pixel predictions.
|
|
const std::string classlabel = get_most_prominent_non_background_classlabel(index_label_image);
|
|
|
|
cout << file.name() << " : " << classlabel << " - hit enter to process the next image";
|
|
cin.get();
|
|
}
|
|
}
|
|
catch(std::exception& e)
|
|
{
|
|
cout << e.what() << endl;
|
|
}
|
|
|