More CUDA performance improvements

AlexeyAB 2019-01-18 16:29:54 +03:00
parent 5343aa4235
commit bf6b40f4e9
11 changed files with 94 additions and 11 deletions

View File

@@ -316,6 +316,8 @@ struct layer {
float *col_image;
float * delta;
float * output;
int delta_pinned;
int output_pinned;
float * loss;
float * squared;
float * norms;
@@ -582,6 +584,8 @@ typedef struct network {
float *output_gpu;
float *input_state_gpu;
float *input_pinned_cpu;
int input_pinned_cpu_flag;
float **input_gpu;
float **truth_gpu;
@@ -777,6 +781,7 @@ LIB_API pthread_t load_data_in_thread(load_args args);
// cuda.h
LIB_API void cuda_pull_array(float *x_gpu, float *x, size_t n);
LIB_API void cuda_pull_array_async(float *x_gpu, float *x, size_t n);
LIB_API void cuda_set_device(int n);
// utils.h

View File

@@ -692,6 +692,14 @@ extern "C" void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int
check_error(cudaPeekAtLastError());
}
__global__ void simple_input_shortcut_kernel(float *in, int size, float *add, float *out)
{
int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if (id >= size) return;
out[id] = in[id] + add[id];
}
__global__ void input_shortcut_kernel(float *in, int size, int minw, int minh, int minc, int stride, int sample, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
{
int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -711,6 +719,13 @@ __global__ void input_shortcut_kernel(float *in, int size, int minw, int minh, i
extern "C" void input_shortcut_gpu(float *in, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
{
if (w1 == w2 && h1 == h2 && c1 == c2) {
int size = batch * w1 * h1 * c1;
simple_input_shortcut_kernel<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream()>>>(in, size, add, out);
check_error(cudaPeekAtLastError());
return;
}
int minw = (w1 < w2) ? w1 : w2;
int minh = (h1 < h2) ? h1 : h2;
int minc = (c1 < c2) ? c1 : c2;

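Note: the `(blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x` index and the `id >= size` early return in the kernels above exist because darknet sizes launches with `cuda_gridsize()`, which folds the grid into two dimensions once the block count exceeds the 65535 gridDim.x limit. Roughly (a paraphrase of the existing helper for illustration, not part of this diff; BLOCK is darknet's thread-block size):

#include <cuda_runtime.h>
#include <math.h>
#define BLOCK 512

dim3 cuda_gridsize(size_t n)
{
    size_t k = (n - 1) / BLOCK + 1;        // blocks needed for n elements
    size_t x = k, y = 1;
    if (x > 65535) {                       // gridDim.x limit on many GPUs
        x = (size_t)ceil(sqrt((double)k)); // fold into a 2-D grid
        y = (n - 1) / (x * BLOCK) + 1;     // so that x*y*BLOCK >= n
    }
    dim3 d((unsigned)x, (unsigned)y, 1);
    return d;
}

Because x*y*BLOCK can overshoot n, every kernel must guard with `if (id >= size) return;`.
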
View File

@@ -82,6 +82,27 @@ cudaStream_t get_cuda_stream() {
return streamsArray[i];
}
static cudaStream_t streamsArray2[16]; // cudaStreamSynchronize( get_cuda_memcpy_stream() );
static int streamInit2[16] = { 0 };
cudaStream_t get_cuda_memcpy_stream() {
int i = cuda_get_device();
if (!streamInit2[i]) {
cudaError_t status = cudaStreamCreate(&streamsArray2[i]);
//cudaError_t status = cudaStreamCreateWithFlags(&streamsArray2[i], cudaStreamNonBlocking);
if (status != cudaSuccess) {
printf(" cudaStreamCreate Memcpy error: %d \n", status);
const char *s = cudaGetErrorString(status);
char buffer[256];
printf("CUDA Error: %s\n", s);
status = cudaStreamCreateWithFlags(&streamsArray2[i], cudaStreamDefault);
check_error(status);
}
streamInit2[i] = 1;
}
return streamsArray2[i];
}
#ifdef CUDNN
cudnnHandle_t cudnn_handle()
@@ -116,6 +137,7 @@ float *cuda_make_array(float *x, size_t n)
float *x_gpu;
size_t size = sizeof(float)*n;
cudaError_t status = cudaMalloc((void **)&x_gpu, size);
if (status != cudaSuccess) fprintf(stderr, " Try to set subdivisions=64 in your cfg-file. \n");
check_error(status);
if(x){
//status = cudaMemcpy(x_gpu, x, size, cudaMemcpyHostToDevice);
@@ -200,6 +222,14 @@ void cuda_pull_array(float *x_gpu, float *x, size_t n)
cudaStreamSynchronize(get_cuda_stream());
}
void cuda_pull_array_async(float *x_gpu, float *x, size_t n)
{
size_t size = sizeof(float)*n;
cudaError_t status = cudaMemcpyAsync(x, x_gpu, size, cudaMemcpyDeviceToHost, get_cuda_stream());
check_error(status);
//cudaStreamSynchronize(get_cuda_stream());
}
#else // GPU
#include "cuda.h"
void cuda_set_device(int n) {}

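Note: `get_cuda_memcpy_stream()` prepares a second per-device stream so that memory transfers can eventually overlap kernels running on the main stream; it is created eagerly in parse_network_cfg_custom, but the copies themselves still run on the main stream in this commit (the synchronize call site below is commented out). Overlap requires both a separate stream and pinned host memory. A minimal standalone sketch of the idea (hypothetical kernel and buffer names, not darknet code):

#include <cuda_runtime.h>
#include <stdio.h>

__global__ void busy_kernel(float *x, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] = x[i] * 0.5f + 1.0f;  // stand-in for real work
}

int main(void)
{
    const int n = 1 << 20;
    float *d_a, *d_b, *h_pinned;
    cudaMalloc((void**)&d_a, n * sizeof(float));
    cudaMalloc((void**)&d_b, n * sizeof(float));
    cudaHostAlloc((void**)&h_pinned, n * sizeof(float), cudaHostAllocDefault);

    cudaStream_t compute, copy;
    cudaStreamCreate(&compute);
    cudaStreamCreate(&copy);

    // Kernel on one stream, device->host copy on the other: with pinned
    // host memory the DMA can run while the kernel executes.
    busy_kernel<<<(n + 255) / 256, 256, 0, compute>>>(d_a, n);
    cudaMemcpyAsync(h_pinned, d_b, n * sizeof(float),
                    cudaMemcpyDeviceToHost, copy);

    cudaStreamSynchronize(compute);
    cudaStreamSynchronize(copy);
    printf("kernel and copy both finished\n");
    return 0;
}
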
View File

@@ -37,6 +37,7 @@ extern "C" {
float cuda_compare(float *x_gpu, float *x, size_t n, char *s);
dim3 cuda_gridsize(size_t n);
cudaStream_t get_cuda_stream();
cudaStream_t get_cuda_memcpy_stream();
#ifdef __cplusplus
}
#endif // __cplusplus

View File

@@ -1030,7 +1030,6 @@ void repack_input_gpu_2(float *input, float *re_packed_input, int w, int h, int
__global__ void repack_input_kernel_bin(float *input, uint32_t *re_packed_input_bin, int w, int h, int c)
{
__shared__ uint32_t tmp[32];
int index = blockIdx.x*blockDim.x + threadIdx.x;
const int num_of_warps = blockDim.x / WARP_SIZE;
@@ -1350,6 +1349,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
unsigned char *B, int ldb,
float *C, int ldc, float *mean_arr, float *bias_arr)
{
// total 57%
int index = blockIdx.x*blockDim.x + threadIdx.x;
__shared__ uint8_t A_s[6144*8/4];
@@ -1363,7 +1363,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
int i_cur = index / N;
int local_i = i_cur - start_i;
// ~10%
for (int k = threadIdx.x * 64; k < shared_size; k += blockDim.x * 64) {
int x = start_i*lda + k;
if (x < (M*lda)) *((uint64_t *)(A_s + k / 8)) = *((uint64_t *)(A + x / 8));
@@ -1371,7 +1371,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
__syncthreads();
int i, j, k, h;
// 47% = 29 + 10 + 8
j = index % N;
{ // out_h*out_w - one channel output size [169 - 173056]
i = index / N;
@@ -1413,7 +1413,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
#endif
//#ifdef NOT_USED
// 32 thread X 64 bit = 2048 bit
// 32 thread X 64 bit = 2048 bit // 29%
for (; k < (K - 2048); k += 2048) { // l.size*l.size*l.c - one filter size [27 - 9216]
uint64_t c_bit64;
@@ -1444,7 +1444,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
//#endif
//#ifdef NOT_USED
// 32 thread X 32 bit = 1024 bit
// 32 thread X 32 bit = 1024 bit // 10%
for (; k < (K - 1024); k += 1024) { // l.size*l.size*l.c - one filter size [27 - 9216]
//int64_t A_cur_index = (i*lda + k) / 8;
@@ -1479,6 +1479,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
float bias_val = bias_arr[i];
//#ifdef NOT_USED
// 8%
for (; k < K; k += 256) { // l.size*l.size*l.c - one filter size [27 - 144 - 9216]
//ulonglong4 a_bit256 = *((ulonglong4 *)(A + (i*lda + k) / 8)); // weights
ulonglong4 a_bit256 = *((ulonglong4 *)(A_s + (local_i*lda + k) / 8)); // weights

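Note: the inner loops above are binary (XNOR-net style) dot products: signs of weights and activations are packed 64 per machine word, XNOR turns matching bits into 1s, and a population count (`__popcll` on device) converts that into a dot product over {-1,+1} values. A standalone host-side illustration of the arithmetic (an illustrative sketch, not the kernel's actual code):

#include <stdint.h>
#include <stdio.h>

// Dot product of two {-1,+1} vectors packed 64 signs per uint64_t:
// matches = popcount(~(a ^ b)); dot = 2*matches - bits.
static int binary_dot(const uint64_t *a, const uint64_t *b, int nwords)
{
    int matches = 0;
    for (int w = 0; w < nwords; ++w)
        matches += __builtin_popcountll(~(a[w] ^ b[w]));  // __popcll in CUDA
    return 2 * matches - 64 * nwords;
}

int main(void)
{
    uint64_t a = 0xF0F0F0F0F0F0F0F0ULL;
    uint64_t b = 0xF0F0F0F0F0F0F0FFULL;          // differs from a in 4 bits
    printf("dot = %d\n", binary_dot(&a, &b, 1)); // 2*60 - 64 = 56
    return 0;
}
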
View File

@@ -35,6 +35,16 @@ void free_layer(layer l)
if (l.weight_updates) free(l.weight_updates);
if (l.align_bit_weights) free(l.align_bit_weights);
if (l.mean_arr) free(l.mean_arr);
#ifdef GPU
if (l.delta && l.delta_pinned) {
cudaFreeHost(l.delta);
l.delta = NULL;
}
if (l.output && l.output_pinned) {
cudaFreeHost(l.output);
l.output = NULL;
}
#endif // GPU
if (l.delta) free(l.delta);
if (l.output) free(l.output);
if (l.squared) free(l.squared);

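Note: a pointer obtained from cudaHostAlloc must be released with cudaFreeHost, never free(); the new `delta_pinned`/`output_pinned` flags record which allocator succeeded, and setting the pointer to NULL above keeps the later `if (l.delta) free(l.delta);` from double-freeing. The ownership pattern in isolation (hypothetical names, a sketch of the convention used here):

#include <cuda_runtime.h>
#include <stdlib.h>

void pinned_buffer_demo(size_t n)
{
    float *buf = NULL;
    int buf_pinned = 0;                 // remembers which allocator won

    if (cudaHostAlloc((void**)&buf, n * sizeof(float),
                      cudaHostAllocMapped) == cudaSuccess)
        buf_pinned = 1;                 // page-locked: fast async transfers
    else
        buf = (float*)calloc(n, sizeof(float)); // pageable fallback

    /* ... use buf ... */

    if (buf_pinned) cudaFreeHost(buf);  // must match the allocator
    else free(buf);
}
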
View File

@@ -856,6 +856,10 @@ void free_network(network net)
if (gpu_index >= 0) cuda_free(net.workspace);
else free(net.workspace);
if (net.input_state_gpu) cuda_free(net.input_state_gpu);
if (net.input_pinned_cpu) { // CPU
if (net.input_pinned_cpu_flag) cudaFreeHost(net.input_pinned_cpu);
else free(net.input_pinned_cpu);
}
if (*net.input_gpu) cuda_free(*net.input_gpu);
if (*net.truth_gpu) cuda_free(*net.truth_gpu);
if (net.input_gpu) free(net.input_gpu);

View File

@@ -87,6 +87,8 @@ void forward_network_gpu(network net, network_state state)
}
*/
}
cudaStreamSynchronize(get_cuda_stream()); // wait for all queued kernels and async copies on the main stream
//cudaStreamSynchronize(get_cuda_memcpy_stream()); // sync cudaMemcpyAsync()
//cudaDeviceSynchronize();
//show_total_time();
}
@@ -444,7 +446,8 @@ float *network_predict_gpu(network net, float *input)
state.net = net;
//state.input = cuda_make_array(input, size); // memory will be allocated in the parse_network_cfg_custom()
state.input = net.input_state_gpu;
cuda_push_array(state.input, input, size);
memcpy(net.input_pinned_cpu, input, size * sizeof(float));
cuda_push_array(state.input, net.input_pinned_cpu, size);
state.truth = 0;
state.train = 0;
state.delta = 0;

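Note: `network_predict_gpu` now memcpy()s the caller's (pageable) input into `net.input_pinned_cpu` before uploading. A cudaMemcpy from pageable memory forces the driver to stage through an internal pinned buffer anyway, so doing the staging explicitly lets the host-to-device transfer run as a straight DMA and keeps it usable with streams. The pattern in isolation (hypothetical names):

#include <cuda_runtime.h>
#include <string.h>

// Upload through an explicit pinned bounce buffer (allocated once
// elsewhere with cudaHostAlloc).
void upload_input(float *d_input, const float *user_input, size_t n,
                  float *pinned_staging, cudaStream_t stream)
{
    memcpy(pinned_staging, user_input, n * sizeof(float)); // CPU-side copy
    cudaMemcpyAsync(d_input, pinned_staging, n * sizeof(float),
                    cudaMemcpyHostToDevice, stream);       // true async DMA
}
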
View File

@@ -829,10 +829,14 @@ network parse_network_cfg_custom(char *filename, int batch)
if(workspace_size){
//printf("%ld\n", workspace_size);
#ifdef GPU
get_cuda_stream();
get_cuda_memcpy_stream();
if(gpu_index >= 0){
net.workspace = cuda_make_array(0, workspace_size/sizeof(float) + 1);
int size = get_network_input_size(net) * net.batch;
net.input_state_gpu = cuda_make_array(0, size);
if (cudaSuccess == cudaHostAlloc(&net.input_pinned_cpu, size*sizeof(float), cudaHostAllocMapped)) net.input_pinned_cpu_flag = 1;
else net.input_pinned_cpu = calloc(size, sizeof(float));
// pre-allocate memory for inference on Tensor Cores (fp16)
if (net.cudnn_half) {

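Note: allocating the pinned input buffer (and warming up both streams) once at parse time moves that cost out of the inference path. The bandwidth difference that motivates the pinned buffer can be measured with a rough standalone sketch like this (numbers vary by machine; names hypothetical):

#include <cuda_runtime.h>
#include <stdio.h>
#include <stdlib.h>

static float time_h2d_ms(float *dst, float *src, size_t bytes)
{
    cudaEvent_t t0, t1;
    cudaEventCreate(&t0);
    cudaEventCreate(&t1);
    cudaEventRecord(t0, 0);
    cudaMemcpy(dst, src, bytes, cudaMemcpyHostToDevice);
    cudaEventRecord(t1, 0);
    cudaEventSynchronize(t1);
    float ms = 0;
    cudaEventElapsedTime(&ms, t0, t1);
    return ms;
}

int main(void)
{
    size_t n = 1 << 24, bytes = n * sizeof(float);
    float *d, *pageable, *pinned;
    cudaMalloc((void**)&d, bytes);
    pageable = (float*)malloc(bytes);
    cudaHostAlloc((void**)&pinned, bytes, cudaHostAllocDefault);
    printf("pageable H2D: %.2f ms\n", time_h2d_ms(d, pageable, bytes));
    printf("pinned   H2D: %.2f ms\n", time_h2d_ms(d, pinned, bytes));
    return 0;
}
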
View File

@@ -67,7 +67,7 @@ void resize_route_layer(route_layer *l, network *net)
l->output_gpu = cuda_make_array(l->output, l->outputs*l->batch);
l->delta_gpu = cuda_make_array(l->delta, l->outputs*l->batch);
#endif
}
void forward_route_layer(const route_layer l, network_state state)
@@ -110,7 +110,8 @@ void forward_route_layer_gpu(const route_layer l, network_state state)
float *input = state.net.layers[index].output_gpu;
int input_size = l.input_sizes[i];
for(j = 0; j < l.batch; ++j){
copy_ongpu(input_size, input + j*input_size, 1, l.output_gpu + offset + j*l.outputs, 1);
//copy_ongpu(input_size, input + j*input_size, 1, l.output_gpu + offset + j*l.outputs, 1);
simple_copy_ongpu(input_size, input + j*input_size, l.output_gpu + offset + j*l.outputs);
}
offset += input_size;
}

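Note: `copy_ongpu` is a BLAS-style copy that takes incx/incy stride arguments; `simple_copy_ongpu` (implementation not shown in this diff) drops the stride arithmetic for the unit-stride case. A plausible form of such a kernel, offered as an assumption rather than the actual implementation:

__global__ void simple_copy_kernel(int n, const float *src, float *dst)
{
    int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
    if (i < n) dst[i] = src[i];   // plain element-wise copy, stride 1
}

// For contiguous data this is interchangeable with
// cudaMemcpyAsync(dst, src, n*sizeof(float), cudaMemcpyDeviceToDevice, stream);
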
View File

@@ -53,6 +53,14 @@ layer make_yolo_layer(int batch, int w, int h, int n, int total, int *mask, int
l.backward_gpu = backward_yolo_layer_gpu;
l.output_gpu = cuda_make_array(l.output, batch*l.outputs);
l.delta_gpu = cuda_make_array(l.delta, batch*l.outputs);
free(l.output);
if (cudaSuccess == cudaHostAlloc(&l.output, batch*l.outputs*sizeof(float), cudaHostAllocMapped)) l.output_pinned = 1;
else l.output = calloc(batch*l.outputs, sizeof(float));
free(l.delta);
if (cudaSuccess == cudaHostAlloc(&l.delta, batch*l.outputs*sizeof(float), cudaHostAllocMapped)) l.delta_pinned = 1;
else l.delta = calloc(batch*l.outputs, sizeof(float));
#endif
fprintf(stderr, "yolo\n");
@@ -411,13 +419,14 @@ void forward_yolo_layer_gpu(const layer l, network_state state)
}
}
if(!state.train || l.onlyforward){
cuda_pull_array(l.output_gpu, l.output, l.batch*l.outputs);
//cuda_pull_array(l.output_gpu, l.output, l.batch*l.outputs);
cuda_pull_array_async(l.output_gpu, l.output, l.batch*l.outputs);
return;
}
//cuda_pull_array(l.output_gpu, state.input, l.batch*l.inputs);
float *in_cpu = calloc(l.batch*l.inputs, sizeof(float));
cuda_pull_array(l.output_gpu, in_cpu, l.batch*l.inputs);
cuda_pull_array(l.output_gpu, l.output, l.batch*l.outputs);
memcpy(in_cpu, l.output, l.batch*l.outputs*sizeof(float));
float *truth_cpu = 0;
if (state.truth) {
int num_truth = l.batch*l.truths;
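
Note: `cuda_pull_array_async` only enqueues the device-to-host copy; `l.output` (now pinned) holds valid data only after the stream is synchronized, which `forward_network_gpu` above now does once at the end of the whole forward pass instead of blocking inside every yolo layer. The ordering constraint in isolation (hypothetical names):

#include <cuda_runtime.h>

// Enqueue the pull; the pinned host buffer is NOT readable yet.
void pull_async(float *host_pinned, const float *dev, size_t n,
                cudaStream_t stream)
{
    cudaMemcpyAsync(host_pinned, dev, n * sizeof(float),
                    cudaMemcpyDeviceToHost, stream);
}

// Only after synchronizing the stream may the CPU read the buffer.
void consume(const float *host_pinned, cudaStream_t stream)
{
    cudaStreamSynchronize(stream);
    /* host_pinned now holds the layer output */
}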