Added unit tests for the svm_multiclass_linear with sparse priors

This commit is contained in:
Davis King 2014-05-23 18:35:49 -04:00
parent ff8fc68f3b
commit d7f207f2f7
1 changed files with 58 additions and 0 deletions

View File

@ -6,6 +6,7 @@
#include <dlib/data_io.h>
#include "create_iris_datafile.h"
#include <vector>
#include <map>
#include <sstream>
namespace
@ -92,6 +93,62 @@ namespace
DLIB_TEST((unsigned int)sum(diag(res))==samples.size());
}
void test_prior_sparse ()
{
print_spinner();
typedef std::map<unsigned long,double> sample_type;
typedef sparse_linear_kernel<sample_type> kernel_type;
std::vector<sample_type> samples;
std::vector<int> labels;
for (int i = 0; i < 4; ++i)
{
if (i==2)
++i;
for (int iter = 0; iter < 5; ++iter)
{
sample_type samp;
samp[i] = 1;
samples.push_back(samp);
labels.push_back(i);
}
}
svm_multiclass_linear_trainer<kernel_type,int> trainer;
multiclass_linear_decision_function<kernel_type,int> df = trainer.train(samples, labels);
//cout << "test: \n" << test_multiclass_decision_function(df, samples, labels) << endl;
//cout << df.weights << endl;
//cout << df.b << endl;
std::vector<sample_type> samples2;
std::vector<int> labels2;
int i = 2;
for (int iter = 0; iter < 5; ++iter)
{
sample_type samp;
samp[i] = 1;
samp[i+10] = 1;
samples2.push_back(samp);
labels2.push_back(i);
samples.push_back(samp);
labels.push_back(i);
}
trainer.set_prior(df);
trainer.set_c(0.1);
df = trainer.train(samples2, labels2);
matrix<double> res = test_multiclass_decision_function(df, samples, labels);
dlog << LINFO << "test: \n" << res;
dlog << LINFO << df.weights;
dlog << LINFO << df.b;
DLIB_TEST((unsigned int)sum(diag(res))==samples.size());
}
template <typename sample_type>
void run_test()
{
@ -158,6 +215,7 @@ namespace
run_test<std::vector<std::pair<unsigned long, double> > >();
test_prior();
test_prior_sparse();
}
};