From 2f34594f47e78b842c189fef8fd809a5e524f971 Mon Sep 17 00:00:00 2001 From: Davis King Date: Fri, 7 Jun 2013 23:50:40 -0400 Subject: [PATCH] Added cross validation functions for ranking tools and slightly improved documentation for other cross validation functions. --- tools/python/src/svm_c_trainer.cpp | 38 ++++++++++++++++++--------- tools/python/src/svm_rank_trainer.cpp | 27 +++++++++++++++++++ 2 files changed, 53 insertions(+), 12 deletions(-) diff --git a/tools/python/src/svm_c_trainer.cpp b/tools/python/src/svm_c_trainer.cpp index 1229868a3..7a8417ec4 100644 --- a/tools/python/src/svm_c_trainer.cpp +++ b/tools/python/src/svm_c_trainer.cpp @@ -6,6 +6,7 @@ #include "serialize_pickle.h" #include #include "pyassert.h" +#include using namespace dlib; using namespace std; @@ -166,34 +167,43 @@ const binary_test _cross_validate_trainer_t ( void bind_svm_c_trainer() { + using boost::python::arg; { typedef svm_c_trainer > T; setup_trainer2("svm_c_trainer_radial_basis") .add_property("gamma", get_gamma, set_gamma); - def("cross_validate_trainer", _cross_validate_trainer); - def("cross_validate_trainer_threaded", _cross_validate_trainer_t); + def("cross_validate_trainer", _cross_validate_trainer, + (arg("trainer"),arg("x"),arg("y"),arg("folds"))); + def("cross_validate_trainer_threaded", _cross_validate_trainer_t, + (arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads"))); } { typedef svm_c_trainer > T; setup_trainer2("svm_c_trainer_sparse_radial_basis") .add_property("gamma", get_gamma_sparse, set_gamma_sparse); - def("cross_validate_trainer", _cross_validate_trainer); - def("cross_validate_trainer_threaded", _cross_validate_trainer_t); + def("cross_validate_trainer", _cross_validate_trainer, + (arg("trainer"),arg("x"),arg("y"),arg("folds"))); + def("cross_validate_trainer_threaded", _cross_validate_trainer_t, + (arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads"))); } { typedef svm_c_trainer > T; setup_trainer2("svm_c_trainer_histogram_intersection"); - def("cross_validate_trainer", _cross_validate_trainer); - def("cross_validate_trainer_threaded", _cross_validate_trainer_t); + def("cross_validate_trainer", _cross_validate_trainer, + (arg("trainer"),arg("x"),arg("y"),arg("folds"))); + def("cross_validate_trainer_threaded", _cross_validate_trainer_t, + (arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads"))); } { typedef svm_c_trainer > T; setup_trainer2("svm_c_trainer_sparse_histogram_intersection"); - def("cross_validate_trainer", _cross_validate_trainer); - def("cross_validate_trainer_threaded", _cross_validate_trainer_t); + def("cross_validate_trainer", _cross_validate_trainer, + (arg("trainer"),arg("x"),arg("y"),arg("folds"))); + def("cross_validate_trainer_threaded", _cross_validate_trainer_t, + (arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads"))); } { @@ -205,8 +215,10 @@ void bind_svm_c_trainer() .def("be_verbose", &T::be_verbose) .def("be_quiet", &T::be_quiet); - def("cross_validate_trainer", _cross_validate_trainer); - def("cross_validate_trainer_threaded", _cross_validate_trainer_t); + def("cross_validate_trainer", _cross_validate_trainer, + (arg("trainer"),arg("x"),arg("y"),arg("folds"))); + def("cross_validate_trainer_threaded", _cross_validate_trainer_t, + (arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads"))); } { @@ -218,8 +230,10 @@ void bind_svm_c_trainer() .def("be_verbose", &T::be_verbose) .def("be_quiet", &T::be_quiet); - def("cross_validate_trainer", _cross_validate_trainer); - def("cross_validate_trainer_threaded", _cross_validate_trainer_t); + def("cross_validate_trainer", _cross_validate_trainer, + (arg("trainer"),arg("x"),arg("y"),arg("folds"))); + def("cross_validate_trainer_threaded", _cross_validate_trainer_t, + (arg("trainer"),arg("x"),arg("y"),arg("folds"),arg("num_threads"))); } } diff --git a/tools/python/src/svm_rank_trainer.cpp b/tools/python/src/svm_rank_trainer.cpp index 4c7b3a9ea..817fba8f6 100644 --- a/tools/python/src/svm_rank_trainer.cpp +++ b/tools/python/src/svm_rank_trainer.cpp @@ -6,6 +6,8 @@ #include #include "pyassert.h" #include +#include "testing_results.h" +#include using namespace dlib; using namespace std; @@ -99,8 +101,26 @@ void add_ranker ( // ---------------------------------------------------------------------------------------- +template < + typename trainer_type, + typename T + > +const ranking_test _cross_ranking_validate_trainer ( + const trainer_type& trainer, + const std::vector >& samples, + const unsigned long folds +) +{ + pyassert(is_ranking_problem(samples), "Training data does not make a valid training set."); + pyassert(1 < folds && folds <= samples.size(), "Invalid number of folds given."); + return cross_validate_ranking_trainer(trainer, samples, folds); +} + +// ---------------------------------------------------------------------------------------- + void bind_svm_rank_trainer() { + using boost::python::arg; class_ >("ranking_pair") .add_property("relevant", &ranking_pair::relevant) .add_property("nonrelevant", &ranking_pair::nonrelevant) @@ -127,6 +147,13 @@ void bind_svm_rank_trainer() add_ranker > >("svm_rank_trainer"); add_ranker > >("svm_rank_trainer_sparse"); + + def("cross_validate_ranking_trainer", &_cross_ranking_validate_trainer< + svm_rank_trainer >,sample_type>, + (arg("trainer"), arg("samples"), arg("folds")) ); + def("cross_validate_ranking_trainer", &_cross_ranking_validate_trainer< + svm_rank_trainer > ,sparse_vect>, + (arg("trainer"), arg("samples"), arg("folds")) ); }