mirror of https://github.com/davisking/dlib.git

Moved gpu_data into its own file, fixed a few bugs, and cleaned up
a few things.

parent 0508fe2bf1
commit 3ba21975f7
dlib/CMakeLists.txt
@@ -441,6 +441,7 @@ if (NOT TARGET dlib)
             dnn/cuda_dlib.cu
             dnn/cudnn_dlibapi.cpp
             dnn/cublas_dlibapi.cpp
+            dnn/gpu_data.cpp
             )
          set(dlib_needed_libraries ${dlib_needed_libraries} ${CUDA_CUBLAS_LIBRARIES} ${cudnn})
          include_directories(${cudnn_include})
dlib/dnn/cublas_dlibapi.h
@@ -29,8 +29,19 @@ namespace dlib
             cublas_context(const cublas_context&) = delete;
             cublas_context& operator=(const cublas_context&) = delete;
             // but is movable
-            cublas_context(cublas_context&&) = default;
-            cublas_context& operator=(cublas_context&&) = default;
+            cublas_context(cublas_context&& item)
+            {
+                handle = item.handle;
+                item.handle = nullptr;
+            }
+            cublas_context& operator=(cublas_context&& item)
+            {
+                if (this == &item)
+                    return *this;
+                handle = item.handle;
+                item.handle = nullptr;
+                return *this;
+            }

             cublas_context();
             ~cublas_context();
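[Annotation: why the defaulted moves above needed replacing. For a class that owns a raw
handle, "= default" just copies the pointer, so the moved-from object still appears to own
the handle and its destructor frees it a second time. A minimal sketch of the hazard and
the fix, using a hypothetical handle-owning wrapper, not dlib code:

    #include <cstdlib>

    // Buggy: the defaulted move copies the pointer, so two destructors
    // eventually free the same allocation.
    struct bad_wrapper
    {
        void* handle;
        bad_wrapper() : handle(std::malloc(16)) {}
        bad_wrapper(bad_wrapper&&) = default;
        ~bad_wrapper() { std::free(handle); }
    };

    // Fixed: the move leaves the source holding nullptr, and freeing
    // nullptr is a harmless no-op.
    struct good_wrapper
    {
        void* handle;
        good_wrapper() : handle(std::malloc(16)) {}
        good_wrapper(good_wrapper&& item) : handle(item.handle)
        {
            item.handle = nullptr;
        }
        ~good_wrapper() { std::free(handle); }
    };
]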
dlib/dnn/cuda_errors.h
@@ -3,27 +3,22 @@
 #ifndef DLIB_CUDA_ERRORs_H_
 #define DLIB_CUDA_ERRORs_H_
 
-#ifdef DLIB_USE_CUDA
-
 #include "../error.h"
 
 namespace dlib
 {
-    namespace cuda
-    {
-        struct cuda_error : public error
-        {
-            cuda_error(const std::string& message): error(message) {}
-        };
-
-        struct cudnn_error : public cuda_error
-        {
-            cudnn_error(const std::string& message): cuda_error(message) {}
-        };
-    }
+    struct cuda_error : public error
+    {
+        cuda_error(const std::string& message): error(message) {}
+    };
+
+    struct cudnn_error : public cuda_error
+    {
+        cudnn_error(const std::string& message): cuda_error(message) {}
+    };
 }
 
-#endif // DLIB_USE_CUDA
-
 #endif // DLIB_CUDA_ERRORs_H_
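[Annotation: with the nested cuda namespace gone, call sites catch dlib::cuda_error
directly. A minimal sketch of a call site; the include path is assumed for illustration:

    #include <iostream>
    #include "dlib/dnn/cuda_errors.h"

    int main()
    {
        try
        {
            // Was dlib::cuda::cuda_error before this commit.
            throw dlib::cuda_error("example failure");
        }
        catch (const dlib::error& e)
        {
            std::cout << e.what() << std::endl;
        }
        return 0;
    }
]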
dlib/dnn/cuda_utils.h
@@ -3,10 +3,13 @@
 #ifndef DLIB_CUDA_UtILS_H_
 #define DLIB_CUDA_UtILS_H_
 
 #ifndef DLIB_USE_CUDA
 #error "This file shouldn't be #included unless DLIB_USE_CUDA is #defined"
 #endif
 
 #include "cuda_errors.h"
 
 #include <cuda.h>
+#include <cuda_runtime.h>
+#include <sstream>
+
 
dlib/dnn/cuda_utils.h
@@ -19,7 +22,7 @@
             std::ostringstream sout;                                          \
             sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
             sout << "code: " << error << ", reason: " << cudaGetErrorString(error);\
-            throw dlib::cuda::cuda_error(sout.str());                         \
+            throw dlib::cuda_error(sout.str());                               \
         }                                                                     \
     }
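[Annotation: the hunk shows only the tail of the CHECK_CUDA macro. Assuming the
conventional shape of such a status checker (the opening lines are not in this diff),
the whole macro reads roughly:

    #define CHECK_CUDA(call)                                                  \
    {                                                                         \
        const cudaError_t error = call;                                       \
        if (error != cudaSuccess)                                             \
        {                                                                     \
            std::ostringstream sout;                                          \
            sout << "Error while calling " << #call << " in file "            \
                 << __FILE__ << ":" << __LINE__ << ". ";                      \
            sout << "code: " << error << ", reason: "                         \
                 << cudaGetErrorString(error);                                \
            throw dlib::cuda_error(sout.str());                               \
        }                                                                     \
    }

so call sites wrap every runtime API call, e.g. CHECK_CUDA(cudaMalloc(&ptr,
n*sizeof(float))), and failures surface as dlib::cuda_error exceptions rather than
silently ignored status codes.]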
dlib/dnn/cuda_utils.h
@@ -27,75 +30,82 @@

[Note: this hunk wraps grid_stride_range in namespace dlib { namespace cuda { ... } },
re-indenting the class and re-wrapping its doc comment; the class itself is otherwise
unchanged. The new version:]

#ifdef __CUDACC__

namespace dlib
{
    namespace cuda
    {
        class grid_stride_range
        {
            /*!
                WHAT THIS OBJECT REPRESENTS
                    This is a tool for making a for loop that loops over an entire block of
                    memory inside a kernel, but doing so in a way that parallelizes
                    appropriately across all the threads in a kernel launch.  For example,
                    the following kernel would add the vector a to the vector b and store
                    the output in out (assuming all vectors are of dimension n):
                        __global__ void add_arrays(
                            const float* a,
                            const float* b,
                            float* out,
                            size_t n
                        )
                        {
                            for (auto i : grid_stride_range(0, n))
                            {
                                out[i] = a[i]+b[i];
                            }
                        }
            !*/

        public:
            __device__ grid_stride_range(
                size_t ibegin_,
                size_t iend_
            ) :
                ibegin(ibegin_),
                iend(iend_)
            {}

            class iterator
            {
            public:
                __device__ iterator() {}
                __device__ iterator(size_t pos_) : pos(pos_) {}

                __device__ size_t operator*() const
                {
                    return pos;
                }

                __device__ iterator& operator++()
                {
                    pos += gridDim.x * blockDim.x;
                    return *this;
                }

                __device__ bool operator!=(const iterator& item) const
                { return pos < item.pos; }

            private:
                size_t pos;
            };

            __device__ iterator begin() const
            {
                return iterator(ibegin+blockDim.x * blockIdx.x + threadIdx.x);
            }
            __device__ iterator end() const
            {
                return iterator(iend);
            }

        private:
            size_t ibegin;
            size_t iend;
        };
    }
}

#endif // __CUDACC__
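[Annotation: the documented add_arrays example becomes a complete unit with a kernel
launch around it. A minimal sketch; the grid and block sizes are illustrative choices,
not values taken from dlib, and the include path is assumed:

    #include <cuda_runtime.h>
    #include "dlib/dnn/cuda_utils.h"

    __global__ void add_arrays(const float* a, const float* b, float* out, size_t n)
    {
        // Each thread starts at its own global index and advances by the total
        // number of threads in the grid, so any n is covered by any launch size.
        for (auto i : dlib::cuda::grid_stride_range(0, n))
            out[i] = a[i] + b[i];
    }

    void launch_add(const float* a, const float* b, float* out, size_t n)
    {
        add_arrays<<<512, 128>>>(a, b, out, n);   // 512 blocks of 128 threads
    }
]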
dlib/dnn/cudnn_dlibapi.cpp
@@ -15,133 +15,6 @@
 namespace dlib
 {
 
     namespace cuda
     {

[Note: this hunk deletes the gpu_data member functions (wait_for_transfer_to_finish(),
copy_to_device(), copy_to_host(), async_copy_to_device(), and set_size()) from
cudnn_dlibapi.cpp. They move, essentially verbatim, into the new dnn/gpu_data.cpp added
later in this commit, so the deleted body is not repeated here.]
dlib/dnn/cudnn_dlibapi.cpp
@@ -155,6 +28,8 @@ namespace dlib
                 throw cudnn_error("CUDA Runtime API initialization failed.");
             case CUDNN_STATUS_ALLOC_FAILED:
                 throw cudnn_error("CUDA Resources could not be allocated.");
+            case CUDNN_STATUS_BAD_PARAM:
+                throw cudnn_error("CUDNN_STATUS_BAD_PARAM");
             default:
                 throw cudnn_error("A call to cuDNN failed.");
         }
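[Annotation: the switch above lives in a small status checker. Assuming the conventional
shape of such a helper (only the cases shown in this hunk are confirmed by the diff), it
reads roughly:

    static void check(cudnnStatus_t s)
    {
        switch(s)
        {
            case CUDNN_STATUS_SUCCESS:
                return;
            case CUDNN_STATUS_NOT_INITIALIZED:
                throw cudnn_error("CUDA Runtime API initialization failed.");
            case CUDNN_STATUS_ALLOC_FAILED:
                throw cudnn_error("CUDA Resources could not be allocated.");
            case CUDNN_STATUS_BAD_PARAM:
                throw cudnn_error("CUDNN_STATUS_BAD_PARAM");
            default:
                throw cudnn_error("A call to cuDNN failed.");
        }
    }
]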
dlib/dnn/cudnn_dlibapi.cpp
@@ -180,20 +55,16 @@ namespace dlib
 
         // ------------------------------------------------------------------------------------
 
-        tensor_descriptor::tensor_descriptor() : handle(nullptr)
+        tensor_descriptor::
+        tensor_descriptor(
+        ) : handle(nullptr)
         {
-            cudnnTensorDescriptor_t h;
-            check(cudnnCreateTensorDescriptor(&h));
-            handle = h;
         }
 
-        tensor_descriptor::~tensor_descriptor()
+        tensor_descriptor::
+        ~tensor_descriptor()
         {
-            if (handle)
-            {
-                cudnnDestroyTensorDescriptor((cudnnTensorDescriptor_t)handle);
-                handle = nullptr;
-            }
+            set_size(0,0,0,0);
         }
 
         void tensor_descriptor::
dlib/dnn/cudnn_dlibapi.cpp
@@ -204,13 +75,28 @@ namespace dlib
             int k
         )
         {
-            check(cudnnSetTensor4dDescriptor((cudnnTensorDescriptor_t)handle,
-                CUDNN_TENSOR_NHWC,
-                CUDNN_DATA_FLOAT,
-                n,
-                k,
-                nr,
-                nc));
+            if (n == 0 || nr == 0 || nc == 0 || k == 0)
+            {
+                if (handle)
+                {
+                    cudnnDestroyTensorDescriptor((cudnnTensorDescriptor_t)handle);
+                    handle = nullptr;
+                }
+            }
+            else
+            {
+                cudnnTensorDescriptor_t h;
+                check(cudnnCreateTensorDescriptor(&h));
+                handle = h;
+
+                check(cudnnSetTensor4dDescriptor((cudnnTensorDescriptor_t)handle,
+                    CUDNN_TENSOR_NHWC,
+                    CUDNN_DATA_FLOAT,
+                    n,
+                    k,
+                    nr,
+                    nc));
+            }
         }
 
         void tensor_descriptor::
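[Annotation: the effect of the rewritten set_size() is that the cuDNN descriptor is now
created lazily and an all-zero size releases it, which lets the destructor funnel through
the same code path. A sketch of the implied lifecycle, illustrative only and with the
argument order assumed from the surrounding hunks:

    dlib::cuda::tensor_descriptor td;   // handle == nullptr, no cuDNN call yet
    td.set_size(10, 28, 28, 3);         // creates and configures the descriptor
    td.set_size(0, 0, 0, 0);            // destroys it; handle is nullptr again
    // ~tensor_descriptor() calls set_size(0,0,0,0), so explicit resets and
    // destruction share one cleanup path.
]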
dlib/dnn/cudnn_dlibapi.cpp
@@ -221,18 +107,28 @@ namespace dlib
             int& k
         ) const
         {
-            int nStride, cStride, hStride, wStride;
-            cudnnDataType_t datatype;
-            check(cudnnGetTensor4dDescriptor((cudnnTensorDescriptor_t)handle,
-                &datatype,
-                &n,
-                &k,
-                &nr,
-                &nc,
-                &nStride,
-                &cStride,
-                &hStride,
-                &wStride));
+            if (handle)
+            {
+                int nStride, cStride, hStride, wStride;
+                cudnnDataType_t datatype;
+                check(cudnnGetTensor4dDescriptor((cudnnTensorDescriptor_t)handle,
+                    &datatype,
+                    &n,
+                    &k,
+                    &nr,
+                    &nc,
+                    &nStride,
+                    &cStride,
+                    &hStride,
+                    &wStride));
+            }
+            else
+            {
+                n = 0;
+                nr = 0;
+                nc = 0;
+                k = 0;
+            }
         }
 
         // ------------------------------------------------------------------------------------
dlib/dnn/cudnn_dlibapi.h
@@ -25,8 +25,19 @@ namespace dlib
             cudnn_context(const cudnn_context&) = delete;
             cudnn_context& operator=(const cudnn_context&) = delete;
             // but is movable
-            cudnn_context(cudnn_context&&) = default;
-            cudnn_context& operator=(cudnn_context&&) = default;
+            cudnn_context(cudnn_context&& item)
+            {
+                handle = item.handle;
+                item.handle = nullptr;
+            }
+            cudnn_context& operator=(cudnn_context&& item)
+            {
+                if (this == &item)
+                    return *this;
+                handle = item.handle;
+                item.handle = nullptr;
+                return *this;
+            }

             cudnn_context();
             ~cudnn_context();
dlib/dnn/cudnn_dlibapi.h
@@ -53,8 +64,19 @@ namespace dlib
             tensor_descriptor(const tensor_descriptor&) = delete;
             tensor_descriptor& operator=(const tensor_descriptor&) = delete;
             // but is movable
-            tensor_descriptor(tensor_descriptor&&) = default;
-            tensor_descriptor& operator=(tensor_descriptor&&) = default;
+            tensor_descriptor(tensor_descriptor&& item)
+            {
+                handle = item.handle;
+                item.handle = nullptr;
+            }
+            tensor_descriptor& operator=(tensor_descriptor&& item)
+            {
+                if (this == &item)
+                    return *this;
+                handle = item.handle;
+                item.handle = nullptr;
+                return *this;
+            }

             tensor_descriptor();
             ~tensor_descriptor();
dlib/dnn/gpu_data.cpp (new file)
@@ -0,0 +1,142 @@
// Copyright (C) 2015  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_GPU_DaTA_CPP_
#define DLIB_GPU_DaTA_CPP_

// Only things that require CUDA are declared in this cpp file.  Everything else is in the
// gpu_data.h header so that it can operate as "header-only" code when using just the CPU.
#ifdef DLIB_USE_CUDA

#include "gpu_data.h"
#include <iostream>
#include "cuda_utils.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    void gpu_data::
    wait_for_transfer_to_finish() const
    {
        if (have_active_transfer)
        {
            std::cout << "wait for cudaStreamSynchronize()" << std::endl;
            CHECK_CUDA(cudaStreamSynchronize((cudaStream_t)cuda_stream.get()));
            have_active_transfer = false;
            // Check for errors.  These calls to cudaGetLastError() are what help us find
            // out if our kernel launches have been failing.
            CHECK_CUDA(cudaGetLastError());
        }
    }

    void gpu_data::
    copy_to_device() const
    {
        wait_for_transfer_to_finish();
        if (!device_current)
        {
            std::cout << "cudaMemcpy to device" << std::endl;
            CHECK_CUDA(cudaMemcpy(data_device.get(), data_host.get(), data_size*sizeof(float), cudaMemcpyHostToDevice));
            device_current = true;
            // Check for errors.  These calls to cudaGetLastError() are what help us find
            // out if our kernel launches have been failing.
            CHECK_CUDA(cudaGetLastError());
        }
    }

    void gpu_data::
    copy_to_host() const
    {
        wait_for_transfer_to_finish();
        if (!host_current)
        {
            std::cout << "cudaMemcpy to host" << std::endl;
            CHECK_CUDA(cudaMemcpy(data_host.get(), data_device.get(), data_size*sizeof(float), cudaMemcpyDeviceToHost));
            host_current = true;
            // Check for errors.  These calls to cudaGetLastError() are what help us find
            // out if our kernel launches have been failing.
            CHECK_CUDA(cudaGetLastError());
        }
    }

    void gpu_data::
    async_copy_to_device()
    {
        if (!device_current)
        {
            std::cout << "cudaMemcpyAsync to device" << std::endl;
            CHECK_CUDA(cudaMemcpyAsync(data_device.get(), data_host.get(), data_size*sizeof(float), cudaMemcpyHostToDevice, (cudaStream_t)cuda_stream.get()));
            have_active_transfer = true;
            device_current = true;
        }
    }

    void gpu_data::
    set_size(
        size_t new_size
    )
    {
        wait_for_transfer_to_finish();
        if (new_size == 0)
        {
            data_size = 0;
            host_current = true;
            device_current = true;
            data_host.reset();
            data_device.reset();
        }
        else if (new_size != data_size)
        {
            data_size = new_size;
            host_current = true;
            device_current = true;

            try
            {
                void* data;
                CHECK_CUDA(cudaMallocHost(&data, new_size*sizeof(float)));
                // Note that we don't throw exceptions since the free calls are invariably
                // called in destructors.  They also shouldn't fail anyway unless someone
                // is resetting the GPU card in the middle of their program.
                data_host.reset((float*)data, [](float* ptr){
                    auto err = cudaFreeHost(ptr);
                    if(err!=cudaSuccess)
                        std::cerr << "cudaFreeHost() failed. Reason: " << cudaGetErrorString(err) << std::endl;
                });

                CHECK_CUDA(cudaMalloc(&data, new_size*sizeof(float)));
                data_device.reset((float*)data, [](float* ptr){
                    auto err = cudaFree(ptr);
                    if(err!=cudaSuccess)
                        std::cerr << "cudaFree() failed. Reason: " << cudaGetErrorString(err) << std::endl;
                });

                if (!cuda_stream)
                {
                    cudaStream_t cstream;
                    CHECK_CUDA(cudaStreamCreateWithFlags(&cstream, cudaStreamNonBlocking));
                    cuda_stream.reset(cstream, [](void* ptr){
                        auto err = cudaStreamDestroy((cudaStream_t)ptr);
                        if(err!=cudaSuccess)
                            std::cerr << "cudaStreamDestroy() failed. Reason: " << cudaGetErrorString(err) << std::endl;
                    });
                }

            }
            catch(...)
            {
                set_size(0);
                throw;
            }
        }
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_USE_CUDA

#endif // DLIB_GPU_DaTA_CPP_
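[Annotation: a short sketch of the transfer bookkeeping these functions implement. Host
and device each carry a "current" flag; the names come from the CONVENTION block in
gpu_data.h below, and the sketch assumes a build with DLIB_USE_CUDA defined:

    dlib::gpu_data buf;
    buf.set_size(100);             // pinned host buffer + device buffer + stream

    float* h = buf.host();         // syncs the host copy, marks the device copy stale
    h[0] = 3.14f;

    buf.async_copy_to_device();    // starts a non-blocking host-to-device copy
    // ... other CPU work can overlap with the transfer here ...

    float* d = buf.device();       // waits for the transfer, then returns the device
                                   // pointer (and marks the host copy stale)
]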
dlib/dnn/gpu_data.h (new file)
@@ -0,0 +1,168 @@
// Copyright (C) 2015  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_GPU_DaTA_H_
#define DLIB_GPU_DaTA_H_

#include <memory>
#include "cuda_errors.h"
#include "../serialize.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    class gpu_data
    {
        /*!
            CONVENTION
                - if (size() != 0) then
                    - data_host == a pointer to size() floats in CPU memory.
                    - if (data_device) then
                        - data_device == a pointer to size() floats in device memory.

                - if (there might be an active transfer between host and device) then
                    - have_active_transfer == true

                - We use the host_current and device_current bools to keep track of which
                  copy of the data (or both) are most current.  e.g. if the CPU has
                  modified the tensor and it hasn't been copied to the device yet then
                  host_current==true and device_current==false.

            THREAD SAFETY
                This object is not thread-safe.  Don't touch it from multiple threads at
                the same time.
        !*/
    public:

        gpu_data(
        ) : data_size(0), host_current(true), device_current(true), have_active_transfer(false)
        {
        }

        // Not copyable
        gpu_data(const gpu_data&) = delete;
        gpu_data& operator=(const gpu_data&) = delete;

        // but is movable
        gpu_data(gpu_data&&) = default;
        gpu_data& operator=(gpu_data&&) = default;

#ifdef DLIB_USE_CUDA
        void async_copy_to_device();
        void set_size(size_t new_size);
#else
        // Note that calls to host() or device() will block until any async transfers are complete.
        void async_copy_to_device(){}

        void set_size(size_t new_size)
        {
            if (new_size == 0)
            {
                data_size = 0;
                host_current = true;
                device_current = true;
                data_host.reset();
                data_device.reset();
            }
            else if (new_size != data_size)
            {
                data_size = new_size;
                host_current = true;
                device_current = true;
                data_host.reset(new float[new_size], std::default_delete<float[]>());
                data_device.reset();
            }
        }
#endif

        const float* host() const
        {
            copy_to_host();
            return data_host.get();
        }

        float* host()
        {
            copy_to_host();
            device_current = false;
            return data_host.get();
        }

        const float* device() const
        {
#ifndef DLIB_USE_CUDA
            DLIB_CASSERT(false, "CUDA NOT ENABLED");
#endif
            copy_to_device();
            return data_device.get();
        }

        float* device()
        {
#ifndef DLIB_USE_CUDA
            DLIB_CASSERT(false, "CUDA NOT ENABLED");
#endif
            copy_to_device();
            host_current = false;
            return data_device.get();
        }

        size_t size() const { return data_size; }

    private:

#ifdef DLIB_USE_CUDA
        void copy_to_device() const;
        void copy_to_host() const;
        void wait_for_transfer_to_finish() const;
#else
        void copy_to_device() const{}
        void copy_to_host() const{}
        void wait_for_transfer_to_finish() const{}
#endif

        size_t data_size;
        mutable bool host_current;
        mutable bool device_current;
        mutable bool have_active_transfer;

        std::shared_ptr<float> data_host;
        std::shared_ptr<float> data_device;
        std::shared_ptr<void> cuda_stream;
    };

    inline void serialize(const gpu_data& item, std::ostream& out)
    {
        int version = 1;
        serialize(version, out);
        serialize(item.size(), out);
        auto data = item.host();
        for (size_t i = 0; i < item.size(); ++i)
            serialize(data[i], out);
    }

    inline void deserialize(gpu_data& item, std::istream& in)
    {
        int version;
        deserialize(version, in);
        if (version != 1)
            throw serialization_error("Unexpected version found while deserializing dlib::gpu_data.");
        size_t s;
        deserialize(s, in);
        item.set_size(s);
        auto data = item.host();
        for (size_t i = 0; i < item.size(); ++i)
            deserialize(data[i], in);
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_GPU_DaTA_H_
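[Annotation: a round-trip sketch of the serialize()/deserialize() pair defined above.
Illustrative only; it uses just what the header shows, and the include path is assumed:

    #include <sstream>
    #include "dlib/dnn/gpu_data.h"

    void roundtrip_example()
    {
        dlib::gpu_data a, b;
        a.set_size(4);
        float* p = a.host();
        p[0] = 1; p[1] = 2; p[2] = 3; p[3] = 4;

        std::stringstream ss;
        serialize(a, ss);     // writes a version tag, the size, then the floats
        deserialize(b, ss);   // calls set_size(), then fills b.host()
        // b now holds the same four floats as a.
    }
]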
dlib/dnn/tensor.h
@@ -3,165 +3,13 @@
 #ifndef DLIB_DNn_TENSOR_H_
 #define DLIB_DNn_TENSOR_H_
 
 #include <memory>
 #include <cstring>
 #include "../matrix.h"
 #include "cudnn_dlibapi.h"
+#include "gpu_data.h"
 
 namespace dlib
 {
 
 // ----------------------------------------------------------------------------------------
 
     class tensor

[Note: this hunk also deletes the inline definition of the gpu_data class and its
serialize()/deserialize() overloads from tensor.h. That code is identical to the new
dnn/gpu_data.h shown above, which tensor.h now includes instead.]