mirror of https://github.com/davisking/dlib.git
Added some tests for the svm_c_ekm_trainer.
--HG-- extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403629
This commit is contained in:
parent
2e75e57f3f
commit
e16ea2eb32
|
@ -347,6 +347,10 @@ namespace
|
|||
pegasos_trainer.set_lambda(0.00001);
|
||||
|
||||
|
||||
svm_c_ekm_trainer<kernel_type> ocas_ekm_trainer;
|
||||
ocas_ekm_trainer.set_kernel(kernel_type(gamma));
|
||||
ocas_ekm_trainer.set_c(100000);
|
||||
|
||||
svm_nu_trainer<kernel_type> trainer;
|
||||
trainer.set_kernel(kernel_type(gamma));
|
||||
trainer.set_nu(0.05);
|
||||
|
@ -369,16 +373,23 @@ namespace
|
|||
print_spinner();
|
||||
matrix<scalar_type> lin_cv = cross_validate_trainer_threaded(lin_trainer, x_linearized, y, 4, 2);
|
||||
print_spinner();
|
||||
matrix<scalar_type> ocas_ekm_cv = cross_validate_trainer_threaded(ocas_ekm_trainer, x, y, 4, 2);
|
||||
print_spinner();
|
||||
ocas_ekm_trainer.set_basis(randomly_subsample(x, 300));
|
||||
matrix<scalar_type> ocas_ekm_cv2 = cross_validate_trainer_threaded(ocas_ekm_trainer, x, y, 4, 2);
|
||||
print_spinner();
|
||||
matrix<scalar_type> peg_cv = cross_validate_trainer_threaded(batch(pegasos_trainer,1.0), x,y, 4, 2);
|
||||
print_spinner();
|
||||
matrix<scalar_type> peg_c_cv = cross_validate_trainer_threaded(batch_cached(pegasos_trainer,1.0), x,y, 4, 2);
|
||||
print_spinner();
|
||||
|
||||
dlog << LDEBUG << "rvm cv: " << rvm_cv;
|
||||
dlog << LDEBUG << "svm cv: " << svm_cv;
|
||||
dlog << LDEBUG << "rbf cv: " << rbf_cv;
|
||||
dlog << LDEBUG << "lin cv: " << lin_cv;
|
||||
dlog << LDEBUG << "peg cv: " << peg_cv;
|
||||
dlog << LDEBUG << "rvm cv: " << rvm_cv;
|
||||
dlog << LDEBUG << "svm cv: " << svm_cv;
|
||||
dlog << LDEBUG << "rbf cv: " << rbf_cv;
|
||||
dlog << LDEBUG << "lin cv: " << lin_cv;
|
||||
dlog << LDEBUG << "ocas_ekm cv: " << ocas_ekm_cv;
|
||||
dlog << LDEBUG << "ocas_ekm cv2: " << ocas_ekm_cv2;
|
||||
dlog << LDEBUG << "peg cv: " << peg_cv;
|
||||
dlog << LDEBUG << "peg cached cv: " << peg_c_cv;
|
||||
|
||||
// make sure the cached version of pegasos computes the same result
|
||||
|
@ -391,6 +402,8 @@ namespace
|
|||
DLIB_TEST_MSG(mean(lin_cv) > 0.9, lin_cv);
|
||||
DLIB_TEST_MSG(mean(peg_cv) > 0.9, peg_cv);
|
||||
DLIB_TEST_MSG(mean(peg_c_cv) > 0.9, peg_c_cv);
|
||||
DLIB_TEST_MSG(mean(ocas_ekm_cv) > 0.9, ocas_ekm_cv);
|
||||
DLIB_TEST_MSG(mean(ocas_ekm_cv2) > 0.9, ocas_ekm_cv2);
|
||||
|
||||
const long num_sv = trainer.train(x,y).basis_vectors.size();
|
||||
print_spinner();
|
||||
|
@ -398,9 +411,14 @@ namespace
|
|||
print_spinner();
|
||||
dlog << LDEBUG << "num sv: " << num_sv;
|
||||
dlog << LDEBUG << "num rv: " << num_rv;
|
||||
print_spinner();
|
||||
ocas_ekm_trainer.clear_basis();
|
||||
const long num_bv = ocas_ekm_trainer.train(x,y).basis_vectors.size();
|
||||
dlog << LDEBUG << "num ekm bv: " << num_bv;
|
||||
|
||||
DLIB_TEST(num_rv <= 17);
|
||||
DLIB_TEST_MSG(num_sv <= 45, num_sv);
|
||||
DLIB_TEST_MSG(num_bv <= 45, num_bv);
|
||||
|
||||
decision_function<kernel_type> df = reduced2(trainer, 19).train(x,y);
|
||||
print_spinner();
|
||||
|
|
Loading…
Reference in New Issue