mirror of https://github.com/AlexeyAB/darknet.git
multilayer [shortcut] in progress
This commit is contained in:
parent
8bd7dcbd58
commit
b267e34487
|
@ -6,6 +6,10 @@ rem Download Yolo9000: http://pjreddie.com/media/files/yolo9000.weights
|
|||
rem darknet.exe partial cfg/tiny-yolo-voc.cfg tiny-yolo-voc.weights tiny-yolo-voc.conv.13 13
|
||||
|
||||
|
||||
|
||||
darknet.exe partial cfg/csresnext50.cfg csresnext50.weights csresnext50.conv.75 75
|
||||
|
||||
|
||||
darknet.exe partial cfg/darknet53_448.cfg darknet53_448.weights darknet53.conv.74 74
|
||||
|
||||
|
||||
|
|
|
@ -328,6 +328,8 @@ struct layer {
|
|||
int * indexes;
|
||||
int * input_layers;
|
||||
int * input_sizes;
|
||||
float **layers_output;
|
||||
float **layers_delta;
|
||||
int * map;
|
||||
int * counts;
|
||||
float ** sums;
|
||||
|
@ -575,6 +577,10 @@ struct layer {
|
|||
|
||||
float *gt_gpu;
|
||||
float *a_avg_gpu;
|
||||
|
||||
int *input_sizes_gpu;
|
||||
float **layers_output_gpu;
|
||||
float **layers_delta_gpu;
|
||||
#ifdef CUDNN
|
||||
cudnnTensorDescriptor_t srcTensorDesc, dstTensorDesc;
|
||||
cudnnTensorDescriptor_t srcTensorDesc16, dstTensorDesc16;
|
||||
|
|
54
src/blas.c
54
src/blas.c
|
@ -68,6 +68,60 @@ void weighted_delta_cpu(float *a, float *b, float *s, float *da, float *db, floa
|
|||
}
|
||||
}
|
||||
|
||||
void shortcut_multilayer_cpu(int size, int src_outputs, int batch, int n, int *outputs_of_layers, float **layers_output, float *out, float *in)
|
||||
{
|
||||
int id;
|
||||
#pragma omp parallel for
|
||||
for (id = 0; id < size; ++id) {
|
||||
|
||||
int src_id = id;
|
||||
int src_i = src_id % src_outputs;
|
||||
src_id /= src_outputs;
|
||||
int src_b = src_id;
|
||||
|
||||
out[id] = in[id];
|
||||
|
||||
// layers
|
||||
for (int i = 0; i < n; ++i) {
|
||||
int add_outputs = outputs_of_layers[i];
|
||||
if (src_i < add_outputs) {
|
||||
int add_index = add_outputs*src_b + src_i;
|
||||
int out_index = id;
|
||||
|
||||
float *add = layers_output[i];
|
||||
out[out_index] += add[add_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void backward_shortcut_multilayer_cpu(int size, int src_outputs, int batch, int n, int *outputs_of_layers,
|
||||
float **layers_delta, float *delta_out, float *delta_in)
|
||||
{
|
||||
int id;
|
||||
#pragma omp parallel for
|
||||
for (id = 0; id < size; ++id) {
|
||||
int src_id = id;
|
||||
int src_i = src_id % src_outputs;
|
||||
src_id /= src_outputs;
|
||||
int src_b = src_id;
|
||||
|
||||
delta_out[id] += delta_in[id];
|
||||
|
||||
// layers
|
||||
for (int i = 0; i < n; ++i) {
|
||||
int add_outputs = outputs_of_layers[i];
|
||||
if (src_i < add_outputs) {
|
||||
int add_index = add_outputs*src_b + src_i;
|
||||
int out_index = id;
|
||||
|
||||
float *layer_delta = layers_delta[i];
|
||||
layer_delta[add_index] += delta_in[id];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void shortcut_cpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
|
||||
{
|
||||
int stride = w1/w2;
|
||||
|
|
|
@ -32,6 +32,9 @@ void fill_cpu(int N, float ALPHA, float * X, int INCX);
|
|||
float dot_cpu(int N, float *X, int INCX, float *Y, int INCY);
|
||||
void test_gpu_blas();
|
||||
void shortcut_cpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out);
|
||||
void shortcut_multilayer_cpu(int size, int src_outputs, int batch, int n, int *outputs_of_layers, float **layers_output, float *out, float *in);
|
||||
void backward_shortcut_multilayer_cpu(int size, int src_outputs, int batch, int n, int *outputs_of_layers,
|
||||
float **layers_delta, float *delta_out, float *delta_in);
|
||||
|
||||
void mean_cpu(float *x, int batch, int filters, int spatial, float *mean);
|
||||
void variance_cpu(float *x, float *mean, int batch, int filters, int spatial, float *variance);
|
||||
|
@ -84,6 +87,8 @@ void fast_variance_delta_gpu(float *x, float *delta, float *mean, float *varianc
|
|||
void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance);
|
||||
void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *mean);
|
||||
void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out);
|
||||
void shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in);
|
||||
void backward_shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_delta_gpu, float *delta_out, float *delta_in);
|
||||
void input_shortcut_gpu(float *in, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out);
|
||||
void scale_bias_gpu(float *output, float *biases, int batch, int n, int size);
|
||||
void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates);
|
||||
|
|
|
@ -668,6 +668,75 @@ extern "C" void fill_ongpu(int N, float ALPHA, float * X, int INCX)
|
|||
CHECK_CUDA(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
__global__ void shortcut_multilayer_kernel(int size, int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in)
|
||||
{
|
||||
const int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
if (id >= size) return;
|
||||
|
||||
int src_id = id;
|
||||
int src_i = src_id % src_outputs;
|
||||
src_id /= src_outputs;
|
||||
int src_b = src_id;
|
||||
|
||||
out[id] = in[id];
|
||||
|
||||
// layers
|
||||
for (int i = 0; i < n; ++i) {
|
||||
int add_outputs = outputs_of_layers_gpu[i];
|
||||
if (src_i < add_outputs) {
|
||||
int add_index = add_outputs*src_b + src_i;
|
||||
int out_index = id;
|
||||
|
||||
float *add = layers_output_gpu[i];
|
||||
out[out_index] += add[add_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" void shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in)
|
||||
{
|
||||
//printf(" src_outputs = %d, batch = %d, n = %d \n", src_outputs, batch, n);
|
||||
int size = batch * src_outputs;
|
||||
shortcut_multilayer_kernel << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> > (size, src_outputs, batch, n, outputs_of_layers_gpu, layers_output_gpu, out, in);
|
||||
CHECK_CUDA(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
|
||||
|
||||
__global__ void backward_shortcut_multilayer_kernel(int size, int src_outputs, int batch, int n, int *outputs_of_layers_gpu,
|
||||
float **layers_delta_gpu, float *delta_out, float *delta_in)
|
||||
{
|
||||
const int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
if (id >= size) return;
|
||||
|
||||
int src_id = id;
|
||||
int src_i = src_id % src_outputs;
|
||||
src_id /= src_outputs;
|
||||
int src_b = src_id;
|
||||
|
||||
delta_out[id] += delta_in[id];
|
||||
|
||||
// layers
|
||||
for (int i = 0; i < n; ++i) {
|
||||
int add_outputs = outputs_of_layers_gpu[i];
|
||||
if (src_i < add_outputs) {
|
||||
int add_index = add_outputs*src_b + src_i;
|
||||
int out_index = id;
|
||||
|
||||
float *layer_delta = layers_delta_gpu[i];
|
||||
layer_delta[add_index] += delta_in[id];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" void backward_shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_delta_gpu, float *delta_out, float *delta_in)
|
||||
{
|
||||
//printf(" src_outputs = %d, batch = %d, n = %d \n", src_outputs, batch, n);
|
||||
int size = batch * src_outputs;
|
||||
backward_shortcut_multilayer_kernel << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> > (size, src_outputs, batch, n, outputs_of_layers_gpu, layers_delta_gpu, delta_out, delta_in);
|
||||
CHECK_CUDA(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
__global__ void shortcut_kernel(int size, int minw, int minh, int minc, int stride, int sample, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
|
||||
{
|
||||
int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
|
|
|
@ -369,6 +369,21 @@ float *cuda_make_array(float *x, size_t n)
|
|||
return x_gpu;
|
||||
}
|
||||
|
||||
void **cuda_make_array_pointers(void **x, size_t n)
|
||||
{
|
||||
void **x_gpu;
|
||||
size_t size = sizeof(void*) * n;
|
||||
cudaError_t status = cudaMalloc((void **)&x_gpu, size);
|
||||
if (status != cudaSuccess) fprintf(stderr, " Try to set subdivisions=64 in your cfg-file. \n");
|
||||
CHECK_CUDA(status);
|
||||
if (x) {
|
||||
status = cudaMemcpyAsync(x_gpu, x, size, cudaMemcpyDefault, get_cuda_stream());
|
||||
CHECK_CUDA(status);
|
||||
}
|
||||
if (!x_gpu) error("Cuda malloc failed\n");
|
||||
return x_gpu;
|
||||
}
|
||||
|
||||
void cuda_random(float *x_gpu, size_t n)
|
||||
{
|
||||
static curandGenerator_t gen[16];
|
||||
|
|
|
@ -62,6 +62,7 @@ extern "C" {
|
|||
float *cuda_make_array_pinned_preallocated(float *x, size_t n);
|
||||
float *cuda_make_array_pinned(float *x, size_t n);
|
||||
float *cuda_make_array(float *x, size_t n);
|
||||
void **cuda_make_array_pointers(void **x, size_t n);
|
||||
int *cuda_make_int_array(size_t n);
|
||||
int *cuda_make_int_array_new_api(int *x, size_t n);
|
||||
void cuda_push_array(float *x_gpu, float *x, size_t n);
|
||||
|
|
|
@ -65,6 +65,8 @@ void free_layer_custom(layer l, int keep_cudnn_desc)
|
|||
if (l.indexes) free(l.indexes);
|
||||
if (l.input_layers) free(l.input_layers);
|
||||
if (l.input_sizes) free(l.input_sizes);
|
||||
if (l.layers_output) free(l.layers_output);
|
||||
if (l.layers_delta) free(l.layers_delta);
|
||||
if (l.map) free(l.map);
|
||||
if (l.rand) free(l.rand);
|
||||
if (l.cost) free(l.cost);
|
||||
|
@ -190,6 +192,9 @@ void free_layer_custom(layer l, int keep_cudnn_desc)
|
|||
if (l.rand_gpu) cuda_free(l.rand_gpu);
|
||||
if (l.squared_gpu) cuda_free(l.squared_gpu);
|
||||
if (l.norms_gpu) cuda_free(l.norms_gpu);
|
||||
if (l.input_sizes_gpu) cuda_free(l.input_sizes_gpu);
|
||||
if (l.layers_output_gpu) cuda_free(l.layers_output_gpu);
|
||||
if (l.layers_delta_gpu) cuda_free(l.layers_delta_gpu);
|
||||
|
||||
// CONV-LSTM
|
||||
if (l.f_gpu) cuda_free(l.f_gpu);
|
||||
|
|
|
@ -538,7 +538,7 @@ int resize_network(network *net, int w, int h)
|
|||
}else if(l.type == ROUTE){
|
||||
resize_route_layer(&l, net);
|
||||
}else if (l.type == SHORTCUT) {
|
||||
resize_shortcut_layer(&l, w, h);
|
||||
resize_shortcut_layer(&l, w, h, net);
|
||||
}else if (l.type == SCALE_CHANNELS) {
|
||||
resize_scale_channels_layer(&l, net);
|
||||
}else if (l.type == DROPOUT) {
|
||||
|
|
|
@ -130,8 +130,8 @@ void forward_network_gpu(network net, network_state state)
|
|||
printf("\n\nSorted by time:\n");
|
||||
qsort(sorted_avg_time_per_layer, net.n, sizeof(time_benchmark_layers), time_comparator);
|
||||
for (i = 0; i < net.n; ++i) {
|
||||
//printf("\layer %d - type: %d - avg_time %lf ms \n", avg_time_per_layer[i].layer_id, avg_time_per_layer[i].layer_type, avg_time_per_layer[i].time);
|
||||
printf("\%d - layer %d - type: %d - avg_time %lf ms \n", i, sorted_avg_time_per_layer[i].layer_id, sorted_avg_time_per_layer[i].layer_type, sorted_avg_time_per_layer[i].time);
|
||||
//printf("layer %d - type: %d - avg_time %lf ms \n", avg_time_per_layer[i].layer_id, avg_time_per_layer[i].layer_type, avg_time_per_layer[i].time);
|
||||
printf("%d - layer %d - type: %d - avg_time %lf ms \n", i, sorted_avg_time_per_layer[i].layer_id, sorted_avg_time_per_layer[i].layer_type, sorted_avg_time_per_layer[i].time);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
51
src/parser.c
51
src/parser.c
|
@ -817,15 +817,54 @@ layer parse_shortcut(list *options, size_params params, network net)
|
|||
ACTIVATION activation = get_activation(activation_s);
|
||||
|
||||
int assisted_excitation = option_find_float_quiet(options, "assisted_excitation", 0);
|
||||
//char *l = option_find(options, "from");
|
||||
//int index = atoi(l);
|
||||
//if(index < 0) index = params.index + index;
|
||||
|
||||
char *l = option_find(options, "from");
|
||||
int index = atoi(l);
|
||||
if(index < 0) index = params.index + index;
|
||||
int len = strlen(l);
|
||||
if (!l) error("Route Layer must specify input layers: from = ...");
|
||||
int n = 1;
|
||||
int i;
|
||||
for (i = 0; i < len; ++i) {
|
||||
if (l[i] == ',') ++n;
|
||||
}
|
||||
|
||||
int batch = params.batch;
|
||||
layer from = net.layers[index];
|
||||
if (from.antialiasing) from = *from.input_layer;
|
||||
int* layers = (int*)calloc(n, sizeof(int));
|
||||
int* sizes = (int*)calloc(n, sizeof(int));
|
||||
float **layers_output = (float **)calloc(n, sizeof(float *));
|
||||
float **layers_delta = (float **)calloc(n, sizeof(float *));
|
||||
float **layers_output_gpu = (float **)calloc(n, sizeof(float *));
|
||||
float **layers_delta_gpu = (float **)calloc(n, sizeof(float *));
|
||||
|
||||
layer s = make_shortcut_layer(batch, index, params.w, params.h, params.c, from.out_w, from.out_h, from.out_c, assisted_excitation, activation, params.train);
|
||||
for (i = 0; i < n; ++i) {
|
||||
int index = atoi(l);
|
||||
l = strchr(l, ',') + 1;
|
||||
if (index < 0) index = params.index + index;
|
||||
layers[i] = index;
|
||||
sizes[i] = params.net.layers[index].outputs;
|
||||
layers_output[i] = params.net.layers[index].output;
|
||||
layers_delta[i] = params.net.layers[index].delta;
|
||||
|
||||
}
|
||||
|
||||
#ifdef GPU
|
||||
for (i = 0; i < n; ++i) {
|
||||
layers_output_gpu[i] = params.net.layers[layers[i]].output_gpu;
|
||||
layers_delta_gpu[i] = params.net.layers[layers[i]].delta_gpu;
|
||||
}
|
||||
#endif// GPU
|
||||
|
||||
layer s = make_shortcut_layer(params.batch, n, layers, sizes, params.w, params.h, params.c, layers_output, layers_delta, layers_output_gpu, layers_delta_gpu, activation, params.train);
|
||||
|
||||
for (i = 0; i < n; ++i) {
|
||||
int index = layers[i];
|
||||
assert(params.w == net.layers[index].out_w && params.h == net.layers[index].out_h);
|
||||
|
||||
if (params.w != net.layers[index].out_w || params.h != net.layers[index].out_h || params.c != net.layers[index].out_c)
|
||||
fprintf(stderr, " w = %d, w2 = %d, h = %d, h2 = %d, c = %d, c2 = %d \n",
|
||||
params.w, net.layers[index].out_w, params.h, net.layers[index].out_h, params.c, params.net.layers[index].out_c);
|
||||
}
|
||||
|
||||
return s;
|
||||
}
|
||||
|
|
|
@ -6,29 +6,36 @@
|
|||
#include <stdio.h>
|
||||
#include <assert.h>
|
||||
|
||||
layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2, int assisted_excitation, ACTIVATION activation, int train)
|
||||
layer make_shortcut_layer(int batch, int n, int *input_layers, int* input_sizes, int w, int h, int c,
|
||||
float **layers_output, float **layers_delta, float **layers_output_gpu, float **layers_delta_gpu, ACTIVATION activation, int train)
|
||||
{
|
||||
if(assisted_excitation) fprintf(stderr, "Shortcut Layer - AE: %d\n", index);
|
||||
else fprintf(stderr,"Shortcut Layer: %d\n", index);
|
||||
fprintf(stderr, "Shortcut Layer: ");
|
||||
int i;
|
||||
for(i = 0; i < n; ++i) fprintf(stderr, "%d, ", input_layers[i]);
|
||||
|
||||
layer l = { (LAYER_TYPE)0 };
|
||||
l.train = train;
|
||||
l.type = SHORTCUT;
|
||||
l.batch = batch;
|
||||
l.activation = activation;
|
||||
l.w = w2;
|
||||
l.h = h2;
|
||||
l.c = c2;
|
||||
l.out_w = w;
|
||||
l.out_h = h;
|
||||
l.out_c = c;
|
||||
l.n = n;
|
||||
l.input_layers = input_layers;
|
||||
l.input_sizes = input_sizes;
|
||||
l.layers_output = layers_output;
|
||||
l.layers_delta = layers_delta;
|
||||
|
||||
//l.w = w2;
|
||||
//l.h = h2;
|
||||
//l.c = c2;
|
||||
l.w = l.out_w = w;
|
||||
l.h = l.out_h = h;
|
||||
l.c = l.out_c = c;
|
||||
l.outputs = w*h*c;
|
||||
l.inputs = l.outputs;
|
||||
|
||||
l.assisted_excitation = assisted_excitation;
|
||||
//if(w != w2 || h != h2 || c != c2) fprintf(stderr, " w = %d, w2 = %d, h = %d, h2 = %d, c = %d, c2 = %d \n", w, w2, h, h2, c, c2);
|
||||
|
||||
if(w != w2 || h != h2 || c != c2) fprintf(stderr, " w = %d, w2 = %d, h = %d, h2 = %d, c = %d, c2 = %d \n", w, w2, h, h2, c, c2);
|
||||
|
||||
l.index = index;
|
||||
l.index = l.input_layers[0];
|
||||
|
||||
if (train) l.delta = (float*)calloc(l.outputs * batch, sizeof(float));
|
||||
l.output = (float*)calloc(l.outputs * batch, sizeof(float));
|
||||
|
@ -47,17 +54,18 @@ layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int
|
|||
|
||||
if (train) l.delta_gpu = cuda_make_array(l.delta, l.outputs*batch);
|
||||
l.output_gpu = cuda_make_array(l.output, l.outputs*batch);
|
||||
if (l.assisted_excitation)
|
||||
{
|
||||
const int size = l.out_w * l.out_h * l.batch;
|
||||
l.gt_gpu = cuda_make_array(NULL, size);
|
||||
l.a_avg_gpu = cuda_make_array(NULL, size);
|
||||
}
|
||||
|
||||
l.input_sizes_gpu = cuda_make_array(input_sizes, l.n);
|
||||
l.layers_output_gpu = cuda_make_array_pointers(layers_output_gpu, l.n);
|
||||
l.layers_delta_gpu = cuda_make_array_pointers(layers_delta_gpu, l.n);
|
||||
#endif // GPU
|
||||
|
||||
l.bflops = l.out_w * l.out_h * l.out_c * l.n / 1000000000.;
|
||||
fprintf(stderr, " outputs:%4d x%4d x%4d %5.3f BF\n", l.out_w, l.out_h, l.out_c, l.bflops);
|
||||
return l;
|
||||
}
|
||||
|
||||
void resize_shortcut_layer(layer *l, int w, int h)
|
||||
void resize_shortcut_layer(layer *l, int w, int h, network *net)
|
||||
{
|
||||
//assert(l->w == l->out_w);
|
||||
//assert(l->h == l->out_h);
|
||||
|
@ -68,6 +76,13 @@ void resize_shortcut_layer(layer *l, int w, int h)
|
|||
if (l->train) l->delta = (float*)realloc(l->delta, l->outputs * l->batch * sizeof(float));
|
||||
l->output = (float*)realloc(l->output, l->outputs * l->batch * sizeof(float));
|
||||
|
||||
int i;
|
||||
for (i = 0; i < l->n; ++i) {
|
||||
int index = l->input_layers[i];
|
||||
l->input_sizes[i] = net->layers[index].outputs;
|
||||
assert(l->w == net->layers[index].w && l->h == net->layers[index].h);
|
||||
}
|
||||
|
||||
#ifdef GPU
|
||||
cuda_free(l->output_gpu);
|
||||
l->output_gpu = cuda_make_array(l->output, l->outputs*l->batch);
|
||||
|
@ -82,7 +97,11 @@ void resize_shortcut_layer(layer *l, int w, int h)
|
|||
|
||||
void forward_shortcut_layer(const layer l, network_state state)
|
||||
{
|
||||
if (l.w == l.out_w && l.h == l.out_h && l.c == l.out_c) {
|
||||
int from_w = state.net.layers[l.index].w;
|
||||
int from_h = state.net.layers[l.index].h;
|
||||
int from_c = state.net.layers[l.index].c;
|
||||
|
||||
if (l.n == 1 && from_w == l.w && from_h == l.h && from_c == l.c) {
|
||||
int size = l.batch * l.w * l.h * l.c;
|
||||
int i;
|
||||
#pragma omp parallel for
|
||||
|
@ -90,16 +109,16 @@ void forward_shortcut_layer(const layer l, network_state state)
|
|||
l.output[i] = state.input[i] + state.net.layers[l.index].output[i];
|
||||
}
|
||||
else {
|
||||
copy_cpu(l.outputs*l.batch, state.input, 1, l.output, 1);
|
||||
shortcut_cpu(l.batch, l.w, l.h, l.c, state.net.layers[l.index].output, l.out_w, l.out_h, l.out_c, l.output);
|
||||
shortcut_multilayer_cpu(l.outputs * l.batch, l.outputs, l.batch, l.n, l.input_sizes, l.layers_output, l.output, state.input);
|
||||
}
|
||||
|
||||
//copy_cpu(l.outputs*l.batch, state.input, 1, l.output, 1);
|
||||
//shortcut_cpu(l.batch, from_w, from_h, from_c, state.net.layers[l.index].output, l.out_w, l.out_h, l.out_c, l.output);
|
||||
|
||||
//activate_array(l.output, l.outputs*l.batch, l.activation);
|
||||
if (l.activation == SWISH) activate_array_swish(l.output, l.outputs*l.batch, l.activation_input, l.output);
|
||||
else if (l.activation == MISH) activate_array_mish(l.output, l.outputs*l.batch, l.activation_input, l.output);
|
||||
else activate_array_cpu_custom(l.output, l.outputs*l.batch, l.activation);
|
||||
|
||||
if (l.assisted_excitation && state.train) assisted_excitation_forward(l, state);
|
||||
}
|
||||
|
||||
void backward_shortcut_layer(const layer l, network_state state)
|
||||
|
@ -108,8 +127,11 @@ void backward_shortcut_layer(const layer l, network_state state)
|
|||
else if (l.activation == MISH) gradient_array_mish(l.outputs*l.batch, l.activation_input, l.delta);
|
||||
else gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta);
|
||||
|
||||
axpy_cpu(l.outputs*l.batch, 1, l.delta, 1, state.delta, 1);
|
||||
shortcut_cpu(l.batch, l.out_w, l.out_h, l.out_c, l.delta, l.w, l.h, l.c, state.net.layers[l.index].delta);
|
||||
backward_shortcut_multilayer_cpu(l.outputs * l.batch, l.outputs, l.batch, l.n, l.input_sizes,
|
||||
l.layers_delta, state.delta, l.delta);
|
||||
|
||||
//axpy_cpu(l.outputs*l.batch, 1, l.delta, 1, state.delta, 1);
|
||||
//shortcut_cpu(l.batch, l.out_w, l.out_h, l.out_c, l.delta, l.w, l.h, l.c, state.net.layers[l.index].delta);
|
||||
}
|
||||
|
||||
#ifdef GPU
|
||||
|
@ -118,13 +140,25 @@ void forward_shortcut_layer_gpu(const layer l, network_state state)
|
|||
//copy_ongpu(l.outputs*l.batch, state.input, 1, l.output_gpu, 1);
|
||||
//simple_copy_ongpu(l.outputs*l.batch, state.input, l.output_gpu);
|
||||
//shortcut_gpu(l.batch, l.w, l.h, l.c, state.net.layers[l.index].output_gpu, l.out_w, l.out_h, l.out_c, l.output_gpu);
|
||||
input_shortcut_gpu(state.input, l.batch, l.w, l.h, l.c, state.net.layers[l.index].output_gpu, l.out_w, l.out_h, l.out_c, l.output_gpu);
|
||||
|
||||
//input_shortcut_gpu(state.input, l.batch, l.w, l.h, l.c, state.net.layers[l.index].output_gpu, l.out_w, l.out_h, l.out_c, l.output_gpu);
|
||||
|
||||
//-----------
|
||||
//if (l.outputs == l.input_sizes[0])
|
||||
//if(l.n == 1)
|
||||
//{
|
||||
// input_shortcut_gpu(state.input, l.batch, state.net.layers[l.index].w, state.net.layers[l.index].h, state.net.layers[l.index].c,
|
||||
// state.net.layers[l.index].output_gpu, l.out_w, l.out_h, l.out_c, l.output_gpu);
|
||||
//}
|
||||
//else
|
||||
{
|
||||
shortcut_multilayer_gpu(l.outputs, l.batch, l.n, l.input_sizes_gpu, l.layers_output_gpu, l.output_gpu, state.input);
|
||||
}
|
||||
|
||||
if (l.activation == SWISH) activate_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.output_gpu);
|
||||
else if (l.activation == MISH) activate_array_mish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.output_gpu);
|
||||
else activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
|
||||
|
||||
if (l.assisted_excitation && state.train) assisted_excitation_forward_gpu(l, state);
|
||||
}
|
||||
|
||||
void backward_shortcut_layer_gpu(const layer l, network_state state)
|
||||
|
@ -133,7 +167,9 @@ void backward_shortcut_layer_gpu(const layer l, network_state state)
|
|||
else if (l.activation == MISH) gradient_array_mish_ongpu(l.outputs*l.batch, l.activation_input_gpu, l.delta_gpu);
|
||||
else gradient_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation, l.delta_gpu);
|
||||
|
||||
axpy_ongpu(l.outputs*l.batch, 1, l.delta_gpu, 1, state.delta, 1);
|
||||
shortcut_gpu(l.batch, l.out_w, l.out_h, l.out_c, l.delta_gpu, l.w, l.h, l.c, state.net.layers[l.index].delta_gpu);
|
||||
backward_shortcut_multilayer_gpu(l.outputs, l.batch, l.n, l.input_sizes_gpu, l.layers_delta_gpu, state.delta, l.delta_gpu);
|
||||
|
||||
//axpy_ongpu(l.outputs*l.batch, 1, l.delta_gpu, 1, state.delta, 1);
|
||||
//shortcut_gpu(l.batch, l.out_w, l.out_h, l.out_c, l.delta_gpu, l.w, l.h, l.c, state.net.layers[l.index].delta_gpu);
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -7,10 +7,11 @@
|
|||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2, int assisted_excitation, ACTIVATION activation, int train);
|
||||
layer make_shortcut_layer(int batch, int n, int *input_layers, int* input_sizes, int w, int h, int c,
|
||||
float **layers_output, float **layers_delta, float **layers_output_gpu, float **layers_delta_gpu, ACTIVATION activation, int train);
|
||||
void forward_shortcut_layer(const layer l, network_state state);
|
||||
void backward_shortcut_layer(const layer l, network_state state);
|
||||
void resize_shortcut_layer(layer *l, int w, int h);
|
||||
void resize_shortcut_layer(layer *l, int w, int h, network *net);
|
||||
|
||||
#ifdef GPU
|
||||
void forward_shortcut_layer_gpu(const layer l, network_state state);
|
||||
|
|
Loading…
Reference in New Issue