mirror of https://github.com/davisking/dlib.git
Cleaned up the code a bit.
This commit is contained in:
parent
0d9043bc09
commit
1fbd1828ab
|
@ -11,7 +11,7 @@ namespace dlib
|
|||
|
||||
namespace qopt_impl
|
||||
{
|
||||
void fit_qp_mse(
|
||||
void fit_quadratic_to_points_mse(
|
||||
const matrix<double>& X,
|
||||
const matrix<double,0,1>& Y,
|
||||
matrix<double>& H,
|
||||
|
@ -64,21 +64,21 @@ namespace dlib
|
|||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void fit_qp(
|
||||
void fit_quadratic_to_points(
|
||||
const matrix<double>& X,
|
||||
const matrix<double,0,1>& Y,
|
||||
matrix<double>& H,
|
||||
matrix<double,0,1>& g,
|
||||
double& c
|
||||
)
|
||||
/*!
|
||||
requires
|
||||
- X.size() > 0
|
||||
/*!
|
||||
requires
|
||||
- X.size() > 0
|
||||
- X.nc() == Y.size()
|
||||
- X.nr()+1 <= X.nc() <= (X.nr()+1)*(X.nr()+2)/2
|
||||
ensures
|
||||
- This function finds a quadratic function, Q(x), that interpolates the
|
||||
given set of points. If there aren't enough points to uniquely define
|
||||
- X.nr()+1 <= X.nc()
|
||||
ensures
|
||||
- This function finds a quadratic function, Q(x), that interpolates the
|
||||
given set of points. If there aren't enough points to uniquely define
|
||||
Q(x) then the Q(x) that fits the given points with the minimum Frobenius
|
||||
norm hessian matrix is selected.
|
||||
- To be precise:
|
||||
|
@ -87,16 +87,19 @@ namespace dlib
|
|||
sum(squared(H))
|
||||
such that:
|
||||
Q(colm(X,i)) == Y(i), for all valid i
|
||||
!*/
|
||||
- If there are more points than necessary to constrain Q then the Q
|
||||
that best interpolates the function in the mean squared sense is
|
||||
found.
|
||||
!*/
|
||||
{
|
||||
DLIB_CASSERT(X.size() > 0);
|
||||
DLIB_CASSERT(X.nc() == Y.size());
|
||||
DLIB_CASSERT(X.nr()+1 <= X.nc());// && X.nc() <= (X.nr()+1)*(X.nr()+2)/2);
|
||||
DLIB_CASSERT(X.nr()+1 <= X.nc());
|
||||
|
||||
|
||||
if (X.nc() >= (X.nr()+1)*(X.nr()+2)/2)
|
||||
{
|
||||
fit_qp_mse(X,Y,H,g,c);
|
||||
fit_quadratic_to_points_mse(X,Y,H,g,c);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -180,7 +183,7 @@ namespace dlib
|
|||
matrix<double,0,1> g;
|
||||
double c;
|
||||
|
||||
fit_qp(X, Y, H, g, c);
|
||||
fit_quadratic_to_points(X, Y, H, g, c);
|
||||
|
||||
matrix<double,0,1> p;
|
||||
|
||||
|
@ -198,7 +201,7 @@ namespace dlib
|
|||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
quad_interp_result pick_next_sample_quad_interp (
|
||||
quad_interp_result pick_next_sample_using_trust_region (
|
||||
const std::vector<function_evaluation>& samples,
|
||||
double& radius,
|
||||
const matrix<double,0,1>& lower,
|
||||
|
@ -324,7 +327,7 @@ namespace dlib
|
|||
|
||||
// ------------------------------------------------------------------------------------
|
||||
|
||||
max_upper_bound_function pick_next_sample_max_upper_bound_function (
|
||||
max_upper_bound_function pick_next_sample_as_max_upper_bound (
|
||||
dlib::rand& rnd,
|
||||
const upper_bound_function& ub,
|
||||
const matrix<double,0,1>& lower,
|
||||
|
@ -417,10 +420,10 @@ namespace dlib
|
|||
{
|
||||
upper_bound_function tmp(ub);
|
||||
|
||||
// we are going to add the incomplete evals into this and assume the
|
||||
// incomplete evals are going to take y values equal to their nearest
|
||||
// we are going to add the outstanding evals into this and assume the
|
||||
// outstanding evals are going to take y values equal to their nearest
|
||||
// neighbor complete evals.
|
||||
for (auto& eval : incomplete_evals)
|
||||
for (auto& eval : outstanding_evals)
|
||||
{
|
||||
function_evaluation e;
|
||||
e.x = eval.x;
|
||||
|
@ -454,6 +457,7 @@ namespace dlib
|
|||
|
||||
} // end namespace gopt_impl
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
|
@ -526,9 +530,9 @@ namespace dlib
|
|||
{
|
||||
std::lock_guard<std::mutex> lock(*info->m);
|
||||
|
||||
// remove the evaluation request from the incomplete list.
|
||||
auto i = std::find(info->incomplete_evals.begin(), info->incomplete_evals.end(), req);
|
||||
info->incomplete_evals.erase(i);
|
||||
// remove the evaluation request from the outstanding list.
|
||||
auto i = std::find(info->outstanding_evals.begin(), info->outstanding_evals.end(), req);
|
||||
info->outstanding_evals.erase(i);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -545,10 +549,10 @@ namespace dlib
|
|||
m_has_been_evaluated = true;
|
||||
|
||||
|
||||
// move the evaluation from incomplete to complete
|
||||
auto i = std::find(info->incomplete_evals.begin(), info->incomplete_evals.end(), req);
|
||||
DLIB_CASSERT(i != info->incomplete_evals.end());
|
||||
info->incomplete_evals.erase(i);
|
||||
// move the evaluation from outstanding to complete
|
||||
auto i = std::find(info->outstanding_evals.begin(), info->outstanding_evals.end(), req);
|
||||
DLIB_CASSERT(i != info->outstanding_evals.end());
|
||||
info->outstanding_evals.erase(i);
|
||||
info->ub.add(function_evaluation(req.x,y));
|
||||
|
||||
|
||||
|
@ -582,6 +586,8 @@ namespace dlib
|
|||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
global_function_search::
|
||||
|
@ -701,13 +707,13 @@ namespace dlib
|
|||
outstanding_function_eval_request new_req;
|
||||
new_req.request_id = next_request_id++;
|
||||
new_req.x = make_random_vector(rnd, info->spec.lower, info->spec.upper, info->spec.is_integer_variable);
|
||||
info->incomplete_evals.emplace_back(new_req);
|
||||
info->outstanding_evals.emplace_back(new_req);
|
||||
return function_evaluation_request(new_req,info);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (do_trust_region_step && !has_incomplete_trust_region_request())
|
||||
if (do_trust_region_step && !has_outstanding_trust_region_request())
|
||||
{
|
||||
// find the currently best performing function, we will do a trust region
|
||||
// step on it.
|
||||
|
@ -716,7 +722,7 @@ namespace dlib
|
|||
// if we have enough points to do a trust region step
|
||||
if (info->ub.num_points() > dims+1)
|
||||
{
|
||||
auto tmp = pick_next_sample_quad_interp(info->ub.get_points(),
|
||||
auto tmp = pick_next_sample_using_trust_region(info->ub.get_points(),
|
||||
info->radius, info->spec.lower, info->spec.upper, info->spec.is_integer_variable);
|
||||
//std::cout << "QP predicted improvement: "<< tmp.predicted_improvement << std::endl;
|
||||
if (tmp.predicted_improvement > min_trust_region_epsilon)
|
||||
|
@ -728,7 +734,7 @@ namespace dlib
|
|||
new_req.was_trust_region_generated_request = true;
|
||||
new_req.anchor_objective_value = info->best_objective_value;
|
||||
new_req.predicted_improvement = tmp.predicted_improvement;
|
||||
info->incomplete_evals.emplace_back(new_req);
|
||||
info->outstanding_evals.emplace_back(new_req);
|
||||
return function_evaluation_request(new_req, info);
|
||||
}
|
||||
}
|
||||
|
@ -747,7 +753,7 @@ namespace dlib
|
|||
// function with the largest upper bound for evaluation.
|
||||
for (auto& info : functions)
|
||||
{
|
||||
auto tmp = pick_next_sample_max_upper_bound_function(rnd,
|
||||
auto tmp = pick_next_sample_as_max_upper_bound(rnd,
|
||||
info->build_upper_bound_with_all_function_evals(), info->spec.lower, info->spec.upper,
|
||||
info->spec.is_integer_variable, num_random_samples);
|
||||
if (tmp.predicted_improvement > 0 && tmp.upper_bound > best_upper_bound)
|
||||
|
@ -764,7 +770,7 @@ namespace dlib
|
|||
outstanding_function_eval_request new_req;
|
||||
new_req.request_id = next_request_id++;
|
||||
new_req.x = std::move(next_sample);
|
||||
best_funct->incomplete_evals.emplace_back(new_req);
|
||||
best_funct->outstanding_evals.emplace_back(new_req);
|
||||
return function_evaluation_request(new_req, best_funct);
|
||||
}
|
||||
}
|
||||
|
@ -776,7 +782,7 @@ namespace dlib
|
|||
outstanding_function_eval_request new_req;
|
||||
new_req.request_id = next_request_id++;
|
||||
new_req.x = make_random_vector(rnd, info->spec.lower, info->spec.upper, info->spec.is_integer_variable);
|
||||
info->incomplete_evals.emplace_back(new_req);
|
||||
info->outstanding_evals.emplace_back(new_req);
|
||||
return function_evaluation_request(new_req, info);
|
||||
|
||||
}
|
||||
|
@ -839,9 +845,13 @@ namespace dlib
|
|||
{
|
||||
DLIB_CASSERT(0 <= value);
|
||||
relative_noise_magnitude = value;
|
||||
// recreate all the upper bound functions with the new relative noise magnitude
|
||||
for (auto& f : functions)
|
||||
f->ub = upper_bound_function(f->ub.get_points(), relative_noise_magnitude);
|
||||
if (m)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*m);
|
||||
// recreate all the upper bound functions with the new relative noise magnitude
|
||||
for (auto& f : functions)
|
||||
f->ub = upper_bound_function(f->ub.get_points(), relative_noise_magnitude);
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
@ -881,8 +891,10 @@ namespace dlib
|
|||
size_t& idx
|
||||
) const
|
||||
{
|
||||
auto i = std::max_element(functions.begin(), functions.end(),
|
||||
[](const std::shared_ptr<gopt_impl::funct_info>& a, const std::shared_ptr<gopt_impl::funct_info>& b) { return a->best_objective_value < b->best_objective_value; });
|
||||
auto compare = [](const std::shared_ptr<gopt_impl::funct_info>& a, const std::shared_ptr<gopt_impl::funct_info>& b)
|
||||
{ return a->best_objective_value < b->best_objective_value; };
|
||||
|
||||
auto i = std::max_element(functions.begin(), functions.end(), compare);
|
||||
|
||||
idx = std::distance(functions.begin(),i);
|
||||
return *i;
|
||||
|
@ -891,12 +903,12 @@ namespace dlib
|
|||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
bool global_function_search::
|
||||
has_incomplete_trust_region_request (
|
||||
has_outstanding_trust_region_request (
|
||||
) const
|
||||
{
|
||||
for (auto& f : functions)
|
||||
{
|
||||
for (auto& i : f->incomplete_evals)
|
||||
for (auto& i : f->outstanding_evals)
|
||||
{
|
||||
if (i.was_trust_region_generated_request)
|
||||
return true;
|
||||
|
|
|
@ -79,7 +79,7 @@ namespace dlib
|
|||
size_t function_idx = 0;
|
||||
std::shared_ptr<std::mutex> m;
|
||||
upper_bound_function ub;
|
||||
std::vector<outstanding_function_eval_request> incomplete_evals;
|
||||
std::vector<outstanding_function_eval_request> outstanding_evals;
|
||||
matrix<double,0,1> best_x;
|
||||
double best_objective_value = -std::numeric_limits<double>::infinity();
|
||||
double radius = 0;
|
||||
|
@ -101,7 +101,7 @@ namespace dlib
|
|||
function_evaluation_request(function_evaluation_request&& item);
|
||||
function_evaluation_request& operator=(function_evaluation_request&& item);
|
||||
|
||||
void swap(function_evaluation_request& item);
|
||||
~function_evaluation_request();
|
||||
|
||||
size_t function_idx (
|
||||
) const;
|
||||
|
@ -112,12 +112,12 @@ namespace dlib
|
|||
bool has_been_evaluated (
|
||||
) const;
|
||||
|
||||
~function_evaluation_request();
|
||||
|
||||
void set (
|
||||
double y
|
||||
);
|
||||
|
||||
void swap(function_evaluation_request& item);
|
||||
|
||||
private:
|
||||
|
||||
friend class global_function_search;
|
||||
|
@ -218,7 +218,7 @@ namespace dlib
|
|||
size_t& idx
|
||||
) const;
|
||||
|
||||
bool has_incomplete_trust_region_request (
|
||||
bool has_outstanding_trust_region_request (
|
||||
) const;
|
||||
|
||||
|
||||
|
|
|
@ -89,14 +89,6 @@ namespace dlib
|
|||
moving from item causes item.has_been_evaluated() == true, TODO, clarify
|
||||
!*/
|
||||
|
||||
void swap(
|
||||
function_evaluation_request& item
|
||||
);
|
||||
/*!
|
||||
ensures
|
||||
- swaps the state of *this and item
|
||||
!*/
|
||||
|
||||
~function_evaluation_request(
|
||||
);
|
||||
/*!
|
||||
|
@ -113,7 +105,6 @@ namespace dlib
|
|||
bool has_been_evaluated (
|
||||
) const;
|
||||
|
||||
|
||||
void set (
|
||||
double y
|
||||
);
|
||||
|
@ -124,6 +115,14 @@ namespace dlib
|
|||
- #has_been_evaluated() == true
|
||||
!*/
|
||||
|
||||
void swap(
|
||||
function_evaluation_request& item
|
||||
);
|
||||
/*!
|
||||
ensures
|
||||
- swaps the state of *this and item
|
||||
!*/
|
||||
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
@ -143,18 +142,6 @@ namespace dlib
|
|||
- #num_functions() == 0
|
||||
!*/
|
||||
|
||||
// This object can't be copied.
|
||||
global_function_search(const global_function_search&) = delete;
|
||||
global_function_search& operator=(const global_function_search& item) = delete;
|
||||
|
||||
global_function_search(global_function_search&& item) = default;
|
||||
global_function_search& operator=(global_function_search&& item) = default;
|
||||
/*!
|
||||
ensures
|
||||
- moves the state of item into *this
|
||||
- #item.num_functions() == 0
|
||||
!*/
|
||||
|
||||
explicit global_function_search(
|
||||
const function_spec& function
|
||||
);
|
||||
|
@ -169,13 +156,25 @@ namespace dlib
|
|||
const double relative_noise_magnitude = 0.001
|
||||
);
|
||||
|
||||
size_t num_functions(
|
||||
) const;
|
||||
// This object can't be copied.
|
||||
global_function_search(const global_function_search&) = delete;
|
||||
global_function_search& operator=(const global_function_search& item) = delete;
|
||||
|
||||
global_function_search(global_function_search&& item) = default;
|
||||
global_function_search& operator=(global_function_search&& item) = default;
|
||||
/*!
|
||||
ensures
|
||||
- moves the state of item into *this
|
||||
- #item.num_functions() == 0
|
||||
!*/
|
||||
|
||||
void set_seed (
|
||||
time_t seed
|
||||
);
|
||||
|
||||
size_t num_functions(
|
||||
) const;
|
||||
|
||||
void get_function_evaluations (
|
||||
std::vector<function_spec>& specs,
|
||||
std::vector<std::vector<function_evaluation>>& function_evals
|
||||
|
|
Loading…
Reference in New Issue