Added an option to do translational jittering of the bounding boxes in the shape_predictor_trainer.
This commit is contained in:
Davis King 2018-06-08 21:51:17 -04:00
parent 89bfb786d1
commit a9c940b167
2 changed files with 60 additions and 0 deletions

View File

@ -37,6 +37,7 @@ namespace dlib
_num_trees_per_cascade_level = 500;
_nu = 0.1;
_oversampling_amount = 20;
_oversampling_translation_jitter = 0;
_feature_pool_size = 400;
_lambda = 0.1;
_num_test_splits = 20;
@ -116,6 +117,7 @@ namespace dlib
// Accessor for the oversampling amount, i.e. the factor of extra training
// samples the trainer generates.
unsigned long get_oversampling_amount (
) const
{
    return _oversampling_amount;
}
void set_oversampling_amount (
unsigned long amount
)
@ -129,6 +131,22 @@ namespace dlib
_oversampling_amount = amount;
}
unsigned long get_oversampling_translation_jitter (
) const { return _oversampling_translation_jitter; }
// Sets the magnitude of the random translational jitter used when generating
// the oversampled training samples.  The trainer only applies the jitter when
// this value is non-zero, so passing 0 turns the feature off.
//
// requires
//     - amount >= 0 (enforced by the DLIB_CASSERT below)
void set_oversampling_translation_jitter (
double amount
)
{
// Reject negative magnitudes; the jitter range used by the trainer is
// [-amount, +amount], which only makes sense for a non-negative amount.
DLIB_CASSERT(amount >= 0,
"\t void shape_predictor_trainer::set_oversampling_translation_jitter()"
<< "\n\t Invalid inputs were given to this function. "
<< "\n\t amount: " << amount
);
_oversampling_translation_jitter = amount;
}
// Accessor for the currently configured feature pool size.
unsigned long get_feature_pool_size (
) const
{
    return _feature_pool_size;
}
void set_feature_pool_size (
@ -706,6 +724,18 @@ namespace dlib
hits += alpha*samples[rand_idx].present;
}
samples[i].current_shape = pointwise_multiply(samples[i].current_shape, reciprocal(hits));
if (_oversampling_translation_jitter != 0)
{
dpoint off;
off.x() = rnd.get_double_in_range(-_oversampling_translation_jitter,_oversampling_translation_jitter);
off.y() = rnd.get_double_in_range(-_oversampling_translation_jitter,_oversampling_translation_jitter);
for (long j = 0; j < samples[i].current_shape.size()/2; ++j)
{
samples[i].current_shape(2*j) += off.x();
samples[i].current_shape(2*j+1) += off.y();
}
}
}
}
@ -795,6 +825,7 @@ namespace dlib
bool _verbose;
unsigned long _num_threads;
padding_mode_t _padding_mode;
double _oversampling_translation_jitter;
};
// ----------------------------------------------------------------------------------------

View File

@ -33,6 +33,7 @@ namespace dlib
- #get_num_trees_per_cascade_level() == 500
- #get_nu() == 0.1
- #get_oversampling_amount() == 20
- #get_oversampling_translation_jitter() == 0
- #get_feature_pool_size() == 400
- #get_lambda() == 0.1
- #get_num_test_splits() == 20
@ -162,6 +163,34 @@ namespace dlib
- #get_oversampling_amount() == amount
!*/
unsigned long get_oversampling_translation_jitter (
) const;
/*!
ensures
- When generating the get_oversampling_amount() factor of extra training
samples you can also jitter the bounding box by adding random small
translational shifts. You can tell the shape_predictor_trainer to do
this by setting get_oversampling_translation_jitter() to some non-zero
value. For instance, if you set it to 0.1 then it would randomly
translate the bounding boxes by between 0% and 10% their width and
height in the x and y directions respectively. Doing this is essentially
equivalent to randomly jittering the bounding boxes in the training data
(i.e. the boxes given by full_object_detection::get_rect()). This is
useful because the seed shape is determined by the bounding box position,
so doing this kind of jittering can help make the learned model more
robust against slightly misplaced bounding boxes.
!*/
void set_oversampling_translation_jitter (
double amount
);
/*!
requires
- amount >= 0
ensures
- #get_oversampling_translation_jitter() == amount
- Setting amount to 0 disables the translational jittering of the
  oversampled training samples.
!*/
unsigned long get_feature_pool_size (
) const;
/*!