updated the svm_c_linear_trainer

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403504
This commit is contained in:
Davis King 2010-02-28 01:42:55 +00:00
parent c2de9780a3
commit ce42b721dd
1 changed files with 161 additions and 14 deletions

View File

@ -41,13 +41,15 @@ namespace dlib
const scalar_type C_neg, const scalar_type C_neg,
const in_sample_vector_type& samples_, const in_sample_vector_type& samples_,
const in_scalar_vector_type& labels_, const in_scalar_vector_type& labels_,
bool be_verbose_ const bool be_verbose_,
const scalar_type eps_
) : ) :
samples(samples_), samples(samples_),
labels(labels_), labels(labels_),
Cpos(C_pos), Cpos(C_pos),
Cneg(C_neg), Cneg(C_neg),
be_verbose(be_verbose_) be_verbose(be_verbose_),
eps(eps_)
{ {
dot_prods.resize(samples.size()); dot_prods.resize(samples.size());
is_first_call = true; is_first_call = true;
@ -83,10 +85,7 @@ namespace dlib
cout << endl; cout << endl;
} }
if (current_error_gap/current_objective_value < 0.001) if (current_error_gap/current_objective_value < eps)
return true;
if (num_iterations > 10000)
return true; return true;
return false; return false;
@ -369,7 +368,8 @@ namespace dlib
const scalar_type Cpos; const scalar_type Cpos;
const scalar_type Cneg; const scalar_type Cneg;
bool be_verbose; const bool be_verbose;
const scalar_type eps;
}; };
@ -384,10 +384,11 @@ namespace dlib
const scalar_type C_neg, const scalar_type C_neg,
const in_sample_vector_type& samples, const in_sample_vector_type& samples,
const in_scalar_vector_type& labels, const in_scalar_vector_type& labels,
bool be_verbose const bool be_verbose,
const scalar_type eps
) )
{ {
return oca_problem_c_svm<matrix_type, in_sample_vector_type, in_scalar_vector_type>(C_pos, C_neg, samples, labels, be_verbose); return oca_problem_c_svm<matrix_type, in_sample_vector_type, in_scalar_vector_type>(C_pos, C_neg, samples, labels, be_verbose, eps);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
@ -419,10 +420,14 @@ namespace dlib
- #get_oca() == oca() (i.e. an instance of oca with default parameters) - #get_oca() == oca() (i.e. an instance of oca with default parameters)
- #get_c_class1() == 1 - #get_c_class1() == 1
- #get_c_class2() == 1 - #get_c_class2() == 1
- #get_epsilon() == 0.001
- this object will not be verbose unless be_verbose() is called
!*/ !*/
{ {
Cpos = 1; Cpos = 1;
Cneg = 1; Cneg = 1;
verbose = false;
eps = 0.001;
} }
explicit svm_c_linear_trainer ( explicit svm_c_linear_trainer (
@ -439,10 +444,69 @@ namespace dlib
- #get_c_class2() == C - #get_c_class2() == C
!*/ !*/
{ {
// make sure requires clause is not broken
DLIB_ASSERT(C > 0,
"\t svm_c_linear_trainer::svm_c_linear_trainer()"
<< "\n\t C must be greater than 0"
<< "\n\t C: " << C
<< "\n\t this: " << this
);
Cpos = C; Cpos = C;
Cneg = C; Cneg = C;
} }
void set_epsilon (
scalar_type eps_
)
/*!
requires
- eps_ > 0
ensures
- #get_epsilon() == eps_
(i.e. eps_ becomes the stopping tolerance reported by get_epsilon())
!*/
{
// make sure requires clause is not broken
DLIB_ASSERT(eps_ > 0,
"\t void svm_c_linear_trainer::set_epsilon()"
<< "\n\t eps_ must be greater than 0"
<< "\n\t eps_: " << eps_
<< "\n\t this: " << this
);
// remember the validated tolerance in the eps member
eps = eps_;
}
const scalar_type get_epsilon (
) const
/*!
    ensures
        - returns the error epsilon that controls when training terminates.
          Choosing a smaller epsilon may yield a more accurate solution, at
          the cost of a longer running time.
!*/
{
    return eps;
}
void be_verbose (
)
/*!
    ensures
        - Turns status output on: this object will write progress messages
          to standard out so a user can follow what the algorithm is doing.
!*/
{
    verbose = true;
}
void be_quiet (
)
/*!
    ensures
        - Turns status output off: this object will not write anything to
          standard out.
!*/
{
    verbose = false;
}
void set_oca ( void set_oca (
const oca& item const oca& item
) )
@ -485,6 +549,14 @@ namespace dlib
- #get_c_class2() == C - #get_c_class2() == C
!*/ !*/
{ {
// make sure requires clause is not broken
DLIB_ASSERT(C > 0,
"\t void svm_c_linear_trainer::set_c()"
<< "\n\t C must be greater than 0"
<< "\n\t C: " << C
<< "\n\t this: " << this
);
Cpos = C; Cpos = C;
Cneg = C; Cneg = C;
} }
@ -529,6 +601,14 @@ namespace dlib
- #get_c_class1() == C - #get_c_class1() == C
!*/ !*/
{ {
// make sure requires clause is not broken
DLIB_ASSERT(C > 0,
"\t void svm_c_linear_trainer::set_c_class1()"
<< "\n\t C must be greater than 0"
<< "\n\t C: " << C
<< "\n\t this: " << this
);
Cpos = C; Cpos = C;
} }
@ -542,6 +622,14 @@ namespace dlib
- #get_c_class2() == C - #get_c_class2() == C
!*/ !*/
{ {
// make sure requires clause is not broken
DLIB_ASSERT(C > 0,
"\t void svm_c_linear_trainer::set_c_class2()"
<< "\n\t C must be greater than 0"
<< "\n\t C: " << C
<< "\n\t this: " << this
);
Cneg = C; Cneg = C;
} }
@ -570,12 +658,71 @@ namespace dlib
- F(new_x) < 0 - F(new_x) < 0
!*/ !*/
{ {
scalar_type obj;
return do_train(vector_to_matrix(x),vector_to_matrix(y),obj);
}
template <
    typename in_sample_vector_type,
    typename in_scalar_vector_type
    >
const decision_function<kernel_type> train (
    const in_sample_vector_type& x,
    const in_scalar_vector_type& y,
    scalar_type& svm_objective
) const
/*!
    requires
        - is_binary_classification_problem(x,y) == true
        - x == a matrix, or something that vector_to_matrix() can convert
          into a matrix, whose elements are sample_type objects.
        - y == a matrix, or something that vector_to_matrix() can convert
          into a matrix, whose elements are scalar_type objects.
    ensures
        - trains a C support vector classifier using the training samples
          in x and their labels in y.
        - #svm_objective == the final value of the SVM objective function
        - returns a decision function F with the following properties:
            - if (new_x is a sample predicted to have a +1 label) then
                - F(new_x) >= 0
            - else
                - F(new_x) < 0
!*/
{
    // Convert both inputs to matrix form, then defer to the shared
    // implementation, which writes the objective value into svm_objective.
    return do_train(vector_to_matrix(x), vector_to_matrix(y), svm_objective);
}
private:
template <
typename in_sample_vector_type,
typename in_scalar_vector_type
>
const decision_function<kernel_type> do_train (
const in_sample_vector_type& x,
const in_scalar_vector_type& y,
scalar_type& svm_objective
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
"\t decision_function svm_c_linear_trainer::train(x,y)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t x.nr(): " << x.nr()
<< "\n\t y.nr(): " << y.nr()
<< "\n\t x.nc(): " << x.nc()
<< "\n\t y.nc(): " << y.nc()
<< "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y)
);
typedef matrix<scalar_type,0,1> w_type; typedef matrix<scalar_type,0,1> w_type;
w_type w; w_type w;
scalar_type obj = solver(make_oca_problem_c_svm<w_type>(Cpos, Cneg, vector_to_matrix(x), vector_to_matrix(y), true), w); svm_objective = solver(
make_oca_problem_c_svm<w_type>(Cpos, Cneg, x, y, verbose, eps),
std::cout << "final obj: "<< obj << std::endl; w);
// put the solution into a decision function and then return it // put the solution into a decision function and then return it
decision_function<kernel_type> df; decision_function<kernel_type> df;
@ -587,12 +734,12 @@ namespace dlib
return df; return df;
} }
private:
scalar_type Cpos; scalar_type Cpos;
scalar_type Cneg; scalar_type Cneg;
oca solver; oca solver;
scalar_type eps;
bool verbose;
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------