mirror of https://github.com/davisking/dlib.git
Made tensor_conv hold references to the cuda_data_void_ptr work buffers as
class members. This way, we avoid a potential error where the buffers are reallocated while the asynchronous cuDNN calls are still using them in the background.
This commit is contained in:
parent
863702f059
commit
1f5335c1ad
|
@ -795,6 +795,9 @@ namespace dlib
|
|||
backward_data_workspace_size_in_bytes = 0;
|
||||
backward_filters_workspace_size_in_bytes = 0;
|
||||
|
||||
forward_workspace.reset();
|
||||
backward_data_workspace.reset();
|
||||
backward_filters_workspace.reset();
|
||||
workspace.reset();
|
||||
}
|
||||
|
||||
|
@ -1030,6 +1033,13 @@ namespace dlib
|
|||
const float alpha = 1;
|
||||
const float beta = add_to_output ? 1 : 0;
|
||||
|
||||
// Since cudnnConvolutionForward() is an asynchronous call, we need to hold a
|
||||
// reference to the workspace buffer so we can be sure it isn't reallocated
|
||||
// while the function is still executing on the device. But each time we come
|
||||
// here, we make sure to grab the latest workspace buffer so that, globally, we
|
||||
// minimize the number of such buffers.
|
||||
forward_workspace = workspace->get(forward_workspace_size_in_bytes);
|
||||
|
||||
CHECK_CUDNN(cudnnConvolutionForward(
|
||||
context(),
|
||||
&alpha,
|
||||
|
@ -1039,7 +1049,7 @@ namespace dlib
|
|||
filters.device(),
|
||||
(const cudnnConvolutionDescriptor_t)conv_handle,
|
||||
(cudnnConvolutionFwdAlgo_t)forward_algo,
|
||||
workspace->get(forward_workspace_size_in_bytes),
|
||||
forward_workspace,
|
||||
forward_workspace_size_in_bytes,
|
||||
&beta,
|
||||
descriptor(output),
|
||||
|
@ -1056,6 +1066,13 @@ namespace dlib
|
|||
const float alpha = 1;
|
||||
const float beta = add_to_output ? 1 : 0;
|
||||
|
||||
// Since cudnnConvolutionBackwardData() is an asynchronous call, we need to hold a
|
||||
// reference to the workspace buffer so we can be sure it isn't reallocated
|
||||
// while the function is still executing on the device. But each time we come
|
||||
// here, we make sure to grab the latest workspace buffer so that, globally, we
|
||||
// minimize the number of such buffers.
|
||||
backward_data_workspace = workspace->get(backward_data_workspace_size_in_bytes);
|
||||
|
||||
|
||||
CHECK_CUDNN(cudnnConvolutionBackwardData(context(),
|
||||
&alpha,
|
||||
|
@ -1065,7 +1082,7 @@ namespace dlib
|
|||
gradient_input.device(),
|
||||
(const cudnnConvolutionDescriptor_t)conv_handle,
|
||||
(cudnnConvolutionBwdDataAlgo_t)backward_data_algo,
|
||||
workspace->get(backward_data_workspace_size_in_bytes),
|
||||
backward_data_workspace,
|
||||
backward_data_workspace_size_in_bytes,
|
||||
&beta,
|
||||
descriptor(data_gradient),
|
||||
|
@ -1082,6 +1099,14 @@ namespace dlib
|
|||
{
|
||||
const float alpha = 1;
|
||||
const float beta = add_to_output ? 1 : 0;
|
||||
|
||||
// Since cudnnConvolutionBackwardFilter() is an asynchronous call, we need to hold a
|
||||
// reference to the workspace buffer so we can be sure it isn't reallocated
|
||||
// while the function is still executing on the device. But each time we come
|
||||
// here, we make sure to grab the latest workspace buffer so that, globally, we
|
||||
// minimize the number of such buffers.
|
||||
backward_filters_workspace = workspace->get(backward_filters_workspace_size_in_bytes);
|
||||
|
||||
CHECK_CUDNN(cudnnConvolutionBackwardFilter(context(),
|
||||
&alpha,
|
||||
descriptor(data),
|
||||
|
@ -1090,7 +1115,7 @@ namespace dlib
|
|||
gradient_input.device(),
|
||||
(const cudnnConvolutionDescriptor_t)conv_handle,
|
||||
(cudnnConvolutionBwdFilterAlgo_t)backward_filters_algo,
|
||||
workspace->get(backward_filters_workspace_size_in_bytes),
|
||||
backward_filters_workspace,
|
||||
backward_filters_workspace_size_in_bytes,
|
||||
&beta,
|
||||
(const cudnnFilterDescriptor_t)filter_handle,
|
||||
|
|
|
@ -269,6 +269,9 @@ namespace dlib
|
|||
size_t backward_data_workspace_size_in_bytes;
|
||||
size_t backward_filters_workspace_size_in_bytes;
|
||||
std::shared_ptr<resizable_cuda_buffer> workspace;
|
||||
cuda_data_void_ptr forward_workspace;
|
||||
cuda_data_void_ptr backward_data_workspace;
|
||||
cuda_data_void_ptr backward_filters_workspace;
|
||||
};
|
||||
|
||||
// ------------------------------------------------------------------------------------
|
||||
|
|
Loading…
Reference in New Issue