From ec4865ed60b21364a3234c9c8965f50d029a02db Mon Sep 17 00:00:00 2001 From: Davis King Date: Thu, 8 Aug 2013 10:20:38 -0400 Subject: [PATCH] Added some asserts into the optimization code to detect when the user accidentally creates objective functions which output infinite or NaN values. --- dlib/optimization/optimization.h | 18 ++++++++++++++++++ dlib/optimization/optimization_trust_region.h | 7 +++++++ 2 files changed, 25 insertions(+) diff --git a/dlib/optimization/optimization.h b/dlib/optimization/optimization.h index 84f2fa29d..cd4541fc1 100644 --- a/dlib/optimization/optimization.h +++ b/dlib/optimization/optimization.h @@ -188,6 +188,9 @@ namespace dlib double f_value = f(x); g = der(x); + DLIB_ASSERT(is_finite(f_value), "The objective function generated non-finite outputs"); + DLIB_ASSERT(is_finite(g), "The objective function generated non-finite outputs"); + while(stop_strategy.should_continue_search(x, f_value, g) && f_value > min_f) { s = search_strategy.get_next_direction(x, f_value, g); @@ -202,6 +205,9 @@ namespace dlib // Take the search step indicated by the above line search x += alpha*s; + + DLIB_ASSERT(is_finite(f_value), "The objective function generated non-finite outputs"); + DLIB_ASSERT(is_finite(g), "The objective function generated non-finite outputs"); } return f_value; @@ -249,6 +255,9 @@ namespace dlib double f_value = -f(x); g = -der(x); + DLIB_ASSERT(is_finite(f_value), "The objective function generated non-finite outputs"); + DLIB_ASSERT(is_finite(g), "The objective function generated non-finite outputs"); + while(stop_strategy.should_continue_search(x, f_value, g) && f_value > -max_f) { s = search_strategy.get_next_direction(x, f_value, g); @@ -269,6 +278,9 @@ namespace dlib // from the unnegated versions of f() and der() g *= -1; f_value *= -1; + + DLIB_ASSERT(is_finite(f_value), "The objective function generated non-finite outputs"); + DLIB_ASSERT(is_finite(g), "The objective function generated non-finite outputs"); } return -f_value; @@ -312,6 +324,9 @@ namespace dlib double f_value = f(x); g = derivative(f,derivative_eps)(x); + DLIB_ASSERT(is_finite(f_value), "The objective function generated non-finite outputs"); + DLIB_ASSERT(is_finite(g), "The objective function generated non-finite outputs"); + while(stop_strategy.should_continue_search(x, f_value, g) && f_value > min_f) { s = search_strategy.get_next_direction(x, f_value, g); @@ -330,6 +345,9 @@ namespace dlib x += alpha*s; g = derivative(f,derivative_eps)(x); + + DLIB_ASSERT(is_finite(f_value), "The objective function generated non-finite outputs"); + DLIB_ASSERT(is_finite(g), "The objective function generated non-finite outputs"); } return f_value; diff --git a/dlib/optimization/optimization_trust_region.h b/dlib/optimization/optimization_trust_region.h index 125b6f595..b178657a5 100644 --- a/dlib/optimization/optimization_trust_region.h +++ b/dlib/optimization/optimization_trust_region.h @@ -264,6 +264,9 @@ namespace dlib model.get_derivative_and_hessian(x,g,h); + DLIB_ASSERT(is_finite(x), "The objective function generated non-finite outputs"); + DLIB_ASSERT(is_finite(g), "The objective function generated non-finite outputs"); + DLIB_ASSERT(is_finite(h), "The objective function generated non-finite outputs"); // Sometimes the loop below won't modify x because the trust region step failed. // This bool tells us when we are in that case. @@ -323,6 +326,10 @@ namespace dlib { stale_x = true; } + + DLIB_ASSERT(is_finite(x), "The objective function generated non-finite outputs"); + DLIB_ASSERT(is_finite(g), "The objective function generated non-finite outputs"); + DLIB_ASSERT(is_finite(h), "The objective function generated non-finite outputs"); } return f_value;