Added these new functions: remove_long_edges(), remove_percent_longest_edges(),

remove_short_edges(), and remove_percent_shortest_edges().   I also reworked
the graph creation functions to make them a little more versatile.  Now
you can use infinite distances to indicate that certain nodes are not
connected at all.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403657
This commit is contained in:
Davis King 2010-05-30 19:42:06 +00:00
parent ccbdf520b7
commit 260c893c7f
2 changed files with 277 additions and 49 deletions

View File

@ -63,8 +63,7 @@ namespace dlib
)
{
// make sure requires clause is not broken
DLIB_ASSERT(samples.size() > 1 &&
0 < percent && percent <= 1 &&
DLIB_ASSERT( 0 < percent && percent <= 1 &&
num > 0,
"\t void find_percent_shortest_edges_randomly()"
<< "\n\t Invalid inputs were given to this function."
@ -73,6 +72,12 @@ namespace dlib
<< "\n\t num: " << num
);
out.clear();
if (samples.size() <= 1)
{
return;
}
std::vector<sample_pair, alloc> edges;
edges.reserve(num);
@ -81,37 +86,45 @@ namespace dlib
rnd.set_seed(cast_to_string(random_seed));
// randomly sample a bunch of edges
while (edges.size() < num)
for (unsigned long i = 0; i < num; ++i)
{
const unsigned long idx1 = rnd.get_random_32bit_number()%samples.size();
const unsigned long idx2 = rnd.get_random_32bit_number()%samples.size();
if (idx1 != idx2)
{
edges.push_back(sample_pair(idx1, idx2, dist_funct(samples[idx1], samples[idx2])));
const float dist = dist_funct(samples[idx1], samples[idx2]);
if (dist < std::numeric_limits<float>::infinity())
{
edges.push_back(sample_pair(idx1, idx2, dist));
}
}
}
// sort the edges so that duplicate edges will be adjacent
std::sort(edges.begin(), edges.end(), &order_by_index);
// now put edges into out while avoiding duplicates
out.clear();
out.reserve(edges.size());
out.push_back(edges[0]);
for (unsigned long i = 1; i < edges.size(); ++i)
if (edges.size() > 0)
{
if (edges[i] != edges[i-1])
// sort the edges so that duplicate edges will be adjacent
std::sort(edges.begin(), edges.end(), &order_by_index);
out.reserve(edges.size());
out.push_back(edges[0]);
for (unsigned long i = 1; i < edges.size(); ++i)
{
out.push_back(edges[i]);
if (edges[i] != edges[i-1])
{
out.push_back(edges[i]);
}
}
// now sort all the edges by distance and take the percent with the smallest distance
std::sort(out.begin(), out.end(), &order_by_distance);
out.swap(edges);
const unsigned long out_size = std::min<unsigned long>(num*percent, edges.size());
out.assign(edges.begin(), edges.begin() + out_size);
}
// now sort all the edges by distance and take the percent with the smallest distance
std::sort(out.begin(), out.end(), &order_by_distance);
out.swap(edges);
out.assign(edges.begin(), edges.begin() + edges.size()*percent);
}
// ----------------------------------------------------------------------------------------
@ -175,8 +188,7 @@ namespace dlib
)
{
// make sure requires clause is not broken
DLIB_ASSERT(samples.size() > 1 &&
num > 0 && k > 0,
DLIB_ASSERT( num > 0 && k > 0,
"\t void find_approximate_k_nearest_neighbors()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t samples.size(): " << samples.size()
@ -184,6 +196,13 @@ namespace dlib
<< "\n\t num: " << num
);
out.clear();
if (samples.size() <= 1)
{
return;
}
// we add each edge twice in the following loop. So multiply num by 2 to account for that.
num *= 2;
@ -196,16 +215,18 @@ namespace dlib
rnd.set_seed(cast_to_string(random_seed));
// randomly sample a bunch of edges
while (edges.size() < num)
for (unsigned long i = 0; i < num; ++i)
{
const unsigned long idx1 = rnd.get_random_32bit_number()%samples.size();
const unsigned long idx2 = rnd.get_random_32bit_number()%samples.size();
if (idx1 != idx2)
{
const float dist = dist_funct(samples[idx1], samples[idx2]);
edges.push_back(impl2::helper(idx1, idx2, dist));
edges.push_back(impl2::helper(idx2, idx1, dist));
if (dist < std::numeric_limits<float>::infinity())
{
edges.push_back(impl2::helper(idx1, idx2, dist));
edges.push_back(impl2::helper(idx2, idx1, dist));
}
}
}
@ -251,14 +272,16 @@ namespace dlib
// now put edges into out while avoiding duplicates
out.clear();
out.reserve(temp.size());
out.push_back(temp[0]);
for (unsigned long i = 1; i < temp.size(); ++i)
if (temp.size() > 0)
{
if (temp[i] != temp[i-1])
out.reserve(temp.size());
out.push_back(temp[0]);
for (unsigned long i = 1; i < temp.size(); ++i)
{
out.push_back(temp[i]);
if (temp[i] != temp[i-1])
{
out.push_back(temp[i]);
}
}
}
}
@ -278,19 +301,29 @@ namespace dlib
)
{
// make sure requires clause is not broken
DLIB_ASSERT(samples.size() > k && k > 0,
DLIB_ASSERT(k > 0,
"\t void find_k_nearest_neighbors()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t samples.size(): " << samples.size()
<< "\n\t k: " << k
);
out.clear();
if (samples.size() <= 1)
{
return;
}
using namespace impl;
std::vector<sample_pair> edges;
edges.resize(samples.size()*k);
// Initialize all the edges to an edge with an invalid index
edges.resize(samples.size()*k,
sample_pair(samples.size(),samples.size(),std::numeric_limits<float>::infinity()));
std::vector<float> worst_dists(samples.size(), std::numeric_limits<float>::max());
// Hold the length for the longest edge for each node. Initially they are all infinity.
std::vector<float> worst_dists(samples.size(), std::numeric_limits<float>::infinity());
std::vector<sample_pair>::iterator begin_i, end_i, begin_j, end_j, itr;
begin_i = edges.begin();
@ -332,18 +365,27 @@ namespace dlib
// sort the edges so that duplicate edges will be adjacent
std::sort(edges.begin(), edges.end(), &order_by_index);
// now put edges into out while avoiding duplicates
out.clear();
out.reserve(edges.size());
out.push_back(edges[0]);
for (unsigned long i = 1; i < edges.size(); ++i)
// if the first edge is valid
if (edges[0].index1() < samples.size())
{
if (edges[i] != edges[i-1])
// now put edges into out while avoiding duplicates and any remaining invalid edges.
out.reserve(edges.size());
out.push_back(edges[0]);
for (unsigned long i = 1; i < edges.size(); ++i)
{
out.push_back(edges[i]);
// if we hit an invalid edge then we can stop
if (edges[i].index1() >= samples.size())
break;
// if this isn't a duplicate edge
if (edges[i] != edges[i-1])
{
out.push_back(edges[i]);
}
}
}
}
// ----------------------------------------------------------------------------------------
@ -396,6 +438,110 @@ namespace dlib
}
}
// ----------------------------------------------------------------------------------------
template <
typename vector_type
>
void remove_long_edges (
vector_type& pairs,
float distance_threshold
)
{
vector_type temp;
temp.reserve(pairs.size());
// add all the pairs shorter than the given threshold into temp
for (unsigned long i = 0; i < pairs.size(); ++i)
{
if (pairs[i].distance() <= distance_threshold)
temp.push_back(pairs[i]);
}
// move temp into the output vector
temp.swap(pairs);
}
// ----------------------------------------------------------------------------------------
template <
typename vector_type
>
void remove_short_edges (
vector_type& pairs,
float distance_threshold
)
{
vector_type temp;
temp.reserve(pairs.size());
// add all the pairs longer than the given threshold into temp
for (unsigned long i = 0; i < pairs.size(); ++i)
{
if (pairs[i].distance() >= distance_threshold)
temp.push_back(pairs[i]);
}
// move temp into the output vector
temp.swap(pairs);
}
// ----------------------------------------------------------------------------------------
template <
typename vector_type
>
void remove_percent_longest_edges (
vector_type& pairs,
double percent
)
{
// make sure requires clause is not broken
DLIB_ASSERT( 0 <= percent && percent < 1,
"\t void remove_percent_longest_edges()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t percent: " << percent
);
std::sort(pairs.begin(), pairs.end(), &order_by_distance);
const unsigned long num = static_cast<unsigned long>((1.0-percent)*pairs.size());
// pick out the num shortest pairs
vector_type temp(pairs.begin(), pairs.begin() + num);
// move temp into the output vector
temp.swap(pairs);
}
// ----------------------------------------------------------------------------------------
template <
typename vector_type
>
void remove_percent_shortest_edges (
vector_type& pairs,
double percent
)
{
// make sure requires clause is not broken
DLIB_ASSERT( 0 <= percent && percent < 1,
"\t void remove_percent_shortest_edges()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t percent: " << percent
);
std::sort(pairs.rbegin(), pairs.rend(), &order_by_distance);
const unsigned long num = static_cast<unsigned long>((1.0-percent)*pairs.size());
// pick out the num shortest pairs
vector_type temp(pairs.begin(), pairs.begin() + num);
// move temp into the output vector
temp.swap(pairs);
}
// ----------------------------------------------------------------------------------------
}

View File

@ -28,7 +28,6 @@ namespace dlib
);
/*!
requires
- samples.size() > 1
- 0 < percent <= 1
- num > 0
- random_seed must be convertible to a string by dlib::cast_to_string()
@ -39,13 +38,14 @@ namespace dlib
0 and samples.size()-1 inclusive. For each of these pairs, (i,j), a
sample_pair is created as follows:
sample_pair(i, j, dist_funct(samples[i], samples[j]))
num such sample_pair objects are generated, duplicates are removed, and
then the top percent of them with the smallest distance are stored into
out.
num such sample_pair objects are generated, duplicates and pairs with distance
values == infinity are removed, and then the top percent of them with the
smallest distance are stored into out.
- #out.size() <= num*percent
- contains_duplicate_pairs(#out) == false
- for all valid i:
- #out[i].distance() == dist_funct(samples[#out[i].index1()], samples[#out[i].index2()])
- #out[i].distance() < std::numeric_limits<float>::infinity()
- random_seed is used to seed the random number generator used by this
function.
!*/
@ -68,7 +68,6 @@ namespace dlib
);
/*!
requires
- samples.size() > 1
- k > 0
- num > 0
- random_seed must be convertible to a string by dlib::cast_to_string()
@ -84,9 +83,12 @@ namespace dlib
sample_pair(i, j, dist_funct(samples[i], samples[j]))
num such sample_pair objects are generated and then exact k-nearest-neighbors
is performed amongst these sample_pairs and the results are stored into #out.
Note that samples with an infinite distance between them are considered to
be not connected at all.
- contains_duplicate_pairs(#out) == false
- for all valid i:
- #out[i].distance() == dist_funct(samples[#out[i].index1()], samples[#out[i].index2()])
- #out[i].distance() < std::numeric_limits<float>::infinity()
- random_seed is used to seed the random number generator used by this
function.
!*/
@ -106,15 +108,17 @@ namespace dlib
);
/*!
requires
- samples.size() > k
- k > 0
- dist_funct(samples[i], samples[j]) must be a valid expression that evaluates
to a floating point number
ensures
- #out == a set of sample_pair objects that represent all the k nearest
neighbors in samples according to the given distance function dist_funct.
- #out == a set of sample_pair objects that represent all the k nearest
neighbors in samples according to the given distance function dist_funct.
Note that samples with an infinite distance between them are considered to
be not connected at all.
- for all valid i:
- #out[i].distance() == dist_funct(samples[#out[i].index1()], samples[#out[i].index2()])
- #out[i].distance() < std::numeric_limits<float>::infinity()
- contains_duplicate_pairs(#out) == false
!*/
@ -158,6 +162,84 @@ namespace dlib
- for some j: pairs[j].index1()+1 == N || pairs[j].index2()+1 == N
!*/
// ----------------------------------------------------------------------------------------
template <
typename vector_type
>
void remove_long_edges (
vector_type& pairs,
float distance_threshold
);
/*!
requires
- vector_type == a type with an interface compatible with std::vector and
it must in turn contain objects with an interface compatible with dlib::sample_pair
ensures
- Removes all elements of pairs that have a distance value greater than the
given threshold.
- #pairs.size() <= pairs.size()
!*/
// ----------------------------------------------------------------------------------------
template <
typename vector_type
>
void remove_short_edges (
vector_type& pairs,
float distance_threshold
);
/*!
requires
- vector_type == a type with an interface compatible with std::vector and
it must in turn contain objects with an interface compatible with dlib::sample_pair
ensures
- Removes all elements of pairs that have a distance value less than the
given threshold.
- #pairs.size() <= pairs.size()
!*/
// ----------------------------------------------------------------------------------------
template <
typename vector_type
>
void remove_percent_longest_edges (
vector_type& pairs,
double percent
);
/*!
requires
- 0 <= percent < 1
- vector_type == a type with an interface compatible with std::vector and
it must in turn contain objects with an interface compatible with dlib::sample_pair
ensures
- Removes the given upper percentage of the longest edges in pairs. I.e.
this function removes the long edges from pairs.
- #pairs.size() == (1-percent)*pairs.size()
!*/
// ----------------------------------------------------------------------------------------
template <
typename vector_type
>
void remove_percent_shortest_edges (
vector_type& pairs,
double percent
);
/*!
requires
- 0 <= percent < 1
- vector_type == a type with an interface compatible with std::vector and
it must in turn contain objects with an interface compatible with dlib::sample_pair
ensures
- Removes the given upper percentage of the shortest edges in pairs. I.e.
this function removes the short edges from pairs.
- #pairs.size() == (1-percent)*pairs.size()
!*/
// ----------------------------------------------------------------------------------------
}