mirror of https://github.com/davisking/dlib.git
Added support for variadic Python functions in find_max_global. (#1141)
* Added support for variadic Python functions in find_max_global. * Add test for find_{min,max}_global on variadic functions.
This commit is contained in:
parent
09a9ad6d1e
commit
9691c194c0
|
@ -45,17 +45,18 @@ py::list mat_to_list (
|
|||
return l;
|
||||
}
|
||||
|
||||
size_t num_function_arguments(py::object f)
|
||||
size_t num_function_arguments(py::object f, size_t expected_num)
|
||||
{
|
||||
if (hasattr(f,"func_code"))
|
||||
return f.attr("func_code").attr("co_argcount").cast<std::size_t>();
|
||||
else
|
||||
return f.attr("__code__").attr("co_argcount").cast<std::size_t>();
|
||||
const auto code_object = f.attr(hasattr(f,"func_code") ? "func_code" : "__code__");
|
||||
const auto num = code_object.attr("co_argcount").cast<std::size_t>();
|
||||
if (num < expected_num && (code_object.attr("co_flags").cast<int>() & CO_VARARGS))
|
||||
return expected_num;
|
||||
return num;
|
||||
}
|
||||
|
||||
double call_func(py::object f, const matrix<double,0,1>& args)
|
||||
{
|
||||
const auto num = num_function_arguments(f);
|
||||
const auto num = num_function_arguments(f, args.size());
|
||||
DLIB_CASSERT(num == args.size(),
|
||||
"The function being optimized takes a number of arguments that doesn't agree with the size of the bounds lists you provided to find_max_global()");
|
||||
DLIB_CASSERT(0 < num && num < 15, "Functions being optimized must take between 1 and 15 scalar arguments.");
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
from dlib import find_max_global, find_min_global
|
||||
from pytest import raises
|
||||
|
||||
|
||||
def test_global_optimization_nargs():
|
||||
w0 = find_max_global(lambda *args: sum(args), [0, 0, 0], [1, 1, 1], 10)
|
||||
w1 = find_min_global(lambda *args: sum(args), [0, 0, 0], [1, 1, 1], 10)
|
||||
assert w0 == ([1, 1, 1], 3)
|
||||
assert w1 == ([0, 0, 0], 0)
|
||||
|
||||
w2 = find_max_global(lambda a, b, c, *args: a + b + c - sum(args), [0, 0, 0], [1, 1, 1], 10)
|
||||
w3 = find_min_global(lambda a, b, c, *args: a + b + c - sum(args), [0, 0, 0], [1, 1, 1], 10)
|
||||
assert w2 == ([1, 1, 1], 3)
|
||||
assert w3 == ([0, 0, 0], 0)
|
||||
|
||||
with raises(Exception):
|
||||
find_max_global(lambda a, b: 0, [0, 0, 0], [1, 1, 1], 10)
|
||||
with raises(Exception):
|
||||
find_min_global(lambda a, b: 0, [0, 0, 0], [1, 1, 1], 10)
|
||||
with raises(Exception):
|
||||
find_max_global(lambda a, b, c, d, *args: 0, [0, 0, 0], [1, 1, 1], 10)
|
||||
with raises(Exception):
|
||||
find_min_global(lambda a, b, c, d, *args: 0, [0, 0, 0], [1, 1, 1], 10)
|
Loading…
Reference in New Issue