From b638145fb6b23ca0b24600eb5e1d92669e12b7b3 Mon Sep 17 00:00:00 2001 From: Davis King Date: Sat, 23 Mar 2013 19:49:42 -0400 Subject: [PATCH] Moved the responsibility for automatically filling out a test_box_overlap object from the structural_object_detection_trainer to the structural_svm_object_detection_problem. This allows us to use image scanners which require an image be loaded before get_best_matching_rect() can be called. I also made it so that the scanner loading (and therefore feature extraction) is threaded. Previously, it only used a single core. --- .../svm/structural_object_detection_trainer.h | 25 +----- .../structural_svm_object_detection_problem.h | 85 ++++++++++++++++--- ...al_svm_object_detection_problem_abstract.h | 18 +++- 3 files changed, 90 insertions(+), 38 deletions(-) diff --git a/dlib/svm/structural_object_detection_trainer.h b/dlib/svm/structural_object_detection_trainer.h index 25d3815e8..2456fd9fa 100644 --- a/dlib/svm/structural_object_detection_trainer.h +++ b/dlib/svm/structural_object_detection_trainer.h @@ -275,29 +275,8 @@ namespace dlib } #endif - test_box_overlap local_overlap_tester; - - if (auto_overlap_tester) - { - std::vector > mapped_rects(truth_object_detections.size()); - for (unsigned long i = 0; i < truth_object_detections.size(); ++i) - { - mapped_rects[i].resize(truth_object_detections[i].size()); - for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j) - { - mapped_rects[i][j] = scanner.get_best_matching_rect(truth_object_detections[i][j].get_rect()); - } - } - - local_overlap_tester = find_tight_overlap_tester(mapped_rects); - } - else - { - local_overlap_tester = overlap_tester; - } - structural_svm_object_detection_problem - svm_prob(scanner, local_overlap_tester, images, truth_object_detections, num_threads); + svm_prob(scanner, overlap_tester, auto_overlap_tester, images, truth_object_detections, num_threads); if (verbose) svm_prob.be_verbose(); @@ -314,7 +293,7 @@ namespace dlib solver(svm_prob,w); // report the results of the training. - return object_detector(scanner, local_overlap_tester, w); + return object_detector(scanner, svm_prob.get_overlap_tester(), w); } template < diff --git a/dlib/svm/structural_svm_object_detection_problem.h b/dlib/svm/structural_svm_object_detection_problem.h index f0ee788c6..adcf596f7 100644 --- a/dlib/svm/structural_svm_object_detection_problem.h +++ b/dlib/svm/structural_svm_object_detection_problem.h @@ -37,6 +37,7 @@ namespace dlib structural_svm_object_detection_problem( const image_scanner_type& scanner, const test_box_overlap& overlap_tester, + const bool auto_overlap_tester, const image_array_type& images_, const std::vector >& truth_object_detections_, unsigned long num_threads = 2 @@ -75,19 +76,34 @@ namespace dlib } } #endif - - scanners.set_max_size(images.size()); - scanners.set_size(images.size()); - + // The purpose of the max_num_dets member variable is to give us a reasonable + // upper limit on the number of detections we can expect from a single image. + // This is used in the separation_oracle to put a hard limit on the number of + // detections we will consider. We do this purely for computational reasons + // since otherwise we can end up wasting large amounts of time on certain + // pathological cases during optimization which ultimately do not influence the + // result. Therefore, we for the separation oracle to only consider the + // max_num_dets strongest detections. max_num_dets = 0; for (unsigned long i = 0; i < truth_object_detections.size(); ++i) { if (truth_object_detections[i].size() > max_num_dets) max_num_dets = truth_object_detections[i].size(); - - scanners[i].copy_configuration(scanner); } max_num_dets = max_num_dets*3 + 10; + + initialize_scanners(scanner, num_threads); + + if (auto_overlap_tester) + { + auto_configure_overlap_tester(); + } + } + + test_box_overlap get_overlap_tester ( + ) const + { + return boxes_overlap; } void set_match_eps ( @@ -154,6 +170,24 @@ namespace dlib } private: + + void auto_configure_overlap_tester( + ) + { + std::vector > mapped_rects(truth_object_detections.size()); + for (unsigned long i = 0; i < truth_object_detections.size(); ++i) + { + mapped_rects[i].resize(truth_object_detections[i].size()); + for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j) + { + mapped_rects[i][j] = scanners[i].get_best_matching_rect(truth_object_detections[i][j].get_rect()); + } + } + + boxes_overlap = find_tight_overlap_tester(mapped_rects); + } + + virtual long get_num_dimensions ( ) const { @@ -172,7 +206,7 @@ namespace dlib feature_vector_type& psi ) const { - const image_scanner_type& scanner = get_scanner(idx); + const image_scanner_type& scanner = scanners[idx]; psi.set_size(get_num_dimensions()); std::vector mapped_rects; @@ -268,7 +302,7 @@ namespace dlib feature_vector_type& psi ) const { - const image_scanner_type& scanner = get_scanner(idx); + const image_scanner_type& scanner = scanners[idx]; std::vector > dets; const double thresh = current_solution(scanner.get_num_dimensions()); @@ -437,13 +471,38 @@ namespace dlib return std::make_pair(match,best_idx); } - - const image_scanner_type& get_scanner (long idx) const + struct init_scanners_helper { - if (scanners[idx].is_loaded_with_image() == false) - scanners[idx].load(images[idx]); + init_scanners_helper ( + array& scanners_, + const image_array_type& images_ + ) : + scanners(scanners_), + images(images_) + {} - return scanners[idx]; + array& scanners; + const image_array_type& images; + + void operator() (long i ) const + { + scanners[i].load(images[i]); + } + }; + + void initialize_scanners ( + const image_scanner_type& scanner, + unsigned long num_threads + ) + { + scanners.set_max_size(images.size()); + scanners.set_size(images.size()); + + for (unsigned long i = 0; i < scanners.size(); ++i) + scanners[i].copy_configuration(scanner); + + // now load the images into all the scanners + parallel_for(num_threads, 0, scanners.size(), init_scanners_helper(scanners, images)); } diff --git a/dlib/svm/structural_svm_object_detection_problem_abstract.h b/dlib/svm/structural_svm_object_detection_problem_abstract.h index a7685186f..3c8960c04 100644 --- a/dlib/svm/structural_svm_object_detection_problem_abstract.h +++ b/dlib/svm/structural_svm_object_detection_problem_abstract.h @@ -68,7 +68,7 @@ namespace dlib A detection is considered a false alarm if it doesn't match with any of the ground truth rectangles or if it is a duplicate detection of a truth rectangle. Finally, for the purposes of calculating loss, a match - is determined using the following formula, rectangles A and B match + is determined using the following formula where rectangles A and B match if and only if: A.intersect(B).area()/(A+B).area() > get_match_eps() !*/ @@ -77,6 +77,7 @@ namespace dlib structural_svm_object_detection_problem( const image_scanner_type& scanner, const test_box_overlap& overlap_tester, + const bool auto_overlap_tester, const image_array_type& images, const std::vector >& truth_object_detections, unsigned long num_threads = 2 @@ -95,12 +96,18 @@ namespace dlib attempts to learn to predict truth_object_detections[i] based on images[i]. Or in other words, this object can be used to learn a parameter vector, w, such that an object_detector declared as: - object_detector detector(scanner,overlap_tester,w) + object_detector detector(scanner,get_overlap_tester(),w) results in a detector object which attempts to compute the locations of all the objects in truth_object_detections. So if you called detector(images[i]) you would hopefully get a list of rectangles back that had truth_object_detections[i].size() elements and contained exactly the rectangles indicated by truth_object_detections[i]. + - if (auto_overlap_tester == true) then + - #get_overlap_tester() == a test_box_overlap object that is configured + using the find_tight_overlap_tester() routine and the contents of + truth_object_detections. + - else + - #get_overlap_tester() == overlap_tester - #get_match_eps() == 0.5 - This object will use num_threads threads during the optimization procedure. You should set this parameter equal to the number of @@ -109,6 +116,13 @@ namespace dlib - #get_loss_per_false_alarm() == 1 !*/ + test_box_overlap get_overlap_tester ( + ) const; + /*! + ensures + - returns the overlap tester used by this object. + !*/ + void set_match_eps ( double eps );