Added some tests for the svm_c_ekm_trainer.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403629
This commit is contained in:
Davis King 2010-05-16 17:14:49 +00:00
parent 2e75e57f3f
commit e16ea2eb32
1 changed files with 23 additions and 5 deletions

View File

@ -347,6 +347,10 @@ namespace
pegasos_trainer.set_lambda(0.00001);
svm_c_ekm_trainer<kernel_type> ocas_ekm_trainer;
ocas_ekm_trainer.set_kernel(kernel_type(gamma));
ocas_ekm_trainer.set_c(100000);
svm_nu_trainer<kernel_type> trainer;
trainer.set_kernel(kernel_type(gamma));
trainer.set_nu(0.05);
@ -369,16 +373,23 @@ namespace
print_spinner();
matrix<scalar_type> lin_cv = cross_validate_trainer_threaded(lin_trainer, x_linearized, y, 4, 2);
print_spinner();
matrix<scalar_type> ocas_ekm_cv = cross_validate_trainer_threaded(ocas_ekm_trainer, x, y, 4, 2);
print_spinner();
ocas_ekm_trainer.set_basis(randomly_subsample(x, 300));
matrix<scalar_type> ocas_ekm_cv2 = cross_validate_trainer_threaded(ocas_ekm_trainer, x, y, 4, 2);
print_spinner();
matrix<scalar_type> peg_cv = cross_validate_trainer_threaded(batch(pegasos_trainer,1.0), x,y, 4, 2);
print_spinner();
matrix<scalar_type> peg_c_cv = cross_validate_trainer_threaded(batch_cached(pegasos_trainer,1.0), x,y, 4, 2);
print_spinner();
dlog << LDEBUG << "rvm cv: " << rvm_cv;
dlog << LDEBUG << "svm cv: " << svm_cv;
dlog << LDEBUG << "rbf cv: " << rbf_cv;
dlog << LDEBUG << "lin cv: " << lin_cv;
dlog << LDEBUG << "peg cv: " << peg_cv;
dlog << LDEBUG << "rvm cv: " << rvm_cv;
dlog << LDEBUG << "svm cv: " << svm_cv;
dlog << LDEBUG << "rbf cv: " << rbf_cv;
dlog << LDEBUG << "lin cv: " << lin_cv;
dlog << LDEBUG << "ocas_ekm cv: " << ocas_ekm_cv;
dlog << LDEBUG << "ocas_ekm cv2: " << ocas_ekm_cv2;
dlog << LDEBUG << "peg cv: " << peg_cv;
dlog << LDEBUG << "peg cached cv: " << peg_c_cv;
// make sure the cached version of pegasos computes the same result
@ -391,6 +402,8 @@ namespace
DLIB_TEST_MSG(mean(lin_cv) > 0.9, lin_cv);
DLIB_TEST_MSG(mean(peg_cv) > 0.9, peg_cv);
DLIB_TEST_MSG(mean(peg_c_cv) > 0.9, peg_c_cv);
DLIB_TEST_MSG(mean(ocas_ekm_cv) > 0.9, ocas_ekm_cv);
DLIB_TEST_MSG(mean(ocas_ekm_cv2) > 0.9, ocas_ekm_cv2);
const long num_sv = trainer.train(x,y).basis_vectors.size();
print_spinner();
@ -398,9 +411,14 @@ namespace
print_spinner();
dlog << LDEBUG << "num sv: " << num_sv;
dlog << LDEBUG << "num rv: " << num_rv;
print_spinner();
ocas_ekm_trainer.clear_basis();
const long num_bv = ocas_ekm_trainer.train(x,y).basis_vectors.size();
dlog << LDEBUG << "num ekm bv: " << num_bv;
DLIB_TEST(num_rv <= 17);
DLIB_TEST_MSG(num_sv <= 45, num_sv);
DLIB_TEST_MSG(num_bv <= 45, num_bv);
decision_function<kernel_type> df = reduced2(trainer, 19).train(x,y);
print_spinner();