mirror of https://github.com/davisking/dlib.git
updated the svm_c_linear_trainer
--HG-- extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403504
This commit is contained in:
parent
c2de9780a3
commit
ce42b721dd
|
@ -41,13 +41,15 @@ namespace dlib
|
|||
const scalar_type C_neg,
|
||||
const in_sample_vector_type& samples_,
|
||||
const in_scalar_vector_type& labels_,
|
||||
bool be_verbose_
|
||||
const bool be_verbose_,
|
||||
const scalar_type eps_
|
||||
) :
|
||||
samples(samples_),
|
||||
labels(labels_),
|
||||
Cpos(C_pos),
|
||||
Cneg(C_neg),
|
||||
be_verbose(be_verbose_)
|
||||
be_verbose(be_verbose_),
|
||||
eps(eps_)
|
||||
{
|
||||
dot_prods.resize(samples.size());
|
||||
is_first_call = true;
|
||||
|
@ -83,10 +85,7 @@ namespace dlib
|
|||
cout << endl;
|
||||
}
|
||||
|
||||
if (current_error_gap/current_objective_value < 0.001)
|
||||
return true;
|
||||
|
||||
if (num_iterations > 10000)
|
||||
if (current_error_gap/current_objective_value < eps)
|
||||
return true;
|
||||
|
||||
return false;
|
||||
|
@ -369,7 +368,8 @@ namespace dlib
|
|||
const scalar_type Cpos;
|
||||
const scalar_type Cneg;
|
||||
|
||||
bool be_verbose;
|
||||
const bool be_verbose;
|
||||
const scalar_type eps;
|
||||
};
|
||||
|
||||
|
||||
|
@ -384,10 +384,11 @@ namespace dlib
|
|||
const scalar_type C_neg,
|
||||
const in_sample_vector_type& samples,
|
||||
const in_scalar_vector_type& labels,
|
||||
bool be_verbose
|
||||
const bool be_verbose,
|
||||
const scalar_type eps
|
||||
)
|
||||
{
|
||||
return oca_problem_c_svm<matrix_type, in_sample_vector_type, in_scalar_vector_type>(C_pos, C_neg, samples, labels, be_verbose);
|
||||
return oca_problem_c_svm<matrix_type, in_sample_vector_type, in_scalar_vector_type>(C_pos, C_neg, samples, labels, be_verbose, eps);
|
||||
}
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
|
@ -419,10 +420,14 @@ namespace dlib
|
|||
- #get_oca() == oca() (i.e. an instance of oca with default parameters)
|
||||
- #get_c_class1() == 1
|
||||
- #get_c_class2() == 1
|
||||
- #get_epsilon() == 0.001
|
||||
- this object will not be verbose unless be_verbose() is called
|
||||
!*/
|
||||
{
|
||||
Cpos = 1;
|
||||
Cneg = 1;
|
||||
verbose = false;
|
||||
eps = 0.001;
|
||||
}
|
||||
|
||||
explicit svm_c_linear_trainer (
|
||||
|
@ -439,10 +444,69 @@ namespace dlib
|
|||
- #get_c_class2() == C
|
||||
!*/
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
DLIB_ASSERT(C > 0,
|
||||
"\t svm_c_linear_trainer::svm_c_linear_trainer()"
|
||||
<< "\n\t C must be greater than 0"
|
||||
<< "\n\t C: " << C
|
||||
<< "\n\t this: " << this
|
||||
);
|
||||
|
||||
Cpos = C;
|
||||
Cneg = C;
|
||||
}
|
||||
|
||||
void set_epsilon (
|
||||
scalar_type eps_
|
||||
)
|
||||
/*!
|
||||
requires
|
||||
- eps > 0
|
||||
ensures
|
||||
- #get_epsilon() == eps
|
||||
!*/
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
DLIB_ASSERT(eps_ > 0,
|
||||
"\t void svm_c_linear_trainer::set_epsilon()"
|
||||
<< "\n\t eps_ must be greater than 0"
|
||||
<< "\n\t eps_: " << eps_
|
||||
<< "\n\t this: " << this
|
||||
);
|
||||
|
||||
eps = eps_;
|
||||
}
|
||||
|
||||
const scalar_type get_epsilon (
|
||||
) const { return eps; }
|
||||
/*!
|
||||
ensures
|
||||
- returns the error epsilon that determines when training should stop.
|
||||
Smaller values may result in a more accurate solution but take longer
|
||||
to execute.
|
||||
!*/
|
||||
|
||||
void be_verbose (
|
||||
)
|
||||
/*!
|
||||
ensures
|
||||
- This object will print status messages to standard out so that a
|
||||
user can observe the progress of the algorithm.
|
||||
!*/
|
||||
{
|
||||
verbose = true;
|
||||
}
|
||||
|
||||
void be_quiet (
|
||||
)
|
||||
/*!
|
||||
ensures
|
||||
- this object will not print anything to standard out
|
||||
!*/
|
||||
{
|
||||
verbose = false;
|
||||
}
|
||||
|
||||
void set_oca (
|
||||
const oca& item
|
||||
)
|
||||
|
@ -485,6 +549,14 @@ namespace dlib
|
|||
- #get_c_class2() == C
|
||||
!*/
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
DLIB_ASSERT(C > 0,
|
||||
"\t void svm_c_linear_trainer::set_c()"
|
||||
<< "\n\t C must be greater than 0"
|
||||
<< "\n\t C: " << C
|
||||
<< "\n\t this: " << this
|
||||
);
|
||||
|
||||
Cpos = C;
|
||||
Cneg = C;
|
||||
}
|
||||
|
@ -529,6 +601,14 @@ namespace dlib
|
|||
- #get_c_class1() == C
|
||||
!*/
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
DLIB_ASSERT(C > 0,
|
||||
"\t void svm_c_linear_trainer::set_c_class1()"
|
||||
<< "\n\t C must be greater than 0"
|
||||
<< "\n\t C: " << C
|
||||
<< "\n\t this: " << this
|
||||
);
|
||||
|
||||
Cpos = C;
|
||||
}
|
||||
|
||||
|
@ -542,6 +622,14 @@ namespace dlib
|
|||
- #get_c_class2() == C
|
||||
!*/
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
DLIB_ASSERT(C > 0,
|
||||
"\t void svm_c_linear_trainer::set_c_class2()"
|
||||
<< "\n\t C must be greater than 0"
|
||||
<< "\n\t C: " << C
|
||||
<< "\n\t this: " << this
|
||||
);
|
||||
|
||||
Cneg = C;
|
||||
}
|
||||
|
||||
|
@ -570,12 +658,71 @@ namespace dlib
|
|||
- F(new_x) < 0
|
||||
!*/
|
||||
{
|
||||
scalar_type obj;
|
||||
return do_train(vector_to_matrix(x),vector_to_matrix(y),obj);
|
||||
}
|
||||
|
||||
|
||||
template <
|
||||
typename in_sample_vector_type,
|
||||
typename in_scalar_vector_type
|
||||
>
|
||||
const decision_function<kernel_type> train (
|
||||
const in_sample_vector_type& x,
|
||||
const in_scalar_vector_type& y,
|
||||
scalar_type& svm_objective
|
||||
) const
|
||||
/*!
|
||||
requires
|
||||
- is_binary_classification_problem(x,y) == true
|
||||
- x == a matrix or something convertible to a matrix via vector_to_matrix().
|
||||
Also, x should contain sample_type objects.
|
||||
- y == a matrix or something convertible to a matrix via vector_to_matrix().
|
||||
Also, y should contain scalar_type objects.
|
||||
ensures
|
||||
- trains a C support vector classifier given the training samples in x and
|
||||
labels in y.
|
||||
- #svm_objective == the final value of the SVM objective function
|
||||
- returns a decision function F with the following properties:
|
||||
- if (new_x is a sample predicted have +1 label) then
|
||||
- F(new_x) >= 0
|
||||
- else
|
||||
- F(new_x) < 0
|
||||
!*/
|
||||
{
|
||||
return do_train(vector_to_matrix(x),vector_to_matrix(y),svm_objective);
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
template <
|
||||
typename in_sample_vector_type,
|
||||
typename in_scalar_vector_type
|
||||
>
|
||||
const decision_function<kernel_type> do_train (
|
||||
const in_sample_vector_type& x,
|
||||
const in_scalar_vector_type& y,
|
||||
scalar_type& svm_objective
|
||||
) const
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
|
||||
"\t decision_function svm_c_linear_trainer::train(x,y)"
|
||||
<< "\n\t invalid inputs were given to this function"
|
||||
<< "\n\t x.nr(): " << x.nr()
|
||||
<< "\n\t y.nr(): " << y.nr()
|
||||
<< "\n\t x.nc(): " << x.nc()
|
||||
<< "\n\t y.nc(): " << y.nc()
|
||||
<< "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y)
|
||||
);
|
||||
|
||||
|
||||
typedef matrix<scalar_type,0,1> w_type;
|
||||
w_type w;
|
||||
|
||||
scalar_type obj = solver(make_oca_problem_c_svm<w_type>(Cpos, Cneg, vector_to_matrix(x), vector_to_matrix(y), true), w);
|
||||
|
||||
std::cout << "final obj: "<< obj << std::endl;
|
||||
svm_objective = solver(
|
||||
make_oca_problem_c_svm<w_type>(Cpos, Cneg, x, y, verbose, eps),
|
||||
w);
|
||||
|
||||
// put the solution into a decision function and then return it
|
||||
decision_function<kernel_type> df;
|
||||
|
@ -587,12 +734,12 @@ namespace dlib
|
|||
|
||||
return df;
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
scalar_type Cpos;
|
||||
scalar_type Cneg;
|
||||
oca solver;
|
||||
scalar_type eps;
|
||||
bool verbose;
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
|
Loading…
Reference in New Issue