mirror of https://github.com/davisking/dlib.git
Added unit tests for the svm_multiclass_linear with sparse priors
This commit is contained in:
parent
ff8fc68f3b
commit
d7f207f2f7
|
@ -6,6 +6,7 @@
|
|||
#include <dlib/data_io.h>
|
||||
#include "create_iris_datafile.h"
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
|
||||
namespace
|
||||
|
@ -92,6 +93,62 @@ namespace
|
|||
DLIB_TEST((unsigned int)sum(diag(res))==samples.size());
|
||||
}
|
||||
|
||||
void test_prior_sparse ()
|
||||
{
|
||||
print_spinner();
|
||||
typedef std::map<unsigned long,double> sample_type;
|
||||
typedef sparse_linear_kernel<sample_type> kernel_type;
|
||||
|
||||
std::vector<sample_type> samples;
|
||||
std::vector<int> labels;
|
||||
|
||||
for (int i = 0; i < 4; ++i)
|
||||
{
|
||||
if (i==2)
|
||||
++i;
|
||||
for (int iter = 0; iter < 5; ++iter)
|
||||
{
|
||||
sample_type samp;
|
||||
samp[i] = 1;
|
||||
samples.push_back(samp);
|
||||
labels.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
svm_multiclass_linear_trainer<kernel_type,int> trainer;
|
||||
|
||||
multiclass_linear_decision_function<kernel_type,int> df = trainer.train(samples, labels);
|
||||
|
||||
//cout << "test: \n" << test_multiclass_decision_function(df, samples, labels) << endl;
|
||||
//cout << df.weights << endl;
|
||||
//cout << df.b << endl;
|
||||
|
||||
std::vector<sample_type> samples2;
|
||||
std::vector<int> labels2;
|
||||
int i = 2;
|
||||
for (int iter = 0; iter < 5; ++iter)
|
||||
{
|
||||
sample_type samp;
|
||||
samp[i] = 1;
|
||||
samp[i+10] = 1;
|
||||
samples2.push_back(samp);
|
||||
labels2.push_back(i);
|
||||
samples.push_back(samp);
|
||||
labels.push_back(i);
|
||||
}
|
||||
|
||||
trainer.set_prior(df);
|
||||
trainer.set_c(0.1);
|
||||
df = trainer.train(samples2, labels2);
|
||||
|
||||
matrix<double> res = test_multiclass_decision_function(df, samples, labels);
|
||||
dlog << LINFO << "test: \n" << res;
|
||||
dlog << LINFO << df.weights;
|
||||
dlog << LINFO << df.b;
|
||||
DLIB_TEST((unsigned int)sum(diag(res))==samples.size());
|
||||
}
|
||||
|
||||
template <typename sample_type>
|
||||
void run_test()
|
||||
{
|
||||
|
@ -158,6 +215,7 @@ namespace
|
|||
run_test<std::vector<std::pair<unsigned long, double> > >();
|
||||
|
||||
test_prior();
|
||||
test_prior_sparse();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in New Issue