Made cuBLAS and cuDNN automatically switch their library handles to the

currently active device id if the user changes the active device via a call to
cudaSetDevice().
This commit is contained in:
Davis King 2015-11-29 08:58:40 -05:00
parent ccb148b445
commit 5c058ea110
2 changed files with 33 additions and 7 deletions

View File

@ -6,6 +6,7 @@
#ifdef DLIB_USE_CUDA
#include "cublas_dlibapi.h"
#include "cuda_utils.h"
#include <cublas_v2.h>
@ -52,6 +53,7 @@ namespace dlib
cublas_context()
{
CHECK_CUBLAS(cublasCreate(&handle));
CHECK_CUDA(cudaGetDevice(&device_id));
}
~cublas_context()
{
@ -59,18 +61,27 @@ namespace dlib
}
cublasHandle_t get_handle (
) const { return handle; }
)
{
// Check if the active device for the current thread changed. If so then
// regenerate our cuBLAS handle so it will use the currently selected
// device.
int new_device_id;
CHECK_CUDA(cudaGetDevice(&new_device_id));
if (new_device_id != device_id)
{
CHECK_CUBLAS(cublasDestroy(handle));
CHECK_CUBLAS(cublasCreate(&handle));
}
return handle;
}
private:
cublasHandle_t handle;
int device_id;
};
// TODO, there should probably be some function that is like dlibCudaSetDevice().
// Because people will call cudaSetDevice() expecting to set the device but for
// cuBLAS and cuDNN, since they have these handles, they will keep using the old
// devices. So we should have something that resets these handles and does a
// "dlibCudaSetDevice()"
static cublasHandle_t context()
{
thread_local cublas_context c;

View File

@ -70,6 +70,7 @@ namespace dlib
cudnn_context()
{
CHECK_CUDNN(cudnnCreate(&handle));
CHECK_CUDA(cudaGetDevice(&device_id));
}
~cudnn_context()
@ -78,10 +79,24 @@ namespace dlib
}
cudnnHandle_t get_handle (
) const { return handle; }
)
{
// Check if the active device for the current thread changed. If so then
// regenerate our cuDNN handle so it will use the currently selected
// device.
int new_device_id;
CHECK_CUDA(cudaGetDevice(&new_device_id));
if (new_device_id != device_id)
{
CHECK_CUDNN(cudnnDestroy(handle));
CHECK_CUDNN(cudnnCreate(&handle));
}
return handle;
}
private:
cudnnHandle_t handle;
int device_id;
};
static cudnnHandle_t context()