mirror of https://github.com/davisking/dlib.git
Made the dnn_trainer not forget all the previous loss values it knows about
when it determines that there have been a lot of steps without progress and shrinks the learning rate. Instead, it removes only the oldest 100. The problem with the old way of removing all the loss values in the history was that if you set the steps without progress threshold to a really high number you would often observe that the last few learning rate values were obviously not making progress, however, since all the previous loss values were forgotten the trainer needed to fully populate it's loss history from scratch before it would figure this out. This new style makes the trainer not waste time running this excessive optimization of obviously useless mini-batches.
This commit is contained in:
parent
618f1084d2
commit
dd62b0e2ff
|
@ -699,7 +699,10 @@ namespace dlib
|
|||
// optimization has flattened out, so drop the learning rate.
|
||||
learning_rate = learning_rate_shrink*learning_rate;
|
||||
test_steps_without_progress = 0;
|
||||
test_previous_loss_values.clear();
|
||||
// Empty out some of the previous loss values so that test_steps_without_progress
|
||||
// will decrease below test_iter_without_progress_thresh.
|
||||
for (int cnt = 0; cnt < previous_loss_values_dump_amount && test_previous_loss_values.size() > 0; ++cnt)
|
||||
test_previous_loss_values.pop_front();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -804,7 +807,7 @@ namespace dlib
|
|||
// this because sometimes a mini-batch might be bad and cause the
|
||||
// loss to suddenly jump up, making count_steps_without_decrease()
|
||||
// return a large number. But if we discard the top 10% of the
|
||||
// values in previous_loss_values when we are robust to that kind
|
||||
// values in previous_loss_values then we are robust to that kind
|
||||
// of noise. Another way of looking at it, if the reason
|
||||
// count_steps_without_decrease() returns a large value is only
|
||||
// because the most recent loss values have suddenly been large,
|
||||
|
@ -816,7 +819,10 @@ namespace dlib
|
|||
// optimization has flattened out, so drop the learning rate.
|
||||
learning_rate = learning_rate_shrink*learning_rate;
|
||||
steps_without_progress = 0;
|
||||
previous_loss_values.clear();
|
||||
// Empty out some of the previous loss values so that steps_without_progress
|
||||
// will decrease below iter_without_progress_thresh.
|
||||
for (int cnt = 0; cnt < previous_loss_values_dump_amount && previous_loss_values.size() > 0; ++cnt)
|
||||
previous_loss_values.pop_front();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -873,6 +879,7 @@ namespace dlib
|
|||
prob_loss_increasing_thresh = prob_loss_increasing_thresh_default_value;
|
||||
updated_net_since_last_sync = false;
|
||||
sync_file_reloaded = false;
|
||||
previous_loss_values_dump_amount = 100;
|
||||
start();
|
||||
}
|
||||
|
||||
|
@ -883,7 +890,7 @@ namespace dlib
|
|||
friend void serialize(const dnn_trainer& item, std::ostream& out)
|
||||
{
|
||||
item.wait_for_thread_to_pause();
|
||||
int version = 9;
|
||||
int version = 10;
|
||||
serialize(version, out);
|
||||
|
||||
size_t nl = dnn_trainer::num_layers;
|
||||
|
@ -909,6 +916,7 @@ namespace dlib
|
|||
serialize(item.test_iter_without_progress_thresh.load(), out);
|
||||
serialize(item.test_steps_without_progress.load(), out);
|
||||
serialize(item.test_previous_loss_values, out);
|
||||
serialize(item.previous_loss_values_dump_amount, out);
|
||||
|
||||
}
|
||||
friend void deserialize(dnn_trainer& item, std::istream& in)
|
||||
|
@ -916,7 +924,7 @@ namespace dlib
|
|||
item.wait_for_thread_to_pause();
|
||||
int version = 0;
|
||||
deserialize(version, in);
|
||||
if (version != 9)
|
||||
if (version != 10)
|
||||
throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");
|
||||
|
||||
size_t num_layers = 0;
|
||||
|
@ -952,6 +960,7 @@ namespace dlib
|
|||
deserialize(ltemp, in); item.test_iter_without_progress_thresh = ltemp;
|
||||
deserialize(ltemp, in); item.test_steps_without_progress = ltemp;
|
||||
deserialize(item.test_previous_loss_values, in);
|
||||
deserialize(item.previous_loss_values_dump_amount, in);
|
||||
|
||||
if (item.devices.size() > 1)
|
||||
{
|
||||
|
@ -1259,6 +1268,7 @@ namespace dlib
|
|||
std::atomic<bool> updated_net_since_last_sync;
|
||||
|
||||
bool sync_file_reloaded;
|
||||
int previous_loss_values_dump_amount;
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
|
Loading…
Reference in New Issue