From b445ddbd8d6f16abb2cd3ebf67e6224603657ff7 Mon Sep 17 00:00:00 2001 From: Davis King Date: Thu, 3 Jan 2013 22:00:02 -0500 Subject: [PATCH] Switched this code to use the oca object's ability to force a weight to 1 instead of rolling its own implementation. --- dlib/svm/svm_rank_trainer.h | 35 ++++++++++++++--------------------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/dlib/svm/svm_rank_trainer.h b/dlib/svm/svm_rank_trainer.h index 43b98753d..7ede8b787 100644 --- a/dlib/svm/svm_rank_trainer.h +++ b/dlib/svm/svm_rank_trainer.h @@ -37,15 +37,13 @@ namespace dlib const std::vector >& samples_, const bool be_verbose_, const scalar_type eps_, - const unsigned long max_iter, - const bool last_weight_1_ + const unsigned long max_iter ) : samples(samples_), C(C_), be_verbose(be_verbose_), eps(eps_), - max_iterations(max_iter), - last_weight_1(last_weight_1_) + max_iterations(max_iter) { } @@ -113,8 +111,6 @@ namespace dlib // rank flips. So a risk of 0.1 would mean that rank flips happen < 10% of the // time. - if(last_weight_1) - w(w.size()-1) = 1; std::vector rel_scores; std::vector nonrel_scores; @@ -163,12 +159,6 @@ namespace dlib risk *= scale; subgradient = scale*subgradient; - - if(last_weight_1) - { - w(w.size()-1) = 0; - subgradient(w.size()-1) = 0; - } } private: @@ -183,7 +173,6 @@ namespace dlib const bool be_verbose; const scalar_type eps; const unsigned long max_iterations; - const bool last_weight_1; }; // ---------------------------------------------------------------------------------------- @@ -198,12 +187,11 @@ namespace dlib const std::vector >& samples, const bool be_verbose, const scalar_type eps, - const unsigned long max_iterations, - const bool last_weight_1 + const unsigned long max_iterations ) { return oca_problem_ranking_svm( - C, samples, be_verbose, eps, max_iterations, last_weight_1); + C, samples, be_verbose, eps, max_iterations); } // ---------------------------------------------------------------------------------------- @@ -385,12 +373,17 @@ namespace dlib num_nonnegative = num_dims; } - solver( make_oca_problem_ranking_svm(C, samples, verbose, eps, max_iterations, last_weight_1), - w, - num_nonnegative); + unsigned long force_weight_1_idx = std::numeric_limits::max(); + if (last_weight_1) + { + force_weight_1_idx = num_dims-1; + } + + solver( make_oca_problem_ranking_svm(C, samples, verbose, eps, max_iterations), + w, + num_nonnegative, + force_weight_1_idx); - if(last_weight_1) - w(w.size()-1) = 1; // put the solution into a decision function and then return it decision_function df;