diff --git a/dlib/dnn/cublas_dlibapi.cpp b/dlib/dnn/cublas_dlibapi.cpp
index 4d7bb07e2..9273280c9 100644
--- a/dlib/dnn/cublas_dlibapi.cpp
+++ b/dlib/dnn/cublas_dlibapi.cpp
@@ -6,6 +6,7 @@
 #ifdef DLIB_USE_CUDA
 
 #include "cublas_dlibapi.h"
+#include "cuda_utils.h"
 
 #include <cublas_v2.h>
 
@@ -52,6 +53,7 @@ namespace dlib
             cublas_context()
             {
                 CHECK_CUBLAS(cublasCreate(&handle));
+                CHECK_CUDA(cudaGetDevice(&device_id));
             }
             ~cublas_context()
             {
@@ -59,18 +61,30 @@ namespace dlib
             }
 
             cublasHandle_t get_handle (
-            ) const { return handle; }
+            )
+            {
+                // Check if the active device for the current thread changed. If so then
+                // regenerate our cuBLAS handle so it will use the currently selected
+                // device.
+                int new_device_id;
+                CHECK_CUDA(cudaGetDevice(&new_device_id));
+                if (new_device_id != device_id)
+                {
+                    CHECK_CUBLAS(cublasDestroy(handle));
+                    CHECK_CUBLAS(cublasCreate(&handle));
+                    // Remember which device the handle now corresponds to. Without this,
+                    // every later call would see a mismatch and rebuild the handle.
+                    device_id = new_device_id;
+                }
+                return handle;
+            }
 
         private:
 
             cublasHandle_t handle;
+            int device_id;
         };
 
-        // TODO, there should probably be some function that is like dlibCudaSetDevice().
-        // Because people will call cudaSetDevice() expecting to set the device but for
-        // cuBLAS and cuDNN, since they have these handles, they will keep using the old
-        // devices. So we should have something that resets these handles and does a
-        // "dlibCudaSetDevice()"
         static cublasHandle_t context()
         {
             thread_local cublas_context c;
diff --git a/dlib/dnn/cudnn_dlibapi.cpp b/dlib/dnn/cudnn_dlibapi.cpp
index 04022449a..aa0446b2a 100644
--- a/dlib/dnn/cudnn_dlibapi.cpp
+++ b/dlib/dnn/cudnn_dlibapi.cpp
@@ -70,6 +70,7 @@ namespace dlib
             cudnn_context()
             {
                 CHECK_CUDNN(cudnnCreate(&handle));
+                CHECK_CUDA(cudaGetDevice(&device_id));
             }
 
             ~cudnn_context()
@@ -78,10 +79,27 @@ namespace dlib
             }
 
             cudnnHandle_t get_handle (
-            ) const { return handle; }
+            )
+            {
+                // Check if the active device for the current thread changed. If so then
+                // regenerate our cuDNN handle so it will use the currently selected
+                // device.
+                int new_device_id;
+                CHECK_CUDA(cudaGetDevice(&new_device_id));
+                if (new_device_id != device_id)
+                {
+                    CHECK_CUDNN(cudnnDestroy(handle));
+                    CHECK_CUDNN(cudnnCreate(&handle));
+                    // Remember which device the handle now corresponds to. Without this,
+                    // every later call would see a mismatch and rebuild the handle.
+                    device_id = new_device_id;
+                }
+                return handle;
+            }
 
         private:
             cudnnHandle_t handle;
+            int device_id;
         };
 
         static cudnnHandle_t context()