diff --git a/dlib/svm/kkmeans.h b/dlib/svm/kkmeans.h index 325c8115b..c83d6c16b 100644 --- a/dlib/svm/kkmeans.h +++ b/dlib/svm/kkmeans.h @@ -13,6 +13,7 @@ #include "kkmeans_abstract.h" #include "../noncopyable.h" #include "../smart_pointers.h" +#include namespace dlib { @@ -224,6 +225,74 @@ namespace dlib void swap(kkmeans& a, kkmeans& b) { a.swap(b); } +// ---------------------------------------------------------------------------------------- + + struct dlib_pick_initial_centers_data + { + dlib_pick_initial_centers_data():idx(0), dist(1e200){} + long idx; + double dist; + bool operator< (const dlib_pick_initial_centers_data& d) const { return dist < d.dist; } + }; + + template < + typename vector_type, + typename kernel_type + > + void pick_initial_centers( + long num_centers, + vector_type& centers, + const vector_type& samples, + const kernel_type& k, + double percentile = 0.01 + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(num_centers > 1 && 0 <= percentile && percentile < 1 && samples.size() > 1, + "\tvoid pick_initial_centers()" + << "\n\tYou passed invalid arguments to this function" + << "\n\tnum_centers: " << num_centers + << "\n\tpercentile: " << percentile + << "\n\tsamples.size(): " << samples.size() + ); + + std::vector scores(samples.size()); + std::vector scores_sorted(samples.size()); + centers.clear(); + + // pick the first sample as one of the centers + centers.push_back(samples[0]); + + const long best_idx = samples.size() - samples.size()*percentile - 1; + + // pick the next center + for (long i = 0; i < num_centers-1; ++i) + { + // Loop over the samples and compare them to the most recent center. Store + // the distance from each sample to its closest center in scores. + const double k_cc = k(centers[i], centers[i]); + for (unsigned long s = 0; s < samples.size(); ++s) + { + // compute the distance between this sample and the current center + const double dist = k_cc + k(samples[s],samples[s]) - 2*k(samples[s], centers[i]); + + if (dist < scores[s].dist) + { + scores[s].dist = dist; + scores[s].idx = s; + } + } + + scores_sorted = scores; + + // now find the winning center and add it to centers. It is the one that is + // far away from all the other centers. + sort(scores_sorted.begin(), scores_sorted.end()); + centers.push_back(samples[scores_sorted[best_idx].idx]); + } + + } + // ---------------------------------------------------------------------------------------- } diff --git a/dlib/svm/kkmeans_abstract.h b/dlib/svm/kkmeans_abstract.h index a8adde00b..9a38584ab 100644 --- a/dlib/svm/kkmeans_abstract.h +++ b/dlib/svm/kkmeans_abstract.h @@ -176,6 +176,38 @@ namespace dlib provides serialization support for kkmeans objects !*/ +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type, + typename kernel_type + > + void pick_initial_centers( + long num_centers, + vector_type& centers, + const vector_type& samples, + const kernel_type& k, + double percentile = 0.01 + ); + /*! + requires + - num_centers > 1 + - 0 <= percentile < 1 + - samples.size() > 1 + - vector_type == something with an interface compatible with std::vector + - k(samples[0],samples[0]) must be a valid expression that returns a double + ensures + - finds num_centers candidate cluster centers in the data in the samples + vector. Assumes that k is the kernel that will be used during clustering + to define the space in which clustering occurs. + - The centers are found by looking for points that are far away from other + candidate centers. However, if the data is noisy you probably want to + ignore the farthest way points since they will be outliers. To do this + set percentile to the fraction of outliers you expect the data to contain. + - #centers.size() == num_centers + - #centers == a vector containing the candidate centers found + !*/ + // ---------------------------------------------------------------------------------------- }