mirror of https://github.com/davisking/dlib.git
Added the pick_initial_centers() function
--HG-- extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402320
This commit is contained in:
parent
1e2a3615cf
commit
842bdb4a49
|
@ -13,6 +13,7 @@
|
||||||
#include "kkmeans_abstract.h"
|
#include "kkmeans_abstract.h"
|
||||||
#include "../noncopyable.h"
|
#include "../noncopyable.h"
|
||||||
#include "../smart_pointers.h"
|
#include "../smart_pointers.h"
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
namespace dlib
|
namespace dlib
|
||||||
{
|
{
|
||||||
|
@ -224,6 +225,74 @@ namespace dlib
|
||||||
void swap(kkmeans<kernel_type>& a, kkmeans<kernel_type>& b)
|
void swap(kkmeans<kernel_type>& a, kkmeans<kernel_type>& b)
|
||||||
{ a.swap(b); }
|
{ a.swap(b); }
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Bookkeeping record used by pick_initial_centers().  It pairs the index of a
// sample with a distance value so that a collection of these records can be
// ordered by distance.
struct dlib_pick_initial_centers_data
{
    // A fresh record refers to index 0 and carries a very large sentinel
    // distance (1e200), so the first distance compared against it with
    // operator< / "dist < record.dist" is normally smaller.
    dlib_pick_initial_centers_data()
        : idx(0),
          dist(1e200)
    {
    }

    // Records are ordered by ascending distance only; idx takes no part in
    // the comparison.
    bool operator< (const dlib_pick_initial_centers_data& rhs) const
    {
        return dist < rhs.dist;
    }

    long idx;     // index of the sample this record describes
    double dist;  // distance value associated with that sample
};
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename vector_type,
|
||||||
|
typename kernel_type
|
||||||
|
>
|
||||||
|
void pick_initial_centers(
|
||||||
|
long num_centers,
|
||||||
|
vector_type& centers,
|
||||||
|
const vector_type& samples,
|
||||||
|
const kernel_type& k,
|
||||||
|
double percentile = 0.01
|
||||||
|
)
|
||||||
|
{
|
||||||
|
// make sure requires clause is not broken
|
||||||
|
DLIB_CASSERT(num_centers > 1 && 0 <= percentile && percentile < 1 && samples.size() > 1,
|
||||||
|
"\tvoid pick_initial_centers()"
|
||||||
|
<< "\n\tYou passed invalid arguments to this function"
|
||||||
|
<< "\n\tnum_centers: " << num_centers
|
||||||
|
<< "\n\tpercentile: " << percentile
|
||||||
|
<< "\n\tsamples.size(): " << samples.size()
|
||||||
|
);
|
||||||
|
|
||||||
|
std::vector<dlib_pick_initial_centers_data> scores(samples.size());
|
||||||
|
std::vector<dlib_pick_initial_centers_data> scores_sorted(samples.size());
|
||||||
|
centers.clear();
|
||||||
|
|
||||||
|
// pick the first sample as one of the centers
|
||||||
|
centers.push_back(samples[0]);
|
||||||
|
|
||||||
|
const long best_idx = samples.size() - samples.size()*percentile - 1;
|
||||||
|
|
||||||
|
// pick the next center
|
||||||
|
for (long i = 0; i < num_centers-1; ++i)
|
||||||
|
{
|
||||||
|
// Loop over the samples and compare them to the most recent center. Store
|
||||||
|
// the distance from each sample to its closest center in scores.
|
||||||
|
const double k_cc = k(centers[i], centers[i]);
|
||||||
|
for (unsigned long s = 0; s < samples.size(); ++s)
|
||||||
|
{
|
||||||
|
// compute the distance between this sample and the current center
|
||||||
|
const double dist = k_cc + k(samples[s],samples[s]) - 2*k(samples[s], centers[i]);
|
||||||
|
|
||||||
|
if (dist < scores[s].dist)
|
||||||
|
{
|
||||||
|
scores[s].dist = dist;
|
||||||
|
scores[s].idx = s;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
scores_sorted = scores;
|
||||||
|
|
||||||
|
// now find the winning center and add it to centers. It is the one that is
|
||||||
|
// far away from all the other centers.
|
||||||
|
sort(scores_sorted.begin(), scores_sorted.end());
|
||||||
|
centers.push_back(samples[scores_sorted[best_idx].idx]);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -176,6 +176,38 @@ namespace dlib
|
||||||
provides serialization support for kkmeans objects
|
provides serialization support for kkmeans objects
|
||||||
!*/
|
!*/
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename vector_type,
|
||||||
|
typename kernel_type
|
||||||
|
>
|
||||||
|
void pick_initial_centers(
|
||||||
|
long num_centers,
|
||||||
|
vector_type& centers,
|
||||||
|
const vector_type& samples,
|
||||||
|
const kernel_type& k,
|
||||||
|
double percentile = 0.01
|
||||||
|
);
|
||||||
|
/*!
    requires
        - num_centers > 1
        - 0 <= percentile < 1
        - samples.size() > 1
        - vector_type == something with an interface compatible with std::vector
        - k(samples[0],samples[0]) must be a valid expression that returns a double
    ensures
        - Finds num_centers candidate cluster centers in the data in the samples
          vector.  Assumes that k is the kernel that will be used during clustering
          to define the space in which clustering occurs.
        - The centers are found by looking for points that are far away from other
          candidate centers.  However, if the data is noisy you probably want to
          ignore the farthest away points since they will be outliers.  To do this,
          set percentile to the fraction of outliers you expect the data to contain.
        - Any previous contents of centers are discarded.
        - #centers.size() == num_centers
        - #centers[0] == samples[0]
        - #centers == a vector containing the candidate centers found
!*/
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue