From 95aacdfdfbeab897738c5e1cfb32dc8f2b4cebb4 Mon Sep 17 00:00:00 2001 From: Davis King Date: Sun, 14 Jul 2013 10:10:14 -0400 Subject: [PATCH] Improved the way the feature vector cache is used within the structural svm solver. This makes some things, such as the structural_object_detection_trainer, significantly faster. --- dlib/svm/structural_svm_problem.h | 32 +++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/dlib/svm/structural_svm_problem.h b/dlib/svm/structural_svm_problem.h index 0937beaae..632b17037 100644 --- a/dlib/svm/structural_svm_problem.h +++ b/dlib/svm/structural_svm_problem.h @@ -24,7 +24,7 @@ namespace dlib public: cache_element_structural_svm ( - ) : prob(0), sample_idx(0) {} + ) : prob(0), sample_idx(0), last_true_risk_computed(std::numeric_limits::infinity()) {} typedef typename structural_svm_problem::scalar_type scalar_type; typedef typename structural_svm_problem::matrix_type matrix_type; @@ -66,19 +66,23 @@ namespace dlib void separation_oracle_cached ( const bool skip_cache, - const scalar_type& cur_risk_lower_bound, + const scalar_type& saved_current_risk_gap, const matrix_type& current_solution, scalar_type& out_loss, feature_vector_type& out_psi ) const { - if (!skip_cache && prob->get_max_cache_size() != 0) + const bool cache_enabled = prob->get_max_cache_size() != 0; + + // Don't waste time computing this if the cache isn't going to be used. + const scalar_type dot_true_psi = cache_enabled ? dot(true_psi, current_solution) : 0; + + if (!skip_cache && cache_enabled) { scalar_type best_risk = -std::numeric_limits::infinity(); unsigned long best_idx = 0; - const scalar_type dot_true_psi = dot(true_psi, current_solution); // figure out which element in the cache is the best (i.e. has the biggest risk) long max_lru_count = 0; @@ -95,7 +99,11 @@ namespace dlib max_lru_count = lru_count[i]; } - if (best_risk - cur_risk_lower_bound > prob->get_epsilon()) + // Check if the best psi vector in the cache is still good enough to use as + // a proxy for the true separation oracle. If the risk value has dropped + // by enough to get into the stopping condition then the best psi isn't + // good enough. + if (best_risk + saved_current_risk_gap-prob->get_epsilon() > last_true_risk_computed) { out_psi = psi[best_idx]; lru_count[best_idx] = max_lru_count + 1; @@ -106,11 +114,13 @@ namespace dlib prob->separation_oracle(sample_idx, current_solution, out_loss, out_psi); - if (prob->get_max_cache_size() == 0) + if (!cache_enabled) return; compact_sparse_vector(out_psi); + last_true_risk_computed = out_loss + dot(out_psi, current_solution) - dot_true_psi; + // if the cache is full if (loss.size() >= prob->get_max_cache_size()) { @@ -167,6 +177,7 @@ namespace dlib mutable std::vector loss; mutable std::vector psi; mutable std::vector lru_count; + mutable double last_true_risk_computed; }; // ---------------------------------------------------------------------------------------- @@ -200,7 +211,7 @@ namespace dlib structural_svm_problem ( ) : - cur_risk_lower_bound(0), + saved_current_risk_gap(0), eps(0.001), verbose(false), skip_cache(true), @@ -315,7 +326,8 @@ namespace dlib cout << endl; } - cur_risk_lower_bound = std::max(current_risk_value - current_risk_gap, 0); + saved_current_risk_gap = std::max(current_risk_value - current_risk_gap, 0); + saved_current_risk_gap = current_risk_gap; bool should_stop = false; @@ -401,7 +413,7 @@ namespace dlib ) const { cache[idx].separation_oracle_cached(skip_cache, - cur_risk_lower_bound, + saved_current_risk_gap, current_solution, loss, psi); @@ -409,7 +421,7 @@ namespace dlib private: - mutable scalar_type cur_risk_lower_bound; + mutable scalar_type saved_current_risk_gap; mutable matrix_type psi_true; scalar_type eps; mutable bool verbose;