mirror of https://github.com/AlexeyAB/darknet.git
Fix [implicit] layer
This commit is contained in:
parent
846c79b6d4
commit
81b768bae0
|
@ -2460,12 +2460,14 @@ __global__ void backward_implicit_kernel(int size, int batch, int nweights, floa
|
|||
const int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
if (id >= size) return;
|
||||
|
||||
weight_updates_gpu[id % nweights] += delta_gpu[id];
|
||||
for (int i = 0; i < batch; ++i) {
|
||||
weight_updates_gpu[id] += delta_gpu[id + i * nweights];
|
||||
}
|
||||
}
|
||||
|
||||
// Host-side launcher for the implicit-layer backward pass.
//
// Accumulates weight gradients from delta_gpu into weight_updates_gpu.
// The kernel is launched with one thread per weight (size = nweights);
// each thread sums its weight's contribution across the whole batch
// internally, which avoids the data race that a batch*nweights launch
// would have when multiple threads atomically-unsafely update the same
// weight_updates_gpu[id % nweights] slot.
//
// Parameters:
//   batch              - number of samples in the minibatch
//   nweights           - number of implicit-layer weights
//   weight_updates_gpu - device buffer of nweights gradient accumulators (read-modify-write)
//   delta_gpu          - device buffer of batch*nweights incoming deltas (read-only)
//
// Runs on the stream returned by get_cuda_stream(); only the launch
// error is checked here (cudaPeekAtLastError reads without clearing),
// execution errors surface at the next synchronizing call.
extern "C" void backward_implicit_gpu(int batch, int nweights, float *weight_updates_gpu, float *delta_gpu)
{
    // One thread per weight, NOT per (weight, batch) element — the kernel
    // loops over the batch dimension itself.
    int size = nweights;
    backward_implicit_kernel << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> > (size, batch, nweights, weight_updates_gpu, delta_gpu);
    CHECK_CUDA(cudaPeekAtLastError());
}
|
||||
|
|
|
@ -67,7 +67,6 @@ void forward_implicit_layer(const layer l, network_state state)
|
|||
void backward_implicit_layer(const layer l, network_state state)
|
||||
{
|
||||
int i;
|
||||
#pragma omp parallel for
|
||||
for (i = 0; i < l.nweights * l.batch; ++i) {
|
||||
l.weight_updates[i % l.nweights] += l.delta[i];
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue