From dd62b0e2ff388f8e35ad2f12ac504199196b3ac7 Mon Sep 17 00:00:00 2001 From: Davis King Date: Sun, 20 Aug 2017 19:28:08 -0400 Subject: [PATCH] Made the dnn_trainer not forget all the previous loss values it knows about when it determines that there have been a lot of steps without progress and shrinks the learning rate. Instead, it removes only the oldest 100. The problem with the old way of removing all the loss values in the history was that if you set the steps without progress threshold to a really high number you would often observe that the last few learning rate values were obviously not making progress, however, since all the previous loss values were forgotten the trainer needed to fully populate it's loss history from scratch before it would figure this out. This new style makes the trainer not waste time running this excessive optimization of obviously useless mini-batches. --- dlib/dnn/trainer.h | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/dlib/dnn/trainer.h b/dlib/dnn/trainer.h index 3f64f0322..d6fd15269 100644 --- a/dlib/dnn/trainer.h +++ b/dlib/dnn/trainer.h @@ -699,7 +699,10 @@ namespace dlib // optimization has flattened out, so drop the learning rate. learning_rate = learning_rate_shrink*learning_rate; test_steps_without_progress = 0; - test_previous_loss_values.clear(); + // Empty out some of the previous loss values so that test_steps_without_progress + // will decrease below test_iter_without_progress_thresh. + for (int cnt = 0; cnt < previous_loss_values_dump_amount && test_previous_loss_values.size() > 0; ++cnt) + test_previous_loss_values.pop_front(); } } } @@ -804,7 +807,7 @@ namespace dlib // this because sometimes a mini-batch might be bad and cause the // loss to suddenly jump up, making count_steps_without_decrease() // return a large number. But if we discard the top 10% of the - // values in previous_loss_values when we are robust to that kind + // values in previous_loss_values then we are robust to that kind // of noise. Another way of looking at it, if the reason // count_steps_without_decrease() returns a large value is only // because the most recent loss values have suddenly been large, @@ -816,7 +819,10 @@ namespace dlib // optimization has flattened out, so drop the learning rate. learning_rate = learning_rate_shrink*learning_rate; steps_without_progress = 0; - previous_loss_values.clear(); + // Empty out some of the previous loss values so that steps_without_progress + // will decrease below iter_without_progress_thresh. + for (int cnt = 0; cnt < previous_loss_values_dump_amount && previous_loss_values.size() > 0; ++cnt) + previous_loss_values.pop_front(); } } } @@ -873,6 +879,7 @@ namespace dlib prob_loss_increasing_thresh = prob_loss_increasing_thresh_default_value; updated_net_since_last_sync = false; sync_file_reloaded = false; + previous_loss_values_dump_amount = 100; start(); } @@ -883,7 +890,7 @@ namespace dlib friend void serialize(const dnn_trainer& item, std::ostream& out) { item.wait_for_thread_to_pause(); - int version = 9; + int version = 10; serialize(version, out); size_t nl = dnn_trainer::num_layers; @@ -909,6 +916,7 @@ namespace dlib serialize(item.test_iter_without_progress_thresh.load(), out); serialize(item.test_steps_without_progress.load(), out); serialize(item.test_previous_loss_values, out); + serialize(item.previous_loss_values_dump_amount, out); } friend void deserialize(dnn_trainer& item, std::istream& in) @@ -916,7 +924,7 @@ namespace dlib item.wait_for_thread_to_pause(); int version = 0; deserialize(version, in); - if (version != 9) + if (version != 10) throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer."); size_t num_layers = 0; @@ -952,6 +960,7 @@ namespace dlib deserialize(ltemp, in); item.test_iter_without_progress_thresh = ltemp; deserialize(ltemp, in); item.test_steps_without_progress = ltemp; deserialize(item.test_previous_loss_values, in); + deserialize(item.previous_loss_values_dump_amount, in); if (item.devices.size() > 1) { @@ -1259,6 +1268,7 @@ namespace dlib std::atomic updated_net_since_last_sync; bool sync_file_reloaded; + int previous_loss_values_dump_amount; }; // ----------------------------------------------------------------------------------------