From dc45871a319095b9a776c65ebc6fbc7b9401faf4 Mon Sep 17 00:00:00 2001 From: Davis King Date: Sun, 20 Aug 2017 20:08:03 -0400 Subject: [PATCH] Made the loss value management a little more conservative. --- dlib/dnn/trainer.h | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/dlib/dnn/trainer.h b/dlib/dnn/trainer.h index 8f0fb1066..d4bf09b37 100644 --- a/dlib/dnn/trainer.h +++ b/dlib/dnn/trainer.h @@ -701,7 +701,7 @@ namespace dlib test_steps_without_progress = 0; // Empty out some of the previous loss values so that test_steps_without_progress // will decrease below test_iter_without_progress_thresh. - for (int cnt = 0; cnt < previous_loss_values_dump_amount && test_previous_loss_values.size() > 0; ++cnt) + for (int cnt = 0; cnt < test_previous_loss_values_dump_amount && test_previous_loss_values.size() > 0; ++cnt) test_previous_loss_values.pop_front(); } } @@ -879,7 +879,8 @@ namespace dlib prob_loss_increasing_thresh = prob_loss_increasing_thresh_default_value; updated_net_since_last_sync = false; sync_file_reloaded = false; - previous_loss_values_dump_amount = 100; + previous_loss_values_dump_amount = 400; + test_previous_loss_values_dump_amount = 100; start(); } @@ -890,7 +891,7 @@ namespace dlib friend void serialize(const dnn_trainer& item, std::ostream& out) { item.wait_for_thread_to_pause(); - int version = 10; + int version = 11; serialize(version, out); size_t nl = dnn_trainer::num_layers; @@ -917,6 +918,7 @@ namespace dlib serialize(item.test_steps_without_progress.load(), out); serialize(item.test_previous_loss_values, out); serialize(item.previous_loss_values_dump_amount, out); + serialize(item.test_previous_loss_values_dump_amount, out); } friend void deserialize(dnn_trainer& item, std::istream& in) @@ -924,7 +926,7 @@ namespace dlib item.wait_for_thread_to_pause(); int version = 0; deserialize(version, in); - if (version != 10) + if (version != 11) throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer."); size_t num_layers = 0; @@ -961,6 +963,7 @@ namespace dlib deserialize(ltemp, in); item.test_steps_without_progress = ltemp; deserialize(item.test_previous_loss_values, in); deserialize(item.previous_loss_values_dump_amount, in); + deserialize(item.test_previous_loss_values_dump_amount, in); if (item.devices.size() > 1) { @@ -1269,6 +1272,7 @@ namespace dlib bool sync_file_reloaded; int previous_loss_values_dump_amount; + int test_previous_loss_values_dump_amount; }; // ----------------------------------------------------------------------------------------