mirror of https://github.com/davisking/dlib.git
Cleanup cuDNN conv algorithm selection code slightly by moving it into its own function.
This commit is contained in:
parent
4d18e0d0c7
commit
6c3243f766
|
@ -776,6 +776,134 @@ namespace dlib
|
|||
return best_alg;
|
||||
}
|
||||
|
||||
void tensor_conv::
|
||||
select_best_algorithms (
|
||||
const tensor& data,
|
||||
const tensor_descriptor& dest_desc
|
||||
)
|
||||
{
|
||||
// Pick which forward algorithm we will use and allocate the necessary
|
||||
// workspace buffer.
|
||||
cudnnConvolutionFwdAlgo_t forward_best_algo;
|
||||
#if CUDNN_MAJOR >= 8
|
||||
{
|
||||
int num_possilbe_algorithms = 0;
|
||||
CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithmMaxCount(context(), &num_possilbe_algorithms));
|
||||
std::vector<cudnnConvolutionFwdAlgoPerf_t> perf_results(num_possilbe_algorithms);
|
||||
int num_algorithms = 0;
|
||||
CHECK_CUDNN(cudnnFindConvolutionForwardAlgorithm(
|
||||
context(),
|
||||
descriptor(data),
|
||||
(const cudnnFilterDescriptor_t)filter_handle,
|
||||
(const cudnnConvolutionDescriptor_t)conv_handle,
|
||||
descriptor(dest_desc),
|
||||
num_possilbe_algorithms,
|
||||
&num_algorithms,
|
||||
perf_results.data()));
|
||||
perf_results.resize(num_algorithms);
|
||||
forward_best_algo = pick_best_algorithm(perf_results);
|
||||
}
|
||||
#else
|
||||
CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm(
|
||||
context(),
|
||||
descriptor(data),
|
||||
(const cudnnFilterDescriptor_t)filter_handle,
|
||||
(const cudnnConvolutionDescriptor_t)conv_handle,
|
||||
descriptor(dest_desc),
|
||||
dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_FWD_PREFER_FASTEST:CUDNN_CONVOLUTION_FWD_NO_WORKSPACE,
|
||||
std::numeric_limits<size_t>::max(),
|
||||
&forward_best_algo));
|
||||
#endif
|
||||
forward_algo = forward_best_algo;
|
||||
|
||||
|
||||
|
||||
// Pick which backward data algorithm we will use and allocate the
|
||||
// necessary workspace buffer.
|
||||
cudnnConvolutionBwdDataAlgo_t backward_data_best_algo;
|
||||
#if CUDNN_MAJOR >= 8
|
||||
{
|
||||
int num_possilbe_algorithms = 0;
|
||||
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(context(), &num_possilbe_algorithms));
|
||||
std::vector<cudnnConvolutionBwdDataAlgoPerf_t> perf_results(num_possilbe_algorithms);
|
||||
int num_algorithms = 0;
|
||||
CHECK_CUDNN(cudnnFindConvolutionBackwardDataAlgorithm(
|
||||
context(),
|
||||
(const cudnnFilterDescriptor_t)filter_handle,
|
||||
descriptor(dest_desc),
|
||||
(const cudnnConvolutionDescriptor_t)conv_handle,
|
||||
descriptor(data),
|
||||
num_possilbe_algorithms,
|
||||
&num_algorithms,
|
||||
perf_results.data()));
|
||||
perf_results.resize(num_algorithms);
|
||||
backward_data_best_algo = pick_best_algorithm(perf_results);
|
||||
}
|
||||
#else
|
||||
CHECK_CUDNN(cudnnGetConvolutionBackwardDataAlgorithm(
|
||||
context(),
|
||||
(const cudnnFilterDescriptor_t)filter_handle,
|
||||
descriptor(dest_desc),
|
||||
(const cudnnConvolutionDescriptor_t)conv_handle,
|
||||
descriptor(data),
|
||||
dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE,
|
||||
std::numeric_limits<size_t>::max(),
|
||||
&backward_data_best_algo));
|
||||
#endif
|
||||
backward_data_algo = backward_data_best_algo;
|
||||
|
||||
|
||||
|
||||
|
||||
// Pick which backward filters algorithm we will use and allocate the
|
||||
// necessary workspace buffer.
|
||||
cudnnConvolutionBwdFilterAlgo_t backward_filters_best_algo;
|
||||
#if CUDNN_MAJOR >= 8
|
||||
{
|
||||
int num_possilbe_algorithms = 0;
|
||||
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(context(), &num_possilbe_algorithms));
|
||||
std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> perf_results(num_possilbe_algorithms);
|
||||
int num_algorithms = 0;
|
||||
CHECK_CUDNN(cudnnFindConvolutionBackwardFilterAlgorithm(
|
||||
context(),
|
||||
descriptor(data),
|
||||
descriptor(dest_desc),
|
||||
(const cudnnConvolutionDescriptor_t)conv_handle,
|
||||
(const cudnnFilterDescriptor_t)filter_handle,
|
||||
num_possilbe_algorithms,
|
||||
&num_algorithms,
|
||||
perf_results.data()));
|
||||
perf_results.resize(num_algorithms);
|
||||
backward_filters_best_algo = pick_best_algorithm(perf_results);
|
||||
}
|
||||
#else
|
||||
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithm(
|
||||
context(),
|
||||
descriptor(data),
|
||||
descriptor(dest_desc),
|
||||
(const cudnnConvolutionDescriptor_t)conv_handle,
|
||||
(const cudnnFilterDescriptor_t)filter_handle,
|
||||
dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE,
|
||||
std::numeric_limits<size_t>::max(),
|
||||
&backward_filters_best_algo));
|
||||
#endif
|
||||
|
||||
// cuDNN 5.1 has a bug that causes
|
||||
// cudnnGetConvolutionBackwardFilterAlgorithm() to pick the winograd
|
||||
// algorithm even for cases where cuDNN doesn't support it, leading to
|
||||
// incorrect outputs. So here we check if we are in a case where winograd
|
||||
// isn't supported and manually overrule
|
||||
// cudnnGetConvolutionBackwardFilterAlgorithm() by picking a safe
|
||||
// algorithm.
|
||||
if (dnn_prefer_fastest_algorithms() &&
|
||||
!(stride_x == 1 && stride_y == 1 && ((filters_nr==3&&filters_nc==3) || (filters_nr==5&&filters_nc==5)))
|
||||
)
|
||||
{
|
||||
backward_filters_best_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
|
||||
}
|
||||
backward_filters_algo = backward_filters_best_algo;
|
||||
}
|
||||
|
||||
void tensor_conv::
|
||||
setup(
|
||||
const tensor& data,
|
||||
|
@ -863,81 +991,17 @@ namespace dlib
|
|||
tensor_descriptor dest_desc;
|
||||
dest_desc.set_size(out_num_samples,out_k,out_nr,out_nc);
|
||||
|
||||
// Pick which forward algorithm we will use and allocate the necessary
|
||||
// workspace buffer.
|
||||
cudnnConvolutionFwdAlgo_t forward_best_algo;
|
||||
#if CUDNN_MAJOR >= 8
|
||||
{
|
||||
int num_possilbe_algorithms = 0;
|
||||
CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithmMaxCount(context(), &num_possilbe_algorithms));
|
||||
std::vector<cudnnConvolutionFwdAlgoPerf_t> perf_results(num_possilbe_algorithms);
|
||||
int num_algorithms = 0;
|
||||
CHECK_CUDNN(cudnnFindConvolutionForwardAlgorithm(
|
||||
context(),
|
||||
descriptor(data),
|
||||
(const cudnnFilterDescriptor_t)filter_handle,
|
||||
(const cudnnConvolutionDescriptor_t)conv_handle,
|
||||
descriptor(dest_desc),
|
||||
num_possilbe_algorithms,
|
||||
&num_algorithms,
|
||||
perf_results.data()));
|
||||
perf_results.resize(num_algorithms);
|
||||
forward_best_algo = pick_best_algorithm(perf_results);
|
||||
}
|
||||
#else
|
||||
CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm(
|
||||
context(),
|
||||
descriptor(data),
|
||||
(const cudnnFilterDescriptor_t)filter_handle,
|
||||
(const cudnnConvolutionDescriptor_t)conv_handle,
|
||||
descriptor(dest_desc),
|
||||
dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_FWD_PREFER_FASTEST:CUDNN_CONVOLUTION_FWD_NO_WORKSPACE,
|
||||
std::numeric_limits<size_t>::max(),
|
||||
&forward_best_algo));
|
||||
#endif
|
||||
forward_algo = forward_best_algo;
|
||||
select_best_algorithms(data, dest_desc);
|
||||
|
||||
CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(
|
||||
context(),
|
||||
descriptor(data),
|
||||
(const cudnnFilterDescriptor_t)filter_handle,
|
||||
(const cudnnConvolutionDescriptor_t)conv_handle,
|
||||
descriptor(dest_desc),
|
||||
forward_best_algo,
|
||||
(cudnnConvolutionFwdAlgo_t)forward_algo,
|
||||
&forward_workspace_size_in_bytes));
|
||||
|
||||
// Pick which backward data algorithm we will use and allocate the
|
||||
// necessary workspace buffer.
|
||||
cudnnConvolutionBwdDataAlgo_t backward_data_best_algo;
|
||||
#if CUDNN_MAJOR >= 8
|
||||
{
|
||||
int num_possilbe_algorithms = 0;
|
||||
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(context(), &num_possilbe_algorithms));
|
||||
std::vector<cudnnConvolutionBwdDataAlgoPerf_t> perf_results(num_possilbe_algorithms);
|
||||
int num_algorithms = 0;
|
||||
CHECK_CUDNN(cudnnFindConvolutionBackwardDataAlgorithm(
|
||||
context(),
|
||||
(const cudnnFilterDescriptor_t)filter_handle,
|
||||
descriptor(dest_desc),
|
||||
(const cudnnConvolutionDescriptor_t)conv_handle,
|
||||
descriptor(data),
|
||||
num_possilbe_algorithms,
|
||||
&num_algorithms,
|
||||
perf_results.data()));
|
||||
perf_results.resize(num_algorithms);
|
||||
backward_data_best_algo = pick_best_algorithm(perf_results);
|
||||
}
|
||||
#else
|
||||
CHECK_CUDNN(cudnnGetConvolutionBackwardDataAlgorithm(
|
||||
context(),
|
||||
(const cudnnFilterDescriptor_t)filter_handle,
|
||||
descriptor(dest_desc),
|
||||
(const cudnnConvolutionDescriptor_t)conv_handle,
|
||||
descriptor(data),
|
||||
dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE,
|
||||
std::numeric_limits<size_t>::max(),
|
||||
&backward_data_best_algo));
|
||||
#endif
|
||||
backward_data_algo = backward_data_best_algo;
|
||||
|
||||
CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize(
|
||||
context(),
|
||||
|
@ -945,55 +1009,9 @@ namespace dlib
|
|||
descriptor(dest_desc),
|
||||
(const cudnnConvolutionDescriptor_t)conv_handle,
|
||||
descriptor(data),
|
||||
backward_data_best_algo,
|
||||
(cudnnConvolutionBwdDataAlgo_t)backward_data_algo,
|
||||
&backward_data_workspace_size_in_bytes));
|
||||
|
||||
// Pick which backward filters algorithm we will use and allocate the
|
||||
// necessary workspace buffer.
|
||||
cudnnConvolutionBwdFilterAlgo_t backward_filters_best_algo;
|
||||
#if CUDNN_MAJOR >= 8
|
||||
{
|
||||
int num_possilbe_algorithms = 0;
|
||||
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(context(), &num_possilbe_algorithms));
|
||||
std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> perf_results(num_possilbe_algorithms);
|
||||
int num_algorithms = 0;
|
||||
CHECK_CUDNN(cudnnFindConvolutionBackwardFilterAlgorithm(
|
||||
context(),
|
||||
descriptor(data),
|
||||
descriptor(dest_desc),
|
||||
(const cudnnConvolutionDescriptor_t)conv_handle,
|
||||
(const cudnnFilterDescriptor_t)filter_handle,
|
||||
num_possilbe_algorithms,
|
||||
&num_algorithms,
|
||||
perf_results.data()));
|
||||
perf_results.resize(num_algorithms);
|
||||
backward_filters_best_algo = pick_best_algorithm(perf_results);
|
||||
}
|
||||
#else
|
||||
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithm(
|
||||
context(),
|
||||
descriptor(data),
|
||||
descriptor(dest_desc),
|
||||
(const cudnnConvolutionDescriptor_t)conv_handle,
|
||||
(const cudnnFilterDescriptor_t)filter_handle,
|
||||
dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE,
|
||||
std::numeric_limits<size_t>::max(),
|
||||
&backward_filters_best_algo));
|
||||
#endif
|
||||
// cuDNN 5.1 has a bug that causes
|
||||
// cudnnGetConvolutionBackwardFilterAlgorithm() to pick the winograd
|
||||
// algorithm even for cases where cuDNN doesn't support it, leading to
|
||||
// incorrect outputs. So here we check if we are in a case where winograd
|
||||
// isn't supported and manually overrule
|
||||
// cudnnGetConvolutionBackwardFilterAlgorithm() by picking a safe
|
||||
// algorithm.
|
||||
if (dnn_prefer_fastest_algorithms() &&
|
||||
!(stride_x == 1 && stride_y == 1 && ((filters_nr==3&&filters_nc==3) || (filters_nr==5&&filters_nc==5)))
|
||||
)
|
||||
{
|
||||
backward_filters_best_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
|
||||
}
|
||||
backward_filters_algo = backward_filters_best_algo;
|
||||
|
||||
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterWorkspaceSize(
|
||||
context(),
|
||||
|
@ -1001,7 +1019,7 @@ namespace dlib
|
|||
descriptor(dest_desc),
|
||||
(const cudnnConvolutionDescriptor_t)conv_handle,
|
||||
(const cudnnFilterDescriptor_t)filter_handle,
|
||||
backward_filters_best_algo,
|
||||
(cudnnConvolutionBwdFilterAlgo_t)backward_filters_algo,
|
||||
&backward_filters_workspace_size_in_bytes));
|
||||
}
|
||||
catch(...)
|
||||
|
|
|
@ -228,6 +228,8 @@ namespace dlib
|
|||
int out_nr;
|
||||
int out_nc;
|
||||
|
||||
// sets the three _algo fields.
|
||||
void select_best_algorithms(const tensor& data, const tensor_descriptor& dest_desc);
|
||||
int forward_algo;
|
||||
int backward_data_algo;
|
||||
int backward_filters_algo;
|
||||
|
|
Loading…
Reference in New Issue