diff --git a/dlib/dnn/trainer.h b/dlib/dnn/trainer.h index 2207db387..333c03a4e 100644 --- a/dlib/dnn/trainer.h +++ b/dlib/dnn/trainer.h @@ -55,6 +55,10 @@ namespace dlib } } + enum class force_flush_to_disk { + no = 0, + yes = 1 + }; template < typename net_type, @@ -135,10 +139,11 @@ namespace dlib } net_type& get_net ( + force_flush_to_disk force_flush = force_flush_to_disk::yes ) { wait_for_thread_to_pause(); - sync_to_disk(true); + sync_to_disk(force_flush == force_flush_to_disk::yes); propagate_exception(); return net; } diff --git a/dlib/dnn/trainer_abstract.h b/dlib/dnn/trainer_abstract.h index 8c4b09606..3bfb6dc99 100644 --- a/dlib/dnn/trainer_abstract.h +++ b/dlib/dnn/trainer_abstract.h @@ -12,6 +12,13 @@ namespace dlib { +// ---------------------------------------------------------------------------------------- + + enum class force_flush_to_disk { + no = 0, + yes = 1 + }; + // ---------------------------------------------------------------------------------------- template < @@ -92,6 +99,7 @@ namespace dlib !*/ net_type& get_net ( + force_flush_to_disk force_flush = force_flush_to_disk::yes ); /*! ensures @@ -102,8 +110,9 @@ namespace dlib dnn_trainer's constructor. - This function blocks until all threads inside the dnn_trainer have stopped touching the net. - - This function will sync the trainer state to disk if the current state - hasn't already been synced to disk since the last network modification. + - If force_flush is yes, then this function will sync the trainer state to + disk if the current state hasn't already been synced to disk since the + last network modification. !*/ const std::vector& get_solvers (