Added cross validation functions for ranking tools and slightly improved documentation

for other cross validation functions.
This commit is contained in:
Davis King 2013-06-07 23:50:40 -04:00
parent 97f82b1e4f
commit 2f34594f47
2 changed files with 53 additions and 12 deletions

View File

@ -6,6 +6,7 @@
#include "serialize_pickle.h"
#include <dlib/svm_threaded.h>
#include "pyassert.h"
#include <boost/python/args.hpp>
using namespace dlib;
using namespace std;
@ -166,34 +167,43 @@ const binary_test _cross_validate_trainer_t (
void bind_svm_c_trainer()
{
using boost::python::arg;
{
typedef svm_c_trainer<radial_basis_kernel<sample_type> > T;
setup_trainer2<T>("svm_c_trainer_radial_basis")
.add_property("gamma", get_gamma, set_gamma);
def("cross_validate_trainer", _cross_validate_trainer<T>);
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>);
def("cross_validate_trainer", _cross_validate_trainer<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds")));
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads")));
}
{
typedef svm_c_trainer<sparse_radial_basis_kernel<sparse_vect> > T;
setup_trainer2<T>("svm_c_trainer_sparse_radial_basis")
.add_property("gamma", get_gamma_sparse, set_gamma_sparse);
def("cross_validate_trainer", _cross_validate_trainer<T>);
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>);
def("cross_validate_trainer", _cross_validate_trainer<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds")));
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads")));
}
{
typedef svm_c_trainer<histogram_intersection_kernel<sample_type> > T;
setup_trainer2<T>("svm_c_trainer_histogram_intersection");
def("cross_validate_trainer", _cross_validate_trainer<T>);
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>);
def("cross_validate_trainer", _cross_validate_trainer<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds")));
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads")));
}
{
typedef svm_c_trainer<sparse_histogram_intersection_kernel<sparse_vect> > T;
setup_trainer2<T>("svm_c_trainer_sparse_histogram_intersection");
def("cross_validate_trainer", _cross_validate_trainer<T>);
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>);
def("cross_validate_trainer", _cross_validate_trainer<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds")));
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads")));
}
{
@ -205,8 +215,10 @@ void bind_svm_c_trainer()
.def("be_verbose", &T::be_verbose)
.def("be_quiet", &T::be_quiet);
def("cross_validate_trainer", _cross_validate_trainer<T>);
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>);
def("cross_validate_trainer", _cross_validate_trainer<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds")));
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads")));
}
{
@ -218,8 +230,10 @@ void bind_svm_c_trainer()
.def("be_verbose", &T::be_verbose)
.def("be_quiet", &T::be_quiet);
def("cross_validate_trainer", _cross_validate_trainer<T>);
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>);
def("cross_validate_trainer", _cross_validate_trainer<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds")));
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
(arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads")));
}
}

View File

@ -6,6 +6,8 @@
#include <dlib/svm.h>
#include "pyassert.h"
#include <boost/python/suite/indexing/vector_indexing_suite.hpp>
#include "testing_results.h"
#include <boost/python/args.hpp>
using namespace dlib;
using namespace std;
@ -99,8 +101,26 @@ void add_ranker (
// ----------------------------------------------------------------------------------------
template <
typename trainer_type,
typename T
>
const ranking_test _cross_ranking_validate_trainer (
const trainer_type& trainer,
const std::vector<ranking_pair<T> >& samples,
const unsigned long folds
)
{
pyassert(is_ranking_problem(samples), "Training data does not make a valid training set.");
pyassert(1 < folds && folds <= samples.size(), "Invalid number of folds given.");
return cross_validate_ranking_trainer(trainer, samples, folds);
}
// ----------------------------------------------------------------------------------------
void bind_svm_rank_trainer()
{
using boost::python::arg;
class_<ranking_pair<sample_type> >("ranking_pair")
.add_property("relevant", &ranking_pair<sample_type>::relevant)
.add_property("nonrelevant", &ranking_pair<sample_type>::nonrelevant)
@ -127,6 +147,13 @@ void bind_svm_rank_trainer()
add_ranker<svm_rank_trainer<linear_kernel<sample_type> > >("svm_rank_trainer");
add_ranker<svm_rank_trainer<sparse_linear_kernel<sparse_vect> > >("svm_rank_trainer_sparse");
def("cross_validate_ranking_trainer", &_cross_ranking_validate_trainer<
svm_rank_trainer<linear_kernel<sample_type> >,sample_type>,
(arg("trainer"), arg("samples"), arg("folds")) );
def("cross_validate_ranking_trainer", &_cross_ranking_validate_trainer<
svm_rank_trainer<sparse_linear_kernel<sparse_vect> > ,sparse_vect>,
(arg("trainer"), arg("samples"), arg("folds")) );
}