mirror of https://github.com/davisking/dlib.git
fix set_learning_rate_multipliers_range not working (#2304)
* fix set_learning_rate_multipliers not working * add tests for set_learning_rate_multipliers
This commit is contained in:
parent
42e6ace845
commit
04a3534af1
|
@ -309,7 +309,7 @@ namespace dlib
|
|||
static_assert(end <= net_type::num_layers, "Invalid range");
|
||||
DLIB_CASSERT(learning_rate_multiplier >= 0);
|
||||
impl::visitor_learning_rate_multiplier temp(learning_rate_multiplier);
|
||||
visit_layers_range<begin, end>(net, temp);
|
||||
visit_computational_layers_range<begin, end>(net, temp);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
|
|
@ -3920,6 +3920,7 @@ namespace
|
|||
template <typename SUBNET> using dense_layer_32 = dense_layer<32, 8, SUBNET>;
|
||||
void test_disable_duplicative_biases()
|
||||
{
|
||||
print_spinner();
|
||||
using net_type = fc<10, relu<layer_norm<fc<15, relu<bn_fc<fc<20,
|
||||
relu<layer_norm<conp<32, 3, 1,
|
||||
repeat<2, dense_layer_32,
|
||||
|
@ -3946,6 +3947,24 @@ namespace
|
|||
DLIB_TEST(layer<31>(net).layer_details().bias_is_disabled() == true);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void test_set_learning_rate_multipliers()
|
||||
{
|
||||
print_spinner();
|
||||
using net_type = loss_binary_log<fc<2, relu<bn_con<con<16, 5, 5, 2, 2, input<matrix<float>>>>>>>;
|
||||
net_type net;
|
||||
set_all_learning_rate_multipliers(net, 0.5);
|
||||
DLIB_TEST(layer<1>(net).layer_details().get_learning_rate_multiplier() == 0.5);
|
||||
DLIB_TEST(layer<3>(net).layer_details().get_learning_rate_multiplier() == 0.5);
|
||||
DLIB_TEST(layer<4>(net).layer_details().get_learning_rate_multiplier() == 0.5);
|
||||
set_learning_rate_multipliers_range<2, 4>(net, 0.1);
|
||||
set_learning_rate_multipliers_range<4, 6>(net, 0.01);
|
||||
DLIB_TEST(layer<1>(net).layer_details().get_learning_rate_multiplier() == 0.5);
|
||||
DLIB_TEST(layer<3>(net).layer_details().get_learning_rate_multiplier() == 0.1);
|
||||
DLIB_TEST(layer<4>(net).layer_details().get_learning_rate_multiplier() == 0.01);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
// This test really just checks if the mmod loss goes negative when a whole lot of overlapping
|
||||
|
@ -4132,6 +4151,7 @@ namespace
|
|||
test_loss_mmod();
|
||||
test_layers_scale_and_scale_prev();
|
||||
test_disable_duplicative_biases();
|
||||
test_set_learning_rate_multipliers();
|
||||
}
|
||||
|
||||
void perform_test()
|
||||
|
|
Loading…
Reference in New Issue