From 1fbd1828abb4d19e53965504295defcd853ba53c Mon Sep 17 00:00:00 2001 From: Davis King Date: Fri, 24 Nov 2017 09:56:26 -0500 Subject: [PATCH] Cleaned up the code a bit. --- .../global_function_search.cpp | 90 +++++++++++-------- .../global_function_search.h | 10 +-- .../global_function_search_abstract.h | 45 +++++----- 3 files changed, 78 insertions(+), 67 deletions(-) diff --git a/dlib/global_optimization/global_function_search.cpp b/dlib/global_optimization/global_function_search.cpp index ea2afaeb7..0da7a92cd 100644 --- a/dlib/global_optimization/global_function_search.cpp +++ b/dlib/global_optimization/global_function_search.cpp @@ -11,7 +11,7 @@ namespace dlib namespace qopt_impl { - void fit_qp_mse( + void fit_quadratic_to_points_mse( const matrix& X, const matrix& Y, matrix& H, @@ -64,21 +64,21 @@ namespace dlib // ---------------------------------------------------------------------------------------- - void fit_qp( + void fit_quadratic_to_points( const matrix& X, const matrix& Y, matrix& H, matrix& g, double& c ) - /*! - requires - - X.size() > 0 + /*! + requires + - X.size() > 0 - X.nc() == Y.size() - - X.nr()+1 <= X.nc() <= (X.nr()+1)*(X.nr()+2)/2 - ensures - - This function finds a quadratic function, Q(x), that interpolates the - given set of points. If there aren't enough points to uniquely define + - X.nr()+1 <= X.nc() + ensures + - This function finds a quadratic function, Q(x), that interpolates the + given set of points. If there aren't enough points to uniquely define Q(x) then the Q(x) that fits the given points with the minimum Frobenius norm hessian matrix is selected. - To be precise: @@ -87,16 +87,19 @@ namespace dlib sum(squared(H)) such that: Q(colm(X,i)) == Y(i), for all valid i - !*/ + - If there are more points than necessary to constrain Q then the Q + that best interpolates the function in the mean squared sense is + found. + !*/ { DLIB_CASSERT(X.size() > 0); DLIB_CASSERT(X.nc() == Y.size()); - DLIB_CASSERT(X.nr()+1 <= X.nc());// && X.nc() <= (X.nr()+1)*(X.nr()+2)/2); + DLIB_CASSERT(X.nr()+1 <= X.nc()); if (X.nc() >= (X.nr()+1)*(X.nr()+2)/2) { - fit_qp_mse(X,Y,H,g,c); + fit_quadratic_to_points_mse(X,Y,H,g,c); return; } @@ -180,7 +183,7 @@ namespace dlib matrix g; double c; - fit_qp(X, Y, H, g, c); + fit_quadratic_to_points(X, Y, H, g, c); matrix p; @@ -198,7 +201,7 @@ namespace dlib // ---------------------------------------------------------------------------------------- - quad_interp_result pick_next_sample_quad_interp ( + quad_interp_result pick_next_sample_using_trust_region ( const std::vector& samples, double& radius, const matrix& lower, @@ -324,7 +327,7 @@ namespace dlib // ------------------------------------------------------------------------------------ - max_upper_bound_function pick_next_sample_max_upper_bound_function ( + max_upper_bound_function pick_next_sample_as_max_upper_bound ( dlib::rand& rnd, const upper_bound_function& ub, const matrix& lower, @@ -417,10 +420,10 @@ namespace dlib { upper_bound_function tmp(ub); - // we are going to add the incomplete evals into this and assume the - // incomplete evals are going to take y values equal to their nearest + // we are going to add the outstanding evals into this and assume the + // outstanding evals are going to take y values equal to their nearest // neighbor complete evals. - for (auto& eval : incomplete_evals) + for (auto& eval : outstanding_evals) { function_evaluation e; e.x = eval.x; @@ -454,6 +457,7 @@ namespace dlib } // end namespace gopt_impl +// ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- @@ -526,9 +530,9 @@ namespace dlib { std::lock_guard lock(*info->m); - // remove the evaluation request from the incomplete list. - auto i = std::find(info->incomplete_evals.begin(), info->incomplete_evals.end(), req); - info->incomplete_evals.erase(i); + // remove the evaluation request from the outstanding list. + auto i = std::find(info->outstanding_evals.begin(), info->outstanding_evals.end(), req); + info->outstanding_evals.erase(i); } } @@ -545,10 +549,10 @@ namespace dlib m_has_been_evaluated = true; - // move the evaluation from incomplete to complete - auto i = std::find(info->incomplete_evals.begin(), info->incomplete_evals.end(), req); - DLIB_CASSERT(i != info->incomplete_evals.end()); - info->incomplete_evals.erase(i); + // move the evaluation from outstanding to complete + auto i = std::find(info->outstanding_evals.begin(), info->outstanding_evals.end(), req); + DLIB_CASSERT(i != info->outstanding_evals.end()); + info->outstanding_evals.erase(i); info->ub.add(function_evaluation(req.x,y)); @@ -582,6 +586,8 @@ namespace dlib } } +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- global_function_search:: @@ -701,13 +707,13 @@ namespace dlib outstanding_function_eval_request new_req; new_req.request_id = next_request_id++; new_req.x = make_random_vector(rnd, info->spec.lower, info->spec.upper, info->spec.is_integer_variable); - info->incomplete_evals.emplace_back(new_req); + info->outstanding_evals.emplace_back(new_req); return function_evaluation_request(new_req,info); } } - if (do_trust_region_step && !has_incomplete_trust_region_request()) + if (do_trust_region_step && !has_outstanding_trust_region_request()) { // find the currently best performing function, we will do a trust region // step on it. @@ -716,7 +722,7 @@ namespace dlib // if we have enough points to do a trust region step if (info->ub.num_points() > dims+1) { - auto tmp = pick_next_sample_quad_interp(info->ub.get_points(), + auto tmp = pick_next_sample_using_trust_region(info->ub.get_points(), info->radius, info->spec.lower, info->spec.upper, info->spec.is_integer_variable); //std::cout << "QP predicted improvement: "<< tmp.predicted_improvement << std::endl; if (tmp.predicted_improvement > min_trust_region_epsilon) @@ -728,7 +734,7 @@ namespace dlib new_req.was_trust_region_generated_request = true; new_req.anchor_objective_value = info->best_objective_value; new_req.predicted_improvement = tmp.predicted_improvement; - info->incomplete_evals.emplace_back(new_req); + info->outstanding_evals.emplace_back(new_req); return function_evaluation_request(new_req, info); } } @@ -747,7 +753,7 @@ namespace dlib // function with the largest upper bound for evaluation. for (auto& info : functions) { - auto tmp = pick_next_sample_max_upper_bound_function(rnd, + auto tmp = pick_next_sample_as_max_upper_bound(rnd, info->build_upper_bound_with_all_function_evals(), info->spec.lower, info->spec.upper, info->spec.is_integer_variable, num_random_samples); if (tmp.predicted_improvement > 0 && tmp.upper_bound > best_upper_bound) @@ -764,7 +770,7 @@ namespace dlib outstanding_function_eval_request new_req; new_req.request_id = next_request_id++; new_req.x = std::move(next_sample); - best_funct->incomplete_evals.emplace_back(new_req); + best_funct->outstanding_evals.emplace_back(new_req); return function_evaluation_request(new_req, best_funct); } } @@ -776,7 +782,7 @@ namespace dlib outstanding_function_eval_request new_req; new_req.request_id = next_request_id++; new_req.x = make_random_vector(rnd, info->spec.lower, info->spec.upper, info->spec.is_integer_variable); - info->incomplete_evals.emplace_back(new_req); + info->outstanding_evals.emplace_back(new_req); return function_evaluation_request(new_req, info); } @@ -839,9 +845,13 @@ namespace dlib { DLIB_CASSERT(0 <= value); relative_noise_magnitude = value; - // recreate all the upper bound functions with the new relative noise magnitude - for (auto& f : functions) - f->ub = upper_bound_function(f->ub.get_points(), relative_noise_magnitude); + if (m) + { + std::lock_guard lock(*m); + // recreate all the upper bound functions with the new relative noise magnitude + for (auto& f : functions) + f->ub = upper_bound_function(f->ub.get_points(), relative_noise_magnitude); + } } // ---------------------------------------------------------------------------------------- @@ -881,8 +891,10 @@ namespace dlib size_t& idx ) const { - auto i = std::max_element(functions.begin(), functions.end(), - [](const std::shared_ptr& a, const std::shared_ptr& b) { return a->best_objective_value < b->best_objective_value; }); + auto compare = [](const std::shared_ptr& a, const std::shared_ptr& b) + { return a->best_objective_value < b->best_objective_value; }; + + auto i = std::max_element(functions.begin(), functions.end(), compare); idx = std::distance(functions.begin(),i); return *i; @@ -891,12 +903,12 @@ namespace dlib // ---------------------------------------------------------------------------------------- bool global_function_search:: - has_incomplete_trust_region_request ( + has_outstanding_trust_region_request ( ) const { for (auto& f : functions) { - for (auto& i : f->incomplete_evals) + for (auto& i : f->outstanding_evals) { if (i.was_trust_region_generated_request) return true; diff --git a/dlib/global_optimization/global_function_search.h b/dlib/global_optimization/global_function_search.h index 95366d70d..b8341c33f 100644 --- a/dlib/global_optimization/global_function_search.h +++ b/dlib/global_optimization/global_function_search.h @@ -79,7 +79,7 @@ namespace dlib size_t function_idx = 0; std::shared_ptr m; upper_bound_function ub; - std::vector incomplete_evals; + std::vector outstanding_evals; matrix best_x; double best_objective_value = -std::numeric_limits::infinity(); double radius = 0; @@ -101,7 +101,7 @@ namespace dlib function_evaluation_request(function_evaluation_request&& item); function_evaluation_request& operator=(function_evaluation_request&& item); - void swap(function_evaluation_request& item); + ~function_evaluation_request(); size_t function_idx ( ) const; @@ -112,12 +112,12 @@ namespace dlib bool has_been_evaluated ( ) const; - ~function_evaluation_request(); - void set ( double y ); + void swap(function_evaluation_request& item); + private: friend class global_function_search; @@ -218,7 +218,7 @@ namespace dlib size_t& idx ) const; - bool has_incomplete_trust_region_request ( + bool has_outstanding_trust_region_request ( ) const; diff --git a/dlib/global_optimization/global_function_search_abstract.h b/dlib/global_optimization/global_function_search_abstract.h index f6b2080a1..0a1a8f880 100644 --- a/dlib/global_optimization/global_function_search_abstract.h +++ b/dlib/global_optimization/global_function_search_abstract.h @@ -89,14 +89,6 @@ namespace dlib moving from item causes item.has_been_evaluated() == true, TODO, clarify !*/ - void swap( - function_evaluation_request& item - ); - /*! - ensures - - swaps the state of *this and item - !*/ - ~function_evaluation_request( ); /*! @@ -113,7 +105,6 @@ namespace dlib bool has_been_evaluated ( ) const; - void set ( double y ); @@ -124,6 +115,14 @@ namespace dlib - #has_been_evaluated() == true !*/ + void swap( + function_evaluation_request& item + ); + /*! + ensures + - swaps the state of *this and item + !*/ + }; // ---------------------------------------------------------------------------------------- @@ -143,18 +142,6 @@ namespace dlib - #num_functions() == 0 !*/ - // This object can't be copied. - global_function_search(const global_function_search&) = delete; - global_function_search& operator=(const global_function_search& item) = delete; - - global_function_search(global_function_search&& item) = default; - global_function_search& operator=(global_function_search&& item) = default; - /*! - ensures - - moves the state of item into *this - - #item.num_functions() == 0 - !*/ - explicit global_function_search( const function_spec& function ); @@ -169,13 +156,25 @@ namespace dlib const double relative_noise_magnitude = 0.001 ); - size_t num_functions( - ) const; + // This object can't be copied. + global_function_search(const global_function_search&) = delete; + global_function_search& operator=(const global_function_search& item) = delete; + + global_function_search(global_function_search&& item) = default; + global_function_search& operator=(global_function_search&& item) = default; + /*! + ensures + - moves the state of item into *this + - #item.num_functions() == 0 + !*/ void set_seed ( time_t seed ); + size_t num_functions( + ) const; + void get_function_evaluations ( std::vector& specs, std::vector>& function_evals