Changed the rbf_network_trainer to use the linearly_independent_subset_finder

to find centers.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402413
This commit is contained in:
Davis King 2008-07-12 19:09:44 +00:00
parent ec316ccb1d
commit c8d9f20bca
2 changed files with 31 additions and 35 deletions

View File

@ -6,7 +6,7 @@
#include "../matrix.h"
#include "rbf_network_abstract.h"
#include "kernel.h"
#include "kcentroid.h"
#include "linearly_independent_subset_finder.h"
#include "function.h"
#include "../algs.h"
@ -23,7 +23,8 @@ namespace dlib
/*!
This is an implementation of an RBF network trainer that follows
the directions right off Wikipedia basically. So nothing
particularly fancy.
particularly fancy. Although the way the centers are selected
is somewhat unique.
!*/
public:
@ -35,7 +36,7 @@ namespace dlib
rbf_network_trainer (
) :
tolerance(0.1)
num_centers(10)
{
}
@ -52,17 +53,17 @@ namespace dlib
return kernel;
}
void set_tolerance (
const scalar_type& tol
void set_num_centers (
const unsigned long num
)
{
tolerance = tol;
num_centers = num;
}
const scalar_type& get_tolerance (
const unsigned long get_num_centers (
) const
{
return tolerance;
return num_centers;
}
template <
@ -82,7 +83,7 @@ namespace dlib
)
{
exchange(kernel, item.kernel);
exchange(tolerance, item.tolerance);
exchange(num_centers, item.num_centers);
}
private:
@ -99,6 +100,7 @@ namespace dlib
) const
{
typedef typename decision_function<kernel_type>::scalar_vector_type scalar_vector_type;
typedef typename decision_function<kernel_type>::sample_vector_type sample_vector_type;
// make sure requires clause is not broken
DLIB_ASSERT(x.nr() > 1 && x.nr() == y.nr() && x.nc() == 1 && y.nc() == 1,
@ -110,18 +112,15 @@ namespace dlib
<< "\n\t y.nc(): " << y.nc()
);
// first run all the samples through a kcentroid object to find the rbf centers
kcentroid<kernel_type> kc(kernel,tolerance);
// use the linearly_independent_subset_finder object to select the centers. So here
// we show it all the data samples so it can find the best centers.
linearly_independent_subset_finder<kernel_type> lisf(kernel, num_centers);
for (long i = 0; i < x.size(); ++i)
{
kc.train(x(i));
lisf.add(x(i));
}
// now we have a trained kcentroid so lets just extract its results. Note that
// all we want out of the kcentroid is really just the set of support vectors
// it contains so that we can use them as the RBF centers.
distance_function<kernel_type> df(kc.get_distance_function());
const long num_centers = df.support_vectors.nr();
const long num_centers = lisf.dictionary_size();
// fill the K matrix with the output of the kernel for all the center and sample point pairs
matrix<scalar_type,0,0,mem_manager_type> K(x.nr(), num_centers+1);
@ -129,7 +128,7 @@ namespace dlib
{
for (long c = 0; c < num_centers; ++c)
{
K(r,c) = kernel(x(r), df.support_vectors(c));
K(r,c) = kernel(x(r), lisf[c]);
}
// This last column of the K matrix takes care of the bias term
K(r,num_centers) = 1;
@ -142,12 +141,12 @@ namespace dlib
return decision_function<kernel_type> (remove_row(weights,num_centers),
-weights(num_centers),
kernel,
df.support_vectors);
lisf.get_dictionary());
}
kernel_type kernel;
scalar_type tolerance;
unsigned long num_centers;
}; // end of class rbf_network_trainer

View File

@ -3,7 +3,6 @@
#undef DLIB_RBf_NETWORK_ABSTRACT_
#ifdef DLIB_RBf_NETWORK_ABSTRACT_
#include "../matrix/matrix_abstract.h"
#include "../algs.h"
#include "function_abstract.h"
#include "kernel_abstract.h"
@ -25,16 +24,16 @@ namespace dlib
to use some sort of radial basis kernel)
INITIAL VALUE
- get_gamma() == 0.1
- get_tolerance() == 0.1
- get_num_centers() == 10
WHAT THIS OBJECT REPRESENTS
This object implements a trainer for an radial basis function network.
This object implements a trainer for a radial basis function network.
The implementation of this algorithm follows the normal RBF training
process. For more details see the code or the Wikipedia article
about RBF networks.
!*/
public:
typedef K kernel_type;
typedef typename kernel_type::scalar_type scalar_type;
@ -64,22 +63,20 @@ namespace dlib
- returns a copy of the kernel function in use by this object
!*/
void set_tolerance (
const scalar_type& tol
void set_num_centers (
const unsigned long num_centers
);
/*!
ensures
- #get_tolerance() == tol
- #get_num_centers() == num_centers
!*/
const scalar_type& get_tolerance (
const unsigned long get_num_centers (
) const;
/*!
ensures
- returns the tolerance parameter. This parameter controls how many
RBF centers (a.k.a. support_vectors in the trained decision_function)
you get when you call the train function. A smaller tolerance
results in more centers while a bigger number results in fewer.
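- returns the maximum number of centers (a.k.a. support_vectors in the
trained decision_function) you will get when you train this object on data.
The trained decision_function may contain fewer centers than this, since
the actual count is whatever the center selection step ends up keeping.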
!*/
template <
@ -118,10 +115,10 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <typename sample_type>
template <typename K>
void swap (
rbf_network_trainer<sample_type>& a,
rbf_network_trainer<sample_type>& b
rbf_network_trainer<K>& a,
rbf_network_trainer<K>& b
) { a.swap(b); }
/*!
provides a global swap