From 2932d6d3f4b4d698d51d2c7c94d4090534eaa279 Mon Sep 17 00:00:00 2001 From: Davis King Date: Wed, 22 Dec 2010 23:06:55 +0000 Subject: [PATCH] Fixed a minor bug and did some cleanup --HG-- extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%404014 --- dlib/svm/svm_c_trainer.h | 76 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 72 insertions(+), 4 deletions(-) diff --git a/dlib/svm/svm_c_trainer.h b/dlib/svm/svm_c_trainer.h index c325dce15..306305ab4 100644 --- a/dlib/svm/svm_c_trainer.h +++ b/dlib/svm/svm_c_trainer.h @@ -6,7 +6,6 @@ //#include "local/make_label_kernel_matrix.h" #include "svm_c_trainer_abstract.h" -#include "calculate_rho_and_b.h" #include #include #include @@ -236,8 +235,8 @@ namespace dlib alpha, eps); - scalar_type rho, b; - calculate_rho_and_b(y,alpha,solver.get_gradient(),rho,b); + scalar_type b; + calculate_b(y,alpha,solver.get_gradient(),Cpos,Cneg,b); alpha = pointwise_multiply(alpha,y); // count the number of support vectors @@ -263,11 +262,80 @@ namespace dlib } // now return the decision function - return decision_function (sv_alpha, b*rho, kernel_function, support_vectors); + return decision_function (sv_alpha, b, kernel_function, support_vectors); } // ------------------------------------------------------------------------------------ + template < + typename scalar_vector_type, + typename scalar_vector_type2 + > + void calculate_b( + const scalar_vector_type2& y, + const scalar_vector_type& alpha, + const scalar_vector_type& df, + const scalar_type& Cpos, + const scalar_type& Cneg, + scalar_type& b + ) const + { + using namespace std; + long num_free = 0; + scalar_type sum_free = 0; + + scalar_type upper_bound = -numeric_limits::infinity(); + scalar_type lower_bound = numeric_limits::infinity(); + + for(long i = 0; i < alpha.nr(); ++i) + { + if(y(i) == 1) + { + if(alpha(i) == Cpos) + { + if (df(i) > upper_bound) + upper_bound = df(i); + } + else if(alpha(i) == 0) + { + if (df(i) < lower_bound) + lower_bound = df(i); + } + else + { + ++num_free; + sum_free += df(i); + } + } + else + { + if(alpha(i) == Cneg) + { + if (-df(i) > upper_bound) + upper_bound = -df(i); + } + else if(alpha(i) == 0) + { + if (-df(i) < lower_bound) + lower_bound = -df(i); + } + else + { + ++num_free; + sum_free -= df(i); + } + } + } + + if(num_free > 0) + b = sum_free/num_free; + else + b = (upper_bound+lower_bound)/2; + } + + // ------------------------------------------------------------------------------------ + + kernel_type kernel_function; scalar_type Cpos; scalar_type Cneg;