mirror of https://github.com/davisking/dlib.git
updated the svm_c_linear_trainer
--HG-- extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403504
This commit is contained in:
parent
c2de9780a3
commit
ce42b721dd
|
@ -41,13 +41,15 @@ namespace dlib
|
||||||
const scalar_type C_neg,
|
const scalar_type C_neg,
|
||||||
const in_sample_vector_type& samples_,
|
const in_sample_vector_type& samples_,
|
||||||
const in_scalar_vector_type& labels_,
|
const in_scalar_vector_type& labels_,
|
||||||
bool be_verbose_
|
const bool be_verbose_,
|
||||||
|
const scalar_type eps_
|
||||||
) :
|
) :
|
||||||
samples(samples_),
|
samples(samples_),
|
||||||
labels(labels_),
|
labels(labels_),
|
||||||
Cpos(C_pos),
|
Cpos(C_pos),
|
||||||
Cneg(C_neg),
|
Cneg(C_neg),
|
||||||
be_verbose(be_verbose_)
|
be_verbose(be_verbose_),
|
||||||
|
eps(eps_)
|
||||||
{
|
{
|
||||||
dot_prods.resize(samples.size());
|
dot_prods.resize(samples.size());
|
||||||
is_first_call = true;
|
is_first_call = true;
|
||||||
|
@ -83,10 +85,7 @@ namespace dlib
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (current_error_gap/current_objective_value < 0.001)
|
if (current_error_gap/current_objective_value < eps)
|
||||||
return true;
|
|
||||||
|
|
||||||
if (num_iterations > 10000)
|
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
|
@ -369,7 +368,8 @@ namespace dlib
|
||||||
const scalar_type Cpos;
|
const scalar_type Cpos;
|
||||||
const scalar_type Cneg;
|
const scalar_type Cneg;
|
||||||
|
|
||||||
bool be_verbose;
|
const bool be_verbose;
|
||||||
|
const scalar_type eps;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -384,10 +384,11 @@ namespace dlib
|
||||||
const scalar_type C_neg,
|
const scalar_type C_neg,
|
||||||
const in_sample_vector_type& samples,
|
const in_sample_vector_type& samples,
|
||||||
const in_scalar_vector_type& labels,
|
const in_scalar_vector_type& labels,
|
||||||
bool be_verbose
|
const bool be_verbose,
|
||||||
|
const scalar_type eps
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
return oca_problem_c_svm<matrix_type, in_sample_vector_type, in_scalar_vector_type>(C_pos, C_neg, samples, labels, be_verbose);
|
return oca_problem_c_svm<matrix_type, in_sample_vector_type, in_scalar_vector_type>(C_pos, C_neg, samples, labels, be_verbose, eps);
|
||||||
}
|
}
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@ -419,10 +420,14 @@ namespace dlib
|
||||||
- #get_oca() == oca() (i.e. an instance of oca with default parameters)
|
- #get_oca() == oca() (i.e. an instance of oca with default parameters)
|
||||||
- #get_c_class1() == 1
|
- #get_c_class1() == 1
|
||||||
- #get_c_class2() == 1
|
- #get_c_class2() == 1
|
||||||
|
- #get_epsilon() == 0.001
|
||||||
|
- this object will not be verbose unless be_verbose() is called
|
||||||
!*/
|
!*/
|
||||||
{
|
{
|
||||||
Cpos = 1;
|
Cpos = 1;
|
||||||
Cneg = 1;
|
Cneg = 1;
|
||||||
|
verbose = false;
|
||||||
|
eps = 0.001;
|
||||||
}
|
}
|
||||||
|
|
||||||
explicit svm_c_linear_trainer (
|
explicit svm_c_linear_trainer (
|
||||||
|
@ -439,10 +444,69 @@ namespace dlib
|
||||||
- #get_c_class2() == C
|
- #get_c_class2() == C
|
||||||
!*/
|
!*/
|
||||||
{
|
{
|
||||||
|
// make sure requires clause is not broken
|
||||||
|
DLIB_ASSERT(C > 0,
|
||||||
|
"\t svm_c_linear_trainer::svm_c_linear_trainer()"
|
||||||
|
<< "\n\t C must be greater than 0"
|
||||||
|
<< "\n\t C: " << C
|
||||||
|
<< "\n\t this: " << this
|
||||||
|
);
|
||||||
|
|
||||||
Cpos = C;
|
Cpos = C;
|
||||||
Cneg = C;
|
Cneg = C;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void set_epsilon (
|
||||||
|
scalar_type eps_
|
||||||
|
)
|
||||||
|
/*!
|
||||||
|
requires
|
||||||
|
- eps > 0
|
||||||
|
ensures
|
||||||
|
- #get_epsilon() == eps
|
||||||
|
!*/
|
||||||
|
{
|
||||||
|
// make sure requires clause is not broken
|
||||||
|
DLIB_ASSERT(eps_ > 0,
|
||||||
|
"\t void svm_c_linear_trainer::set_epsilon()"
|
||||||
|
<< "\n\t eps_ must be greater than 0"
|
||||||
|
<< "\n\t eps_: " << eps_
|
||||||
|
<< "\n\t this: " << this
|
||||||
|
);
|
||||||
|
|
||||||
|
eps = eps_;
|
||||||
|
}
|
||||||
|
|
||||||
|
const scalar_type get_epsilon (
|
||||||
|
) const { return eps; }
|
||||||
|
/*!
|
||||||
|
ensures
|
||||||
|
- returns the error epsilon that determines when training should stop.
|
||||||
|
Smaller values may result in a more accurate solution but take longer
|
||||||
|
to execute.
|
||||||
|
!*/
|
||||||
|
|
||||||
|
void be_verbose (
|
||||||
|
)
|
||||||
|
/*!
|
||||||
|
ensures
|
||||||
|
- This object will print status messages to standard out so that a
|
||||||
|
user can observe the progress of the algorithm.
|
||||||
|
!*/
|
||||||
|
{
|
||||||
|
verbose = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void be_quiet (
|
||||||
|
)
|
||||||
|
/*!
|
||||||
|
ensures
|
||||||
|
- this object will not print anything to standard out
|
||||||
|
!*/
|
||||||
|
{
|
||||||
|
verbose = false;
|
||||||
|
}
|
||||||
|
|
||||||
void set_oca (
|
void set_oca (
|
||||||
const oca& item
|
const oca& item
|
||||||
)
|
)
|
||||||
|
@ -485,6 +549,14 @@ namespace dlib
|
||||||
- #get_c_class2() == C
|
- #get_c_class2() == C
|
||||||
!*/
|
!*/
|
||||||
{
|
{
|
||||||
|
// make sure requires clause is not broken
|
||||||
|
DLIB_ASSERT(C > 0,
|
||||||
|
"\t void svm_c_linear_trainer::set_c()"
|
||||||
|
<< "\n\t C must be greater than 0"
|
||||||
|
<< "\n\t C: " << C
|
||||||
|
<< "\n\t this: " << this
|
||||||
|
);
|
||||||
|
|
||||||
Cpos = C;
|
Cpos = C;
|
||||||
Cneg = C;
|
Cneg = C;
|
||||||
}
|
}
|
||||||
|
@ -529,6 +601,14 @@ namespace dlib
|
||||||
- #get_c_class1() == C
|
- #get_c_class1() == C
|
||||||
!*/
|
!*/
|
||||||
{
|
{
|
||||||
|
// make sure requires clause is not broken
|
||||||
|
DLIB_ASSERT(C > 0,
|
||||||
|
"\t void svm_c_linear_trainer::set_c_class1()"
|
||||||
|
<< "\n\t C must be greater than 0"
|
||||||
|
<< "\n\t C: " << C
|
||||||
|
<< "\n\t this: " << this
|
||||||
|
);
|
||||||
|
|
||||||
Cpos = C;
|
Cpos = C;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -542,6 +622,14 @@ namespace dlib
|
||||||
- #get_c_class2() == C
|
- #get_c_class2() == C
|
||||||
!*/
|
!*/
|
||||||
{
|
{
|
||||||
|
// make sure requires clause is not broken
|
||||||
|
DLIB_ASSERT(C > 0,
|
||||||
|
"\t void svm_c_linear_trainer::set_c_class2()"
|
||||||
|
<< "\n\t C must be greater than 0"
|
||||||
|
<< "\n\t C: " << C
|
||||||
|
<< "\n\t this: " << this
|
||||||
|
);
|
||||||
|
|
||||||
Cneg = C;
|
Cneg = C;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -570,12 +658,71 @@ namespace dlib
|
||||||
- F(new_x) < 0
|
- F(new_x) < 0
|
||||||
!*/
|
!*/
|
||||||
{
|
{
|
||||||
|
scalar_type obj;
|
||||||
|
return do_train(vector_to_matrix(x),vector_to_matrix(y),obj);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename in_sample_vector_type,
|
||||||
|
typename in_scalar_vector_type
|
||||||
|
>
|
||||||
|
const decision_function<kernel_type> train (
|
||||||
|
const in_sample_vector_type& x,
|
||||||
|
const in_scalar_vector_type& y,
|
||||||
|
scalar_type& svm_objective
|
||||||
|
) const
|
||||||
|
/*!
|
||||||
|
requires
|
||||||
|
- is_binary_classification_problem(x,y) == true
|
||||||
|
- x == a matrix or something convertible to a matrix via vector_to_matrix().
|
||||||
|
Also, x should contain sample_type objects.
|
||||||
|
- y == a matrix or something convertible to a matrix via vector_to_matrix().
|
||||||
|
Also, y should contain scalar_type objects.
|
||||||
|
ensures
|
||||||
|
- trains a C support vector classifier given the training samples in x and
|
||||||
|
labels in y.
|
||||||
|
- #svm_objective == the final value of the SVM objective function
|
||||||
|
- returns a decision function F with the following properties:
|
||||||
|
- if (new_x is a sample predicted have +1 label) then
|
||||||
|
- F(new_x) >= 0
|
||||||
|
- else
|
||||||
|
- F(new_x) < 0
|
||||||
|
!*/
|
||||||
|
{
|
||||||
|
return do_train(vector_to_matrix(x),vector_to_matrix(y),svm_objective);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename in_sample_vector_type,
|
||||||
|
typename in_scalar_vector_type
|
||||||
|
>
|
||||||
|
const decision_function<kernel_type> do_train (
|
||||||
|
const in_sample_vector_type& x,
|
||||||
|
const in_scalar_vector_type& y,
|
||||||
|
scalar_type& svm_objective
|
||||||
|
) const
|
||||||
|
{
|
||||||
|
// make sure requires clause is not broken
|
||||||
|
DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
|
||||||
|
"\t decision_function svm_c_linear_trainer::train(x,y)"
|
||||||
|
<< "\n\t invalid inputs were given to this function"
|
||||||
|
<< "\n\t x.nr(): " << x.nr()
|
||||||
|
<< "\n\t y.nr(): " << y.nr()
|
||||||
|
<< "\n\t x.nc(): " << x.nc()
|
||||||
|
<< "\n\t y.nc(): " << y.nc()
|
||||||
|
<< "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y)
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
typedef matrix<scalar_type,0,1> w_type;
|
typedef matrix<scalar_type,0,1> w_type;
|
||||||
w_type w;
|
w_type w;
|
||||||
|
|
||||||
scalar_type obj = solver(make_oca_problem_c_svm<w_type>(Cpos, Cneg, vector_to_matrix(x), vector_to_matrix(y), true), w);
|
svm_objective = solver(
|
||||||
|
make_oca_problem_c_svm<w_type>(Cpos, Cneg, x, y, verbose, eps),
|
||||||
std::cout << "final obj: "<< obj << std::endl;
|
w);
|
||||||
|
|
||||||
// put the solution into a decision function and then return it
|
// put the solution into a decision function and then return it
|
||||||
decision_function<kernel_type> df;
|
decision_function<kernel_type> df;
|
||||||
|
@ -587,12 +734,12 @@ namespace dlib
|
||||||
|
|
||||||
return df;
|
return df;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
|
|
||||||
scalar_type Cpos;
|
scalar_type Cpos;
|
||||||
scalar_type Cneg;
|
scalar_type Cneg;
|
||||||
oca solver;
|
oca solver;
|
||||||
|
scalar_type eps;
|
||||||
|
bool verbose;
|
||||||
};
|
};
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
Loading…
Reference in New Issue