mirror of https://github.com/davisking/dlib.git
merged
This commit is contained in:
commit
12ca8ad67f
|
@ -1135,7 +1135,7 @@ namespace dlib
|
|||
// Prevent calls to tensor_to_dets() from running for a really long time
|
||||
// due to the production of an obscene number of detections.
|
||||
const unsigned long max_num_initial_dets = max_num_dets*100;
|
||||
if (dets.size() >= max_num_initial_dets)
|
||||
if (dets.size() > max_num_initial_dets)
|
||||
{
|
||||
det_thresh_speed_adjust = std::max(det_thresh_speed_adjust,dets[max_num_initial_dets].detection_confidence + options.loss_per_false_alarm);
|
||||
}
|
||||
|
|
|
@ -206,6 +206,36 @@ py::list chinese_whispers_clustering(py::list descriptors, float threshold)
|
|||
return clusters;
|
||||
}
|
||||
|
||||
py::list chinese_whispers_raw(py::list edges)
|
||||
{
|
||||
py::list clusters;
|
||||
size_t num_edges = py::len(edges);
|
||||
|
||||
std::vector<sample_pair> edges_pairs;
|
||||
std::vector<unsigned long> labels;
|
||||
for (size_t idx = 0; idx < num_edges; ++idx)
|
||||
{
|
||||
py::tuple t = edges[idx].cast<py::tuple>();
|
||||
if ((len(t) != 2) && (len(t) != 3))
|
||||
{
|
||||
PyErr_SetString( PyExc_IndexError, "Input must be a list of tuples with 2 or 3 elements.");
|
||||
throw py::error_already_set();
|
||||
}
|
||||
size_t i = t[0].cast<size_t>();
|
||||
size_t j = t[1].cast<size_t>();
|
||||
double distance = (len(t) == 3) ? t[2].cast<double>(): 1;
|
||||
|
||||
edges_pairs.push_back(sample_pair(i, j, distance));
|
||||
}
|
||||
|
||||
chinese_whispers(edges_pairs, labels);
|
||||
for (size_t i = 0; i < labels.size(); ++i)
|
||||
{
|
||||
clusters.append(labels[i]);
|
||||
}
|
||||
return clusters;
|
||||
}
|
||||
|
||||
void save_face_chips (
|
||||
numpy_image<rgb_pixel> img,
|
||||
const std::vector<full_object_detection>& faces,
|
||||
|
@ -296,5 +326,10 @@ void bind_face_recognition(py::module &m)
|
|||
m.def("chinese_whispers_clustering", &chinese_whispers_clustering, py::arg("descriptors"), py::arg("threshold"),
|
||||
"Takes a list of descriptors and returns a list that contains a label for each descriptor. Clustering is done using dlib::chinese_whispers."
|
||||
);
|
||||
m.def("chinese_whispers", &chinese_whispers_raw, py::arg("edges"),
|
||||
"Given a graph with vertices represented as numbers indexed from 0, this algorithm takes a list of edges and returns back a list that contains a labels (found clusters) for each vertex. "
|
||||
"Edges are tuples with either 2 elements (integers presenting indexes of connected vertices) or 3 elements, where additional one element is float which presents distance weight of the edge). "
|
||||
"Offers direct access to dlib::chinese_whispers."
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
from random import Random
|
||||
|
||||
from dlib import chinese_whispers
|
||||
from pytest import raises
|
||||
|
||||
|
||||
def test_chinese_whispers():
|
||||
assert len(chinese_whispers([])) == 0
|
||||
assert len(chinese_whispers([(0, 0), (1, 1)])) == 2
|
||||
|
||||
# Test that values from edges are actually used and that correct values are returned
|
||||
labels = chinese_whispers([(0, 0), (0, 1), (1, 1)])
|
||||
assert len(labels) == 2
|
||||
assert labels[0] == labels[1]
|
||||
labels = chinese_whispers([(0, 0), (1, 1)])
|
||||
assert len(labels) == 2
|
||||
assert labels[0] != labels[1]
|
||||
|
||||
|
||||
def test_chinese_whispers_with_distance():
|
||||
assert len(chinese_whispers([(0, 0, 1)])) == 1
|
||||
assert len(chinese_whispers([(0, 0, 1), (0, 1, 0.5), (1, 1, 1)])) == 2
|
||||
|
||||
# Test that values from edges and distances are actually used and that correct values are returned
|
||||
labels = chinese_whispers([(0, 0, 1), (0, 1, 1), (1, 1, 1)])
|
||||
assert len(labels) == 2
|
||||
assert labels[0] == labels[1]
|
||||
labels = chinese_whispers([(0, 0, 1), (0, 1, 0.0), (1, 1, 1)])
|
||||
assert len(labels) == 2
|
||||
assert labels[0] != labels[1]
|
||||
|
||||
# Non-trivial test
|
||||
edges = []
|
||||
r = Random(0)
|
||||
for i in range(100):
|
||||
edges.append((i, i, 1))
|
||||
edges.append((i, r.randint(0, 99), r.random()))
|
||||
assert len(chinese_whispers(edges)) == 100
|
||||
|
||||
|
||||
def test_chinese_whispers_type_checks():
|
||||
"""
|
||||
Tests contract (expected errors) in case client provides wrong types
|
||||
"""
|
||||
with raises(TypeError):
|
||||
chinese_whispers()
|
||||
with raises(TypeError):
|
||||
chinese_whispers('foo')
|
||||
with raises(RuntimeError):
|
||||
chinese_whispers(['foo'])
|
||||
with raises(IndexError):
|
||||
chinese_whispers([(0,)])
|
||||
with raises(IndexError):
|
||||
chinese_whispers([(0, 1, 2, 3)])
|
||||
with raises(RuntimeError):
|
||||
chinese_whispers([('foo', 'bar')])
|
Loading…
Reference in New Issue