diff --git a/tools/python/src/global_optimization.cpp b/tools/python/src/global_optimization.cpp index 741126691..39147be41 100644 --- a/tools/python/src/global_optimization.cpp +++ b/tools/python/src/global_optimization.cpp @@ -45,17 +45,18 @@ py::list mat_to_list ( return l; } -size_t num_function_arguments(py::object f) +size_t num_function_arguments(py::object f, size_t expected_num) { - if (hasattr(f,"func_code")) - return f.attr("func_code").attr("co_argcount").cast(); - else - return f.attr("__code__").attr("co_argcount").cast(); + const auto code_object = f.attr(hasattr(f,"func_code") ? "func_code" : "__code__"); + const auto num = code_object.attr("co_argcount").cast(); + if (num < expected_num && (code_object.attr("co_flags").cast() & CO_VARARGS)) + return expected_num; + return num; } double call_func(py::object f, const matrix& args) { - const auto num = num_function_arguments(f); + const auto num = num_function_arguments(f, args.size()); DLIB_CASSERT(num == args.size(), "The function being optimized takes a number of arguments that doesn't agree with the size of the bounds lists you provided to find_max_global()"); DLIB_CASSERT(0 < num && num < 15, "Functions being optimized must take between 1 and 15 scalar arguments."); diff --git a/tools/python/test/test_global_optimization.py b/tools/python/test/test_global_optimization.py new file mode 100644 index 000000000..5e85ca920 --- /dev/null +++ b/tools/python/test/test_global_optimization.py @@ -0,0 +1,23 @@ +from dlib import find_max_global, find_min_global +from pytest import raises + + +def test_global_optimization_nargs(): + w0 = find_max_global(lambda *args: sum(args), [0, 0, 0], [1, 1, 1], 10) + w1 = find_min_global(lambda *args: sum(args), [0, 0, 0], [1, 1, 1], 10) + assert w0 == ([1, 1, 1], 3) + assert w1 == ([0, 0, 0], 0) + + w2 = find_max_global(lambda a, b, c, *args: a + b + c - sum(args), [0, 0, 0], [1, 1, 1], 10) + w3 = find_min_global(lambda a, b, c, *args: a + b + c - sum(args), [0, 0, 0], [1, 1, 1], 10) + assert w2 == ([1, 1, 1], 3) + assert w3 == ([0, 0, 0], 0) + + with raises(Exception): + find_max_global(lambda a, b: 0, [0, 0, 0], [1, 1, 1], 10) + with raises(Exception): + find_min_global(lambda a, b: 0, [0, 0, 0], [1, 1, 1], 10) + with raises(Exception): + find_max_global(lambda a, b, c, d, *args: 0, [0, 0, 0], [1, 1, 1], 10) + with raises(Exception): + find_min_global(lambda a, b, c, d, *args: 0, [0, 0, 0], [1, 1, 1], 10)