Cleaned up the code a bit.

This commit is contained in:
Davis King 2017-11-24 09:56:26 -05:00
parent 0d9043bc09
commit 1fbd1828ab
3 changed files with 78 additions and 67 deletions

View File

@ -11,7 +11,7 @@ namespace dlib
namespace qopt_impl
{
void fit_qp_mse(
void fit_quadratic_to_points_mse(
const matrix<double>& X,
const matrix<double,0,1>& Y,
matrix<double>& H,
@ -64,21 +64,21 @@ namespace dlib
// ----------------------------------------------------------------------------------------
void fit_qp(
void fit_quadratic_to_points(
const matrix<double>& X,
const matrix<double,0,1>& Y,
matrix<double>& H,
matrix<double,0,1>& g,
double& c
)
/*!
requires
- X.size() > 0
/*!
requires
- X.size() > 0
- X.nc() == Y.size()
- X.nr()+1 <= X.nc() <= (X.nr()+1)*(X.nr()+2)/2
ensures
- This function finds a quadratic function, Q(x), that interpolates the
given set of points. If there aren't enough points to uniquely define
- X.nr()+1 <= X.nc()
ensures
- This function finds a quadratic function, Q(x), that interpolates the
given set of points. If there aren't enough points to uniquely define
Q(x) then the Q(x) that fits the given points with the minimum Frobenius
norm hessian matrix is selected.
- To be precise:
@ -87,16 +87,19 @@ namespace dlib
sum(squared(H))
such that:
Q(colm(X,i)) == Y(i), for all valid i
!*/
- If there are more points than necessary to constrain Q then the Q
that best interpolates the function in the mean squared sense is
found.
!*/
{
DLIB_CASSERT(X.size() > 0);
DLIB_CASSERT(X.nc() == Y.size());
DLIB_CASSERT(X.nr()+1 <= X.nc());// && X.nc() <= (X.nr()+1)*(X.nr()+2)/2);
DLIB_CASSERT(X.nr()+1 <= X.nc());
if (X.nc() >= (X.nr()+1)*(X.nr()+2)/2)
{
fit_qp_mse(X,Y,H,g,c);
fit_quadratic_to_points_mse(X,Y,H,g,c);
return;
}
@ -180,7 +183,7 @@ namespace dlib
matrix<double,0,1> g;
double c;
fit_qp(X, Y, H, g, c);
fit_quadratic_to_points(X, Y, H, g, c);
matrix<double,0,1> p;
@ -198,7 +201,7 @@ namespace dlib
// ----------------------------------------------------------------------------------------
quad_interp_result pick_next_sample_quad_interp (
quad_interp_result pick_next_sample_using_trust_region (
const std::vector<function_evaluation>& samples,
double& radius,
const matrix<double,0,1>& lower,
@ -324,7 +327,7 @@ namespace dlib
// ------------------------------------------------------------------------------------
max_upper_bound_function pick_next_sample_max_upper_bound_function (
max_upper_bound_function pick_next_sample_as_max_upper_bound (
dlib::rand& rnd,
const upper_bound_function& ub,
const matrix<double,0,1>& lower,
@ -417,10 +420,10 @@ namespace dlib
{
upper_bound_function tmp(ub);
// we are going to add the incomplete evals into this and assume the
// incomplete evals are going to take y values equal to their nearest
// we are going to add the outstanding evals into this and assume the
// outstanding evals are going to take y values equal to their nearest
// neighbor complete evals.
for (auto& eval : incomplete_evals)
for (auto& eval : outstanding_evals)
{
function_evaluation e;
e.x = eval.x;
@ -454,6 +457,7 @@ namespace dlib
} // end namespace gopt_impl
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
@ -526,9 +530,9 @@ namespace dlib
{
std::lock_guard<std::mutex> lock(*info->m);
// remove the evaluation request from the incomplete list.
auto i = std::find(info->incomplete_evals.begin(), info->incomplete_evals.end(), req);
info->incomplete_evals.erase(i);
// remove the evaluation request from the outstanding list.
auto i = std::find(info->outstanding_evals.begin(), info->outstanding_evals.end(), req);
info->outstanding_evals.erase(i);
}
}
@ -545,10 +549,10 @@ namespace dlib
m_has_been_evaluated = true;
// move the evaluation from incomplete to complete
auto i = std::find(info->incomplete_evals.begin(), info->incomplete_evals.end(), req);
DLIB_CASSERT(i != info->incomplete_evals.end());
info->incomplete_evals.erase(i);
// move the evaluation from outstanding to complete
auto i = std::find(info->outstanding_evals.begin(), info->outstanding_evals.end(), req);
DLIB_CASSERT(i != info->outstanding_evals.end());
info->outstanding_evals.erase(i);
info->ub.add(function_evaluation(req.x,y));
@ -582,6 +586,8 @@ namespace dlib
}
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
global_function_search::
@ -701,13 +707,13 @@ namespace dlib
outstanding_function_eval_request new_req;
new_req.request_id = next_request_id++;
new_req.x = make_random_vector(rnd, info->spec.lower, info->spec.upper, info->spec.is_integer_variable);
info->incomplete_evals.emplace_back(new_req);
info->outstanding_evals.emplace_back(new_req);
return function_evaluation_request(new_req,info);
}
}
if (do_trust_region_step && !has_incomplete_trust_region_request())
if (do_trust_region_step && !has_outstanding_trust_region_request())
{
// find the currently best performing function, we will do a trust region
// step on it.
@ -716,7 +722,7 @@ namespace dlib
// if we have enough points to do a trust region step
if (info->ub.num_points() > dims+1)
{
auto tmp = pick_next_sample_quad_interp(info->ub.get_points(),
auto tmp = pick_next_sample_using_trust_region(info->ub.get_points(),
info->radius, info->spec.lower, info->spec.upper, info->spec.is_integer_variable);
//std::cout << "QP predicted improvement: "<< tmp.predicted_improvement << std::endl;
if (tmp.predicted_improvement > min_trust_region_epsilon)
@ -728,7 +734,7 @@ namespace dlib
new_req.was_trust_region_generated_request = true;
new_req.anchor_objective_value = info->best_objective_value;
new_req.predicted_improvement = tmp.predicted_improvement;
info->incomplete_evals.emplace_back(new_req);
info->outstanding_evals.emplace_back(new_req);
return function_evaluation_request(new_req, info);
}
}
@ -747,7 +753,7 @@ namespace dlib
// function with the largest upper bound for evaluation.
for (auto& info : functions)
{
auto tmp = pick_next_sample_max_upper_bound_function(rnd,
auto tmp = pick_next_sample_as_max_upper_bound(rnd,
info->build_upper_bound_with_all_function_evals(), info->spec.lower, info->spec.upper,
info->spec.is_integer_variable, num_random_samples);
if (tmp.predicted_improvement > 0 && tmp.upper_bound > best_upper_bound)
@ -764,7 +770,7 @@ namespace dlib
outstanding_function_eval_request new_req;
new_req.request_id = next_request_id++;
new_req.x = std::move(next_sample);
best_funct->incomplete_evals.emplace_back(new_req);
best_funct->outstanding_evals.emplace_back(new_req);
return function_evaluation_request(new_req, best_funct);
}
}
@ -776,7 +782,7 @@ namespace dlib
outstanding_function_eval_request new_req;
new_req.request_id = next_request_id++;
new_req.x = make_random_vector(rnd, info->spec.lower, info->spec.upper, info->spec.is_integer_variable);
info->incomplete_evals.emplace_back(new_req);
info->outstanding_evals.emplace_back(new_req);
return function_evaluation_request(new_req, info);
}
@ -839,9 +845,13 @@ namespace dlib
{
DLIB_CASSERT(0 <= value);
relative_noise_magnitude = value;
// recreate all the upper bound functions with the new relative noise magnitude
for (auto& f : functions)
f->ub = upper_bound_function(f->ub.get_points(), relative_noise_magnitude);
if (m)
{
std::lock_guard<std::mutex> lock(*m);
// recreate all the upper bound functions with the new relative noise magnitude
for (auto& f : functions)
f->ub = upper_bound_function(f->ub.get_points(), relative_noise_magnitude);
}
}
// ----------------------------------------------------------------------------------------
@ -881,8 +891,10 @@ namespace dlib
size_t& idx
) const
{
auto i = std::max_element(functions.begin(), functions.end(),
[](const std::shared_ptr<gopt_impl::funct_info>& a, const std::shared_ptr<gopt_impl::funct_info>& b) { return a->best_objective_value < b->best_objective_value; });
auto compare = [](const std::shared_ptr<gopt_impl::funct_info>& a, const std::shared_ptr<gopt_impl::funct_info>& b)
{ return a->best_objective_value < b->best_objective_value; };
auto i = std::max_element(functions.begin(), functions.end(), compare);
idx = std::distance(functions.begin(),i);
return *i;
@ -891,12 +903,12 @@ namespace dlib
// ----------------------------------------------------------------------------------------
bool global_function_search::
has_incomplete_trust_region_request (
has_outstanding_trust_region_request (
) const
{
for (auto& f : functions)
{
for (auto& i : f->incomplete_evals)
for (auto& i : f->outstanding_evals)
{
if (i.was_trust_region_generated_request)
return true;

View File

@ -79,7 +79,7 @@ namespace dlib
size_t function_idx = 0;
std::shared_ptr<std::mutex> m;
upper_bound_function ub;
std::vector<outstanding_function_eval_request> incomplete_evals;
std::vector<outstanding_function_eval_request> outstanding_evals;
matrix<double,0,1> best_x;
double best_objective_value = -std::numeric_limits<double>::infinity();
double radius = 0;
@ -101,7 +101,7 @@ namespace dlib
function_evaluation_request(function_evaluation_request&& item);
function_evaluation_request& operator=(function_evaluation_request&& item);
void swap(function_evaluation_request& item);
~function_evaluation_request();
size_t function_idx (
) const;
@ -112,12 +112,12 @@ namespace dlib
bool has_been_evaluated (
) const;
~function_evaluation_request();
void set (
double y
);
void swap(function_evaluation_request& item);
private:
friend class global_function_search;
@ -218,7 +218,7 @@ namespace dlib
size_t& idx
) const;
bool has_incomplete_trust_region_request (
bool has_outstanding_trust_region_request (
) const;

View File

@ -89,14 +89,6 @@ namespace dlib
moving from item causes item.has_been_evaluated() == true, TODO, clarify
!*/
void swap(
function_evaluation_request& item
);
/*!
ensures
- swaps the state of *this and item
!*/
~function_evaluation_request(
);
/*!
@ -113,7 +105,6 @@ namespace dlib
bool has_been_evaluated (
) const;
void set (
double y
);
@ -124,6 +115,14 @@ namespace dlib
- #has_been_evaluated() == true
!*/
void swap(
function_evaluation_request& item
);
/*!
ensures
- swaps the state of *this and item
!*/
};
// ----------------------------------------------------------------------------------------
@ -143,18 +142,6 @@ namespace dlib
- #num_functions() == 0
!*/
// This object can't be copied.
global_function_search(const global_function_search&) = delete;
global_function_search& operator=(const global_function_search& item) = delete;
global_function_search(global_function_search&& item) = default;
global_function_search& operator=(global_function_search&& item) = default;
/*!
ensures
- moves the state of item into *this
- #item.num_functions() == 0
!*/
explicit global_function_search(
const function_spec& function
);
@ -169,13 +156,25 @@ namespace dlib
const double relative_noise_magnitude = 0.001
);
size_t num_functions(
) const;
// This object can't be copied.
global_function_search(const global_function_search&) = delete;
global_function_search& operator=(const global_function_search& item) = delete;
global_function_search(global_function_search&& item) = default;
global_function_search& operator=(global_function_search&& item) = default;
/*!
ensures
- moves the state of item into *this
- #item.num_functions() == 0
!*/
void set_seed (
time_t seed
);
size_t num_functions(
) const;
void get_function_evaluations (
std::vector<function_spec>& specs,
std::vector<std::vector<function_evaluation>>& function_evals