Improved the way the feature vector cache is used within the structural svm
solver.  This makes some things, such as the structural_object_detection_trainer,
significantly faster.
This commit is contained in:
Davis King 2013-07-14 10:10:14 -04:00
parent 06d1331c4d
commit 95aacdfdfb
1 changed file with 22 additions and 10 deletions

View File

@ -24,7 +24,7 @@ namespace dlib
public:
cache_element_structural_svm (
) : prob(0), sample_idx(0) {}
) : prob(0), sample_idx(0), last_true_risk_computed(std::numeric_limits<double>::infinity()) {}
typedef typename structural_svm_problem::scalar_type scalar_type;
typedef typename structural_svm_problem::matrix_type matrix_type;
@ -66,19 +66,23 @@ namespace dlib
void separation_oracle_cached (
const bool skip_cache,
const scalar_type& cur_risk_lower_bound,
const scalar_type& saved_current_risk_gap,
const matrix_type& current_solution,
scalar_type& out_loss,
feature_vector_type& out_psi
) const
{
if (!skip_cache && prob->get_max_cache_size() != 0)
const bool cache_enabled = prob->get_max_cache_size() != 0;
// Don't waste time computing this if the cache isn't going to be used.
const scalar_type dot_true_psi = cache_enabled ? dot(true_psi, current_solution) : 0;
if (!skip_cache && cache_enabled)
{
scalar_type best_risk = -std::numeric_limits<scalar_type>::infinity();
unsigned long best_idx = 0;
const scalar_type dot_true_psi = dot(true_psi, current_solution);
// figure out which element in the cache is the best (i.e. has the biggest risk)
long max_lru_count = 0;
@ -95,7 +99,11 @@ namespace dlib
max_lru_count = lru_count[i];
}
if (best_risk - cur_risk_lower_bound > prob->get_epsilon())
// Check if the best psi vector in the cache is still good enough to use as
// a proxy for the true separation oracle. If the risk value has dropped
// by enough to get into the stopping condition then the best psi isn't
// good enough.
if (best_risk + saved_current_risk_gap-prob->get_epsilon() > last_true_risk_computed)
{
out_psi = psi[best_idx];
lru_count[best_idx] = max_lru_count + 1;
@ -106,11 +114,13 @@ namespace dlib
prob->separation_oracle(sample_idx, current_solution, out_loss, out_psi);
if (prob->get_max_cache_size() == 0)
if (!cache_enabled)
return;
compact_sparse_vector(out_psi);
last_true_risk_computed = out_loss + dot(out_psi, current_solution) - dot_true_psi;
// if the cache is full
if (loss.size() >= prob->get_max_cache_size())
{
@ -167,6 +177,7 @@ namespace dlib
mutable std::vector<scalar_type> loss;
mutable std::vector<feature_vector_type> psi;
mutable std::vector<long> lru_count;
mutable double last_true_risk_computed;
};
// ----------------------------------------------------------------------------------------
@ -200,7 +211,7 @@ namespace dlib
structural_svm_problem (
) :
cur_risk_lower_bound(0),
saved_current_risk_gap(0),
eps(0.001),
verbose(false),
skip_cache(true),
@ -315,7 +326,8 @@ namespace dlib
cout << endl;
}
cur_risk_lower_bound = std::max<scalar_type>(current_risk_value - current_risk_gap, 0);
saved_current_risk_gap = std::max<scalar_type>(current_risk_value - current_risk_gap, 0);
saved_current_risk_gap = current_risk_gap;
bool should_stop = false;
@ -401,7 +413,7 @@ namespace dlib
) const
{
cache[idx].separation_oracle_cached(skip_cache,
cur_risk_lower_bound,
saved_current_risk_gap,
current_solution,
loss,
psi);
@ -409,7 +421,7 @@ namespace dlib
private:
mutable scalar_type cur_risk_lower_bound;
mutable scalar_type saved_current_risk_gap;
mutable matrix_type psi_true;
scalar_type eps;
mutable bool verbose;