mirror of https://github.com/AlexeyAB/darknet.git
Accelerated BiFPN back-propagation
parent 3cb9125b95
commit f6baa62c9b
@@ -9,6 +9,16 @@
 #include "utils.h"
 #include "tree.h"
 
+__inline__ __device__
+float warpAllReduceSum(float val) {
+    for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2)
+#if CUDART_VERSION >= 9000
+        val += __shfl_xor_sync(0xffffffff, val, mask);
+#else
+        val += __shfl_xor(val, mask);
+#endif
+    return val;
+}
 
 __global__ void compare_2_arrays_kernel(float *one, float *two, int size)
 {
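The helper added above is a warp-wide all-reduce built on XOR-butterfly shuffles: with masks 16, 8, 4, 2, 1 each lane exchanges its partial sum with the lane whose index differs in that bit, so after log2(WARP_SIZE) exchanges every lane holds the full warp sum. The __shfl_xor_sync path is taken from CUDA 9 onward, where the unsynchronized __shfl_xor is deprecated. A minimal standalone check, assuming WARP_SIZE is 32 as in darknet (the demo kernel and main are not part of the commit):

    // Hypothetical standalone harness; warpAllReduceSum mirrors the commit's helper.
    #include <cstdio>
    #include <cuda_runtime.h>

    #define WARP_SIZE 32

    __inline__ __device__
    float warpAllReduceSum(float val) {
        // XOR butterfly: masks 16,8,4,2,1 pair each lane with the lane whose
        // index differs in that bit; after 5 exchanges all lanes hold the sum.
        for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2)
    #if CUDART_VERSION >= 9000
            val += __shfl_xor_sync(0xffffffff, val, mask);
    #else
            val += __shfl_xor(val, mask);
    #endif
        return val;
    }

    __global__ void warp_sum_demo() {
        float total = warpAllReduceSum((float)threadIdx.x);
        if (threadIdx.x == 0) printf("warp sum = %.0f\n", total);  // 0+1+...+31 = 496
    }

    int main() {
        warp_sum_demo<<<1, WARP_SIZE>>>();
        cudaDeviceSynchronize();
        return 0;
    }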
@@ -949,9 +959,18 @@ __global__ void backward_shortcut_multilayer_kernel(int size, int src_outputs, i
             delta_out[id] += delta_in[id] * w; // [0 or c or (c, h ,w)]
             float weights_update_tmp = delta_in[id] * in[id] * grad;
 
-            if (!isnan(weights_update_tmp) && !isinf(weights_update_tmp))
-                atomicAdd(&weight_updates_gpu[src_i / step], weights_update_tmp);
-            //weight_updates_gpu[src_i / step] += weights_update_tmp;
+            if (layer_step == 1 && (size/32) > (id/32 + 1)) {
+                float wu = warpAllReduceSum(weights_update_tmp);
+                if (threadIdx.x % 32 == 0) {
+                    if (!isnan(wu) && !isinf(wu))
+                        atomicAdd(&weight_updates_gpu[src_i / step], wu);
+                }
+            }
+            else {
+                if (!isnan(weights_update_tmp) && !isinf(weights_update_tmp))
+                    atomicAdd(&weight_updates_gpu[src_i / step], weights_update_tmp);
+                //weight_updates_gpu[src_i / step] += weights_update_tmp;
+            }
         }
         else delta_out[id] += delta_in[id];
 
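The rewrite above is the warp-aggregated atomics pattern. Before this commit every thread issued its own atomicAdd, and when layer_step == 1 all 32 lanes of a warp contend for the same weight_updates_gpu slot; summing within the warp first and letting lane 0 issue a single atomicAdd cuts that contention by up to 32x. The (size/32) > (id/32 + 1) guard restricts the fast path to warps whose lanes are all in range, which also keeps the full-mask __shfl_xor_sync legal; remaining threads fall back to per-lane atomics. A minimal sketch of the pattern in isolation, with hypothetical names (accumulate_single_weight, contrib) and the commit's NaN/Inf filtering omitted for brevity:

    // Sketch: all n threads accumulate into ONE weight slot, as in the
    // layer_step == 1 case. Assumes warpAllReduceSum as defined above.
    __global__ void accumulate_single_weight(const float *contrib, int n,
                                             float *weight_update)
    {
        int id = blockIdx.x * blockDim.x + threadIdx.x;
        if (id >= n) return;

        float v = contrib[id];
        if ((n / 32) > (id / 32 + 1)) {      // warp fully populated and in range
            float wu = warpAllReduceSum(v);  // every lane now holds the warp total
            if (threadIdx.x % 32 == 0)       // one atomic per warp instead of 32
                atomicAdd(weight_update, wu);
        }
        else {                               // tail threads: per-lane atomics stay correct
            atomicAdd(weight_update, v);
        }
    }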
@@ -977,9 +996,18 @@ __global__ void backward_shortcut_multilayer_kernel(int size, int src_outputs, i
                 layer_delta[add_index] += delta_in[id] * w;
                 float weights_update_tmp = delta_in[id] * add[add_index] * grad;
 
-                if (!isnan(weights_update_tmp) && !isinf(weights_update_tmp))
-                    atomicAdd(&weight_updates_gpu[weights_index], weights_update_tmp);
-                //weight_updates_gpu[weights_index] += weights_update_tmp;
+                if (layer_step == 1 && (size / 32) > (id / 32 + 1)) {
+                    float wu = warpAllReduceSum(weights_update_tmp);
+                    if (threadIdx.x % 32 == 0) {
+                        if (!isnan(wu) && !isinf(wu))
+                            atomicAdd(&weight_updates_gpu[weights_index], wu);
+                    }
+                }
+                else {
+                    if (!isnan(weights_update_tmp) && !isinf(weights_update_tmp))
+                        atomicAdd(&weight_updates_gpu[weights_index], weights_update_tmp);
+                    //weight_updates_gpu[weights_index] += weights_update_tmp;
+                }
             }
             else layer_delta[add_index] += delta_in[id];
         }
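The same transformation is applied to this second accumulation site. Two details of the condition are easy to miss: layer_step == 1 means every lane of the warp maps to the same weight slot, so a warp-wide sum is safe (with per-channel or per-element weights the lanes would target different slots and the reduction would merge unrelated gradients), and the guard is deliberately conservative, as this hypothetical host-side walk-through shows:

    // Hypothetical walk-through of the guard, runnable as plain host code.
    #include <cstdio>

    // Mirrors the kernel's fast-path test for thread `id` out of `size` threads.
    static int takes_warp_path(int id, int size) {
        return (size / 32) > (id / 32 + 1);
    }

    int main() {
        int size = 100;                        // example: 100 threads -> warps 0..3
        for (int id = 0; id < size; id += 32)  // first lane of each warp
            printf("warp %d -> %s\n", id / 32,
                   takes_warp_path(id, size) ? "warp-reduced atomicAdd"
                                             : "per-lane atomicAdd");
        // Warps 0 and 1 take the fast path; warp 2 (ids 64..95) is actually
        // full but still falls back, as does the partial warp 3 (ids 96..99).
        return 0;
    }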
@@ -1480,16 +1508,7 @@ extern "C" void scale_channels_gpu(float *in_w_h_c, int size, int channel_size,
 }
 
 
-__inline__ __device__
-float warpAllReduceSum(float val) {
-    for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2)
-#if CUDART_VERSION >= 9000
-        val += __shfl_xor_sync(0xffffffff, val, mask);
-#else
-        val += __shfl_xor(val, mask);
-#endif
-    return val;
-}
-
 
 __global__ void backward_scale_channels_kernel(float *in_w_h_c_delta, int size, int channel_size, int batch_size, int scale_wh,
     float *in_scales_c, float *out_from_delta,