From cabf9fc57722dc1af4a3338c6aad4f6ff7b0bd43 Mon Sep 17 00:00:00 2001 From: Davis King Date: Sat, 6 Mar 2010 14:28:20 +0000 Subject: [PATCH] Added more tests for the svm_c_linear_trainer --HG-- extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403539 --- dlib/test/svm.cpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/dlib/test/svm.cpp b/dlib/test/svm.cpp index de3661187..645db5ef7 100644 --- a/dlib/test/svm.cpp +++ b/dlib/test/svm.cpp @@ -327,6 +327,7 @@ namespace typedef matrix sample_type; std::vector x; + std::vector > x_linearized; std::vector y; get_checkerboard_problem(x,y, 300, 3); @@ -350,6 +351,15 @@ namespace trainer.set_kernel(kernel_type(gamma)); trainer.set_nu(0.05); + svm_c_linear_trainer > > lin_trainer; + lin_trainer.set_c(100000); + // use an ekm to linearize this dataset so we can use it with the lin_trainer + empirical_kernel_map ekm; + ekm.load(kernel_type(gamma), x); + for (unsigned long i = 0; i < x.size(); ++i) + x_linearized.push_back(ekm.project(x[i])); + + print_spinner(); matrix rvm_cv = cross_validate_trainer_threaded(rvm_trainer, x,y, 4, 2); print_spinner(); @@ -357,6 +367,8 @@ namespace print_spinner(); matrix rbf_cv = cross_validate_trainer_threaded(rbf_trainer, x,y, 4, 2); print_spinner(); + matrix lin_cv = cross_validate_trainer_threaded(lin_trainer, x_linearized, y, 4, 2); + print_spinner(); matrix peg_cv = cross_validate_trainer_threaded(batch(pegasos_trainer,1.0), x,y, 4, 2); print_spinner(); matrix peg_c_cv = cross_validate_trainer_threaded(batch_cached(pegasos_trainer,1.0), x,y, 4, 2); @@ -365,6 +377,7 @@ namespace dlog << LDEBUG << "rvm cv: " << rvm_cv; dlog << LDEBUG << "svm cv: " << svm_cv; dlog << LDEBUG << "rbf cv: " << rbf_cv; + dlog << LDEBUG << "lin cv: " << lin_cv; dlog << LDEBUG << "peg cv: " << peg_cv; dlog << LDEBUG << "peg cached cv: " << peg_c_cv; @@ -375,6 +388,7 @@ namespace DLIB_TEST_MSG(mean(rvm_cv) > 0.9, rvm_cv); DLIB_TEST_MSG(mean(svm_cv) > 0.9, svm_cv); DLIB_TEST_MSG(mean(rbf_cv) > 0.9, rbf_cv); + DLIB_TEST_MSG(mean(lin_cv) > 0.9, lin_cv); DLIB_TEST_MSG(mean(peg_cv) > 0.9, peg_cv); DLIB_TEST_MSG(mean(peg_c_cv) > 0.9, peg_c_cv);