make update_parameters() a little more uniform

This commit is contained in:
Davis King 2020-03-29 11:19:37 -04:00
parent fd0145345e
commit c79f64f52d
2 changed files with 51 additions and 1 deletions

View File

@ -1003,7 +1003,7 @@ namespace dlib
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{
subnetwork->update_parameters(make_sstack(solvers), learning_rate);
update_parameters(make_sstack(solvers), learning_rate);
}
const tensor& get_parameter_gradient(
@ -1369,6 +1369,12 @@ namespace dlib
}
}
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{
update_parameters(make_sstack(solvers), learning_rate);
}
const tensor& get_parameter_gradient(
) const { return params_grad; }
@ -1609,6 +1615,12 @@ namespace dlib
subnetwork.update_parameters(solvers, learning_rate);
}
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{
update_parameters(make_sstack(solvers), learning_rate);
}
const tensor& get_parameter_gradient(
) const { return params_grad; }
@ -1905,6 +1917,12 @@ namespace dlib
subnetwork.update_parameters(solvers.pop(comp_layers_in_each_group*details.size()),learning_rate);
}
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{
update_parameters(make_sstack(solvers), learning_rate);
}
const subnet_type& subnet() const { return subnetwork; }
subnet_type& subnet() { return subnetwork; }
@ -2135,6 +2153,12 @@ namespace dlib
// nothing to do
}
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{
update_parameters(make_sstack(solvers), learning_rate);
}
const subnet_type& subnet() const { return input_layer; }
subnet_type& subnet() { return input_layer; }
@ -2550,6 +2574,12 @@ namespace dlib
subnetwork.update_parameters(solvers, learning_rate);
}
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{
update_parameters(make_sstack(solvers), learning_rate);
}
const subnet_type& subnet() const { return subnetwork; }
subnet_type& subnet() { return subnetwork; }
const loss_details_type& loss_details() const { return loss; }
@ -2940,6 +2970,12 @@ namespace dlib
subnetwork.update_parameters(solvers, learning_rate);
}
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{
update_parameters(make_sstack(solvers), learning_rate);
}
const tensor& get_parameter_gradient(
) const { return params_grad; }

View File

@ -639,6 +639,13 @@ namespace dlib
- The solvers use the given learning rate.
!*/
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{ update_parameters(make_sstack(solvers), learning_rate); }
/*!
Convenience method for calling update_parameters()
!*/
void clean(
);
/*!
@ -1155,6 +1162,13 @@ namespace dlib
- The solvers use the given learning rate.
!*/
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{ update_parameters(make_sstack(solvers), learning_rate); }
/*!
Convenience method for calling update_parameters()
!*/
// -------------
void clean (