fix set_learning_rate_multipliers_range not working (#2304)

* fix set_learning_rate_multipliers not working

* add tests for set_learning_rate_multipliers
This commit is contained in:
Adrià Arrufat 2021-02-11 11:55:54 +09:00 committed by GitHub
parent 42e6ace845
commit 04a3534af1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 3 deletions

View File

@ -281,7 +281,7 @@ namespace dlib
{ {
set_learning_rate_multiplier(l, new_learning_rate_multiplier); set_learning_rate_multiplier(l, new_learning_rate_multiplier);
} }
private: private:
double new_learning_rate_multiplier; double new_learning_rate_multiplier;
@ -309,7 +309,7 @@ namespace dlib
static_assert(end <= net_type::num_layers, "Invalid range"); static_assert(end <= net_type::num_layers, "Invalid range");
DLIB_CASSERT(learning_rate_multiplier >= 0); DLIB_CASSERT(learning_rate_multiplier >= 0);
impl::visitor_learning_rate_multiplier temp(learning_rate_multiplier); impl::visitor_learning_rate_multiplier temp(learning_rate_multiplier);
visit_layers_range<begin, end>(net, temp); visit_computational_layers_range<begin, end>(net, temp);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------

View File

@ -3920,6 +3920,7 @@ namespace
template <typename SUBNET> using dense_layer_32 = dense_layer<32, 8, SUBNET>; template <typename SUBNET> using dense_layer_32 = dense_layer<32, 8, SUBNET>;
void test_disable_duplicative_biases() void test_disable_duplicative_biases()
{ {
print_spinner();
using net_type = fc<10, relu<layer_norm<fc<15, relu<bn_fc<fc<20, using net_type = fc<10, relu<layer_norm<fc<15, relu<bn_fc<fc<20,
relu<layer_norm<conp<32, 3, 1, relu<layer_norm<conp<32, 3, 1,
repeat<2, dense_layer_32, repeat<2, dense_layer_32,
@ -3947,7 +3948,25 @@ namespace
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
void test_set_learning_rate_multipliers()
{
print_spinner();
using net_type = loss_binary_log<fc<2, relu<bn_con<con<16, 5, 5, 2, 2, input<matrix<float>>>>>>>;
net_type net;
set_all_learning_rate_multipliers(net, 0.5);
DLIB_TEST(layer<1>(net).layer_details().get_learning_rate_multiplier() == 0.5);
DLIB_TEST(layer<3>(net).layer_details().get_learning_rate_multiplier() == 0.5);
DLIB_TEST(layer<4>(net).layer_details().get_learning_rate_multiplier() == 0.5);
set_learning_rate_multipliers_range<2, 4>(net, 0.1);
set_learning_rate_multipliers_range<4, 6>(net, 0.01);
DLIB_TEST(layer<1>(net).layer_details().get_learning_rate_multiplier() == 0.5);
DLIB_TEST(layer<3>(net).layer_details().get_learning_rate_multiplier() == 0.1);
DLIB_TEST(layer<4>(net).layer_details().get_learning_rate_multiplier() == 0.01);
}
// ----------------------------------------------------------------------------------------
// This test really just checks if the mmod loss goes negative when a whole lot of overlapping // This test really just checks if the mmod loss goes negative when a whole lot of overlapping
// truth rectangles are given. // truth rectangles are given.
void test_loss_mmod() void test_loss_mmod()
@ -4132,6 +4151,7 @@ namespace
test_loss_mmod(); test_loss_mmod();
test_layers_scale_and_scale_prev(); test_layers_scale_and_scale_prev();
test_disable_duplicative_biases(); test_disable_duplicative_biases();
test_set_learning_rate_multipliers();
} }
void perform_test() void perform_test()