mirror of https://github.com/AlexeyAB/darknet.git
Fixed CUDA-implementation of back-propagation for [scale_channels]-layer
This commit is contained in:
parent
aeb9da6918
commit
1ed71f4b29
|
@ -1119,17 +1119,45 @@ extern "C" void scale_channels_gpu(float *in_w_h_c, int size, int channel_size,
|
|||
}
|
||||
|
||||
|
||||
__inline__ __device__
|
||||
float warpAllReduceSum(float val) {
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2)
|
||||
#if CUDART_VERSION >= 9000
|
||||
val += __shfl_xor_sync(0xffffffff, val, mask);
|
||||
#else
|
||||
val += __shfl_xor(val, mask);
|
||||
#endif
|
||||
return val;
|
||||
}
|
||||
|
||||
__global__ void backward_scale_channels_kernel(float *in_w_h_c_delta, int size, int channel_size,
|
||||
float *in_scales_c, float *out_from_delta,
|
||||
float *in_from_output, float *out_state_delta)
|
||||
{
|
||||
const int index = blockIdx.x*blockDim.x + threadIdx.x;
|
||||
if (index < size) {
|
||||
out_state_delta[index / channel_size] += in_w_h_c_delta[index] * in_from_output[index]; // l.delta * from (should be divided by channel_size?)
|
||||
out_from_delta[index] += in_scales_c[index / channel_size] * in_w_h_c_delta[index]; // input * l.delta
|
||||
int osd_index = index / channel_size;
|
||||
|
||||
//out_state_delta[index / channel_size] += in_w_h_c_delta[index] / channel_size;
|
||||
//out_from_delta[index] = in_w_h_c_delta[index];
|
||||
if (index < size) {
|
||||
//out_state_delta[osd_index] += in_w_h_c_delta[index] * in_from_output[index]; // l.delta * from (should be divided by channel_size?)
|
||||
|
||||
int warp_id = index / 32;
|
||||
int index_warp_start = warp_id * 32;
|
||||
int osd_index_warp_start = index_warp_start / channel_size;
|
||||
int osd_index_warp_end = (index_warp_start + 31) / channel_size;
|
||||
|
||||
if (osd_index_warp_start == osd_index_warp_end) // all thread in warp process the same channel
|
||||
{
|
||||
float sum = warpAllReduceSum(in_w_h_c_delta[index] * in_from_output[index]); // l.delta * from
|
||||
if (threadIdx.x % 32 == 0) {
|
||||
atomicAdd(&out_state_delta[osd_index], sum);
|
||||
//out_state_delta[osd_index] += sum;
|
||||
}
|
||||
}
|
||||
else {
|
||||
atomicAdd(&out_state_delta[osd_index], in_w_h_c_delta[index] * in_from_output[index]); // l.delta * from
|
||||
}
|
||||
|
||||
out_from_delta[index] += in_scales_c[osd_index] * in_w_h_c_delta[index]; // input * l.delta // atomic isn't required here
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -86,9 +86,9 @@ void backward_scale_channels_layer(const layer l, network_state state)
|
|||
int i;
|
||||
#pragma omp parallel for
|
||||
for (i = 0; i < size; ++i) {
|
||||
state.delta[i / channel_size] += l.delta[i] * from_output[i]; // l.delta * from (should be divided by channel_size?)
|
||||
state.delta[i / channel_size] += l.delta[i] * from_output[i] / channel_size; // l.delta * from (should be divided by channel_size?)
|
||||
|
||||
from_delta[i] = state.input[i / channel_size] * l.delta[i]; // input * l.delta
|
||||
from_delta[i] += state.input[i / channel_size] * l.delta[i]; // input * l.delta
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -112,7 +112,6 @@ void backward_scale_channels_layer_gpu(const layer l, network_state state)
|
|||
float *from_output = state.net.layers[l.index].output_gpu;
|
||||
float *from_delta = state.net.layers[l.index].delta_gpu;
|
||||
|
||||
|
||||
backward_scale_channels_gpu(l.delta_gpu, size, channel_size, state.input, from_delta, from_output, state.delta);
|
||||
}
|
||||
#endif
|
||||
|
|
Loading…
Reference in New Issue