diff --git a/dlib/dnn/core.h b/dlib/dnn/core.h
index f44d9a913..696b37f96 100644
--- a/dlib/dnn/core.h
+++ b/dlib/dnn/core.h
@@ -1003,7 +1003,7 @@ namespace dlib
         template <typename solver_type>
         void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
         {
-            subnetwork->update_parameters(make_sstack(solvers), learning_rate);
+            update_parameters(make_sstack(solvers), learning_rate);
         }
 
         const tensor& get_parameter_gradient(
@@ -1369,6 +1369,12 @@ namespace dlib
             }
         }
 
+        template <typename solver_type>
+        void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
+        {
+            update_parameters(make_sstack(solvers), learning_rate);
+        }
+
         const tensor& get_parameter_gradient(
         ) const { return params_grad; }
 
@@ -1609,6 +1615,12 @@ namespace dlib
             subnetwork.update_parameters(solvers, learning_rate);
         }
 
+        template <typename solver_type>
+        void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
+        {
+            update_parameters(make_sstack(solvers), learning_rate);
+        }
+
         const tensor& get_parameter_gradient(
         ) const { return params_grad; }
 
@@ -1905,6 +1917,12 @@ namespace dlib
             subnetwork.update_parameters(solvers.pop(comp_layers_in_each_group*details.size()),learning_rate);
         }
 
+        template <typename solver_type>
+        void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
+        {
+            update_parameters(make_sstack(solvers), learning_rate);
+        }
+
         const subnet_type& subnet() const { return subnetwork; }
         subnet_type& subnet() { return subnetwork; }
 
@@ -2135,6 +2153,12 @@ namespace dlib
             // nothing to do
         }
 
+        template <typename solver_type>
+        void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
+        {
+            update_parameters(make_sstack(solvers), learning_rate);
+        }
+
         const subnet_type& subnet() const { return input_layer; }
         subnet_type& subnet() { return input_layer; }
 
@@ -2550,6 +2574,12 @@ namespace dlib
             subnetwork.update_parameters(solvers, learning_rate);
         }
+        template <typename solver_type>
+        void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
+        {
+            update_parameters(make_sstack(solvers), learning_rate);
+        }
+
         const subnet_type& subnet() const { return subnetwork; }
         subnet_type& subnet()
         { return subnetwork; }
         const loss_details_type& loss_details() const { return loss; }
@@ -2940,6 +2970,12 @@ namespace dlib
             subnetwork.update_parameters(solvers, learning_rate);
         }
 
+        template <typename solver_type>
+        void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
+        {
+            update_parameters(make_sstack(solvers), learning_rate);
+        }
+
         const tensor& get_parameter_gradient(
         ) const { return params_grad; }
 
diff --git a/dlib/dnn/core_abstract.h b/dlib/dnn/core_abstract.h
index 7b98a48ac..c8352a8c5 100644
--- a/dlib/dnn/core_abstract.h
+++ b/dlib/dnn/core_abstract.h
@@ -639,6 +639,13 @@ namespace dlib
                 - The solvers use the given learning rate.
         !*/
 
+        template <typename solver_type>
+        void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
+        { update_parameters(make_sstack(solvers), learning_rate); }
+        /*!
+            Convenience method for calling update_parameters()
+        !*/
+
         void clean(
         );
         /*!
@@ -1155,6 +1162,13 @@
                 - The solvers use the given learning rate.
         !*/
 
+        template <typename solver_type>
+        void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
+        { update_parameters(make_sstack(solvers), learning_rate); }
+        /*!
+            Convenience method for calling update_parameters()
+        !*/
+
         // -------------
 
         void clean (