mirror of https://github.com/davisking/dlib.git
Added a simple python interface for training fhog object detectors.
This commit is contained in:
parent
15207aad71
commit
ddc44067b4
|
@ -8,6 +8,7 @@
|
|||
#include <boost/python/suite/indexing/vector_indexing_suite.hpp>
|
||||
#include <dlib/image_processing/frontal_face_detector.h>
|
||||
#include <dlib/gui_widgets.h>
|
||||
#include "simple_object_detector.h"
|
||||
|
||||
|
||||
using namespace dlib;
|
||||
|
@ -120,6 +121,38 @@ std::vector<rectangle> run_detector (
|
|||
}
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
struct simple_object_detector_py
|
||||
{
|
||||
simple_object_detector detector;
|
||||
unsigned int upsampling_amount;
|
||||
|
||||
std::vector<rectangle> run_detector1 (object img, const unsigned int upsampling_amount_)
|
||||
{ return ::run_detector(detector, img, upsampling_amount_); }
|
||||
|
||||
std::vector<rectangle> run_detector2 (object img)
|
||||
{ return ::run_detector(detector, img, upsampling_amount); }
|
||||
};
|
||||
|
||||
void serialize (const simple_object_detector_py& item, std::ostream& out)
|
||||
{
|
||||
int version = 1;
|
||||
serialize(item.detector, out);
|
||||
serialize(version, out);
|
||||
serialize(item.upsampling_amount, out);
|
||||
}
|
||||
|
||||
// Reads item from in, expecting the fields in the order the matching serialize()
// routine writes them: detector, version, upsampling amount.
//
// Throws dlib::serialization_error if the version read from the stream is not 1.
void deserialize (simple_object_detector_py& item, std::istream& in)
{
    deserialize(item.detector, in);

    int format_version = 0;
    deserialize(format_version, in);
    if (format_version != 1)
    {
        throw dlib::serialization_error("Unexpected version found while deserializing a simple_object_detector.");
    }

    deserialize(item.upsampling_amount, in);
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void image_window_set_image (
|
||||
|
@ -163,16 +196,11 @@ boost::shared_ptr<image_window> make_image_window_from_image_and_title(object im
|
|||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
boost::shared_ptr<frontal_face_detector> load_fhog_object_detector_from_file (
|
||||
const std::string& filename
|
||||
)
|
||||
string print_simple_test_results(const simple_test_results& r)
|
||||
{
|
||||
ifstream fin(filename.c_str(), ios::binary);
|
||||
if (!fin)
|
||||
throw dlib::error("Unable to open " + filename);
|
||||
boost::shared_ptr<frontal_face_detector> detector(new frontal_face_detector());
|
||||
deserialize(*detector, fin);
|
||||
return detector;
|
||||
std::ostringstream sout;
|
||||
sout << "precision: "<<r.precision << ", recall: "<< r.recall << ", average precision: " << r.average_precision;
|
||||
return sout.str();
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
@ -181,6 +209,23 @@ void bind_object_detection()
|
|||
{
|
||||
using boost::python::arg;
|
||||
|
||||
class_<simple_object_detector_training_options>("simple_object_detector_training_options")
|
||||
.add_property("be_verbose", &simple_object_detector_training_options::be_verbose,
|
||||
&simple_object_detector_training_options::be_verbose)
|
||||
.add_property("add_left_right_image_flips", &simple_object_detector_training_options::add_left_right_image_flips,
|
||||
&simple_object_detector_training_options::add_left_right_image_flips)
|
||||
.add_property("detection_window_size", &simple_object_detector_training_options::detection_window_size,
|
||||
&simple_object_detector_training_options::detection_window_size)
|
||||
.add_property("num_threads", &simple_object_detector_training_options::num_threads,
|
||||
&simple_object_detector_training_options::num_threads);
|
||||
|
||||
class_<simple_test_results>("simple_test_results")
|
||||
.add_property("precision", &simple_test_results::precision)
|
||||
.add_property("recall", &simple_test_results::recall)
|
||||
.add_property("average_precision", &simple_test_results::average_precision)
|
||||
.def("__str__", &::print_simple_test_results);
|
||||
|
||||
|
||||
{
|
||||
typedef rectangle type;
|
||||
class_<type>("rectangle", "This object represents a rectangular area of an image.")
|
||||
|
@ -199,12 +244,77 @@ void bind_object_detection()
|
|||
def("get_frontal_face_detector", get_frontal_face_detector,
|
||||
"Returns the default face detector");
|
||||
|
||||
def("train_simple_object_detector", train_simple_object_detector,
|
||||
(arg("dataset_filename"), arg("detector_output_filename"), arg("C"), arg("options")=simple_object_detector_training_options()),
|
||||
"whatever");
|
||||
|
||||
def("test_simple_object_detector", test_simple_object_detector,
|
||||
(arg("dataset_filename"), arg("detector_filename")),
|
||||
"whatever");
|
||||
|
||||
{
|
||||
typedef simple_object_detector_py type;
|
||||
class_<type>("simple_object_detector",
|
||||
"This object represents a sliding window histogram-of-oriented-gradients based object detector.")
|
||||
.def("__init__", make_constructor(&load_object_from_file<type>),
|
||||
"Loads a simple_object_detector from a file that contains the output of the \n\
|
||||
train_simple_object_detector() routine."
|
||||
/*!
|
||||
Loads a simple_object_detector from a file that contains the output of the
|
||||
train_simple_object_detector() routine.
|
||||
!*/)
|
||||
.def("__call__", &type::run_detector1, (arg("image"), arg("upsample_num_times")),
|
||||
"requires \n\
|
||||
- image is a numpy ndarray containing either an 8bit grayscale or RGB \n\
|
||||
image. \n\
|
||||
- upsample_num_times >= 0 \n\
|
||||
ensures \n\
|
||||
- This function runs the object detector on the input image and returns \n\
|
||||
a list of detections. \n\
|
||||
- Upsamples the image upsample_num_times before running the basic \n\
|
||||
detector. If you don't know how many times you want to upsample then \n\
|
||||
don't provide a value for upsample_num_times and an appropriate \n\
|
||||
default will be used."
|
||||
/*!
|
||||
requires
|
||||
- image is a numpy ndarray containing either an 8bit grayscale or RGB
|
||||
image.
|
||||
- upsample_num_times >= 0
|
||||
ensures
|
||||
- This function runs the object detector on the input image and returns
|
||||
a list of detections.
|
||||
- Upsamples the image upsample_num_times before running the basic
|
||||
detector. If you don't know how many times you want to upsample then
|
||||
don't provide a value for upsample_num_times and an appropriate
|
||||
default will be used.
|
||||
!*/
|
||||
)
|
||||
.def("__call__", &type::run_detector2, (arg("image")),
|
||||
"requires \n\
|
||||
- image is a numpy ndarray containing either an 8bit grayscale or RGB \n\
|
||||
image. \n\
|
||||
ensures \n\
|
||||
- This function runs the object detector on the input image and returns \n\
|
||||
a list of detections. "
|
||||
/*!
|
||||
requires
|
||||
- image is a numpy ndarray containing either an 8bit grayscale or RGB
|
||||
image.
|
||||
ensures
|
||||
- This function runs the object detector on the input image and returns
|
||||
a list of detections.
|
||||
!*/
|
||||
)
|
||||
.def_pickle(serialize_pickle<type>());
|
||||
}
|
||||
|
||||
{
|
||||
typedef frontal_face_detector type;
|
||||
class_<type>("fhog_object_detector",
|
||||
"This object represents a sliding window histogram-of-oriented-gradients based object detector.")
|
||||
.def("__init__", make_constructor(&load_fhog_object_detector_from_file),
|
||||
"Loads a fhog_object_detector from a file.")
|
||||
.def("__init__", make_constructor(&load_object_from_file<type>),
|
||||
"Loads a fhog_object_detector from a file that contains a serialized \n\
|
||||
object_detector<scan_fhog_pyramid<pyramid_down<6>>> object. " )
|
||||
.def("__call__", &::run_detector, (arg("image"), arg("upsample_num_times")=0),
|
||||
"requires \n\
|
||||
- image is a numpy ndarray containing either an 8bit \n\
|
||||
|
|
|
@ -0,0 +1,259 @@
|
|||
// Copyright (C) 2014 Davis E. King (davis@dlib.net)
|
||||
// License: Boost Software License See LICENSE.txt for the full license.
|
||||
#ifndef DLIB_SIMPLE_ObJECT_DETECTOR_H__
|
||||
#define DLIB_SIMPLE_ObJECT_DETECTOR_H__
|
||||
|
||||
#include "simple_object_detector_abstract.h"
|
||||
#include "dlib/image_processing/object_detector.h"
|
||||
#include "dlib/string.h"
|
||||
#include "dlib/image_processing/scan_fhog_pyramid.h"
|
||||
#include "dlib/svm/structural_object_detection_trainer.h"
|
||||
#include "dlib/geometry.h"
|
||||
#include "dlib/data_io/load_image_dataset.h"
|
||||
#include "dlib/image_processing/remove_unobtainable_rectangles.h"
|
||||
|
||||
|
||||
namespace dlib
|
||||
{
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
// The detector type produced by train_simple_object_detector(): an object_detector
// that scans images with a scan_fhog_pyramid built on pyramid_down<6>.
typedef object_detector<scan_fhog_pyramid<pyramid_down<6> > > simple_object_detector;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
struct simple_object_detector_training_options
{
    // Optional settings consumed by train_simple_object_detector().  The defaults
    // are: quiet output, no left/right flips, 4 threads, and a detection window
    // of about 80*80 pixels.
    simple_object_detector_training_options()
        : be_verbose(false),
          add_left_right_image_flips(false),
          num_threads(4),
          detection_window_size(80*80)
    {
    }

    // If true, progress information is printed to std::cout during training.
    bool be_verbose;
    // If true, left/right flipped copies of the training images are added.
    bool add_left_right_image_flips;
    // Number of threads handed to the trainer.
    unsigned long num_threads;
    // Target area, in pixels, of the sliding detection window.
    unsigned long detection_window_size;
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
namespace impl
|
||||
{
|
||||
inline void pick_best_window_size (
|
||||
const std::vector<std::vector<rectangle> >& boxes,
|
||||
unsigned long& width,
|
||||
unsigned long& height,
|
||||
const unsigned long target_size
|
||||
)
|
||||
{
|
||||
// find the average width and height
|
||||
running_stats<double> avg_width, avg_height;
|
||||
for (unsigned long i = 0; i < boxes.size(); ++i)
|
||||
{
|
||||
for (unsigned long j = 0; j < boxes[i].size(); ++j)
|
||||
{
|
||||
avg_width.add(boxes[i][j].width());
|
||||
avg_height.add(boxes[i][j].height());
|
||||
}
|
||||
}
|
||||
|
||||
// now adjust the box size so that it is about target_pixels pixels in size
|
||||
double size = avg_width.mean()*avg_height.mean();
|
||||
double scale = std::sqrt(target_size/size);
|
||||
|
||||
width = (unsigned long)(avg_width.mean()*scale+0.5);
|
||||
height = (unsigned long)(avg_height.mean()*scale+0.5);
|
||||
// make sure the width and height never round to zero.
|
||||
if (width == 0)
|
||||
width = 1;
|
||||
if (height == 0)
|
||||
height = 1;
|
||||
}
|
||||
|
||||
inline bool contains_any_boxes (
|
||||
const std::vector<std::vector<rectangle> >& boxes
|
||||
)
|
||||
{
|
||||
for (unsigned long i = 0; i < boxes.size(); ++i)
|
||||
{
|
||||
if (boxes[i].size() != 0)
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
inline void throw_invalid_box_error_message (
|
||||
const std::string& dataset_filename,
|
||||
const std::vector<std::vector<rectangle> >& removed,
|
||||
const simple_object_detector_training_options& options,
|
||||
const unsigned long width,
|
||||
const unsigned long height
|
||||
)
|
||||
{
|
||||
image_dataset_metadata::dataset data;
|
||||
load_image_dataset_metadata(data, dataset_filename);
|
||||
|
||||
std::ostringstream sout;
|
||||
sout << "Error, an impossible set of object boxes was given for training. ";
|
||||
sout << "All the boxes need to have a similar aspect ratio and also not be ";
|
||||
sout << "smaller than about " << options.detection_window_size/16 << " pixels in area. ";
|
||||
sout << "The following images contain invalid boxes:\n";
|
||||
std::ostringstream sout2;
|
||||
for (unsigned long i = 0; i < removed.size(); ++i)
|
||||
{
|
||||
if (removed[i].size() != 0)
|
||||
{
|
||||
const std::string imgname = data.images[i].filename;
|
||||
sout2 << " " << imgname << "\n";
|
||||
}
|
||||
}
|
||||
throw error("\n"+wrap_string(sout.str()) + "\n" + sout2.str());
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
inline void train_simple_object_detector (
|
||||
const std::string& dataset_filename,
|
||||
const std::string& detector_output_filename,
|
||||
const double C,
|
||||
const simple_object_detector_training_options& options = simple_object_detector_training_options()
|
||||
)
|
||||
{
|
||||
if (C <= 0)
|
||||
throw error("Invalid C value given to train_simple_object_detector(), C must be > 0.");
|
||||
|
||||
dlib::array<array2d<unsigned char> > images;
|
||||
std::vector<std::vector<rectangle> > boxes, ignore;
|
||||
ignore = load_image_dataset(images, boxes, dataset_filename);
|
||||
|
||||
if (impl::contains_any_boxes(boxes) == false)
|
||||
throw error("Error, the dataset in " + dataset_filename + " does not have any labeled object boxes in it.");
|
||||
|
||||
typedef scan_fhog_pyramid<pyramid_down<6> > image_scanner_type;
|
||||
image_scanner_type scanner;
|
||||
unsigned long width, height;
|
||||
impl::pick_best_window_size(boxes, width, height, options.detection_window_size);
|
||||
scanner.set_detection_window_size(width, height);
|
||||
structural_object_detection_trainer<image_scanner_type> trainer(scanner);
|
||||
trainer.set_num_threads(options.num_threads);
|
||||
trainer.set_c(C);
|
||||
trainer.set_epsilon(0.01);
|
||||
if (options.be_verbose)
|
||||
{
|
||||
std::cout << "Training with C: " << C << std::endl;
|
||||
std::cout << "Training using " << options.num_threads << " threads."<< std::endl;
|
||||
std::cout << "Training with sliding window " << width << " pixels wide by " << height << " pixels tall." << std::endl;
|
||||
if (options.add_left_right_image_flips)
|
||||
std::cout << "Training on both left and right flipped versions of images." << std::endl;
|
||||
trainer.be_verbose();
|
||||
}
|
||||
|
||||
|
||||
unsigned long upsample_amount = 0;
|
||||
|
||||
// now make sure all the boxes are obtainable by the scanner. We will try and
|
||||
// upsample the images at most two times to help make the boxes obtainable.
|
||||
std::vector<std::vector<rectangle> > temp(boxes), removed;
|
||||
removed = remove_unobtainable_rectangles(trainer, images, temp);
|
||||
if (impl::contains_any_boxes(removed))
|
||||
{
|
||||
++upsample_amount;
|
||||
if (options.be_verbose)
|
||||
std::cout << "upsample images..." << std::endl;
|
||||
upsample_image_dataset<pyramid_down<2> >(images, boxes, ignore);
|
||||
temp = boxes;
|
||||
removed = remove_unobtainable_rectangles(trainer, images, temp);
|
||||
if (impl::contains_any_boxes(removed))
|
||||
{
|
||||
++upsample_amount;
|
||||
if (options.be_verbose)
|
||||
std::cout << "upsample images..." << std::endl;
|
||||
upsample_image_dataset<pyramid_down<2> >(images, boxes, ignore);
|
||||
temp = boxes;
|
||||
removed = remove_unobtainable_rectangles(trainer, images, temp);
|
||||
}
|
||||
}
|
||||
// if we weren't able to get all the boxes to match then throw an error
|
||||
if (impl::contains_any_boxes(removed))
|
||||
impl::throw_invalid_box_error_message(dataset_filename, removed, options, width, height);
|
||||
|
||||
if (options.add_left_right_image_flips)
|
||||
add_image_left_right_flips(images, boxes, ignore);
|
||||
|
||||
simple_object_detector detector = trainer.train(images, boxes, ignore);
|
||||
|
||||
std::ofstream fout(detector_output_filename.c_str(), std::ios::binary);
|
||||
int version = 1;
|
||||
serialize(detector, fout);
|
||||
serialize(version, fout);
|
||||
serialize(upsample_amount, fout);
|
||||
|
||||
if (options.be_verbose)
|
||||
{
|
||||
std::cout << "Training complete, saved detector to file " << detector_output_filename << std::endl;
|
||||
std::cout << "Trained with C: " << C << std::endl;
|
||||
std::cout << "Trained using " << options.num_threads << " threads."<< std::endl;
|
||||
std::cout << "Trained with sliding window " << width << " pixels wide by " << height << " pixels tall." << std::endl;
|
||||
if (upsample_amount != 0)
|
||||
{
|
||||
if (upsample_amount == 1)
|
||||
std::cout << "Upsampled images " << upsample_amount << " time to allow detection of small boxes." << std::endl;
|
||||
else
|
||||
std::cout << "Upsampled images " << upsample_amount << " times to allow detection of small boxes." << std::endl;
|
||||
}
|
||||
if (options.add_left_right_image_flips)
|
||||
std::cout << "Trained on both left and right flipped versions of images." << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
struct simple_test_results
{
    // The three numbers filled in by test_simple_object_detector(), in the order
    // they appear in the 1x3 matrix returned by test_object_detection_function().
    double precision;
    double recall;
    double average_precision;
};
|
||||
|
||||
inline const simple_test_results test_simple_object_detector (
    const std::string& dataset_filename,
    const std::string& detector_filename
)
/*!
    ensures
        - Loads the image dataset in dataset_filename and the detector stored in
          detector_filename (a file written by train_simple_object_detector()).
        - Upsamples the dataset the same number of times the training images were
          upsampled, runs test_object_detection_function() on it, and returns the
          resulting precision, recall, and average precision.
    throws
        - dlib::error if detector_filename can't be opened or if it contains an
          unknown format version.
!*/
{
    dlib::array<array2d<unsigned char> > images;
    std::vector<std::vector<rectangle> > boxes, ignore;
    ignore = load_image_dataset(images, boxes, dataset_filename);

    simple_object_detector detector;
    int version = 0;
    unsigned int upsample_amount = 0;
    std::ifstream fin(detector_filename.c_str(), std::ios::binary);
    if (!fin)
        throw error("Unable to open file " + detector_filename);

    // File layout: detector, version, upsample amount.
    deserialize(detector, fin);
    deserialize(version, fin);
    if (version != 1)
        throw error("Unknown simple_object_detector format.");
    deserialize(upsample_amount, fin);

    // The ignore boxes have to be upsampled together with the images and the truth
    // boxes, exactly as is done during training.  Otherwise they would stay in the
    // coordinate frame of the original images and no longer line up with them.
    for (unsigned int i = 0; i < upsample_amount; ++i)
        upsample_image_dataset<pyramid_down<2> >(images, boxes, ignore);

    matrix<double,1,3> res = test_object_detection_function(detector, images, boxes, ignore);

    simple_test_results ret;
    ret.precision = res(0);
    ret.recall = res(1);
    ret.average_precision = res(2);
    return ret;
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
||||
#endif // DLIB_SIMPLE_ObJECT_DETECTOR_H__
|
||||
|
|
@ -0,0 +1,117 @@
|
|||
// Copyright (C) 2014 Davis E. King (davis@dlib.net)
|
||||
// License: Boost Software License See LICENSE.txt for the full license.
|
||||
#undef DLIB_SIMPLE_ObJECT_DETECTOR_ABSTRACT_H__
|
||||
#ifdef DLIB_SIMPLE_ObJECT_DETECTOR_ABSTRACT_H__
|
||||
|
||||
#include <dlib/image_processing/object_detector_abstract.h>
|
||||
#include <dlib/image_processing/scan_fhog_pyramid_abstract.h>
|
||||
#include <dlib/svm/structural_object_detection_trainer_abstract.h>
|
||||
#include <dlib/data_io/image_dataset_metadata.h>
|
||||
#include <dlib/matrix.h>
|
||||
|
||||
|
||||
namespace dlib
|
||||
{
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
struct fhog_training_options
|
||||
{
|
||||
/*!
|
||||
WHAT THIS OBJECT REPRESENTS
|
||||
This object is a container for the more advanced options to the
|
||||
train_simple_object_detector() routine. The parameters have the following
|
||||
interpretations:
|
||||
- be_verbose: If true, train_simple_object_detector() will print out a
|
||||
lot of information to the screen while training.
|
||||
- add_left_right_image_flips: if true, train_simple_object_detector()
|
||||
will assume the objects are left/right symmetric and add in left
|
||||
right flips of the training images. This doubles the size of the
|
||||
training dataset.
|
||||
- num_threads: train_simple_object_detector() will use this many
|
||||
threads of execution. Set this to the number of CPU cores on your
|
||||
machine to obtain the fastest training speed.
|
||||
- detection_window_size: The sliding window used will have about this
|
||||
many pixels inside it.
|
||||
!*/
|
||||
|
||||
fhog_training_options()
|
||||
{
|
||||
be_verbose = false;
|
||||
add_left_right_image_flips = false;
|
||||
num_threads = 4;
|
||||
detection_window_size = 80*80;
|
||||
}
|
||||
|
||||
bool be_verbose;
|
||||
bool add_left_right_image_flips;
|
||||
unsigned long num_threads;
|
||||
unsigned long detection_window_size;
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
typedef object_detector<scan_fhog_pyramid<pyramid_down<6> > > simple_object_detector;
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void train_simple_object_detector (
|
||||
const std::string& dataset_filename,
|
||||
const std::string& detector_output_filename,
|
||||
const double C,
|
||||
const fhog_training_options& options = fhog_training_options()
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- C > 0
|
||||
ensures
|
||||
- Uses the structural_object_detection_trainer to train a
|
||||
simple_object_detector based on the labeled images in the XML file
|
||||
dataset_filename. This function assumes the file dataset_filename is in the
|
||||
XML format produced by the save_image_dataset_metadata() routine.
|
||||
- This function will apply a reasonable set of default parameters and
|
||||
preprocessing techniques to the training procedure for simple_object_detector
|
||||
objects. So the point of this function is to provide you with a very easy
|
||||
way to train a basic object detector.
|
||||
- C is the usual SVM C regularization parameter. So it is passed to
|
||||
structural_object_detection_trainer::set_c(). Larger values of C will
|
||||
encourage the trainer to fit the data better but might lead to overfitting.
|
||||
Therefore, you must determine the proper setting of this parameter
|
||||
experimentally.
|
||||
- The trained object detector is serialized to the file detector_output_filename.
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
struct simple_test_results
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            This object holds the three numbers returned by
            test_simple_object_detector().
    !*/

    double precision;
    double recall;
    double average_precision;
};
|
||||
|
||||
inline const simple_test_results test_simple_object_detector (
    const std::string& dataset_filename,
    const std::string& detector_filename
);
/*!
    ensures
        - Loads an image dataset from dataset_filename.  We assume dataset_filename is
          a file using the XML format written by save_image_dataset_metadata().
        - Loads a simple_object_detector from the file detector_filename.  This means
          detector_filename should be a file produced by the
          train_simple_object_detector() routine defined above.
        - This function tests the detector against the dataset and returns three
          numbers that tell you how well the detector does at detecting the objects
          in the dataset.  The precision, recall, and average_precision fields of
          the returned object are, in that order, the three values produced by
          test_object_detection_function().  Therefore, see the documentation for
          test_object_detection_function() for an extended definition of these
          metrics.
    throws
        - error
          This exception is thrown if detector_filename can't be opened or if it
          is not in the format written by train_simple_object_detector().
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
||||
#endif // DLIB_SIMPLE_ObJECT_DETECTOR_ABSTRACT_H__
|
||||
|
||||
|
Loading…
Reference in New Issue