diff --git a/dlib/dnn/trainer.h b/dlib/dnn/trainer.h
index 4a822a9f9..33e89d6c2 100644
--- a/dlib/dnn/trainer.h
+++ b/dlib/dnn/trainer.h
@@ -713,7 +713,7 @@ namespace dlib
                 // We can't do this outside the loop because the tensors that get
                 // averaged need to be allocated to their devices before we call set()
                 // so that the averagers can determine how best to average them.
-                if (averagers.size() == 0)
+                if (averagers.size() == 0 || sync_file_reloaded)
                 {
                     averagers = std::vector<tt::multi_device_tensor_averager>(net_type::num_computational_layers);
                     // setup the averagers to point to the tensors in the networks.
@@ -736,6 +736,8 @@ namespace dlib
                         if (temp[0]->size() != 0)
                             averagers[i].set(temp);
                     }
+
+                    sync_file_reloaded = false;
                 }


@@ -855,6 +857,7 @@ namespace dlib
             prob_loss_increasing_thresh_max_value = 0.99999;
             prob_loss_increasing_thresh = prob_loss_increasing_thresh_default_value;
             updated_net_since_last_sync = false;
+            sync_file_reloaded = false;
             start();
         }

@@ -979,6 +982,7 @@ namespace dlib
             {
                 std::ifstream fin(sync_filename, std::ios::binary);
                 deserialize(*this, fin);
+                sync_file_reloaded = true;
                 if (verbose)
                     std::cout << "Loss has been increasing, reloading saved state from " << sync_filename << std::endl;
             }

@@ -1230,6 +1234,7 @@ namespace dlib

         double prob_loss_increasing_thresh;
         std::atomic<bool> updated_net_since_last_sync;
+        bool sync_file_reloaded;
     };

// ----------------------------------------------------------------------------------------