weighted [shortcut] layer

This commit is contained in:
AlexeyAB 2020-01-07 17:11:57 +03:00
parent e62506629e
commit 9bd88d7fd7
7 changed files with 222 additions and 43 deletions

View File

@ -120,6 +120,11 @@ typedef enum {
YOLO_CENTER = 1 << 0, YOLO_LEFT_TOP = 1 << 1, YOLO_RIGHT_BOTTOM = 1 << 2
} YOLO_POINT;
// parser.h
typedef enum {
NO_WEIGHTS, PER_FEATURE, PER_CHANNEL
} WEIGHTS_TYPE_T;
// image.h
typedef enum{
@ -330,6 +335,7 @@ struct layer {
int * input_sizes;
float **layers_output;
float **layers_delta;
WEIGHTS_TYPE_T weights_type;
int * map;
int * counts;
float ** sums;

View File

@ -68,36 +68,49 @@ void weighted_delta_cpu(float *a, float *b, float *s, float *da, float *db, floa
}
}
void shortcut_multilayer_cpu(int size, int src_outputs, int batch, int n, int *outputs_of_layers, float **layers_output, float *out, float *in)
void shortcut_multilayer_cpu(int size, int src_outputs, int batch, int n, int *outputs_of_layers, float **layers_output, float *out, float *in, float *weights, int nweights)
{
// nweights - l.n or l.n*l.c or (l.n*l.c*l.h*l.w)
const int layer_step = nweights / (n + 1); // 1 or l.c or (l.c * l.h * l.w)
const int step = src_outputs / layer_step; // (l.c * l.h * l.w) or (l.w*l.h) or 1
int id;
#pragma omp parallel for
for (id = 0; id < size; ++id) {
int src_id = id;
int src_i = src_id % src_outputs;
const int src_i = src_id % src_outputs;
src_id /= src_outputs;
int src_b = src_id;
out[id] = in[id];
if(weights) out[id] = in[id] * weights[src_i / step]; // [0 or c or (c, h ,w)]
else out[id] = in[id];
// layers
for (int i = 0; i < n; ++i) {
int i;
for (i = 0; i < n; ++i) {
int add_outputs = outputs_of_layers[i];
if (src_i < add_outputs) {
int add_index = add_outputs*src_b + src_i;
int out_index = id;
float *add = layers_output[i];
out[out_index] += add[add_index];
const int weights_index = src_i / step + (i + 1)*layer_step; // [0 or c or (c, h ,w)]
if (weights) out[out_index] += add[add_index] * weights[weights_index]; // [0 or c or (c, h ,w)]
else out[out_index] += add[add_index];
}
}
}
}
void backward_shortcut_multilayer_cpu(int size, int src_outputs, int batch, int n, int *outputs_of_layers,
float **layers_delta, float *delta_out, float *delta_in)
float **layers_delta, float *delta_out, float *delta_in, float *weights, float *weight_updates, int nweights, float *in, float **layers_output)
{
// nweights - l.n or l.n*l.c or (l.n*l.c*l.h*l.w)
const int layer_step = nweights / (n + 1); // 1 or l.c or (l.c * l.h * l.w)
const int step = src_outputs / layer_step; // (l.c * l.h * l.w) or (l.w*l.h) or 1
int id;
#pragma omp parallel for
for (id = 0; id < size; ++id) {
@ -106,17 +119,29 @@ void backward_shortcut_multilayer_cpu(int size, int src_outputs, int batch, int
src_id /= src_outputs;
int src_b = src_id;
delta_out[id] += delta_in[id];
if (weights) {
delta_out[id] += delta_in[id] * weights[src_i / step]; // [0 or c or (c, h ,w)]
weight_updates[src_i / step] += delta_in[id] * in[id];
}
else delta_out[id] += delta_in[id];
// layers
for (int i = 0; i < n; ++i) {
int i;
for (i = 0; i < n; ++i) {
int add_outputs = outputs_of_layers[i];
if (src_i < add_outputs) {
int add_index = add_outputs*src_b + src_i;
int out_index = id;
float *layer_delta = layers_delta[i];
layer_delta[add_index] += delta_in[id];
if (weights) {
float *add = layers_output[i];
const int weights_index = src_i / step + (i + 1)*layer_step; // [0 or c or (c, h ,w)]
layer_delta[add_index] += delta_in[id] * weights[weights_index]; // [0 or c or (c, h ,w)]
weight_updates[weights_index] += delta_in[id] * add[add_index];
}
else layer_delta[add_index] += delta_in[id];
}
}
}

View File

@ -32,9 +32,9 @@ void fill_cpu(int N, float ALPHA, float * X, int INCX);
float dot_cpu(int N, float *X, int INCX, float *Y, int INCY);
void test_gpu_blas();
void shortcut_cpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out);
void shortcut_multilayer_cpu(int size, int src_outputs, int batch, int n, int *outputs_of_layers, float **layers_output, float *out, float *in);
void shortcut_multilayer_cpu(int size, int src_outputs, int batch, int n, int *outputs_of_layers, float **layers_output, float *out, float *in, float *weights, int nweights);
void backward_shortcut_multilayer_cpu(int size, int src_outputs, int batch, int n, int *outputs_of_layers,
float **layers_delta, float *delta_out, float *delta_in);
float **layers_delta, float *delta_out, float *delta_in, float *weights, float *weight_updates, int nweights, float *in, float **layers_output);
void mean_cpu(float *x, int batch, int filters, int spatial, float *mean);
void variance_cpu(float *x, float *mean, int batch, int filters, int spatial, float *variance);
@ -88,8 +88,9 @@ void fast_variance_delta_gpu(float *x, float *delta, float *mean, float *varianc
void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance);
void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *mean);
void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out);
void shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in);
void backward_shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_delta_gpu, float *delta_out, float *delta_in);
void shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in, float *weights_gpu, int nweights);
void backward_shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_delta_gpu, float *delta_out, float *delta_in,
float *weights, float *weight_updates, int nweights, float *in, float **layers_output);
void input_shortcut_gpu(float *in, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out);
void scale_bias_gpu(float *output, float *biases, int batch, int n, int size);
void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates);

View File

@ -674,17 +674,22 @@ extern "C" void fill_ongpu(int N, float ALPHA, float * X, int INCX)
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void shortcut_multilayer_kernel(int size, int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in)
__global__ void shortcut_multilayer_kernel(int size, int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in, float *weights_gpu, int nweights)
{
const int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if (id >= size) return;
// nweights - l.n or l.n*l.c or (l.n*l.c*l.h*l.w)
const int layer_step = nweights / (n + 1); // 1 or l.c or (l.c * l.h * l.w)
const int step = src_outputs / layer_step; // (l.c * l.h * l.w) or (l.w*l.h) or 1
int src_id = id;
int src_i = src_id % src_outputs;
const int src_i = src_id % src_outputs;
src_id /= src_outputs;
int src_b = src_id;
out[id] = in[id];
if (weights_gpu) out[id] = in[id] * weights_gpu[src_i / step]; // [0 or c or (c, h ,w)]
else out[id] = in[id];
// layers
for (int i = 0; i < n; ++i) {
@ -694,33 +699,45 @@ __global__ void shortcut_multilayer_kernel(int size, int src_outputs, int batch,
int out_index = id;
float *add = layers_output_gpu[i];
out[out_index] += add[add_index];
const int weights_index = src_i / step + (i + 1)*layer_step; // [0 or c or (c, h ,w)]
if (weights_gpu) out[out_index] += add[add_index] * weights_gpu[weights_index]; // [0 or c or (c, h ,w)]
else out[out_index] += add[add_index];
}
}
}
extern "C" void shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in)
extern "C" void shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in, float *weights_gpu, int nweights)
{
//printf(" src_outputs = %d, batch = %d, n = %d \n", src_outputs, batch, n);
int size = batch * src_outputs;
shortcut_multilayer_kernel << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> > (size, src_outputs, batch, n, outputs_of_layers_gpu, layers_output_gpu, out, in);
shortcut_multilayer_kernel << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> > (size, src_outputs, batch, n, outputs_of_layers_gpu, layers_output_gpu, out, in, weights_gpu, nweights);
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void backward_shortcut_multilayer_kernel(int size, int src_outputs, int batch, int n, int *outputs_of_layers_gpu,
float **layers_delta_gpu, float *delta_out, float *delta_in)
float **layers_delta_gpu, float *delta_out, float *delta_in, float *weights_gpu, float *weight_updates_gpu, int nweights, float *in, float **layers_output_gpu)
{
const int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if (id >= size) return;
// nweights - l.n or l.n*l.c or (l.n*l.c*l.h*l.w)
const int layer_step = nweights / (n + 1); // 1 or l.c or (l.c * l.h * l.w)
const int step = src_outputs / layer_step; // (l.c * l.h * l.w) or (l.w*l.h) or 1
//if(id == 0) printf(" layer_step = %d, step = %d \n", layer_step, step);
int src_id = id;
int src_i = src_id % src_outputs;
const int src_i = src_id % src_outputs;
src_id /= src_outputs;
int src_b = src_id;
delta_out[id] += delta_in[id];
if (weights_gpu) {
delta_out[id] += delta_in[id] * weights_gpu[src_i / step]; // [0 or c or (c, h ,w)]
weight_updates_gpu[src_i / step] += delta_in[id] * in[id];
}
else delta_out[id] += delta_in[id];
// layers
for (int i = 0; i < n; ++i) {
@ -730,16 +747,28 @@ __global__ void backward_shortcut_multilayer_kernel(int size, int src_outputs, i
int out_index = id;
float *layer_delta = layers_delta_gpu[i];
layer_delta[add_index] += delta_in[id];
if (weights_gpu) {
float *add = layers_output_gpu[i];
const int weights_index = src_i / step + (i + 1)*layer_step; // [0 or c or (c, h ,w)]
layer_delta[add_index] += delta_in[id] * weights_gpu[weights_index];
weight_updates_gpu[weights_index] += delta_in[id] * add[add_index];
}
else layer_delta[add_index] += delta_in[id];
}
}
}
extern "C" void backward_shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_delta_gpu, float *delta_out, float *delta_in)
extern "C" void backward_shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu,
float **layers_delta_gpu, float *delta_out, float *delta_in, float *weights_gpu, float *weight_updates_gpu, int nweights, float *in, float **layers_output_gpu)
{
const int layer_step = nweights / (n + 1); // 1 or l.c or (l.c * l.h * l.w)
const int step = src_outputs / layer_step; // (l.c * l.h * l.w) or (l.w*l.h) or 1
//printf(" nweights = %d, n = %d, layer_step = %d, step = %d \n", nweights, n, layer_step, step);
//printf(" src_outputs = %d, batch = %d, n = %d \n", src_outputs, batch, n);
int size = batch * src_outputs;
backward_shortcut_multilayer_kernel << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> > (size, src_outputs, batch, n, outputs_of_layers_gpu, layers_delta_gpu, delta_out, delta_in);
backward_shortcut_multilayer_kernel << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> > (size, src_outputs, batch, n, outputs_of_layers_gpu,
layers_delta_gpu, delta_out, delta_in, weights_gpu, weight_updates_gpu, nweights, in, layers_output_gpu);
CHECK_CUDA(cudaPeekAtLastError());
}

View File

@ -816,10 +816,10 @@ layer parse_shortcut(list *options, size_params params, network net)
char *activation_s = option_find_str(options, "activation", "logistic");
ACTIVATION activation = get_activation(activation_s);
int assisted_excitation = option_find_float_quiet(options, "assisted_excitation", 0);
//char *l = option_find(options, "from");
//int index = atoi(l);
//if(index < 0) index = params.index + index;
char *weights_type_srt = option_find_str_quiet(options, "weights_type", "none");
WEIGHTS_TYPE_T weights_type = NO_WEIGHTS;
if(strcmp(weights_type_srt, "per_feature") == 0) weights_type = PER_FEATURE;
else if (strcmp(weights_type_srt, "per_channel") == 0) weights_type = PER_CHANNEL;
char *l = option_find(options, "from");
int len = strlen(l);
@ -854,7 +854,11 @@ layer parse_shortcut(list *options, size_params params, network net)
}
#endif// GPU
layer s = make_shortcut_layer(params.batch, n, layers, sizes, params.w, params.h, params.c, layers_output, layers_delta, layers_output_gpu, layers_delta_gpu, activation, params.train);
layer s = make_shortcut_layer(params.batch, n, layers, sizes, params.w, params.h, params.c, layers_output, layers_delta,
layers_output_gpu, layers_delta_gpu, weights_type, activation, params.train);
free(layers_output_gpu);
free(layers_delta_gpu);
for (i = 0; i < n; ++i) {
int index = layers[i];
@ -1515,6 +1519,18 @@ void save_convolutional_weights_binary(layer l, FILE *fp)
}
}
void save_shortcut_weights(layer l, FILE *fp)
{
#ifdef GPU
if (gpu_index >= 0) {
pull_shortcut_layer(l);
}
#endif
int num = l.nweights;
fwrite(l.weights, sizeof(float), num, fp);
}
void save_convolutional_weights(layer l, FILE *fp)
{
if(l.binary){
@ -1591,8 +1607,10 @@ void save_weights_upto(network net, char *filename, int cutoff)
int i;
for(i = 0; i < net.n && i < cutoff; ++i){
layer l = net.layers[i];
if(l.type == CONVOLUTIONAL && l.share_layer == NULL){
if (l.type == CONVOLUTIONAL && l.share_layer == NULL) {
save_convolutional_weights(l, fp);
} if (l.type == SHORTCUT && l.nweights > 0) {
save_shortcut_weights(l, fp);
} if(l.type == CONNECTED){
save_connected_weights(l, fp);
} if(l.type == BATCHNORM){
@ -1786,6 +1804,22 @@ void load_convolutional_weights(layer l, FILE *fp)
#endif
}
void load_shortcut_weights(layer l, FILE *fp)
{
if (l.binary) {
//load_convolutional_weights_binary(l, fp);
//return;
}
int num = l.nweights;
int read_bytes;
read_bytes = fread(l.weights, sizeof(float), num, fp);
if (read_bytes > 0 && read_bytes < num) printf("\n Warning: Unexpected end of wights-file! l.weights - l.index = %d \n", l.index);
#ifdef GPU
if (gpu_index >= 0) {
push_shortcut_layer(l);
}
#endif
}
void load_weights_upto(network *net, char *filename, int cutoff)
{
@ -1826,6 +1860,9 @@ void load_weights_upto(network *net, char *filename, int cutoff)
if(l.type == CONVOLUTIONAL && l.share_layer == NULL){
load_convolutional_weights(l, fp);
}
if (l.type == SHORTCUT && l.nweights > 0) {
load_shortcut_weights(l, fp);
}
if(l.type == CONNECTED){
load_connected_weights(l, fp, transpose);
}

View File

@ -7,7 +7,7 @@
#include <assert.h>
layer make_shortcut_layer(int batch, int n, int *input_layers, int* input_sizes, int w, int h, int c,
float **layers_output, float **layers_delta, float **layers_output_gpu, float **layers_delta_gpu, ACTIVATION activation, int train)
float **layers_output, float **layers_delta, float **layers_output_gpu, float **layers_delta_gpu, WEIGHTS_TYPE_T weights_type, ACTIVATION activation, int train)
{
fprintf(stderr, "Shortcut Layer: ");
int i;
@ -23,6 +23,7 @@ layer make_shortcut_layer(int batch, int n, int *input_layers, int* input_sizes,
l.input_sizes = input_sizes;
l.layers_output = layers_output;
l.layers_delta = layers_delta;
l.weights_type = weights_type;
//l.w = w2;
//l.h = h2;
@ -40,6 +41,18 @@ layer make_shortcut_layer(int batch, int n, int *input_layers, int* input_sizes,
if (train) l.delta = (float*)calloc(l.outputs * batch, sizeof(float));
l.output = (float*)calloc(l.outputs * batch, sizeof(float));
if (l.weights_type == PER_FEATURE) l.nweights = (l.n + 1);
else if (l.weights_type == PER_CHANNEL) l.nweights = (l.n + 1) * l.c;
if (l.nweights > 0) {
l.weights = (float*)calloc(l.nweights, sizeof(float));
float scale = sqrt(2. / l.nweights);
for (i = 0; i < l.nweights; ++i) l.weights[i] = 1;// scale*rand_uniform(-1, 1); // rand_normal();
if (train) l.weight_updates = (float*)calloc(l.nweights, sizeof(float));
l.update = update_shortcut_layer;
}
l.forward = forward_shortcut_layer;
l.backward = backward_shortcut_layer;
#ifndef GPU
@ -52,16 +65,23 @@ layer make_shortcut_layer(int batch, int n, int *input_layers, int* input_sizes,
l.forward_gpu = forward_shortcut_layer_gpu;
l.backward_gpu = backward_shortcut_layer_gpu;
if (l.nweights > 0) {
l.update_gpu = update_shortcut_layer_gpu;
l.weights_gpu = cuda_make_array(l.weights, l.nweights);
if (train) l.weight_updates_gpu = cuda_make_array(l.weight_updates, l.nweights);
}
if (train) l.delta_gpu = cuda_make_array(l.delta, l.outputs*batch);
l.output_gpu = cuda_make_array(l.output, l.outputs*batch);
l.input_sizes_gpu = cuda_make_int_array_new_api(input_sizes, l.n);
l.layers_output_gpu = cuda_make_array_pointers((void**)layers_output_gpu, l.n);
l.layers_delta_gpu = cuda_make_array_pointers((void**)layers_delta_gpu, l.n);
l.layers_output_gpu = (float**)cuda_make_array_pointers((void**)layers_output_gpu, l.n);
l.layers_delta_gpu = (float**)cuda_make_array_pointers((void**)layers_delta_gpu, l.n);
#endif // GPU
l.bflops = l.out_w * l.out_h * l.out_c * l.n / 1000000000.;
fprintf(stderr, " outputs:%4d x%4d x%4d %5.3f BF\n", l.out_w, l.out_h, l.out_c, l.bflops);
if (l.weights_type) l.bflops *= 2;
fprintf(stderr, " wt = %d, outputs:%4d x%4d x%4d %5.3f BF\n", l.weights_type, l.out_w, l.out_h, l.out_c, l.bflops);
return l;
}
@ -95,9 +115,21 @@ void resize_shortcut_layer(layer *l, int w, int h, network *net)
l->delta_gpu = cuda_make_array(l->delta, l->outputs*l->batch);
}
float **layers_output_gpu = (float **)calloc(l->n, sizeof(float *));
float **layers_delta_gpu = (float **)calloc(l->n, sizeof(float *));
for (i = 0; i < l->n; ++i) {
const int index = l->input_layers[i];
layers_output_gpu[i] = net->layers[index].output_gpu;
layers_delta_gpu[i] = net->layers[index].delta_gpu;
}
memcpy_ongpu(l->input_sizes_gpu, l->input_sizes, l->n * sizeof(int));
memcpy_ongpu(l->layers_output_gpu, l->layers_output, l->n * sizeof(float*));
memcpy_ongpu(l->layers_delta_gpu, l->layers_delta, l->n * sizeof(float*));
memcpy_ongpu(l->layers_output_gpu, layers_output_gpu, l->n * sizeof(float*));
memcpy_ongpu(l->layers_delta_gpu, layers_delta_gpu, l->n * sizeof(float*));
free(layers_output_gpu);
free(layers_delta_gpu);
#endif
}
@ -108,7 +140,7 @@ void forward_shortcut_layer(const layer l, network_state state)
int from_h = state.net.layers[l.index].h;
int from_c = state.net.layers[l.index].c;
if (l.n == 1 && from_w == l.w && from_h == l.h && from_c == l.c) {
if (l.nweights == 0 && l.n == 1 && from_w == l.w && from_h == l.h && from_c == l.c) {
int size = l.batch * l.w * l.h * l.c;
int i;
#pragma omp parallel for
@ -116,7 +148,7 @@ void forward_shortcut_layer(const layer l, network_state state)
l.output[i] = state.input[i] + state.net.layers[l.index].output[i];
}
else {
shortcut_multilayer_cpu(l.outputs * l.batch, l.outputs, l.batch, l.n, l.input_sizes, l.layers_output, l.output, state.input);
shortcut_multilayer_cpu(l.outputs * l.batch, l.outputs, l.batch, l.n, l.input_sizes, l.layers_output, l.output, state.input, l.weights, l.nweights);
}
//copy_cpu(l.outputs*l.batch, state.input, 1, l.output, 1);
@ -135,12 +167,24 @@ void backward_shortcut_layer(const layer l, network_state state)
else gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta);
backward_shortcut_multilayer_cpu(l.outputs * l.batch, l.outputs, l.batch, l.n, l.input_sizes,
l.layers_delta, state.delta, l.delta);
l.layers_delta, state.delta, l.delta, l.weights, l.weight_updates, l.nweights, state.input, l.layers_output);
//axpy_cpu(l.outputs*l.batch, 1, l.delta, 1, state.delta, 1);
//shortcut_cpu(l.batch, l.out_w, l.out_h, l.out_c, l.delta, l.w, l.h, l.c, state.net.layers[l.index].delta);
}
void update_shortcut_layer(layer l, int batch, float learning_rate_init, float momentum, float decay)
{
float learning_rate = learning_rate_init*l.learning_rate_scale;
//float momentum = a.momentum;
//float decay = a.decay;
//int batch = a.batch;
axpy_cpu(l.nweights, -decay*batch, l.weights, 1, l.weight_updates, 1);
axpy_cpu(l.nweights, learning_rate / batch, l.weight_updates, 1, l.weights, 1);
scal_cpu(l.nweights, momentum, l.weight_updates, 1);
}
#ifdef GPU
void forward_shortcut_layer_gpu(const layer l, network_state state)
{
@ -159,7 +203,7 @@ void forward_shortcut_layer_gpu(const layer l, network_state state)
//}
//else
{
shortcut_multilayer_gpu(l.outputs, l.batch, l.n, l.input_sizes_gpu, l.layers_output_gpu, l.output_gpu, state.input);
shortcut_multilayer_gpu(l.outputs, l.batch, l.n, l.input_sizes_gpu, l.layers_output_gpu, l.output_gpu, state.input, l.weights_gpu, l.nweights);
}
if (l.activation == SWISH) activate_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.output_gpu);
@ -174,9 +218,42 @@ void backward_shortcut_layer_gpu(const layer l, network_state state)
else if (l.activation == MISH) gradient_array_mish_ongpu(l.outputs*l.batch, l.activation_input_gpu, l.delta_gpu);
else gradient_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation, l.delta_gpu);
backward_shortcut_multilayer_gpu(l.outputs, l.batch, l.n, l.input_sizes_gpu, l.layers_delta_gpu, state.delta, l.delta_gpu);
backward_shortcut_multilayer_gpu(l.outputs, l.batch, l.n, l.input_sizes_gpu, l.layers_delta_gpu, state.delta, l.delta_gpu,
l.weights_gpu, l.weight_updates_gpu, l.nweights, state.input, l.layers_output_gpu);
//axpy_ongpu(l.outputs*l.batch, 1, l.delta_gpu, 1, state.delta, 1);
//shortcut_gpu(l.batch, l.out_w, l.out_h, l.out_c, l.delta_gpu, l.w, l.h, l.c, state.net.layers[l.index].delta_gpu);
}
void update_shortcut_layer_gpu(layer l, int batch, float learning_rate_init, float momentum, float decay)
{
float learning_rate = learning_rate_init*l.learning_rate_scale;
//float momentum = a.momentum;
//float decay = a.decay;
//int batch = a.batch;
fix_nan_and_inf(l.weight_updates_gpu, l.nweights);
fix_nan_and_inf(l.weights_gpu, l.nweights);
axpy_ongpu(l.nweights, -decay*batch, l.weights_gpu, 1, l.weight_updates_gpu, 1);
axpy_ongpu(l.nweights, learning_rate / batch, l.weight_updates_gpu, 1, l.weights_gpu, 1);
scal_ongpu(l.nweights, momentum, l.weight_updates_gpu, 1);
//if (l.clip) {
// constrain_gpu(l.nweights, l.clip, l.weights_gpu, 1);
//}
}
void pull_shortcut_layer(layer l)
{
cuda_pull_array_async(l.weights_gpu, l.weights, l.nweights);
CHECK_CUDA(cudaPeekAtLastError());
cudaStreamSynchronize(get_cuda_stream());
}
void push_shortcut_layer(layer l)
{
cuda_push_array(l.weights_gpu, l.weights, l.nweights);
CHECK_CUDA(cudaPeekAtLastError());
}
#endif

View File

@ -8,14 +8,18 @@
extern "C" {
#endif
layer make_shortcut_layer(int batch, int n, int *input_layers, int* input_sizes, int w, int h, int c,
float **layers_output, float **layers_delta, float **layers_output_gpu, float **layers_delta_gpu, ACTIVATION activation, int train);
float **layers_output, float **layers_delta, float **layers_output_gpu, float **layers_delta_gpu, WEIGHTS_TYPE_T weights_type, ACTIVATION activation, int train);
void forward_shortcut_layer(const layer l, network_state state);
void backward_shortcut_layer(const layer l, network_state state);
void update_shortcut_layer(layer l, int batch, float learning_rate_init, float momentum, float decay);
void resize_shortcut_layer(layer *l, int w, int h, network *net);
#ifdef GPU
void forward_shortcut_layer_gpu(const layer l, network_state state);
void backward_shortcut_layer_gpu(const layer l, network_state state);
void update_shortcut_layer_gpu(layer l, int batch, float learning_rate_init, float momentum, float decay);
void pull_shortcut_layer(layer l);
void push_shortcut_layer(layer l);
#endif
#ifdef __cplusplus