diff --git a/dlib/dnn/trainer.h b/dlib/dnn/trainer.h
index 18962de20..d8bf45f05 100644
--- a/dlib/dnn/trainer.h
+++ b/dlib/dnn/trainer.h
@@ -697,10 +697,10 @@ namespace dlib
 
                     // Check if we should shrink the learning rate based on how the test
                     // error has been doing lately.
-                    if (learning_rate_shrink != 1 && steps_since_last_learning_rate_shrink > iter_without_progress_thresh)
+                    if (learning_rate_shrink != 1)
                     {
                         test_steps_without_progress = count_steps_without_decrease(test_previous_loss_values);
-                        if (test_steps_without_progress >= test_iter_without_progress_thresh)
+                        if (test_steps_without_progress >= test_iter_without_progress_thresh && steps_since_last_learning_rate_shrink >= test_iter_without_progress_thresh)
                         {
                             test_steps_without_progress = count_steps_without_decrease_robust(test_previous_loss_values);
                             if (test_steps_without_progress >= test_iter_without_progress_thresh)
@@ -809,13 +809,11 @@ namespace dlib
                 // have a "budget" that prevents us from calling
                 // count_steps_without_decrease() every iteration. We do this because
                 // it can be expensive to compute when previous_loss_values is large.
-                if (gradient_check_budget > iter_without_progress_thresh
-                    && learning_rate_shrink != 1
-                    && steps_since_last_learning_rate_shrink > iter_without_progress_thresh)
+                if (gradient_check_budget > iter_without_progress_thresh && learning_rate_shrink != 1)
                 {
                     gradient_check_budget = 0;
                     steps_without_progress = count_steps_without_decrease(previous_loss_values);
-                    if (steps_without_progress >= iter_without_progress_thresh)
+                    if (steps_without_progress >= iter_without_progress_thresh && steps_since_last_learning_rate_shrink >= iter_without_progress_thresh)
                     {
                         // Double check that we aren't seeing decrease. This second check
                         // discards the top 10% largest values and checks again. We do
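
In both hunks the patch moves the steps_since_last_learning_rate_shrink gate from the outer condition onto the shrink decision itself: count_steps_without_decrease() now runs whenever shrinking is enabled (and, in the second hunk, the budget allows), so steps_without_progress and its test-set counterpart stay up to date even during the window right after a shrink, while the shrink itself still waits for that window to pass. The standalone sketch below illustrates that control flow under simplified assumptions; it is not dlib code, and the loop, state variables, and the crude count_steps_without_decrease stand-in (dlib's real version is more sophisticated than counting a trailing run) are all illustrative.

// Standalone illustration only -- simplified stand-ins, not the dlib trainer.
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    // Simplified trainer state (names mirror the diff, values are made up).
    double learning_rate = 0.1;
    const double learning_rate_shrink = 0.1;              // != 1 enables shrinking
    const std::size_t iter_without_progress_thresh = 3;   // tiny value for the demo
    std::size_t steps_since_last_learning_rate_shrink = 0;
    std::size_t steps_without_progress = 0;

    // Loss history with a trailing plateau so "no progress" is detected.
    std::vector<double> previous_loss_values = {1.0, 0.9, 0.91, 0.92, 0.93};

    // Crude stand-in for count_steps_without_decrease(): length of the
    // trailing run of non-improving losses.
    auto count_steps_without_decrease = [](const std::vector<double>& v) {
        std::size_t n = 0;
        for (std::size_t i = v.size(); i >= 2 && v[i-1] >= v[i-2]; --i)
            ++n;
        return n;
    };

    for (std::size_t step = 0; step < 6; ++step)
    {
        ++steps_since_last_learning_rate_shrink;

        if (learning_rate_shrink != 1)
        {
            // After the patch the count is refreshed on every check, so the
            // reported steps_without_progress stays current even right after
            // a shrink...
            steps_without_progress = count_steps_without_decrease(previous_loss_values);

            // ...but the shrink itself still waits until enough steps have
            // passed since the previous shrink.
            if (steps_without_progress >= iter_without_progress_thresh &&
                steps_since_last_learning_rate_shrink >= iter_without_progress_thresh)
            {
                learning_rate *= learning_rate_shrink;
                std::cout << "step " << step << ": shrank learning rate to "
                          << learning_rate << "\n";
                steps_without_progress = 0;
                steps_since_last_learning_rate_shrink = 0;
                previous_loss_values.clear();
            }
        }

        std::cout << "step " << step
                  << " steps_without_progress=" << steps_without_progress << "\n";
    }
    return 0;
}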