mirror of https://github.com/AlexeyAB/darknet.git
CHECK_CUDA is used everywhere
This commit is contained in:
parent ce2e0eff00
commit 12b6e93893
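The diff replaces bare check_error(...) calls with the CHECK_CUDA(...) macro. The macro's real definition lives in darknet's CUDA headers and is not part of this diff; a minimal sketch of what such a wrapper typically does (an assumption, not the project's exact code) is:

    /* Sketch only -- not darknet's exact definition. A CHECK_CUDA-style
       macro evaluates the CUDA call once and reports the failing file and
       line instead of just an error code. */
    #include <stdio.h>
    #include <stdlib.h>
    #include <cuda_runtime.h>

    #define CHECK_CUDA(call)                                                  \
        do {                                                                  \
            cudaError_t status_ = (call);                                     \
            if (status_ != cudaSuccess) {                                     \
                fprintf(stderr, "CUDA error: %s at %s:%d\n",                  \
                        cudaGetErrorString(status_), __FILE__, __LINE__);     \
                exit(EXIT_FAILURE);                                           \
            }                                                                 \
        } while (0)

In the hunks below the macro wraps cudaPeekAtLastError() right after kernel launches, the return codes of cudaMalloc/cudaMemcpy/cudaMemGetInfo, and cudaDeviceSynchronize().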
@@ -48,7 +48,7 @@ extern "C" void forward_avgpool_layer_gpu(avgpool_layer layer, network_state sta
     size_t n = layer.c*layer.batch;

     forward_avgpool_layer_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >>>(n, layer.w, layer.h, layer.c, state.input, layer.output_gpu);
-    check_error(cudaPeekAtLastError());
+    CHECK_CUDA(cudaPeekAtLastError());
 }

 extern "C" void backward_avgpool_layer_gpu(avgpool_layer layer, network_state state)
@@ -56,6 +56,6 @@ extern "C" void backward_avgpool_layer_gpu(avgpool_layer layer, network_state st
     size_t n = layer.c*layer.batch;

     backward_avgpool_layer_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >>>(n, layer.w, layer.h, layer.c, state.delta, layer.delta_gpu);
-    check_error(cudaPeekAtLastError());
+    CHECK_CUDA(cudaPeekAtLastError());
 }
@@ -723,7 +723,7 @@ extern "C" void input_shortcut_gpu(float *in, int batch, int w1, int h1, int c1,
     if (w1 == w2 && h1 == h2 && c1 == c2) {
         int size = batch * w1 * h1 * c1;
         simple_input_shortcut_kernel << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> >(in, size, add, out);
-        check_error(cudaPeekAtLastError());
+        CHECK_CUDA(cudaPeekAtLastError());
         return;
     }
@@ -54,5 +54,7 @@ void col2im_ongpu(float *data_col,
         num_kernels, data_col, height, width, ksize, pad,
         stride, height_col,
         width_col, data_im);
+
+    CHECK_CUDA(cudaPeekAtLastError());
 }
@@ -116,7 +116,7 @@ void fast_binarize_weights_gpu(float *weights, int n, int size, float *binary, f
 {
     if (size % 32 == 0) {
         size_t gridsize = n * size;
-        const int num_blocks = gridsize / BLOCK + 1;
+        const int num_blocks = get_number_of_blocks(gridsize, BLOCK);// gridsize / BLOCK + 1;

         set_zero_kernel << <(n/BLOCK + 1), BLOCK, 0, get_cuda_stream() >> > (mean_arr_gpu, n);
         reduce_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (weights, n, size, mean_arr_gpu);
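get_number_of_blocks() itself is not shown in this diff; assuming the usual ceil-division grid sizing, it would look roughly like the sketch below. The point of the change is that gridsize / BLOCK + 1 launches one surplus block whenever gridsize is an exact multiple of BLOCK, while rounding up does not.

    /* Assumed ceil-division helper (sketch, not darknet's exact code or
       signature): smallest block count that covers array_size elements. */
    static inline int get_number_of_blocks(size_t array_size, int block_size)
    {
        return (int)((array_size + block_size - 1) / block_size);
    }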
@@ -139,7 +139,7 @@ __global__ void cuda_f32_to_f16(float* input_f32, size_t size, half *output_f16)
 }

 void cuda_convert_f32_to_f16(float* input_f32, size_t size, float *output_f16) {
-    cuda_f32_to_f16 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> (input_f32, size, (half *)output_f16);
+    cuda_f32_to_f16 <<< get_number_of_blocks(size, BLOCK), BLOCK, 0, get_cuda_stream() >>> (input_f32, size, (half *)output_f16);
     CHECK_CUDA(cudaPeekAtLastError());
 }
@@ -151,7 +151,7 @@ __global__ void cuda_f16_to_f32(half* input_f16, size_t size, float *output_f32)
 }

 void cuda_convert_f16_to_f32(float* input_f16, size_t size, float *output_f32) {
-    cuda_f16_to_f32 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> ((half *)input_f16, size, output_f32);
+    cuda_f16_to_f32 <<< get_number_of_blocks(size, BLOCK), BLOCK, 0, get_cuda_stream() >>> ((half *)input_f16, size, output_f32);
     CHECK_CUDA(cudaPeekAtLastError());
 }
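The bodies of cuda_f32_to_f16 / cuda_f16_to_f32 sit outside these hunks. As a sketch only, assuming the one-thread-per-element layout implied by the launches above, the fp32-to-fp16 direction could be written like this (hypothetical name f32_to_f16_sketch):

    #include <cuda_fp16.h>

    // Sketch: per-element float -> half cast; threads past `size` do nothing.
    __global__ void f32_to_f16_sketch(const float *input_f32, size_t size, half *output_f16)
    {
        size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
        if (idx < size) output_f16[idx] = __float2half(input_f32[idx]);
    }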
@@ -159,7 +159,7 @@ half *cuda_make_f16_from_f32_array(float *src, size_t n)
 {
     half *dst16;
     size_t size = sizeof(half)*n;
-    check_error(cudaMalloc((void **)&dst16, size));
+    CHECK_CUDA(cudaMalloc((void **)&dst16, size));
     if (src) {
         cuda_convert_f32_to_f16(src, n, (float *)dst16);
     }
@@ -574,7 +574,7 @@ void resize_convolutional_layer(convolutional_layer *l, int w, int h)
     // check for excessive memory consumption
     size_t free_byte;
     size_t total_byte;
-    check_error(cudaMemGetInfo(&free_byte, &total_byte));
+    CHECK_CUDA(cudaMemGetInfo(&free_byte, &total_byte));
     if (l->workspace_size > free_byte || l->workspace_size >= total_byte / 2) {
         printf(" used slow CUDNN algo without Workspace! Need memory: %zu, available: %zu\n", l->workspace_size, (free_byte < total_byte/2) ? free_byte : total_byte/2);
         cudnn_convolutional_setup(l, cudnn_smallest);
@@ -759,19 +759,19 @@ void binary_align_weights(convolutional_layer *l)
     l->align_workspace_size = l->bit_align * l->size * l->size * l->c;
     status = cudaMalloc((void **)&l->align_workspace_gpu, l->align_workspace_size * sizeof(float));
     status = cudaMalloc((void **)&l->transposed_align_workspace_gpu, l->align_workspace_size * sizeof(float));
-    check_error(status);
+    CHECK_CUDA(status);

     //l->align_bit_weights_gpu = cuda_make_array(l->align_bit_weights, l->align_bit_weights_size * sizeof(char)/sizeof(float));
     status = cudaMalloc((void **)&l->align_bit_weights_gpu, l->align_bit_weights_size);
-    check_error(status);
+    CHECK_CUDA(status);
     status = cudaMemcpy(l->align_bit_weights_gpu, l->align_bit_weights, l->align_bit_weights_size, cudaMemcpyHostToDevice);
-    check_error(status);
+    CHECK_CUDA(status);
     status = cudaMemcpy(l->binary_weights_gpu, l->binary_weights, m*k * sizeof(float), cudaMemcpyHostToDevice);
-    check_error(status);
+    CHECK_CUDA(status);

     //l->mean_arr_gpu = cuda_make_array(l->mean_arr, l->n);
     cuda_push_array(l->mean_arr_gpu, l->mean_arr, l->n);
-    cudaDeviceSynchronize();
+    CHECK_CUDA(cudaDeviceSynchronize());
 #endif // GPU

     free(align_weights);
@@ -196,12 +196,12 @@ extern "C" void forward_crop_layer_gpu(crop_layer layer, network_state state)
     int size = layer.batch * layer.w * layer.h;

     levels_image_kernel<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >>>(state.input, layer.rand_gpu, layer.batch, layer.w, layer.h, state.train, layer.saturation, layer.exposure, translate, scale, layer.shift);
-    check_error(cudaPeekAtLastError());
+    CHECK_CUDA(cudaPeekAtLastError());

     size = layer.batch*layer.c*layer.out_w*layer.out_h;

     forward_crop_layer_kernel<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >>>(state.input, layer.rand_gpu, size, layer.c, layer.h, layer.w, layer.out_h, layer.out_w, state.train, layer.flip, radians, layer.output_gpu);
-    check_error(cudaPeekAtLastError());
+    CHECK_CUDA(cudaPeekAtLastError());

     /*
        cuda_pull_array(layer.output_gpu, layer.output, size);
@@ -255,6 +255,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i

         int draw_precision = 0;
         int calc_map_for_each = 4 * train_images_num / (net.batch * net.subdivisions); // calculate mAP for each 4 Epochs
+        if (calc_map) printf(" Next mAP calculation at %d iterations \n", (iter_map + calc_map_for_each));
         if (calc_map && (i >= (iter_map + calc_map_for_each) || i == net.max_batches) && i >= net.burn_in && i >= 1000) {
             if (l.random) {
                 printf("Resizing to initial size: %d x %d \n", init_w, init_h);
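As a worked example with assumed numbers: for train_images_num = 6400 and a full batch of net.batch * net.subdivisions = 64, calc_map_for_each = 4 * 6400 / 64 = 400, so the mAP check above fires roughly every 400 iterations, i.e. about every 4 epochs.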
@@ -275,7 +276,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i

             iter_map = i;
             mean_average_precision = validate_detector_map(datacfg, cfgfile, weightfile, 0.25, 0.5, &net_combined);
-            printf("\n mean_average_precision = %f \n", mean_average_precision);
+            printf("\n mean_average_precision (mAP@0.5) = %f \n", mean_average_precision);
             draw_precision = 1;
         }
 #ifdef OPENCV
@@ -973,13 +974,9 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
         thresh_calc_avg_iou, tp_for_thresh, fp_for_thresh, unique_truth_count - tp_for_thresh, avg_iou * 100);

     mean_average_precision = mean_average_precision / classes;
-    if (iou_thresh == 0.5) {
-        printf("\n mean average precision (mAP) = %f, or %2.2f %% \n", mean_average_precision, mean_average_precision * 100);
-    }
-    else {
-        printf("\n average precision (AP) = %f, or %2.2f %% for IoU threshold = %f \n", mean_average_precision, mean_average_precision * 100, iou_thresh);
-    }
+    printf("\n IoU threshold = %2.2f %% \n", iou_thresh * 100);
+
+    printf(" mean average precision (mAP@%0.2f) = %f, or %2.2f %% \n", iou_thresh, mean_average_precision, mean_average_precision * 100);

     for (i = 0; i < classes; ++i) {
         free(pr[i]);
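For reference, the quantity printed above is the per-class AP averaged over all classes at the chosen IoU threshold, matching the division by classes in the hunk: mAP@t = (1 / classes) * sum over classes of AP_class, evaluated at IoU threshold t (0.5 in the train_detector call above, hence the mAP@0.5 label).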
@@ -28,7 +28,7 @@ void forward_dropout_layer_gpu(dropout_layer layer, network_state state)
     */

     yoloswag420blazeit360noscope<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >>>(state.input, size, layer.rand_gpu, layer.probability, layer.scale);
-    check_error(cudaPeekAtLastError());
+    CHECK_CUDA(cudaPeekAtLastError());
 }

 void backward_dropout_layer_gpu(dropout_layer layer, network_state state)
@@ -37,5 +37,5 @@ void backward_dropout_layer_gpu(dropout_layer layer, network_state state)
     int size = layer.inputs*layer.batch;

     yoloswag420blazeit360noscope<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >>>(state.delta, size, layer.rand_gpu, layer.probability, layer.scale);
-    check_error(cudaPeekAtLastError());
+    CHECK_CUDA(cudaPeekAtLastError());
 }
@@ -2669,9 +2669,10 @@ void gemm_ongpu(int TA, int TB, int M, int N, int K, float ALPHA,
 {
     cublasHandle_t handle = blas_handle();
     cudaError_t stream_status = cublasSetStream(handle, get_cuda_stream());
+    CHECK_CUDA(stream_status);
     cudaError_t status = cublasSgemm(handle, (TB ? CUBLAS_OP_T : CUBLAS_OP_N),
             (TA ? CUBLAS_OP_T : CUBLAS_OP_N), N, M, K, &ALPHA, B_gpu, ldb, A_gpu, lda, &BETA, C_gpu, ldc);
-    check_error(status);
+    CHECK_CUDA(status);
 }

 void gemm_gpu(int TA, int TB, int M, int N, int K, float ALPHA,
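Note that cublasSetStream() and cublasSgemm() return cublasStatus_t rather than cudaError_t; how darknet reconciles the two types is not visible in this hunk. A separate cuBLAS status check, sketched here as an assumption in the same spirit as CHECK_CUDA (not darknet's code), could look like:

    #include <stdio.h>
    #include <stdlib.h>
    #include <cublas_v2.h>

    /* Sketch only: a cuBLAS analogue of CHECK_CUDA. */
    #define CHECK_CUBLAS(call)                                        \
        do {                                                          \
            cublasStatus_t s_ = (call);                               \
            if (s_ != CUBLAS_STATUS_SUCCESS) {                        \
                fprintf(stderr, "cuBLAS error %d at %s:%d\n",         \
                        (int)s_, __FILE__, __LINE__);                 \
                exit(EXIT_FAILURE);                                   \
            }                                                         \
        } while (0)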
@@ -118,7 +118,7 @@ extern "C" void forward_maxpool_layer_gpu(maxpool_layer layer, network_state sta
     size_t n = h*w*c*layer.batch;

     forward_maxpool_layer_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream()>>>(n, layer.h, layer.w, layer.c, layer.stride, layer.size, layer.pad, state.input, layer.output_gpu, layer.indexes_gpu);
-    check_error(cudaPeekAtLastError());
+    CHECK_CUDA(cudaPeekAtLastError());
 }

 extern "C" void backward_maxpool_layer_gpu(maxpool_layer layer, network_state state)
@@ -126,6 +126,6 @@ extern "C" void backward_maxpool_layer_gpu(maxpool_layer layer, network_state st
     size_t n = layer.h*layer.w*layer.c*layer.batch;

     backward_maxpool_layer_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >>>(n, layer.h, layer.w, layer.c, layer.stride, layer.size, layer.pad, layer.delta_gpu, state.delta, layer.indexes_gpu);
-    check_error(cudaPeekAtLastError());
+    CHECK_CUDA(cudaPeekAtLastError());
 }
@@ -865,9 +865,9 @@ network parse_network_cfg_custom(char *filename, int batch, int time_steps)
         // pre-allocate memory for inference on Tensor Cores (fp16)
         if (net.cudnn_half) {
             *net.max_input16_size = max_inputs;
-            check_error(cudaMalloc((void **)net.input16_gpu, *net.max_input16_size * sizeof(short))); //sizeof(half)
+            CHECK_CUDA(cudaMalloc((void **)net.input16_gpu, *net.max_input16_size * sizeof(short))); //sizeof(half)
             *net.max_output16_size = max_outputs;
-            check_error(cudaMalloc((void **)net.output16_gpu, *net.max_output16_size * sizeof(short))); //sizeof(half)
+            CHECK_CUDA(cudaMalloc((void **)net.output16_gpu, *net.max_output16_size * sizeof(short))); //sizeof(half)
         }
         if (workspace_size) {
             fprintf(stderr, " Allocate additional workspace_size = %1.2f MB \n", (float)workspace_size/1000000);
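The //sizeof(half) comments rely on short and half both being 2 bytes, so sizing the fp16 buffers in units of short allocates the right amount. A compile-time check of that assumption, illustrative only and written for a CUDA C++ translation unit rather than for parser.c itself, could be:

    #include <cuda_fp16.h>

    // Sketch: documents the 2-byte assumption behind sizeof(short) //sizeof(half).
    static_assert(sizeof(short) == sizeof(half), "fp16 buffers assume short and half are both 2 bytes");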