diff --git a/dlib/dnn/core_abstract.h b/dlib/dnn/core_abstract.h
index 53d7e194a..a24c98251 100644
--- a/dlib/dnn/core_abstract.h
+++ b/dlib/dnn/core_abstract.h
@@ -67,6 +67,31 @@ namespace dlib
             (except computes it using a numerically accurate method)
     !*/
 
+// ----------------------------------------------------------------------------------------
+
+    bool dnn_prefer_fastest_algorithms(
+    );
+    /*!
+        ensures
+            - If dlib should prefer to use fast algorithms rather than ones that use less
+              RAM then this function returns true and false otherwise.
+            - On program startup this function will default to true.
+    !*/
+
+    void set_dnn_prefer_fastest_algorithms(
+    );
+    /*!
+        ensures
+            - #dnn_prefer_fastest_algorithms() == true
+    !*/
+
+    void set_dnn_prefer_smallest_algorithms(
+    );
+    /*!
+        ensures
+            - #dnn_prefer_fastest_algorithms() == false
+    !*/
+
 // ----------------------------------------------------------------------------------------
 
     template <
diff --git a/dlib/dnn/cudnn_dlibapi.cpp b/dlib/dnn/cudnn_dlibapi.cpp
index ce1cdbae9..9a27d8b4d 100644
--- a/dlib/dnn/cudnn_dlibapi.cpp
+++ b/dlib/dnn/cudnn_dlibapi.cpp
@@ -13,6 +13,7 @@
 #include "cuda_utils.h"
 #include "cpu_dlib.h"
 #include "cuda_dlib.h"
+#include "tensor_tools.h"
 
 static const char* cudnn_get_error_string(cudnnStatus_t s)
 {
@@ -773,7 +774,7 @@ namespace dlib
                     (const cudnnFilterDescriptor_t)filter_handle,
                     (const cudnnConvolutionDescriptor_t)conv_handle,
                     descriptor(dest_desc),
-                    CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, // or CUDNN_CONVOLUTION_FWD_NO_WORKSPACE,
+                    dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_FWD_PREFER_FASTEST:CUDNN_CONVOLUTION_FWD_NO_WORKSPACE,
                     std::numeric_limits<int>::max(),
                     &forward_best_algo));
                 forward_algo = forward_best_algo;
@@ -797,7 +798,7 @@ namespace dlib
                     descriptor(dest_desc),
                     (const cudnnConvolutionDescriptor_t)conv_handle,
                     descriptor(data),
-                    CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
+                    dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE,
                     std::numeric_limits<int>::max(),
                     &backward_data_best_algo));
                 backward_data_algo = backward_data_best_algo;
@@ -821,7 +822,7 @@ namespace dlib
                     descriptor(dest_desc),
                     (const cudnnConvolutionDescriptor_t)conv_handle,
                     (const cudnnFilterDescriptor_t)filter_handle,
-                    CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
+                    dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE,
                     std::numeric_limits<int>::max(),
                     &backward_filters_best_algo));
                 backward_filters_algo = backward_filters_best_algo;
diff --git a/dlib/dnn/tensor_tools.cpp b/dlib/dnn/tensor_tools.cpp
index fa96d1dce..5b72dfa1a 100644
--- a/dlib/dnn/tensor_tools.cpp
+++ b/dlib/dnn/tensor_tools.cpp
@@ -5,6 +5,38 @@
 #include "tensor_tools.h"
 #include "../string.h"
+#include <atomic>
+
+namespace dlib
+{
+    namespace
+    {
+        std::atomic<bool>& dnn_prefer_fastest_algo (
+        )
+        {
+            static std::atomic<bool> var(true);
+            return var;
+        }
+    }
+
+    bool dnn_prefer_fastest_algorithms (
+    )
+    {
+        return dnn_prefer_fastest_algo();
+    }
+
+    void set_dnn_prefer_fastest_algorithms(
+    )
+    {
+        dnn_prefer_fastest_algo() = true;
+    }
+
+    void set_dnn_prefer_smallest_algorithms(
+    )
+    {
+        dnn_prefer_fastest_algo() = false;
+    }
+}
 namespace dlib
 {
 namespace tt
 {
diff --git a/dlib/dnn/tensor_tools.h b/dlib/dnn/tensor_tools.h
index 433dbee5a..4febd3ce6 100644
--- a/dlib/dnn/tensor_tools.h
+++ b/dlib/dnn/tensor_tools.h
@@ -11,6 +11,13 @@
 #include "cuda_dlib.h"
 #include "../rand.h"
+namespace dlib
+{
+    bool dnn_prefer_fastest_algorithms();
+    void set_dnn_prefer_fastest_algorithms();
+    void set_dnn_prefer_smallest_algorithms();
+}
+
 namespace dlib
 {
 namespace tt
 {