From b267e34487231a401e79699c24294072096ae3f9 Mon Sep 17 00:00:00 2001 From: AlexeyAB Date: Sun, 5 Jan 2020 14:55:21 +0300 Subject: [PATCH] multilayer [shortcut] in progress --- build/darknet/x64/partial.cmd | 4 ++ include/darknet.h | 6 +++ src/blas.c | 54 +++++++++++++++++++ src/blas.h | 5 ++ src/blas_kernels.cu | 69 ++++++++++++++++++++++++ src/dark_cuda.c | 15 ++++++ src/dark_cuda.h | 1 + src/layer.c | 5 ++ src/network.c | 2 +- src/network_kernels.cu | 4 +- src/parser.c | 51 +++++++++++++++--- src/shortcut_layer.c | 98 ++++++++++++++++++++++++----------- src/shortcut_layer.h | 5 +- 13 files changed, 277 insertions(+), 42 deletions(-) diff --git a/build/darknet/x64/partial.cmd b/build/darknet/x64/partial.cmd index 5aa972f3..79292b50 100644 --- a/build/darknet/x64/partial.cmd +++ b/build/darknet/x64/partial.cmd @@ -6,6 +6,10 @@ rem Download Yolo9000: http://pjreddie.com/media/files/yolo9000.weights rem darknet.exe partial cfg/tiny-yolo-voc.cfg tiny-yolo-voc.weights tiny-yolo-voc.conv.13 13 + +darknet.exe partial cfg/csresnext50.cfg csresnext50.weights csresnext50.conv.75 75 + + darknet.exe partial cfg/darknet53_448.cfg darknet53_448.weights darknet53.conv.74 74 diff --git a/include/darknet.h b/include/darknet.h index 43bf23d8..bd85e880 100644 --- a/include/darknet.h +++ b/include/darknet.h @@ -328,6 +328,8 @@ struct layer { int * indexes; int * input_layers; int * input_sizes; + float **layers_output; + float **layers_delta; int * map; int * counts; float ** sums; @@ -575,6 +577,10 @@ struct layer { float *gt_gpu; float *a_avg_gpu; + + int *input_sizes_gpu; + float **layers_output_gpu; + float **layers_delta_gpu; #ifdef CUDNN cudnnTensorDescriptor_t srcTensorDesc, dstTensorDesc; cudnnTensorDescriptor_t srcTensorDesc16, dstTensorDesc16; diff --git a/src/blas.c b/src/blas.c index d00cb89d..9533f74d 100644 --- a/src/blas.c +++ b/src/blas.c @@ -68,6 +68,60 @@ void weighted_delta_cpu(float *a, float *b, float *s, float *da, float *db, floa } } +void shortcut_multilayer_cpu(int size, int src_outputs, int batch, int n, int *outputs_of_layers, float **layers_output, float *out, float *in) +{ + int id; + #pragma omp parallel for + for (id = 0; id < size; ++id) { + + int src_id = id; + int src_i = src_id % src_outputs; + src_id /= src_outputs; + int src_b = src_id; + + out[id] = in[id]; + + // layers + for (int i = 0; i < n; ++i) { + int add_outputs = outputs_of_layers[i]; + if (src_i < add_outputs) { + int add_index = add_outputs*src_b + src_i; + int out_index = id; + + float *add = layers_output[i]; + out[out_index] += add[add_index]; + } + } + } +} + +void backward_shortcut_multilayer_cpu(int size, int src_outputs, int batch, int n, int *outputs_of_layers, + float **layers_delta, float *delta_out, float *delta_in) +{ + int id; + #pragma omp parallel for + for (id = 0; id < size; ++id) { + int src_id = id; + int src_i = src_id % src_outputs; + src_id /= src_outputs; + int src_b = src_id; + + delta_out[id] += delta_in[id]; + + // layers + for (int i = 0; i < n; ++i) { + int add_outputs = outputs_of_layers[i]; + if (src_i < add_outputs) { + int add_index = add_outputs*src_b + src_i; + int out_index = id; + + float *layer_delta = layers_delta[i]; + layer_delta[add_index] += delta_in[id]; + } + } + } +} + void shortcut_cpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out) { int stride = w1/w2; diff --git a/src/blas.h b/src/blas.h index ebe7fb0e..a2f7e5d4 100644 --- a/src/blas.h +++ b/src/blas.h @@ -32,6 +32,9 @@ void fill_cpu(int N, float ALPHA, float * X, int INCX); float dot_cpu(int N, float *X, int INCX, float *Y, int INCY); void test_gpu_blas(); void shortcut_cpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out); +void shortcut_multilayer_cpu(int size, int src_outputs, int batch, int n, int *outputs_of_layers, float **layers_output, float *out, float *in); +void backward_shortcut_multilayer_cpu(int size, int src_outputs, int batch, int n, int *outputs_of_layers, + float **layers_delta, float *delta_out, float *delta_in); void mean_cpu(float *x, int batch, int filters, int spatial, float *mean); void variance_cpu(float *x, float *mean, int batch, int filters, int spatial, float *variance); @@ -84,6 +87,8 @@ void fast_variance_delta_gpu(float *x, float *delta, float *mean, float *varianc void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance); void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *mean); void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out); +void shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in); +void backward_shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_delta_gpu, float *delta_out, float *delta_in); void input_shortcut_gpu(float *in, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out); void scale_bias_gpu(float *output, float *biases, int batch, int n, int size); void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates); diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu index 7f84f337..d1ee4b80 100644 --- a/src/blas_kernels.cu +++ b/src/blas_kernels.cu @@ -668,6 +668,75 @@ extern "C" void fill_ongpu(int N, float ALPHA, float * X, int INCX) CHECK_CUDA(cudaPeekAtLastError()); } +__global__ void shortcut_multilayer_kernel(int size, int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in) +{ + const int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; + if (id >= size) return; + + int src_id = id; + int src_i = src_id % src_outputs; + src_id /= src_outputs; + int src_b = src_id; + + out[id] = in[id]; + + // layers + for (int i = 0; i < n; ++i) { + int add_outputs = outputs_of_layers_gpu[i]; + if (src_i < add_outputs) { + int add_index = add_outputs*src_b + src_i; + int out_index = id; + + float *add = layers_output_gpu[i]; + out[out_index] += add[add_index]; + } + } +} + +extern "C" void shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in) +{ + //printf(" src_outputs = %d, batch = %d, n = %d \n", src_outputs, batch, n); + int size = batch * src_outputs; + shortcut_multilayer_kernel << > > (size, src_outputs, batch, n, outputs_of_layers_gpu, layers_output_gpu, out, in); + CHECK_CUDA(cudaPeekAtLastError()); +} + + + +__global__ void backward_shortcut_multilayer_kernel(int size, int src_outputs, int batch, int n, int *outputs_of_layers_gpu, + float **layers_delta_gpu, float *delta_out, float *delta_in) +{ + const int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; + if (id >= size) return; + + int src_id = id; + int src_i = src_id % src_outputs; + src_id /= src_outputs; + int src_b = src_id; + + delta_out[id] += delta_in[id]; + + // layers + for (int i = 0; i < n; ++i) { + int add_outputs = outputs_of_layers_gpu[i]; + if (src_i < add_outputs) { + int add_index = add_outputs*src_b + src_i; + int out_index = id; + + float *layer_delta = layers_delta_gpu[i]; + layer_delta[add_index] += delta_in[id]; + } + } +} + +extern "C" void backward_shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_delta_gpu, float *delta_out, float *delta_in) +{ + //printf(" src_outputs = %d, batch = %d, n = %d \n", src_outputs, batch, n); + int size = batch * src_outputs; + backward_shortcut_multilayer_kernel << > > (size, src_outputs, batch, n, outputs_of_layers_gpu, layers_delta_gpu, delta_out, delta_in); + CHECK_CUDA(cudaPeekAtLastError()); +} + __global__ void shortcut_kernel(int size, int minw, int minh, int minc, int stride, int sample, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out) { int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; diff --git a/src/dark_cuda.c b/src/dark_cuda.c index 7606be3c..e7b6e6af 100644 --- a/src/dark_cuda.c +++ b/src/dark_cuda.c @@ -369,6 +369,21 @@ float *cuda_make_array(float *x, size_t n) return x_gpu; } +void **cuda_make_array_pointers(void **x, size_t n) +{ + void **x_gpu; + size_t size = sizeof(void*) * n; + cudaError_t status = cudaMalloc((void **)&x_gpu, size); + if (status != cudaSuccess) fprintf(stderr, " Try to set subdivisions=64 in your cfg-file. \n"); + CHECK_CUDA(status); + if (x) { + status = cudaMemcpyAsync(x_gpu, x, size, cudaMemcpyDefault, get_cuda_stream()); + CHECK_CUDA(status); + } + if (!x_gpu) error("Cuda malloc failed\n"); + return x_gpu; +} + void cuda_random(float *x_gpu, size_t n) { static curandGenerator_t gen[16]; diff --git a/src/dark_cuda.h b/src/dark_cuda.h index fe4176cf..4798d729 100644 --- a/src/dark_cuda.h +++ b/src/dark_cuda.h @@ -62,6 +62,7 @@ extern "C" { float *cuda_make_array_pinned_preallocated(float *x, size_t n); float *cuda_make_array_pinned(float *x, size_t n); float *cuda_make_array(float *x, size_t n); + void **cuda_make_array_pointers(void **x, size_t n); int *cuda_make_int_array(size_t n); int *cuda_make_int_array_new_api(int *x, size_t n); void cuda_push_array(float *x_gpu, float *x, size_t n); diff --git a/src/layer.c b/src/layer.c index 2ae781d0..1acc4bf9 100644 --- a/src/layer.c +++ b/src/layer.c @@ -65,6 +65,8 @@ void free_layer_custom(layer l, int keep_cudnn_desc) if (l.indexes) free(l.indexes); if (l.input_layers) free(l.input_layers); if (l.input_sizes) free(l.input_sizes); + if (l.layers_output) free(l.layers_output); + if (l.layers_delta) free(l.layers_delta); if (l.map) free(l.map); if (l.rand) free(l.rand); if (l.cost) free(l.cost); @@ -190,6 +192,9 @@ void free_layer_custom(layer l, int keep_cudnn_desc) if (l.rand_gpu) cuda_free(l.rand_gpu); if (l.squared_gpu) cuda_free(l.squared_gpu); if (l.norms_gpu) cuda_free(l.norms_gpu); + if (l.input_sizes_gpu) cuda_free(l.input_sizes_gpu); + if (l.layers_output_gpu) cuda_free(l.layers_output_gpu); + if (l.layers_delta_gpu) cuda_free(l.layers_delta_gpu); // CONV-LSTM if (l.f_gpu) cuda_free(l.f_gpu); diff --git a/src/network.c b/src/network.c index 0a2117a8..0de38420 100644 --- a/src/network.c +++ b/src/network.c @@ -538,7 +538,7 @@ int resize_network(network *net, int w, int h) }else if(l.type == ROUTE){ resize_route_layer(&l, net); }else if (l.type == SHORTCUT) { - resize_shortcut_layer(&l, w, h); + resize_shortcut_layer(&l, w, h, net); }else if (l.type == SCALE_CHANNELS) { resize_scale_channels_layer(&l, net); }else if (l.type == DROPOUT) { diff --git a/src/network_kernels.cu b/src/network_kernels.cu index 5386a420..2146f0ae 100644 --- a/src/network_kernels.cu +++ b/src/network_kernels.cu @@ -130,8 +130,8 @@ void forward_network_gpu(network net, network_state state) printf("\n\nSorted by time:\n"); qsort(sorted_avg_time_per_layer, net.n, sizeof(time_benchmark_layers), time_comparator); for (i = 0; i < net.n; ++i) { - //printf("\layer %d - type: %d - avg_time %lf ms \n", avg_time_per_layer[i].layer_id, avg_time_per_layer[i].layer_type, avg_time_per_layer[i].time); - printf("\%d - layer %d - type: %d - avg_time %lf ms \n", i, sorted_avg_time_per_layer[i].layer_id, sorted_avg_time_per_layer[i].layer_type, sorted_avg_time_per_layer[i].time); + //printf("layer %d - type: %d - avg_time %lf ms \n", avg_time_per_layer[i].layer_id, avg_time_per_layer[i].layer_type, avg_time_per_layer[i].time); + printf("%d - layer %d - type: %d - avg_time %lf ms \n", i, sorted_avg_time_per_layer[i].layer_id, sorted_avg_time_per_layer[i].layer_type, sorted_avg_time_per_layer[i].time); } } diff --git a/src/parser.c b/src/parser.c index ec9e3007..d4b13ccb 100644 --- a/src/parser.c +++ b/src/parser.c @@ -817,15 +817,54 @@ layer parse_shortcut(list *options, size_params params, network net) ACTIVATION activation = get_activation(activation_s); int assisted_excitation = option_find_float_quiet(options, "assisted_excitation", 0); + //char *l = option_find(options, "from"); + //int index = atoi(l); + //if(index < 0) index = params.index + index; + char *l = option_find(options, "from"); - int index = atoi(l); - if(index < 0) index = params.index + index; + int len = strlen(l); + if (!l) error("Route Layer must specify input layers: from = ..."); + int n = 1; + int i; + for (i = 0; i < len; ++i) { + if (l[i] == ',') ++n; + } - int batch = params.batch; - layer from = net.layers[index]; - if (from.antialiasing) from = *from.input_layer; + int* layers = (int*)calloc(n, sizeof(int)); + int* sizes = (int*)calloc(n, sizeof(int)); + float **layers_output = (float **)calloc(n, sizeof(float *)); + float **layers_delta = (float **)calloc(n, sizeof(float *)); + float **layers_output_gpu = (float **)calloc(n, sizeof(float *)); + float **layers_delta_gpu = (float **)calloc(n, sizeof(float *)); - layer s = make_shortcut_layer(batch, index, params.w, params.h, params.c, from.out_w, from.out_h, from.out_c, assisted_excitation, activation, params.train); + for (i = 0; i < n; ++i) { + int index = atoi(l); + l = strchr(l, ',') + 1; + if (index < 0) index = params.index + index; + layers[i] = index; + sizes[i] = params.net.layers[index].outputs; + layers_output[i] = params.net.layers[index].output; + layers_delta[i] = params.net.layers[index].delta; + + } + +#ifdef GPU + for (i = 0; i < n; ++i) { + layers_output_gpu[i] = params.net.layers[layers[i]].output_gpu; + layers_delta_gpu[i] = params.net.layers[layers[i]].delta_gpu; + } +#endif// GPU + + layer s = make_shortcut_layer(params.batch, n, layers, sizes, params.w, params.h, params.c, layers_output, layers_delta, layers_output_gpu, layers_delta_gpu, activation, params.train); + + for (i = 0; i < n; ++i) { + int index = layers[i]; + assert(params.w == net.layers[index].out_w && params.h == net.layers[index].out_h); + + if (params.w != net.layers[index].out_w || params.h != net.layers[index].out_h || params.c != net.layers[index].out_c) + fprintf(stderr, " w = %d, w2 = %d, h = %d, h2 = %d, c = %d, c2 = %d \n", + params.w, net.layers[index].out_w, params.h, net.layers[index].out_h, params.c, params.net.layers[index].out_c); + } return s; } diff --git a/src/shortcut_layer.c b/src/shortcut_layer.c index f9b9209a..f87d6b61 100644 --- a/src/shortcut_layer.c +++ b/src/shortcut_layer.c @@ -6,29 +6,36 @@ #include #include -layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2, int assisted_excitation, ACTIVATION activation, int train) +layer make_shortcut_layer(int batch, int n, int *input_layers, int* input_sizes, int w, int h, int c, + float **layers_output, float **layers_delta, float **layers_output_gpu, float **layers_delta_gpu, ACTIVATION activation, int train) { - if(assisted_excitation) fprintf(stderr, "Shortcut Layer - AE: %d\n", index); - else fprintf(stderr,"Shortcut Layer: %d\n", index); + fprintf(stderr, "Shortcut Layer: "); + int i; + for(i = 0; i < n; ++i) fprintf(stderr, "%d, ", input_layers[i]); + layer l = { (LAYER_TYPE)0 }; l.train = train; l.type = SHORTCUT; l.batch = batch; l.activation = activation; - l.w = w2; - l.h = h2; - l.c = c2; - l.out_w = w; - l.out_h = h; - l.out_c = c; + l.n = n; + l.input_layers = input_layers; + l.input_sizes = input_sizes; + l.layers_output = layers_output; + l.layers_delta = layers_delta; + + //l.w = w2; + //l.h = h2; + //l.c = c2; + l.w = l.out_w = w; + l.h = l.out_h = h; + l.c = l.out_c = c; l.outputs = w*h*c; l.inputs = l.outputs; - l.assisted_excitation = assisted_excitation; + //if(w != w2 || h != h2 || c != c2) fprintf(stderr, " w = %d, w2 = %d, h = %d, h2 = %d, c = %d, c2 = %d \n", w, w2, h, h2, c, c2); - if(w != w2 || h != h2 || c != c2) fprintf(stderr, " w = %d, w2 = %d, h = %d, h2 = %d, c = %d, c2 = %d \n", w, w2, h, h2, c, c2); - - l.index = index; + l.index = l.input_layers[0]; if (train) l.delta = (float*)calloc(l.outputs * batch, sizeof(float)); l.output = (float*)calloc(l.outputs * batch, sizeof(float)); @@ -47,17 +54,18 @@ layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int if (train) l.delta_gpu = cuda_make_array(l.delta, l.outputs*batch); l.output_gpu = cuda_make_array(l.output, l.outputs*batch); - if (l.assisted_excitation) - { - const int size = l.out_w * l.out_h * l.batch; - l.gt_gpu = cuda_make_array(NULL, size); - l.a_avg_gpu = cuda_make_array(NULL, size); - } + + l.input_sizes_gpu = cuda_make_array(input_sizes, l.n); + l.layers_output_gpu = cuda_make_array_pointers(layers_output_gpu, l.n); + l.layers_delta_gpu = cuda_make_array_pointers(layers_delta_gpu, l.n); #endif // GPU + + l.bflops = l.out_w * l.out_h * l.out_c * l.n / 1000000000.; + fprintf(stderr, " outputs:%4d x%4d x%4d %5.3f BF\n", l.out_w, l.out_h, l.out_c, l.bflops); return l; } -void resize_shortcut_layer(layer *l, int w, int h) +void resize_shortcut_layer(layer *l, int w, int h, network *net) { //assert(l->w == l->out_w); //assert(l->h == l->out_h); @@ -68,6 +76,13 @@ void resize_shortcut_layer(layer *l, int w, int h) if (l->train) l->delta = (float*)realloc(l->delta, l->outputs * l->batch * sizeof(float)); l->output = (float*)realloc(l->output, l->outputs * l->batch * sizeof(float)); + int i; + for (i = 0; i < l->n; ++i) { + int index = l->input_layers[i]; + l->input_sizes[i] = net->layers[index].outputs; + assert(l->w == net->layers[index].w && l->h == net->layers[index].h); + } + #ifdef GPU cuda_free(l->output_gpu); l->output_gpu = cuda_make_array(l->output, l->outputs*l->batch); @@ -82,7 +97,11 @@ void resize_shortcut_layer(layer *l, int w, int h) void forward_shortcut_layer(const layer l, network_state state) { - if (l.w == l.out_w && l.h == l.out_h && l.c == l.out_c) { + int from_w = state.net.layers[l.index].w; + int from_h = state.net.layers[l.index].h; + int from_c = state.net.layers[l.index].c; + + if (l.n == 1 && from_w == l.w && from_h == l.h && from_c == l.c) { int size = l.batch * l.w * l.h * l.c; int i; #pragma omp parallel for @@ -90,16 +109,16 @@ void forward_shortcut_layer(const layer l, network_state state) l.output[i] = state.input[i] + state.net.layers[l.index].output[i]; } else { - copy_cpu(l.outputs*l.batch, state.input, 1, l.output, 1); - shortcut_cpu(l.batch, l.w, l.h, l.c, state.net.layers[l.index].output, l.out_w, l.out_h, l.out_c, l.output); + shortcut_multilayer_cpu(l.outputs * l.batch, l.outputs, l.batch, l.n, l.input_sizes, l.layers_output, l.output, state.input); } + //copy_cpu(l.outputs*l.batch, state.input, 1, l.output, 1); + //shortcut_cpu(l.batch, from_w, from_h, from_c, state.net.layers[l.index].output, l.out_w, l.out_h, l.out_c, l.output); + //activate_array(l.output, l.outputs*l.batch, l.activation); if (l.activation == SWISH) activate_array_swish(l.output, l.outputs*l.batch, l.activation_input, l.output); else if (l.activation == MISH) activate_array_mish(l.output, l.outputs*l.batch, l.activation_input, l.output); else activate_array_cpu_custom(l.output, l.outputs*l.batch, l.activation); - - if (l.assisted_excitation && state.train) assisted_excitation_forward(l, state); } void backward_shortcut_layer(const layer l, network_state state) @@ -108,8 +127,11 @@ void backward_shortcut_layer(const layer l, network_state state) else if (l.activation == MISH) gradient_array_mish(l.outputs*l.batch, l.activation_input, l.delta); else gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta); - axpy_cpu(l.outputs*l.batch, 1, l.delta, 1, state.delta, 1); - shortcut_cpu(l.batch, l.out_w, l.out_h, l.out_c, l.delta, l.w, l.h, l.c, state.net.layers[l.index].delta); + backward_shortcut_multilayer_cpu(l.outputs * l.batch, l.outputs, l.batch, l.n, l.input_sizes, + l.layers_delta, state.delta, l.delta); + + //axpy_cpu(l.outputs*l.batch, 1, l.delta, 1, state.delta, 1); + //shortcut_cpu(l.batch, l.out_w, l.out_h, l.out_c, l.delta, l.w, l.h, l.c, state.net.layers[l.index].delta); } #ifdef GPU @@ -118,13 +140,25 @@ void forward_shortcut_layer_gpu(const layer l, network_state state) //copy_ongpu(l.outputs*l.batch, state.input, 1, l.output_gpu, 1); //simple_copy_ongpu(l.outputs*l.batch, state.input, l.output_gpu); //shortcut_gpu(l.batch, l.w, l.h, l.c, state.net.layers[l.index].output_gpu, l.out_w, l.out_h, l.out_c, l.output_gpu); - input_shortcut_gpu(state.input, l.batch, l.w, l.h, l.c, state.net.layers[l.index].output_gpu, l.out_w, l.out_h, l.out_c, l.output_gpu); + + //input_shortcut_gpu(state.input, l.batch, l.w, l.h, l.c, state.net.layers[l.index].output_gpu, l.out_w, l.out_h, l.out_c, l.output_gpu); + + //----------- + //if (l.outputs == l.input_sizes[0]) + //if(l.n == 1) + //{ + // input_shortcut_gpu(state.input, l.batch, state.net.layers[l.index].w, state.net.layers[l.index].h, state.net.layers[l.index].c, + // state.net.layers[l.index].output_gpu, l.out_w, l.out_h, l.out_c, l.output_gpu); + //} + //else + { + shortcut_multilayer_gpu(l.outputs, l.batch, l.n, l.input_sizes_gpu, l.layers_output_gpu, l.output_gpu, state.input); + } if (l.activation == SWISH) activate_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.output_gpu); else if (l.activation == MISH) activate_array_mish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.output_gpu); else activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation); - if (l.assisted_excitation && state.train) assisted_excitation_forward_gpu(l, state); } void backward_shortcut_layer_gpu(const layer l, network_state state) @@ -133,7 +167,9 @@ void backward_shortcut_layer_gpu(const layer l, network_state state) else if (l.activation == MISH) gradient_array_mish_ongpu(l.outputs*l.batch, l.activation_input_gpu, l.delta_gpu); else gradient_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation, l.delta_gpu); - axpy_ongpu(l.outputs*l.batch, 1, l.delta_gpu, 1, state.delta, 1); - shortcut_gpu(l.batch, l.out_w, l.out_h, l.out_c, l.delta_gpu, l.w, l.h, l.c, state.net.layers[l.index].delta_gpu); + backward_shortcut_multilayer_gpu(l.outputs, l.batch, l.n, l.input_sizes_gpu, l.layers_delta_gpu, state.delta, l.delta_gpu); + + //axpy_ongpu(l.outputs*l.batch, 1, l.delta_gpu, 1, state.delta, 1); + //shortcut_gpu(l.batch, l.out_w, l.out_h, l.out_c, l.delta_gpu, l.w, l.h, l.c, state.net.layers[l.index].delta_gpu); } #endif diff --git a/src/shortcut_layer.h b/src/shortcut_layer.h index b245868f..9341a4a3 100644 --- a/src/shortcut_layer.h +++ b/src/shortcut_layer.h @@ -7,10 +7,11 @@ #ifdef __cplusplus extern "C" { #endif -layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2, int assisted_excitation, ACTIVATION activation, int train); +layer make_shortcut_layer(int batch, int n, int *input_layers, int* input_sizes, int w, int h, int c, + float **layers_output, float **layers_delta, float **layers_output_gpu, float **layers_delta_gpu, ACTIVATION activation, int train); void forward_shortcut_layer(const layer l, network_state state); void backward_shortcut_layer(const layer l, network_state state); -void resize_shortcut_layer(layer *l, int w, int h); +void resize_shortcut_layer(layer *l, int w, int h, network *net); #ifdef GPU void forward_shortcut_layer_gpu(const layer l, network_state state);