This commit is contained in:
Davis King 2015-11-29 12:10:00 -05:00
commit adec3eefd1
3 changed files with 34 additions and 8 deletions

View File

@ -1278,7 +1278,7 @@ namespace dlib
{
private:
// We don't want anyone making these no_label_type objects. They are here only to
// allow add_loss_layer::label_type and dnn_trainer::label_type to exist which voids
// allow add_loss_layer::label_type and dnn_trainer::label_type to exist which avoids
// needing to overload add_loss_layer and dnn_trainer for supervised an unsupervised
// losses. It also can be a type to use in template metaprogramming to indicate
// "no label". So here we make the constructor private with the exception that

View File

@ -6,6 +6,7 @@
#ifdef DLIB_USE_CUDA
#include "cublas_dlibapi.h"
#include "cuda_utils.h"
#include <cublas_v2.h>
@ -52,6 +53,7 @@ namespace dlib
cublas_context()
{
CHECK_CUBLAS(cublasCreate(&handle));
CHECK_CUDA(cudaGetDevice(&device_id));
}
~cublas_context()
{
@ -59,18 +61,27 @@ namespace dlib
}
cublasHandle_t get_handle (
) const { return handle; }
)
{
// Check if the active device for the current thread changed. If so then
// regenerate our cuBLAS handle so it will use the currently selected
// device.
int new_device_id;
CHECK_CUDA(cudaGetDevice(&new_device_id));
if (new_device_id != device_id)
{
CHECK_CUBLAS(cublasDestroy(handle));
CHECK_CUBLAS(cublasCreate(&handle));
}
return handle;
}
private:
cublasHandle_t handle;
int device_id;
};
// TODO, there should probably be some function that is like dlibCudaSetDevice().
// Because people will call cudaSetDevice() expecting to set the device but for
// cuBLAS and cuDNN, since they have these handles, they will keep using the old
// devices. So we should have something that resets these handles and does a
// "dlibCudaSetDevice()"
static cublasHandle_t context()
{
thread_local cublas_context c;

View File

@ -70,6 +70,7 @@ namespace dlib
cudnn_context()
{
CHECK_CUDNN(cudnnCreate(&handle));
CHECK_CUDA(cudaGetDevice(&device_id));
}
~cudnn_context()
@ -78,10 +79,24 @@ namespace dlib
}
cudnnHandle_t get_handle (
) const { return handle; }
)
{
// Check if the active device for the current thread changed. If so then
// regenerate our cuDNN handle so it will use the currently selected
// device.
int new_device_id;
CHECK_CUDA(cudaGetDevice(&new_device_id));
if (new_device_id != device_id)
{
CHECK_CUDNN(cudnnDestroy(handle));
CHECK_CUDNN(cudnnCreate(&handle));
}
return handle;
}
private:
cudnnHandle_t handle;
int device_id;
};
static cudnnHandle_t context()