mirror of https://github.com/davisking/dlib.git
Improved the way the feature vector cache is used within the structural svm
solver. This makes some things, such as the structural_object_detection_trainer, significantly faster.
This commit is contained in:
parent
06d1331c4d
commit
95aacdfdfb
|
@ -24,7 +24,7 @@ namespace dlib
|
|||
public:
|
||||
|
||||
cache_element_structural_svm (
|
||||
) : prob(0), sample_idx(0) {}
|
||||
) : prob(0), sample_idx(0), last_true_risk_computed(std::numeric_limits<double>::infinity()) {}
|
||||
|
||||
typedef typename structural_svm_problem::scalar_type scalar_type;
|
||||
typedef typename structural_svm_problem::matrix_type matrix_type;
|
||||
|
@ -66,19 +66,23 @@ namespace dlib
|
|||
|
||||
void separation_oracle_cached (
|
||||
const bool skip_cache,
|
||||
const scalar_type& cur_risk_lower_bound,
|
||||
const scalar_type& saved_current_risk_gap,
|
||||
const matrix_type& current_solution,
|
||||
scalar_type& out_loss,
|
||||
feature_vector_type& out_psi
|
||||
) const
|
||||
{
|
||||
if (!skip_cache && prob->get_max_cache_size() != 0)
|
||||
const bool cache_enabled = prob->get_max_cache_size() != 0;
|
||||
|
||||
// Don't waste time computing this if the cache isn't going to be used.
|
||||
const scalar_type dot_true_psi = cache_enabled ? dot(true_psi, current_solution) : 0;
|
||||
|
||||
if (!skip_cache && cache_enabled)
|
||||
{
|
||||
scalar_type best_risk = -std::numeric_limits<scalar_type>::infinity();
|
||||
unsigned long best_idx = 0;
|
||||
|
||||
|
||||
const scalar_type dot_true_psi = dot(true_psi, current_solution);
|
||||
|
||||
// figure out which element in the cache is the best (i.e. has the biggest risk)
|
||||
long max_lru_count = 0;
|
||||
|
@ -95,7 +99,11 @@ namespace dlib
|
|||
max_lru_count = lru_count[i];
|
||||
}
|
||||
|
||||
if (best_risk - cur_risk_lower_bound > prob->get_epsilon())
|
||||
// Check if the best psi vector in the cache is still good enough to use as
|
||||
// a proxy for the true separation oracle. If the risk value has dropped
|
||||
// by enough to get into the stopping condition then the best psi isn't
|
||||
// good enough.
|
||||
if (best_risk + saved_current_risk_gap-prob->get_epsilon() > last_true_risk_computed)
|
||||
{
|
||||
out_psi = psi[best_idx];
|
||||
lru_count[best_idx] = max_lru_count + 1;
|
||||
|
@ -106,11 +114,13 @@ namespace dlib
|
|||
|
||||
prob->separation_oracle(sample_idx, current_solution, out_loss, out_psi);
|
||||
|
||||
if (prob->get_max_cache_size() == 0)
|
||||
if (!cache_enabled)
|
||||
return;
|
||||
|
||||
compact_sparse_vector(out_psi);
|
||||
|
||||
last_true_risk_computed = out_loss + dot(out_psi, current_solution) - dot_true_psi;
|
||||
|
||||
// if the cache is full
|
||||
if (loss.size() >= prob->get_max_cache_size())
|
||||
{
|
||||
|
@ -167,6 +177,7 @@ namespace dlib
|
|||
mutable std::vector<scalar_type> loss;
|
||||
mutable std::vector<feature_vector_type> psi;
|
||||
mutable std::vector<long> lru_count;
|
||||
mutable double last_true_risk_computed;
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
@ -200,7 +211,7 @@ namespace dlib
|
|||
|
||||
structural_svm_problem (
|
||||
) :
|
||||
cur_risk_lower_bound(0),
|
||||
saved_current_risk_gap(0),
|
||||
eps(0.001),
|
||||
verbose(false),
|
||||
skip_cache(true),
|
||||
|
@ -315,7 +326,8 @@ namespace dlib
|
|||
cout << endl;
|
||||
}
|
||||
|
||||
cur_risk_lower_bound = std::max<scalar_type>(current_risk_value - current_risk_gap, 0);
|
||||
saved_current_risk_gap = std::max<scalar_type>(current_risk_value - current_risk_gap, 0);
|
||||
saved_current_risk_gap = current_risk_gap;
|
||||
|
||||
bool should_stop = false;
|
||||
|
||||
|
@ -401,7 +413,7 @@ namespace dlib
|
|||
) const
|
||||
{
|
||||
cache[idx].separation_oracle_cached(skip_cache,
|
||||
cur_risk_lower_bound,
|
||||
saved_current_risk_gap,
|
||||
current_solution,
|
||||
loss,
|
||||
psi);
|
||||
|
@ -409,7 +421,7 @@ namespace dlib
|
|||
private:
|
||||
|
||||
|
||||
mutable scalar_type cur_risk_lower_bound;
|
||||
mutable scalar_type saved_current_risk_gap;
|
||||
mutable matrix_type psi_true;
|
||||
scalar_type eps;
|
||||
mutable bool verbose;
|
||||
|
|
Loading…
Reference in New Issue