mirror of https://github.com/davisking/dlib.git
Added cross validation functions for ranking tools and slightly improved documentation
for other cross validation functions.
This commit is contained in:
parent
97f82b1e4f
commit
2f34594f47
|
@ -6,6 +6,7 @@
|
|||
#include "serialize_pickle.h"
|
||||
#include <dlib/svm_threaded.h>
|
||||
#include "pyassert.h"
|
||||
#include <boost/python/args.hpp>
|
||||
|
||||
using namespace dlib;
|
||||
using namespace std;
|
||||
|
@ -166,34 +167,43 @@ const binary_test _cross_validate_trainer_t (
|
|||
|
||||
void bind_svm_c_trainer()
|
||||
{
|
||||
using boost::python::arg;
|
||||
{
|
||||
typedef svm_c_trainer<radial_basis_kernel<sample_type> > T;
|
||||
setup_trainer2<T>("svm_c_trainer_radial_basis")
|
||||
.add_property("gamma", get_gamma, set_gamma);
|
||||
def("cross_validate_trainer", _cross_validate_trainer<T>);
|
||||
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>);
|
||||
def("cross_validate_trainer", _cross_validate_trainer<T>,
|
||||
(arg("trainer"),arg("x"),arg("y"),arg("folds")));
|
||||
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
|
||||
(arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads")));
|
||||
}
|
||||
|
||||
{
|
||||
typedef svm_c_trainer<sparse_radial_basis_kernel<sparse_vect> > T;
|
||||
setup_trainer2<T>("svm_c_trainer_sparse_radial_basis")
|
||||
.add_property("gamma", get_gamma_sparse, set_gamma_sparse);
|
||||
def("cross_validate_trainer", _cross_validate_trainer<T>);
|
||||
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>);
|
||||
def("cross_validate_trainer", _cross_validate_trainer<T>,
|
||||
(arg("trainer"),arg("x"),arg("y"),arg("folds")));
|
||||
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
|
||||
(arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads")));
|
||||
}
|
||||
|
||||
{
|
||||
typedef svm_c_trainer<histogram_intersection_kernel<sample_type> > T;
|
||||
setup_trainer2<T>("svm_c_trainer_histogram_intersection");
|
||||
def("cross_validate_trainer", _cross_validate_trainer<T>);
|
||||
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>);
|
||||
def("cross_validate_trainer", _cross_validate_trainer<T>,
|
||||
(arg("trainer"),arg("x"),arg("y"),arg("folds")));
|
||||
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
|
||||
(arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads")));
|
||||
}
|
||||
|
||||
{
|
||||
typedef svm_c_trainer<sparse_histogram_intersection_kernel<sparse_vect> > T;
|
||||
setup_trainer2<T>("svm_c_trainer_sparse_histogram_intersection");
|
||||
def("cross_validate_trainer", _cross_validate_trainer<T>);
|
||||
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>);
|
||||
def("cross_validate_trainer", _cross_validate_trainer<T>,
|
||||
(arg("trainer"),arg("x"),arg("y"),arg("folds")));
|
||||
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
|
||||
(arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads")));
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -205,8 +215,10 @@ void bind_svm_c_trainer()
|
|||
.def("be_verbose", &T::be_verbose)
|
||||
.def("be_quiet", &T::be_quiet);
|
||||
|
||||
def("cross_validate_trainer", _cross_validate_trainer<T>);
|
||||
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>);
|
||||
def("cross_validate_trainer", _cross_validate_trainer<T>,
|
||||
(arg("trainer"),arg("x"),arg("y"),arg("folds")));
|
||||
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
|
||||
(arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads")));
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -218,8 +230,10 @@ void bind_svm_c_trainer()
|
|||
.def("be_verbose", &T::be_verbose)
|
||||
.def("be_quiet", &T::be_quiet);
|
||||
|
||||
def("cross_validate_trainer", _cross_validate_trainer<T>);
|
||||
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>);
|
||||
def("cross_validate_trainer", _cross_validate_trainer<T>,
|
||||
(arg("trainer"),arg("x"),arg("y"),arg("folds")));
|
||||
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
|
||||
(arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads")));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -6,6 +6,8 @@
|
|||
#include <dlib/svm.h>
|
||||
#include "pyassert.h"
|
||||
#include <boost/python/suite/indexing/vector_indexing_suite.hpp>
|
||||
#include "testing_results.h"
|
||||
#include <boost/python/args.hpp>
|
||||
|
||||
using namespace dlib;
|
||||
using namespace std;
|
||||
|
@ -99,8 +101,26 @@ void add_ranker (
|
|||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
typename trainer_type,
|
||||
typename T
|
||||
>
|
||||
const ranking_test _cross_ranking_validate_trainer (
|
||||
const trainer_type& trainer,
|
||||
const std::vector<ranking_pair<T> >& samples,
|
||||
const unsigned long folds
|
||||
)
|
||||
{
|
||||
pyassert(is_ranking_problem(samples), "Training data does not make a valid training set.");
|
||||
pyassert(1 < folds && folds <= samples.size(), "Invalid number of folds given.");
|
||||
return cross_validate_ranking_trainer(trainer, samples, folds);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void bind_svm_rank_trainer()
|
||||
{
|
||||
using boost::python::arg;
|
||||
class_<ranking_pair<sample_type> >("ranking_pair")
|
||||
.add_property("relevant", &ranking_pair<sample_type>::relevant)
|
||||
.add_property("nonrelevant", &ranking_pair<sample_type>::nonrelevant)
|
||||
|
@ -127,6 +147,13 @@ void bind_svm_rank_trainer()
|
|||
|
||||
add_ranker<svm_rank_trainer<linear_kernel<sample_type> > >("svm_rank_trainer");
|
||||
add_ranker<svm_rank_trainer<sparse_linear_kernel<sparse_vect> > >("svm_rank_trainer_sparse");
|
||||
|
||||
def("cross_validate_ranking_trainer", &_cross_ranking_validate_trainer<
|
||||
svm_rank_trainer<linear_kernel<sample_type> >,sample_type>,
|
||||
(arg("trainer"), arg("samples"), arg("folds")) );
|
||||
def("cross_validate_ranking_trainer", &_cross_ranking_validate_trainer<
|
||||
svm_rank_trainer<sparse_linear_kernel<sparse_vect> > ,sparse_vect>,
|
||||
(arg("trainer"), arg("samples"), arg("folds")) );
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue