diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu
index f269f04e..4dcc9fac 100644
--- a/src/blas_kernels.cu
+++ b/src/blas_kernels.cu
@@ -943,11 +943,15 @@ __global__ void backward_shortcut_multilayer_kernel(int size, int src_outputs, i
         if (weights_normalizion == RELU_NORMALIZATION) w = relu(w) / sum;
         else if (weights_normalizion == SOFTMAX_NORMALIZATION) w = expf(w - max_val) / sum;
 
+        if (weights_normalizion == RELU_NORMALIZATION) grad = w;
+        else if (weights_normalizion == SOFTMAX_NORMALIZATION) grad = w*(1-w);
+
         delta_out[id] += delta_in[id] * w; // [0 or c or (c, h ,w)]
         float weights_update_tmp = delta_in[id] * in[id] * grad;
         if (!isnan(weights_update_tmp) && !isinf(weights_update_tmp))
-            weight_updates_gpu[src_i / step] += weights_update_tmp;
+            atomicAdd(&weight_updates_gpu[src_i / step], weights_update_tmp);
+            //weight_updates_gpu[src_i / step] += weights_update_tmp;
     }
     else delta_out[id] += delta_in[id];
@@ -967,11 +971,15 @@ __global__ void backward_shortcut_multilayer_kernel(int size, int src_outputs, i
                 if (weights_normalizion == RELU_NORMALIZATION) w = relu(w) / sum;
                 else if (weights_normalizion == SOFTMAX_NORMALIZATION) w = expf(w - max_val) / sum;
 
+                if (weights_normalizion == RELU_NORMALIZATION) grad = w;
+                else if (weights_normalizion == SOFTMAX_NORMALIZATION) grad = w*(1 - w);
+
                 layer_delta[add_index] += delta_in[id] * w;
                 float weights_update_tmp = delta_in[id] * add[add_index] * grad;
                 if (!isnan(weights_update_tmp) && !isinf(weights_update_tmp))
-                    weight_updates_gpu[weights_index] += weights_update_tmp;
+                    atomicAdd(&weight_updates_gpu[weights_index], weights_update_tmp);
+                    //weight_updates_gpu[weights_index] += weights_update_tmp;
             }
             else layer_delta[add_index] += delta_in[id];
         }
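
Reviewer note (not part of the patch): the diff makes two changes. First, it computes `grad` as the diagonal term of each normalization's derivative; for `SOFTMAX_NORMALIZATION`, `w*(1 - w)` is the diagonal of the softmax Jacobian, d(softmax(a)_i)/d(a_i) = w_i*(1 - w_i). Second, it replaces the plain `+=` on `weight_updates_gpu` with `atomicAdd`. The `+=` was a data race: every thread whose `src_i / step` (or `weights_index`) maps to the same weight slot performs an unsynchronized read-modify-write, so concurrent updates are lost. The standalone sketch below reproduces that failure mode and the fix; all names in it (`accumulate_racy`, `accumulate_atomic`, `N`, `SLOTS`) are hypothetical and chosen only for the demo.

// Minimal sketch: many threads accumulate into a few shared slots.
// A plain "+=" loses updates under contention; atomicAdd does not.
#include <cstdio>
#include <cuda_runtime.h>

#define N (256 * 1024)   // total threads
#define SLOTS 4          // shared accumulation slots (stand-in for weight_updates_gpu entries)

__global__ void accumulate_racy(float *acc) {
    int id = blockIdx.x * blockDim.x + threadIdx.x;
    if (id < N) acc[id % SLOTS] += 1.0f;            // unsynchronized read-modify-write
}

__global__ void accumulate_atomic(float *acc) {
    int id = blockIdx.x * blockDim.x + threadIdx.x;
    if (id < N) atomicAdd(&acc[id % SLOTS], 1.0f);  // race-free accumulation
}

int main() {
    float *acc, host[SLOTS];
    cudaMalloc(&acc, SLOTS * sizeof(float));

    cudaMemset(acc, 0, SLOTS * sizeof(float));
    accumulate_racy<<<(N + 255) / 256, 256>>>(acc);
    cudaMemcpy(host, acc, SLOTS * sizeof(float), cudaMemcpyDeviceToHost);
    printf("racy:   slot 0 = %8.0f (expected %d)\n", host[0], N / SLOTS);

    cudaMemset(acc, 0, SLOTS * sizeof(float));
    accumulate_atomic<<<(N + 255) / 256, 256>>>(acc);
    cudaMemcpy(host, acc, SLOTS * sizeof(float), cudaMemcpyDeviceToHost);
    printf("atomic: slot 0 = %8.0f (expected %d)\n", host[0], N / SLOTS);

    cudaFree(acc);
    return 0;
}

One caveat worth flagging in review: float atomicAdd makes the accumulation order nondeterministic, so the summed weight updates can differ in the last bits from run to run; that is the usual trade-off for correctness here, not a bug.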