diff --git a/dlib/python/numpy_image.h b/dlib/python/numpy_image.h index 9d73bdf7e..af2808ce1 100644 --- a/dlib/python/numpy_image.h +++ b/dlib/python/numpy_image.h @@ -13,6 +13,7 @@ #include #include #include +#include namespace py = pybind11; @@ -356,18 +357,28 @@ namespace pybind11 { using basic_pixel_type = typename dlib::pixel_traits::basic_pixel_type; - static PYBIND11_DESCR name() { - constexpr size_t channels = dlib::pixel_traits::num; - if (channels == 1) - return _("numpy.ndarray[(rows,cols),") + npy_format_descriptor::name() + _("]"); - else if (channels == 2) + template + static PYBIND11_DESCR getname(typename std::enable_if::type) { + return _("numpy.ndarray[(rows,cols),") + npy_format_descriptor::name() + _("]"); + }; + template + static PYBIND11_DESCR getname(typename std::enable_if::type) { + if (channels == 2) return _("numpy.ndarray[(rows,cols,2),") + npy_format_descriptor::name() + _("]"); else if (channels == 3) return _("numpy.ndarray[(rows,cols,3),") + npy_format_descriptor::name() + _("]"); else if (channels == 4) return _("numpy.ndarray[(rows,cols,4),") + npy_format_descriptor::name() + _("]"); - else - DLIB_CASSERT(false,"unsupported pixel type"); + }; + + static PYBIND11_DESCR name() { + constexpr size_t channels = dlib::pixel_traits::num; + // The reason we have to call getname() in this wonky way is because + // pybind11 uses a type that records the length of the returned string in + // the type. So we have to do this overloading to make the return type + // from name() consistent. In C++17 this would be a lot cleaner with + // constexpr if, but can't use C++17 yet because of lack of wide support :( + return getname(0); } };