Updated the Python API, train_simple_object_detector(), so you can call it

directly on already loaded data rather than needing to use an XML file as
input.
This commit is contained in:
Davis King 2014-08-12 19:47:41 -04:00
parent ae9546cbdc
commit cd71dab3f2
3 changed files with 138 additions and 14 deletions

View File

@ -99,3 +99,33 @@ for f in glob.glob(faces_folder+"/*.jpg"):
win.add_overlay(dets)
raw_input("Hit enter to continue")
# Finally, note that you don't have to use the XML based input to
# train_simple_object_detector().  If you have already loaded your training
# images and bounding boxes for the objects then you can call it directly on
# that in-memory data, as shown below.
# You just need to put your images into a list.
images = [io.imread(faces_folder + '/2008_002506.jpg'), io.imread(faces_folder + '/2009_004587.jpg') ]
# Then for each image you make a list of rectangles which give the pixel
# locations of the edges of the boxes.  There must be exactly one list of
# boxes per image, given in the same order as the images.
boxes_img1 = ([dlib.rectangle(left=329, top=78, right=437, bottom=186),
dlib.rectangle(left=224, top=95, right=314, bottom=185),
dlib.rectangle(left=125, top=65, right=214, bottom=155) ] )
boxes_img2 = ([dlib.rectangle(left=154, top=46, right=228, bottom=121 ),
dlib.rectangle(left=266, top=280, right=328, bottom=342) ] )
# And then you aggregate those lists of boxes into one big list and then call
# train_simple_object_detector().  The trained detector is serialized to the
# file named by the third argument ("detector2.svm" here).
# NOTE: `options` is assumed to be the training options object set up earlier
# in this example -- it is not defined in this part of the script.
boxes = [boxes_img1, boxes_img2]
dlib.train_simple_object_detector(images, boxes, "detector2.svm", options)
# Now let's load the trained detector from that file and look at its HOG filter!
# NOTE: `win_det` is assumed to be a window created earlier in this example.
detector2 = dlib.simple_object_detector("detector2.svm")
win_det.set_image(detector2)
raw_input("Hit enter to continue")

View File

@ -220,6 +220,39 @@ string print_simple_test_results(const simple_test_results& r)
return sout.str();
}
inline void train_simple_object_detector_on_images_py (
const object& pyimages,
const object& pyboxes,
const std::string& detector_output_filename,
const simple_object_detector_training_options& options
)
{
const unsigned long num_images = len(pyimages);
if (num_images != len(pyboxes))
throw dlib::error("The length of the boxes list must match the length of the images list.");
// We never have any ignore boxes for this version of the API.
std::vector<std::vector<rectangle> > ignore(num_images), boxes(num_images);
dlib::array<array2d<rgb_pixel> > images(num_images);
// Now copy the data into dlib based objects so we can call the trainer.
for (unsigned long i = 0; i < num_images; ++i)
{
const unsigned long num_boxes = len(pyboxes[i]);
for (unsigned long j = 0; j < num_boxes; ++j)
boxes[i].push_back(extract<rectangle>(pyboxes[i][j]));
object img = pyimages[i];
if (is_gray_python_image(img))
assign_image(images[i], numpy_gray_image(img));
else if (is_rgb_python_image(img))
assign_image(images[i], numpy_rgb_image(img));
else
throw dlib::error("Unsupported image type, must be 8bit gray or RGB image.");
}
train_simple_object_detector_on_images("", images, boxes, ignore, detector_output_filename, options);
}
// ----------------------------------------------------------------------------------------
void bind_object_detection()
@ -340,6 +373,39 @@ ensures \n\
!*/
);
// Exposes train_simple_object_detector_on_images_py to Python under the name
// "train_simple_object_detector", with keyword arguments images, boxes,
// detector_output_filename and options.  The string literal is the docstring
// given to Python; the /*! !*/ comment that follows mirrors it in readable form.
def("train_simple_object_detector", train_simple_object_detector_on_images_py,
(arg("images"), arg("boxes"), arg("detector_output_filename"), arg("options")),
"requires \n\
- options.C > 0 \n\
- len(images) == len(boxes) \n\
- images should be a list of numpy matrices that represent images, either RGB or grayscale. \n\
- boxes should be a list of lists of dlib.rectangle object. \n\
ensures \n\
- Uses the structural_object_detection_trainer to train a \n\
simple_object_detector based on the labeled images and bounding boxes. \n\
- This function will apply a reasonable set of default parameters and \n\
preprocessing techniques to the training procedure for simple_object_detector \n\
objects. So the point of this function is to provide you with a very easy \n\
way to train a basic object detector. \n\
- The trained object detector is serialized to the file detector_output_filename."
/*!
requires
- options.C > 0
- len(images) == len(boxes)
- images should be a list of numpy matrices that represent images, either RGB or grayscale.
- boxes should be a list of lists of dlib.rectangle object.
ensures
- Uses the structural_object_detection_trainer to train a
simple_object_detector based on the labeled images and bounding boxes.
- This function will apply a reasonable set of default parameters and
preprocessing techniques to the training procedure for simple_object_detector
objects. So the point of this function is to provide you with a very easy
way to train a basic object detector.
- The trained object detector is serialized to the file detector_output_filename.
!*/
);
def("test_simple_object_detector", test_simple_object_detector,
(arg("dataset_filename"), arg("detector_filename")),
"ensures \n\

View File

@ -95,21 +95,29 @@ namespace dlib
const simple_object_detector_training_options& options
)
{
image_dataset_metadata::dataset data;
load_image_dataset_metadata(data, dataset_filename);
std::ostringstream sout;
// Note that the 1/16 factor is here because we will try to upsample the image
// 2 times to accommodate small boxes. We also take the max because we want to
// lower bound the size of the smallest recommended box. This is because the
// 8x8 HOG cells can't really deal with really small object boxes.
sout << "Error! An impossible set of object boxes was given for training. ";
sout << "All the boxes need to have a similar aspect ratio and also not be ";
sout << "smaller than about " << options.detection_window_size/16 << " pixels in area. ";
sout << "The following images contain invalid boxes:\n";
sout << "smaller than about " << std::max<long>(20*20,options.detection_window_size/16) << " pixels in area. ";
std::ostringstream sout2;
for (unsigned long i = 0; i < removed.size(); ++i)
if (dataset_filename.size() != 0)
{
if (removed[i].size() != 0)
sout << "The following images contain invalid boxes:\n";
image_dataset_metadata::dataset data;
load_image_dataset_metadata(data, dataset_filename);
for (unsigned long i = 0; i < removed.size(); ++i)
{
const std::string imgname = data.images[i].filename;
sout2 << " " << imgname << "\n";
if (removed[i].size() != 0)
{
const std::string imgname = data.images[i].filename;
sout2 << " " << imgname << "\n";
}
}
}
throw error("\n"+wrap_string(sout.str()) + "\n" + sout2.str());
@ -118,8 +126,12 @@ namespace dlib
// ----------------------------------------------------------------------------------------
inline void train_simple_object_detector (
const std::string& dataset_filename,
template <typename image_array>
inline void train_simple_object_detector_on_images (
const std::string& dataset_filename, // can be "" if it's not applicable
image_array& images,
std::vector<std::vector<rectangle> >& boxes,
std::vector<std::vector<rectangle> >& ignore,
const std::string& detector_output_filename,
const simple_object_detector_training_options& options
)
@ -129,12 +141,13 @@ namespace dlib
if (options.epsilon <= 0)
throw error("Invalid epsilon value given to train_simple_object_detector(), epsilon must be > 0.");
dlib::array<array2d<rgb_pixel> > images;
std::vector<std::vector<rectangle> > boxes, ignore;
ignore = load_image_dataset(images, boxes, dataset_filename);
if (images.size() != boxes.size())
throw error("The list of images must have the same length as the list of boxes.");
if (images.size() != ignore.size())
throw error("The list of images must have the same length as the list of ignore boxes.");
if (impl::contains_any_boxes(boxes) == false)
throw error("Error, the dataset in " + dataset_filename + " does not have any labeled object boxes in it.");
throw error("Error, the training dataset does not have any labeled object boxes in it.");
typedef scan_fhog_pyramid<pyramid_down<6> > image_scanner_type;
image_scanner_type scanner;
@ -215,6 +228,21 @@ namespace dlib
}
}
// ----------------------------------------------------------------------------------------
inline void train_simple_object_detector (
const std::string& dataset_filename,
const std::string& detector_output_filename,
const simple_object_detector_training_options& options
)
{
dlib::array<array2d<rgb_pixel> > images;
std::vector<std::vector<rectangle> > boxes, ignore;
ignore = load_image_dataset(images, boxes, dataset_filename);
train_simple_object_detector_on_images(dataset_filename, images, boxes, ignore, detector_output_filename, options);
}
// ----------------------------------------------------------------------------------------
struct simple_test_results