Added support for variadic Python functions in find_max_global. (#1141)

* Added support for variadic Python functions in find_max_global.

* Add test for find_{min,max}_global on variadic functions.
This commit is contained in:
Morten Hustveit 2018-02-19 13:26:45 -05:00 committed by Davis E. King
parent 09a9ad6d1e
commit 9691c194c0
2 changed files with 30 additions and 6 deletions

View File

@ -45,17 +45,18 @@ py::list mat_to_list (
return l;
}
size_t num_function_arguments(py::object f)
size_t num_function_arguments(py::object f, size_t expected_num)
{
if (hasattr(f,"func_code"))
return f.attr("func_code").attr("co_argcount").cast<std::size_t>();
else
return f.attr("__code__").attr("co_argcount").cast<std::size_t>();
const auto code_object = f.attr(hasattr(f,"func_code") ? "func_code" : "__code__");
const auto num = code_object.attr("co_argcount").cast<std::size_t>();
if (num < expected_num && (code_object.attr("co_flags").cast<int>() & CO_VARARGS))
return expected_num;
return num;
}
double call_func(py::object f, const matrix<double,0,1>& args)
{
const auto num = num_function_arguments(f);
const auto num = num_function_arguments(f, args.size());
DLIB_CASSERT(num == args.size(),
"The function being optimized takes a number of arguments that doesn't agree with the size of the bounds lists you provided to find_max_global()");
DLIB_CASSERT(0 < num && num < 15, "Functions being optimized must take between 1 and 15 scalar arguments.");

View File

@ -0,0 +1,23 @@
from dlib import find_max_global, find_min_global
from pytest import raises
def test_global_optimization_nargs():
w0 = find_max_global(lambda *args: sum(args), [0, 0, 0], [1, 1, 1], 10)
w1 = find_min_global(lambda *args: sum(args), [0, 0, 0], [1, 1, 1], 10)
assert w0 == ([1, 1, 1], 3)
assert w1 == ([0, 0, 0], 0)
w2 = find_max_global(lambda a, b, c, *args: a + b + c - sum(args), [0, 0, 0], [1, 1, 1], 10)
w3 = find_min_global(lambda a, b, c, *args: a + b + c - sum(args), [0, 0, 0], [1, 1, 1], 10)
assert w2 == ([1, 1, 1], 3)
assert w3 == ([0, 0, 0], 0)
with raises(Exception):
find_max_global(lambda a, b: 0, [0, 0, 0], [1, 1, 1], 10)
with raises(Exception):
find_min_global(lambda a, b: 0, [0, 0, 0], [1, 1, 1], 10)
with raises(Exception):
find_max_global(lambda a, b, c, d, *args: 0, [0, 0, 0], [1, 1, 1], 10)
with raises(Exception):
find_min_global(lambda a, b, c, d, *args: 0, [0, 0, 0], [1, 1, 1], 10)