mirror of https://github.com/davisking/dlib.git
Changed the rbf_network_trainer to use the linearly_independent_subset_finder
to find centers. --HG-- extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402413
This commit is contained in:
parent
ec316ccb1d
commit
c8d9f20bca
|
@ -6,7 +6,7 @@
|
|||
#include "../matrix.h"
|
||||
#include "rbf_network_abstract.h"
|
||||
#include "kernel.h"
|
||||
#include "kcentroid.h"
|
||||
#include "linearly_independent_subset_finder.h"
|
||||
#include "function.h"
|
||||
#include "../algs.h"
|
||||
|
||||
|
@ -23,7 +23,8 @@ namespace dlib
|
|||
/*!
|
||||
This is an implementation of an RBF network trainer that follows
|
||||
the directions right off Wikipedia basically. So nothing
|
||||
particularly fancy.
|
||||
particularly fancy. Although the way the centers are selected
|
||||
is somewhat unique.
|
||||
!*/
|
||||
|
||||
public:
|
||||
|
@ -35,7 +36,7 @@ namespace dlib
|
|||
|
||||
rbf_network_trainer (
|
||||
) :
|
||||
tolerance(0.1)
|
||||
num_centers(10)
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -52,17 +53,17 @@ namespace dlib
|
|||
return kernel;
|
||||
}
|
||||
|
||||
void set_tolerance (
|
||||
const scalar_type& tol
|
||||
void set_num_centers (
|
||||
const unsigned long num
|
||||
)
|
||||
{
|
||||
tolerance = tol;
|
||||
num_centers = num;
|
||||
}
|
||||
|
||||
const scalar_type& get_tolerance (
|
||||
const unsigned long get_num_centers (
|
||||
) const
|
||||
{
|
||||
return tolerance;
|
||||
return num_centers;
|
||||
}
|
||||
|
||||
template <
|
||||
|
@ -82,7 +83,7 @@ namespace dlib
|
|||
)
|
||||
{
|
||||
exchange(kernel, item.kernel);
|
||||
exchange(tolerance, item.tolerance);
|
||||
exchange(num_centers, item.num_centers);
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -99,6 +100,7 @@ namespace dlib
|
|||
) const
|
||||
{
|
||||
typedef typename decision_function<kernel_type>::scalar_vector_type scalar_vector_type;
|
||||
typedef typename decision_function<kernel_type>::sample_vector_type sample_vector_type;
|
||||
|
||||
// make sure requires clause is not broken
|
||||
DLIB_ASSERT(x.nr() > 1 && x.nr() == y.nr() && x.nc() == 1 && y.nc() == 1,
|
||||
|
@ -110,18 +112,15 @@ namespace dlib
|
|||
<< "\n\t y.nc(): " << y.nc()
|
||||
);
|
||||
|
||||
// first run all the samples through a kcentroid object to find the rbf centers
|
||||
kcentroid<kernel_type> kc(kernel,tolerance);
|
||||
// use the linearly_independent_subset_finder object to select the centers. So here
|
||||
// we show it all the data samples so it can find the best centers.
|
||||
linearly_independent_subset_finder<kernel_type> lisf(kernel, num_centers);
|
||||
for (long i = 0; i < x.size(); ++i)
|
||||
{
|
||||
kc.train(x(i));
|
||||
lisf.add(x(i));
|
||||
}
|
||||
|
||||
// now we have a trained kcentroid so lets just extract its results. Note that
|
||||
// all we want out of the kcentroid is really just the set of support vectors
|
||||
// it contains so that we can use them as the RBF centers.
|
||||
distance_function<kernel_type> df(kc.get_distance_function());
|
||||
const long num_centers = df.support_vectors.nr();
|
||||
const long num_centers = lisf.dictionary_size();
|
||||
|
||||
// fill the K matrix with the output of the kernel for all the center and sample point pairs
|
||||
matrix<scalar_type,0,0,mem_manager_type> K(x.nr(), num_centers+1);
|
||||
|
@ -129,7 +128,7 @@ namespace dlib
|
|||
{
|
||||
for (long c = 0; c < num_centers; ++c)
|
||||
{
|
||||
K(r,c) = kernel(x(r), df.support_vectors(c));
|
||||
K(r,c) = kernel(x(r), lisf[c]);
|
||||
}
|
||||
// This last column of the K matrix takes care of the bias term
|
||||
K(r,num_centers) = 1;
|
||||
|
@ -142,12 +141,12 @@ namespace dlib
|
|||
return decision_function<kernel_type> (remove_row(weights,num_centers),
|
||||
-weights(num_centers),
|
||||
kernel,
|
||||
df.support_vectors);
|
||||
lisf.get_dictionary());
|
||||
|
||||
}
|
||||
|
||||
kernel_type kernel;
|
||||
scalar_type tolerance;
|
||||
unsigned long num_centers;
|
||||
|
||||
}; // end of class rbf_network_trainer
|
||||
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
#undef DLIB_RBf_NETWORK_ABSTRACT_
|
||||
#ifdef DLIB_RBf_NETWORK_ABSTRACT_
|
||||
|
||||
#include "../matrix/matrix_abstract.h"
|
||||
#include "../algs.h"
|
||||
#include "function_abstract.h"
|
||||
#include "kernel_abstract.h"
|
||||
|
@ -25,16 +24,16 @@ namespace dlib
|
|||
to use some sort of radial basis kernel)
|
||||
|
||||
INITIAL VALUE
|
||||
- get_gamma() == 0.1
|
||||
- get_tolerance() == 0.1
|
||||
- get_num_centers() == 10
|
||||
|
||||
WHAT THIS OBJECT REPRESENTS
|
||||
This object implements a trainer for an radial basis function network.
|
||||
This object implements a trainer for a radial basis function network.
|
||||
|
||||
The implementation of this algorithm follows the normal RBF training
|
||||
process. For more details see the code or the Wikipedia article
|
||||
about RBF networks.
|
||||
!*/
|
||||
|
||||
public:
|
||||
typedef K kernel_type;
|
||||
typedef typename kernel_type::scalar_type scalar_type;
|
||||
|
@ -64,22 +63,20 @@ namespace dlib
|
|||
- returns a copy of the kernel function in use by this object
|
||||
!*/
|
||||
|
||||
void set_tolerance (
|
||||
const scalar_type& tol
|
||||
void set_num_centers (
|
||||
const unsigned long num_centers
|
||||
);
|
||||
/*!
|
||||
ensures
|
||||
- #get_tolerance() == tol
|
||||
- #get_num_centers() == num_centers
|
||||
!*/
|
||||
|
||||
const scalar_type& get_tolerance (
|
||||
const unsigned long get_num_centers (
|
||||
) const;
|
||||
/*!
|
||||
ensures
|
||||
- returns the tolerance parameter. This parameter controls how many
|
||||
RBF centers (a.k.a. support_vectors in the trained decision_function)
|
||||
you get when you call the train function. A smaller tolerance
|
||||
results in more centers while a bigger number results in fewer.
|
||||
- returns the number of centers (a.k.a. support_vectors in the
|
||||
trained decision_function) you will get when you train this object on data.
|
||||
!*/
|
||||
|
||||
template <
|
||||
|
@ -118,10 +115,10 @@ namespace dlib
|
|||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename sample_type>
|
||||
template <typename K>
|
||||
void swap (
|
||||
rbf_network_trainer<sample_type>& a,
|
||||
rbf_network_trainer<sample_type>& b
|
||||
rbf_network_trainer<K>& a,
|
||||
rbf_network_trainer<K>& b
|
||||
) { a.swap(b); }
|
||||
/*!
|
||||
provides a global swap
|
||||
|
|
Loading…
Reference in New Issue