diff --git a/tools/python/src/face_recognition.cpp b/tools/python/src/face_recognition.cpp index b61375849..ca2af5c76 100644 --- a/tools/python/src/face_recognition.cpp +++ b/tools/python/src/face_recognition.cpp @@ -230,6 +230,37 @@ private: // ---------------------------------------------------------------------------------------- +py::list bottom_up_clustering(py::list descriptors, const int min_num_clusters, const double max_dist) +{ + DLIB_CASSERT(min_num_clusters > 0); + + size_t num_descriptors = py::len(descriptors); + matrix dist(static_cast(num_descriptors), static_cast(num_descriptors)); + + for (size_t i = 0; i < num_descriptors; ++i) + { + for (size_t j = i+1; j < num_descriptors; ++j) + { + const long i2 = static_cast(i); + const long j2 = static_cast(j); + matrix& first_descriptor = descriptors[i].cast&>(); + matrix& second_descriptor = descriptors[j].cast&>(); + dist(i2, j2) = dlib::length( first_descriptor- second_descriptor); + dist(j2, i2) = dist(i2, j2); + } + } + + std::vector labels; + const auto num_clusters = dlib::bottom_up_cluster(dist, labels, min_num_clusters, max_dist); + + py::list clusters; + for (size_t i = 0; i < labels.size(); ++i) + { + clusters.append(labels[i]); + } + return clusters; +} + py::list chinese_whispers_clustering(py::list descriptors, float threshold) { DLIB_CASSERT(threshold > 0); @@ -391,6 +422,8 @@ void bind_face_recognition(py::module &m) "Takes an image and a full_object_detections object that reference faces in that image and saves the faces with the specified file name prefix. The faces will be rotated upright and scaled to 150x150 pixels or with the optional specified size and padding.", py::arg("img"), py::arg("faces"), py::arg("chip_filename"), py::arg("size")=150, py::arg("padding")=0.25 ); + m.def("bottom_up_clustering", &bottom_up_clustering, py::arg("descriptors"), py::arg("min_num_clusters")=1, py::arg("max_dist")=0.6, + "Takes a list of descriptors and returns a list that contains a label for each descriptor. Clustering is done using dlib::bottom_up_cluster."); m.def("chinese_whispers_clustering", &chinese_whispers_clustering, py::arg("descriptors"), py::arg("threshold"), "Takes a list of descriptors and returns a list that contains a label for each descriptor. Clustering is done using dlib::chinese_whispers." );