This commit is contained in:
Davis King 2019-02-20 07:48:50 -05:00
commit 12ca8ad67f
3 changed files with 92 additions and 1 deletions

View File

@ -1135,7 +1135,7 @@ namespace dlib
// Prevent calls to tensor_to_dets() from running for a really long time // Prevent calls to tensor_to_dets() from running for a really long time
// due to the production of an obscene number of detections. // due to the production of an obscene number of detections.
const unsigned long max_num_initial_dets = max_num_dets*100; const unsigned long max_num_initial_dets = max_num_dets*100;
if (dets.size() >= max_num_initial_dets) if (dets.size() > max_num_initial_dets)
{ {
det_thresh_speed_adjust = std::max(det_thresh_speed_adjust,dets[max_num_initial_dets].detection_confidence + options.loss_per_false_alarm); det_thresh_speed_adjust = std::max(det_thresh_speed_adjust,dets[max_num_initial_dets].detection_confidence + options.loss_per_false_alarm);
} }

View File

@ -206,6 +206,36 @@ py::list chinese_whispers_clustering(py::list descriptors, float threshold)
return clusters; return clusters;
} }
py::list chinese_whispers_raw(py::list edges)
{
py::list clusters;
size_t num_edges = py::len(edges);
std::vector<sample_pair> edges_pairs;
std::vector<unsigned long> labels;
for (size_t idx = 0; idx < num_edges; ++idx)
{
py::tuple t = edges[idx].cast<py::tuple>();
if ((len(t) != 2) && (len(t) != 3))
{
PyErr_SetString( PyExc_IndexError, "Input must be a list of tuples with 2 or 3 elements.");
throw py::error_already_set();
}
size_t i = t[0].cast<size_t>();
size_t j = t[1].cast<size_t>();
double distance = (len(t) == 3) ? t[2].cast<double>(): 1;
edges_pairs.push_back(sample_pair(i, j, distance));
}
chinese_whispers(edges_pairs, labels);
for (size_t i = 0; i < labels.size(); ++i)
{
clusters.append(labels[i]);
}
return clusters;
}
void save_face_chips ( void save_face_chips (
numpy_image<rgb_pixel> img, numpy_image<rgb_pixel> img,
const std::vector<full_object_detection>& faces, const std::vector<full_object_detection>& faces,
@ -296,5 +326,10 @@ void bind_face_recognition(py::module &m)
m.def("chinese_whispers_clustering", &chinese_whispers_clustering, py::arg("descriptors"), py::arg("threshold"), m.def("chinese_whispers_clustering", &chinese_whispers_clustering, py::arg("descriptors"), py::arg("threshold"),
"Takes a list of descriptors and returns a list that contains a label for each descriptor. Clustering is done using dlib::chinese_whispers." "Takes a list of descriptors and returns a list that contains a label for each descriptor. Clustering is done using dlib::chinese_whispers."
); );
m.def("chinese_whispers", &chinese_whispers_raw, py::arg("edges"),
"Given a graph with vertices represented as numbers indexed from 0, this algorithm takes a list of edges and returns back a list that contains a labels (found clusters) for each vertex. "
"Edges are tuples with either 2 elements (integers presenting indexes of connected vertices) or 3 elements, where additional one element is float which presents distance weight of the edge). "
"Offers direct access to dlib::chinese_whispers."
);
} }

View File

@ -0,0 +1,56 @@
from random import Random
from dlib import chinese_whispers
from pytest import raises
def test_chinese_whispers():
assert len(chinese_whispers([])) == 0
assert len(chinese_whispers([(0, 0), (1, 1)])) == 2
# Test that values from edges are actually used and that correct values are returned
labels = chinese_whispers([(0, 0), (0, 1), (1, 1)])
assert len(labels) == 2
assert labels[0] == labels[1]
labels = chinese_whispers([(0, 0), (1, 1)])
assert len(labels) == 2
assert labels[0] != labels[1]
def test_chinese_whispers_with_distance():
assert len(chinese_whispers([(0, 0, 1)])) == 1
assert len(chinese_whispers([(0, 0, 1), (0, 1, 0.5), (1, 1, 1)])) == 2
# Test that values from edges and distances are actually used and that correct values are returned
labels = chinese_whispers([(0, 0, 1), (0, 1, 1), (1, 1, 1)])
assert len(labels) == 2
assert labels[0] == labels[1]
labels = chinese_whispers([(0, 0, 1), (0, 1, 0.0), (1, 1, 1)])
assert len(labels) == 2
assert labels[0] != labels[1]
# Non-trivial test
edges = []
r = Random(0)
for i in range(100):
edges.append((i, i, 1))
edges.append((i, r.randint(0, 99), r.random()))
assert len(chinese_whispers(edges)) == 100
def test_chinese_whispers_type_checks():
"""
Tests contract (expected errors) in case client provides wrong types
"""
with raises(TypeError):
chinese_whispers()
with raises(TypeError):
chinese_whispers('foo')
with raises(RuntimeError):
chinese_whispers(['foo'])
with raises(IndexError):
chinese_whispers([(0,)])
with raises(IndexError):
chinese_whispers([(0, 1, 2, 3)])
with raises(RuntimeError):
chinese_whispers([('foo', 'bar')])