From a9c940b1676c3c71e29bfe37d1133af527161bc3 Mon Sep 17 00:00:00 2001 From: Davis King Date: Fri, 8 Jun 2018 21:51:17 -0400 Subject: [PATCH] Added an option to do translational jittering of the bounding boxes in the shape_predictor_trainer. --- .../shape_predictor_trainer.h | 31 +++++++++++++++++++ .../shape_predictor_trainer_abstract.h | 29 +++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/dlib/image_processing/shape_predictor_trainer.h b/dlib/image_processing/shape_predictor_trainer.h index 3090998f9..3846e0af7 100644 --- a/dlib/image_processing/shape_predictor_trainer.h +++ b/dlib/image_processing/shape_predictor_trainer.h @@ -37,6 +37,7 @@ namespace dlib _num_trees_per_cascade_level = 500; _nu = 0.1; _oversampling_amount = 20; + _oversampling_translation_jitter = 0; _feature_pool_size = 400; _lambda = 0.1; _num_test_splits = 20; @@ -116,6 +117,7 @@ namespace dlib unsigned long get_oversampling_amount ( ) const { return _oversampling_amount; } + void set_oversampling_amount ( unsigned long amount ) @@ -129,6 +131,22 @@ namespace dlib _oversampling_amount = amount; } + unsigned long get_oversampling_translation_jitter ( + ) const { return _oversampling_translation_jitter; } + + void set_oversampling_translation_jitter ( + double amount + ) + { + DLIB_CASSERT(amount >= 0, + "\t void shape_predictor_trainer::set_oversampling_translation_jitter()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t amount: " << amount + ); + + _oversampling_translation_jitter = amount; + } + unsigned long get_feature_pool_size ( ) const { return _feature_pool_size; } void set_feature_pool_size ( @@ -706,6 +724,18 @@ namespace dlib hits += alpha*samples[rand_idx].present; } samples[i].current_shape = pointwise_multiply(samples[i].current_shape, reciprocal(hits)); + + if (_oversampling_translation_jitter != 0) + { + dpoint off; + off.x() = rnd.get_double_in_range(-_oversampling_translation_jitter,_oversampling_translation_jitter); + off.y() = rnd.get_double_in_range(-_oversampling_translation_jitter,_oversampling_translation_jitter); + for (long j = 0; j < samples[i].current_shape.size()/2; ++j) + { + samples[i].current_shape(2*j) += off.x(); + samples[i].current_shape(2*j+1) += off.y(); + } + } } } @@ -795,6 +825,7 @@ namespace dlib bool _verbose; unsigned long _num_threads; padding_mode_t _padding_mode; + double _oversampling_translation_jitter; }; // ---------------------------------------------------------------------------------------- diff --git a/dlib/image_processing/shape_predictor_trainer_abstract.h b/dlib/image_processing/shape_predictor_trainer_abstract.h index 278b97842..6633707f1 100644 --- a/dlib/image_processing/shape_predictor_trainer_abstract.h +++ b/dlib/image_processing/shape_predictor_trainer_abstract.h @@ -33,6 +33,7 @@ namespace dlib - #get_num_trees_per_cascade_level() == 500 - #get_nu() == 0.1 - #get_oversampling_amount() == 20 + - #get_oversampling_translation_jitter() == 0 - #get_feature_pool_size() == 400 - #get_lambda() == 0.1 - #get_num_test_splits() == 20 @@ -162,6 +163,34 @@ namespace dlib - #get_oversampling_amount() == amount !*/ + unsigned long get_oversampling_translation_jitter ( + ) const; + /*! + ensures + - When generating the get_oversampling_amount() factor of extra training + samples you can also jitter the bounding box by adding random small + translational shifts. You can tell the shape_predictor_trainer to do + this by setting get_oversampling_translation_jitter() to some non-zero + value. For instance, if you set it to 0.1 then it would randomly + translate the bounding boxes by between 0% and 10% their width and + height in the x and y directions respectively. Doing this is essentially + equivalent to randomly jittering the bounding boxes in the training data + (i.e. the boxes given by full_object_detection::get_rect()). This is + useful because the seed shape is determined by the bounding box position, + so doing this kind of jittering can help make the learned model more + robust against slightly misplaced bounding boxes. + !*/ + + void set_oversampling_translation_jitter ( + double amount + ); + /*! + requires + - amount >= 0 + ensures + - #get_oversampling_translation_jitter() == amount + !*/ + unsigned long get_feature_pool_size ( ) const; /*!