mirror of https://github.com/davisking/dlib.git
Added unit tests for the svm_multiclass_linear with sparse priors
This commit is contained in:
parent
ff8fc68f3b
commit
d7f207f2f7
|
@ -6,6 +6,7 @@
|
||||||
#include <dlib/data_io.h>
|
#include <dlib/data_io.h>
|
||||||
#include "create_iris_datafile.h"
|
#include "create_iris_datafile.h"
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
namespace
|
namespace
|
||||||
|
@ -92,6 +93,62 @@ namespace
|
||||||
DLIB_TEST((unsigned int)sum(diag(res))==samples.size());
|
DLIB_TEST((unsigned int)sum(diag(res))==samples.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void test_prior_sparse ()
|
||||||
|
{
|
||||||
|
print_spinner();
|
||||||
|
typedef std::map<unsigned long,double> sample_type;
|
||||||
|
typedef sparse_linear_kernel<sample_type> kernel_type;
|
||||||
|
|
||||||
|
std::vector<sample_type> samples;
|
||||||
|
std::vector<int> labels;
|
||||||
|
|
||||||
|
for (int i = 0; i < 4; ++i)
|
||||||
|
{
|
||||||
|
if (i==2)
|
||||||
|
++i;
|
||||||
|
for (int iter = 0; iter < 5; ++iter)
|
||||||
|
{
|
||||||
|
sample_type samp;
|
||||||
|
samp[i] = 1;
|
||||||
|
samples.push_back(samp);
|
||||||
|
labels.push_back(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
svm_multiclass_linear_trainer<kernel_type,int> trainer;
|
||||||
|
|
||||||
|
multiclass_linear_decision_function<kernel_type,int> df = trainer.train(samples, labels);
|
||||||
|
|
||||||
|
//cout << "test: \n" << test_multiclass_decision_function(df, samples, labels) << endl;
|
||||||
|
//cout << df.weights << endl;
|
||||||
|
//cout << df.b << endl;
|
||||||
|
|
||||||
|
std::vector<sample_type> samples2;
|
||||||
|
std::vector<int> labels2;
|
||||||
|
int i = 2;
|
||||||
|
for (int iter = 0; iter < 5; ++iter)
|
||||||
|
{
|
||||||
|
sample_type samp;
|
||||||
|
samp[i] = 1;
|
||||||
|
samp[i+10] = 1;
|
||||||
|
samples2.push_back(samp);
|
||||||
|
labels2.push_back(i);
|
||||||
|
samples.push_back(samp);
|
||||||
|
labels.push_back(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
trainer.set_prior(df);
|
||||||
|
trainer.set_c(0.1);
|
||||||
|
df = trainer.train(samples2, labels2);
|
||||||
|
|
||||||
|
matrix<double> res = test_multiclass_decision_function(df, samples, labels);
|
||||||
|
dlog << LINFO << "test: \n" << res;
|
||||||
|
dlog << LINFO << df.weights;
|
||||||
|
dlog << LINFO << df.b;
|
||||||
|
DLIB_TEST((unsigned int)sum(diag(res))==samples.size());
|
||||||
|
}
|
||||||
|
|
||||||
template <typename sample_type>
|
template <typename sample_type>
|
||||||
void run_test()
|
void run_test()
|
||||||
{
|
{
|
||||||
|
@ -158,6 +215,7 @@ namespace
|
||||||
run_test<std::vector<std::pair<unsigned long, double> > >();
|
run_test<std::vector<std::pair<unsigned long, double> > >();
|
||||||
|
|
||||||
test_prior();
|
test_prior();
|
||||||
|
test_prior_sparse();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue