mirror of https://github.com/davisking/dlib.git
Reinitialize averagers when saved sync file was reloaded. (#629)
This commit is contained in:
parent
d2b80bfe6f
commit
362bec1099
|
@ -713,7 +713,7 @@ namespace dlib
|
|||
// We can't do this outside the loop because the tensors that get
|
||||
// averaged need to be allocated to their devices before we call set()
|
||||
// so that the averagers can determine how best to average them.
|
||||
if (averagers.size() == 0)
|
||||
if (averagers.size() == 0 || sync_file_reloaded)
|
||||
{
|
||||
averagers = std::vector<tt::multi_device_tensor_averager>(net_type::num_computational_layers);
|
||||
// setup the averagers to point to the tensors in the networks.
|
||||
|
@ -736,6 +736,8 @@ namespace dlib
|
|||
if (temp[0]->size() != 0)
|
||||
averagers[i].set(temp);
|
||||
}
|
||||
|
||||
sync_file_reloaded = false;
|
||||
}
|
||||
|
||||
|
||||
|
@ -855,6 +857,7 @@ namespace dlib
|
|||
prob_loss_increasing_thresh_max_value = 0.99999;
|
||||
prob_loss_increasing_thresh = prob_loss_increasing_thresh_default_value;
|
||||
updated_net_since_last_sync = false;
|
||||
sync_file_reloaded = false;
|
||||
start();
|
||||
}
|
||||
|
||||
|
@ -979,6 +982,7 @@ namespace dlib
|
|||
{
|
||||
std::ifstream fin(sync_filename, std::ios::binary);
|
||||
deserialize(*this, fin);
|
||||
sync_file_reloaded = true;
|
||||
if (verbose)
|
||||
std::cout << "Loss has been increasing, reloading saved state from " << sync_filename << std::endl;
|
||||
}
|
||||
|
@ -1230,6 +1234,7 @@ namespace dlib
|
|||
double prob_loss_increasing_thresh;
|
||||
std::atomic<bool> updated_net_since_last_sync;
|
||||
|
||||
bool sync_file_reloaded;
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
|
Loading…
Reference in New Issue