Added the pick_initial_centers() function

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402320
This commit is contained in:
Davis King 2008-06-15 15:13:41 +00:00
parent 1e2a3615cf
commit 842bdb4a49
2 changed files with 101 additions and 0 deletions

View File

@ -13,6 +13,7 @@
#include "kkmeans_abstract.h" #include "kkmeans_abstract.h"
#include "../noncopyable.h" #include "../noncopyable.h"
#include "../smart_pointers.h" #include "../smart_pointers.h"
#include <vector>
namespace dlib namespace dlib
{ {
@ -224,6 +225,74 @@ namespace dlib
void swap(kkmeans<kernel_type>& a, kkmeans<kernel_type>& b) void swap(kkmeans<kernel_type>& a, kkmeans<kernel_type>& b)
{ a.swap(b); } { a.swap(b); }
// ----------------------------------------------------------------------------------------
struct dlib_pick_initial_centers_data
{
dlib_pick_initial_centers_data():idx(0), dist(1e200){}
long idx;
double dist;
bool operator< (const dlib_pick_initial_centers_data& d) const { return dist < d.dist; }
};
template <
typename vector_type,
typename kernel_type
>
void pick_initial_centers(
long num_centers,
vector_type& centers,
const vector_type& samples,
const kernel_type& k,
double percentile = 0.01
)
{
// make sure requires clause is not broken
DLIB_CASSERT(num_centers > 1 && 0 <= percentile && percentile < 1 && samples.size() > 1,
"\tvoid pick_initial_centers()"
<< "\n\tYou passed invalid arguments to this function"
<< "\n\tnum_centers: " << num_centers
<< "\n\tpercentile: " << percentile
<< "\n\tsamples.size(): " << samples.size()
);
std::vector<dlib_pick_initial_centers_data> scores(samples.size());
std::vector<dlib_pick_initial_centers_data> scores_sorted(samples.size());
centers.clear();
// pick the first sample as one of the centers
centers.push_back(samples[0]);
const long best_idx = samples.size() - samples.size()*percentile - 1;
// pick the next center
for (long i = 0; i < num_centers-1; ++i)
{
// Loop over the samples and compare them to the most recent center. Store
// the distance from each sample to its closest center in scores.
const double k_cc = k(centers[i], centers[i]);
for (unsigned long s = 0; s < samples.size(); ++s)
{
// compute the distance between this sample and the current center
const double dist = k_cc + k(samples[s],samples[s]) - 2*k(samples[s], centers[i]);
if (dist < scores[s].dist)
{
scores[s].dist = dist;
scores[s].idx = s;
}
}
scores_sorted = scores;
// now find the winning center and add it to centers. It is the one that is
// far away from all the other centers.
sort(scores_sorted.begin(), scores_sorted.end());
centers.push_back(samples[scores_sorted[best_idx].idx]);
}
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
} }

View File

@ -176,6 +176,38 @@ namespace dlib
provides serialization support for kkmeans objects provides serialization support for kkmeans objects
!*/ !*/
// ----------------------------------------------------------------------------------------
template <
typename vector_type,
typename kernel_type
>
void pick_initial_centers(
long num_centers,
vector_type& centers,
const vector_type& samples,
const kernel_type& k,
double percentile = 0.01
);
/*!
requires
- num_centers > 1
- 0 <= percentile < 1
- samples.size() > 1
- vector_type == something with an interface compatible with std::vector
- k(samples[0],samples[0]) must be a valid expression that returns a double
ensures
- finds num_centers candidate cluster centers in the data in the samples
vector. Assumes that k is the kernel that will be used during clustering
to define the space in which clustering occurs.
- The centers are found by looking for points that are far away from other
candidate centers. However, if the data is noisy you probably want to
ignore the farthest way points since they will be outliers. To do this
set percentile to the fraction of outliers you expect the data to contain.
- #centers.size() == num_centers
- #centers == a vector containing the candidate centers found
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
} }