Fixed a minor bug and did some cleanup

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%404014
This commit is contained in:
Davis King 2010-12-22 23:06:55 +00:00
parent 00325e7521
commit 2932d6d3f4
1 changed files with 72 additions and 4 deletions

View File

@ -6,7 +6,6 @@
//#include "local/make_label_kernel_matrix.h"
#include "svm_c_trainer_abstract.h"
#include "calculate_rho_and_b.h"
#include <cmath>
#include <limits>
#include <sstream>
@ -236,8 +235,8 @@ namespace dlib
alpha,
eps);
scalar_type rho, b;
calculate_rho_and_b(y,alpha,solver.get_gradient(),rho,b);
scalar_type b;
calculate_b(y,alpha,solver.get_gradient(),Cpos,Cneg,b);
alpha = pointwise_multiply(alpha,y);
// count the number of support vectors
@ -263,11 +262,80 @@ namespace dlib
}
// now return the decision function
return decision_function<K> (sv_alpha, b*rho, kernel_function, support_vectors);
return decision_function<K> (sv_alpha, b, kernel_function, support_vectors);
}
// ------------------------------------------------------------------------------------
template <
typename scalar_vector_type,
typename scalar_vector_type2
>
void calculate_b(
const scalar_vector_type2& y,
const scalar_vector_type& alpha,
const scalar_vector_type& df,
const scalar_type& Cpos,
const scalar_type& Cneg,
scalar_type& b
) const
{
using namespace std;
long num_free = 0;
scalar_type sum_free = 0;
scalar_type upper_bound = -numeric_limits<scalar_type>::infinity();
scalar_type lower_bound = numeric_limits<scalar_type>::infinity();
for(long i = 0; i < alpha.nr(); ++i)
{
if(y(i) == 1)
{
if(alpha(i) == Cpos)
{
if (df(i) > upper_bound)
upper_bound = df(i);
}
else if(alpha(i) == 0)
{
if (df(i) < lower_bound)
lower_bound = df(i);
}
else
{
++num_free;
sum_free += df(i);
}
}
else
{
if(alpha(i) == Cneg)
{
if (-df(i) > upper_bound)
upper_bound = -df(i);
}
else if(alpha(i) == 0)
{
if (-df(i) < lower_bound)
lower_bound = -df(i);
}
else
{
++num_free;
sum_free -= df(i);
}
}
}
if(num_free > 0)
b = sum_free/num_free;
else
b = (upper_bound+lower_bound)/2;
}
// ------------------------------------------------------------------------------------
kernel_type kernel_function;
scalar_type Cpos;
scalar_type Cneg;