mirror of https://github.com/davisking/dlib.git
merged
This commit is contained in:
commit
adec3eefd1
|
@ -1278,7 +1278,7 @@ namespace dlib
|
|||
{
|
||||
private:
|
||||
// We don't want anyone making these no_label_type objects. They are here only to
|
||||
// allow add_loss_layer::label_type and dnn_trainer::label_type to exist which voids
|
||||
// allow add_loss_layer::label_type and dnn_trainer::label_type to exist which avoids
|
||||
// needing to overload add_loss_layer and dnn_trainer for supervised an unsupervised
|
||||
// losses. It also can be a type to use in template metaprogramming to indicate
|
||||
// "no label". So here we make the constructor private with the exception that
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
#ifdef DLIB_USE_CUDA
|
||||
|
||||
#include "cublas_dlibapi.h"
|
||||
#include "cuda_utils.h"
|
||||
|
||||
#include <cublas_v2.h>
|
||||
|
||||
|
@ -52,6 +53,7 @@ namespace dlib
|
|||
cublas_context()
|
||||
{
|
||||
CHECK_CUBLAS(cublasCreate(&handle));
|
||||
CHECK_CUDA(cudaGetDevice(&device_id));
|
||||
}
|
||||
~cublas_context()
|
||||
{
|
||||
|
@ -59,18 +61,27 @@ namespace dlib
|
|||
}
|
||||
|
||||
cublasHandle_t get_handle (
|
||||
) const { return handle; }
|
||||
)
|
||||
{
|
||||
// Check if the active device for the current thread changed. If so then
|
||||
// regenerate our cuBLAS handle so it will use the currently selected
|
||||
// device.
|
||||
int new_device_id;
|
||||
CHECK_CUDA(cudaGetDevice(&new_device_id));
|
||||
if (new_device_id != device_id)
|
||||
{
|
||||
CHECK_CUBLAS(cublasDestroy(handle));
|
||||
CHECK_CUBLAS(cublasCreate(&handle));
|
||||
}
|
||||
return handle;
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
cublasHandle_t handle;
|
||||
int device_id;
|
||||
};
|
||||
|
||||
// TODO, there should probably be some function that is like dlibCudaSetDevice().
|
||||
// Because people will call cudaSetDevice() expecting to set the device but for
|
||||
// cuBLAS and cuDNN, since they have these handles, they will keep using the old
|
||||
// devices. So we should have something that resets these handles and does a
|
||||
// "dlibCudaSetDevice()"
|
||||
static cublasHandle_t context()
|
||||
{
|
||||
thread_local cublas_context c;
|
||||
|
|
|
@ -70,6 +70,7 @@ namespace dlib
|
|||
cudnn_context()
|
||||
{
|
||||
CHECK_CUDNN(cudnnCreate(&handle));
|
||||
CHECK_CUDA(cudaGetDevice(&device_id));
|
||||
}
|
||||
|
||||
~cudnn_context()
|
||||
|
@ -78,10 +79,24 @@ namespace dlib
|
|||
}
|
||||
|
||||
cudnnHandle_t get_handle (
|
||||
) const { return handle; }
|
||||
)
|
||||
{
|
||||
// Check if the active device for the current thread changed. If so then
|
||||
// regenerate our cuDNN handle so it will use the currently selected
|
||||
// device.
|
||||
int new_device_id;
|
||||
CHECK_CUDA(cudaGetDevice(&new_device_id));
|
||||
if (new_device_id != device_id)
|
||||
{
|
||||
CHECK_CUDNN(cudnnDestroy(handle));
|
||||
CHECK_CUDNN(cudnnCreate(&handle));
|
||||
}
|
||||
return handle;
|
||||
}
|
||||
|
||||
private:
|
||||
cudnnHandle_t handle;
|
||||
int device_id;
|
||||
};
|
||||
|
||||
static cudnnHandle_t context()
|
||||
|
|
Loading…
Reference in New Issue