diff --git a/dlib/dnn/loss.h b/dlib/dnn/loss.h index e4776aa43..5b5c53ce8 100644 --- a/dlib/dnn/loss.h +++ b/dlib/dnn/loss.h @@ -1135,7 +1135,7 @@ namespace dlib // Prevent calls to tensor_to_dets() from running for a really long time // due to the production of an obscene number of detections. const unsigned long max_num_initial_dets = max_num_dets*100; - if (dets.size() >= max_num_initial_dets) + if (dets.size() > max_num_initial_dets) { det_thresh_speed_adjust = std::max(det_thresh_speed_adjust,dets[max_num_initial_dets].detection_confidence + options.loss_per_false_alarm); } diff --git a/tools/python/src/face_recognition.cpp b/tools/python/src/face_recognition.cpp index 03c19b0c9..d0a0b0094 100644 --- a/tools/python/src/face_recognition.cpp +++ b/tools/python/src/face_recognition.cpp @@ -206,6 +206,36 @@ py::list chinese_whispers_clustering(py::list descriptors, float threshold) return clusters; } +py::list chinese_whispers_raw(py::list edges) +{ + py::list clusters; + size_t num_edges = py::len(edges); + + std::vector edges_pairs; + std::vector labels; + for (size_t idx = 0; idx < num_edges; ++idx) + { + py::tuple t = edges[idx].cast(); + if ((len(t) != 2) && (len(t) != 3)) + { + PyErr_SetString( PyExc_IndexError, "Input must be a list of tuples with 2 or 3 elements."); + throw py::error_already_set(); + } + size_t i = t[0].cast(); + size_t j = t[1].cast(); + double distance = (len(t) == 3) ? t[2].cast(): 1; + + edges_pairs.push_back(sample_pair(i, j, distance)); + } + + chinese_whispers(edges_pairs, labels); + for (size_t i = 0; i < labels.size(); ++i) + { + clusters.append(labels[i]); + } + return clusters; +} + void save_face_chips ( numpy_image img, const std::vector& faces, @@ -296,5 +326,10 @@ void bind_face_recognition(py::module &m) m.def("chinese_whispers_clustering", &chinese_whispers_clustering, py::arg("descriptors"), py::arg("threshold"), "Takes a list of descriptors and returns a list that contains a label for each descriptor. Clustering is done using dlib::chinese_whispers." ); + m.def("chinese_whispers", &chinese_whispers_raw, py::arg("edges"), + "Given a graph with vertices represented as numbers indexed from 0, this algorithm takes a list of edges and returns back a list that contains a labels (found clusters) for each vertex. " + "Edges are tuples with either 2 elements (integers presenting indexes of connected vertices) or 3 elements, where additional one element is float which presents distance weight of the edge). " + "Offers direct access to dlib::chinese_whispers." + ); } diff --git a/tools/python/test/test_chinese_whispers.py b/tools/python/test/test_chinese_whispers.py new file mode 100644 index 000000000..0e6b07854 --- /dev/null +++ b/tools/python/test/test_chinese_whispers.py @@ -0,0 +1,56 @@ +from random import Random + +from dlib import chinese_whispers +from pytest import raises + + +def test_chinese_whispers(): + assert len(chinese_whispers([])) == 0 + assert len(chinese_whispers([(0, 0), (1, 1)])) == 2 + + # Test that values from edges are actually used and that correct values are returned + labels = chinese_whispers([(0, 0), (0, 1), (1, 1)]) + assert len(labels) == 2 + assert labels[0] == labels[1] + labels = chinese_whispers([(0, 0), (1, 1)]) + assert len(labels) == 2 + assert labels[0] != labels[1] + + +def test_chinese_whispers_with_distance(): + assert len(chinese_whispers([(0, 0, 1)])) == 1 + assert len(chinese_whispers([(0, 0, 1), (0, 1, 0.5), (1, 1, 1)])) == 2 + + # Test that values from edges and distances are actually used and that correct values are returned + labels = chinese_whispers([(0, 0, 1), (0, 1, 1), (1, 1, 1)]) + assert len(labels) == 2 + assert labels[0] == labels[1] + labels = chinese_whispers([(0, 0, 1), (0, 1, 0.0), (1, 1, 1)]) + assert len(labels) == 2 + assert labels[0] != labels[1] + + # Non-trivial test + edges = [] + r = Random(0) + for i in range(100): + edges.append((i, i, 1)) + edges.append((i, r.randint(0, 99), r.random())) + assert len(chinese_whispers(edges)) == 100 + + +def test_chinese_whispers_type_checks(): + """ + Tests contract (expected errors) in case client provides wrong types + """ + with raises(TypeError): + chinese_whispers() + with raises(TypeError): + chinese_whispers('foo') + with raises(RuntimeError): + chinese_whispers(['foo']) + with raises(IndexError): + chinese_whispers([(0,)]) + with raises(IndexError): + chinese_whispers([(0, 1, 2, 3)]) + with raises(RuntimeError): + chinese_whispers([('foo', 'bar')])