Reinitialize averagers when saved sync file was reloaded. (#629)

This commit is contained in:
Plumtus 2017-06-07 02:19:23 +08:00 committed by Davis E. King
parent d2b80bfe6f
commit 362bec1099
1 changed files with 6 additions and 1 deletions

View File

@ -713,7 +713,7 @@ namespace dlib
// We can't do this outside the loop because the tensors that get
// averaged need to be allocated to their devices before we call set()
// so that the averagers can determine how best to average them.
if (averagers.size() == 0)
if (averagers.size() == 0 || sync_file_reloaded)
{
averagers = std::vector<tt::multi_device_tensor_averager>(net_type::num_computational_layers);
// setup the averagers to point to the tensors in the networks.
@ -736,6 +736,8 @@ namespace dlib
if (temp[0]->size() != 0)
averagers[i].set(temp);
}
sync_file_reloaded = false;
}
@ -855,6 +857,7 @@ namespace dlib
prob_loss_increasing_thresh_max_value = 0.99999;
prob_loss_increasing_thresh = prob_loss_increasing_thresh_default_value;
updated_net_since_last_sync = false;
sync_file_reloaded = false;
start();
}
@ -979,6 +982,7 @@ namespace dlib
{
std::ifstream fin(sync_filename, std::ios::binary);
deserialize(*this, fin);
sync_file_reloaded = true;
if (verbose)
std::cout << "Loss has been increasing, reloading saved state from " << sync_filename << std::endl;
}
@ -1230,6 +1234,7 @@ namespace dlib
double prob_loss_increasing_thresh;
std::atomic<bool> updated_net_since_last_sync;
bool sync_file_reloaded;
};
// ----------------------------------------------------------------------------------------