From 362bec1099c5e57c388eb6c66471affddba0074e Mon Sep 17 00:00:00 2001 From: Plumtus Date: Wed, 7 Jun 2017 02:19:23 +0800 Subject: [PATCH] Reinitialize averagers when saved sync file was reloaded. (#629) --- dlib/dnn/trainer.h | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dlib/dnn/trainer.h b/dlib/dnn/trainer.h index 4a822a9f9..33e89d6c2 100644 --- a/dlib/dnn/trainer.h +++ b/dlib/dnn/trainer.h @@ -713,7 +713,7 @@ namespace dlib // We can't do this outside the loop because the tensors that get // averaged need to be allocated to their devices before we call set() // so that the averagers can determine how best to average them. - if (averagers.size() == 0) + if (averagers.size() == 0 || sync_file_reloaded) { averagers = std::vector(net_type::num_computational_layers); // setup the averagers to point to the tensors in the networks. @@ -736,6 +736,8 @@ namespace dlib if (temp[0]->size() != 0) averagers[i].set(temp); } + + sync_file_reloaded = false; } @@ -855,6 +857,7 @@ namespace dlib prob_loss_increasing_thresh_max_value = 0.99999; prob_loss_increasing_thresh = prob_loss_increasing_thresh_default_value; updated_net_since_last_sync = false; + sync_file_reloaded = false; start(); } @@ -979,6 +982,7 @@ namespace dlib { std::ifstream fin(sync_filename, std::ios::binary); deserialize(*this, fin); + sync_file_reloaded = true; if (verbose) std::cout << "Loss has been increasing, reloading saved state from " << sync_filename << std::endl; } @@ -1230,6 +1234,7 @@ namespace dlib double prob_loss_increasing_thresh; std::atomic updated_net_since_last_sync; + bool sync_file_reloaded; }; // ----------------------------------------------------------------------------------------