mirror of https://github.com/AlexeyAB/darknet.git
weighted [shortcut] layer
This commit is contained in:
parent
e62506629e
commit
9bd88d7fd7
|
@ -120,6 +120,11 @@ typedef enum {
|
|||
YOLO_CENTER = 1 << 0, YOLO_LEFT_TOP = 1 << 1, YOLO_RIGHT_BOTTOM = 1 << 2
|
||||
} YOLO_POINT;
|
||||
|
||||
// parser.h
|
||||
typedef enum {
|
||||
NO_WEIGHTS, PER_FEATURE, PER_CHANNEL
|
||||
} WEIGHTS_TYPE_T;
|
||||
|
||||
|
||||
// image.h
|
||||
typedef enum{
|
||||
|
@ -330,6 +335,7 @@ struct layer {
|
|||
int * input_sizes;
|
||||
float **layers_output;
|
||||
float **layers_delta;
|
||||
WEIGHTS_TYPE_T weights_type;
|
||||
int * map;
|
||||
int * counts;
|
||||
float ** sums;
|
||||
|
|
43
src/blas.c
43
src/blas.c
|
@ -68,36 +68,49 @@ void weighted_delta_cpu(float *a, float *b, float *s, float *da, float *db, floa
|
|||
}
|
||||
}
|
||||
|
||||
void shortcut_multilayer_cpu(int size, int src_outputs, int batch, int n, int *outputs_of_layers, float **layers_output, float *out, float *in)
|
||||
void shortcut_multilayer_cpu(int size, int src_outputs, int batch, int n, int *outputs_of_layers, float **layers_output, float *out, float *in, float *weights, int nweights)
|
||||
{
|
||||
// nweights - l.n or l.n*l.c or (l.n*l.c*l.h*l.w)
|
||||
const int layer_step = nweights / (n + 1); // 1 or l.c or (l.c * l.h * l.w)
|
||||
const int step = src_outputs / layer_step; // (l.c * l.h * l.w) or (l.w*l.h) or 1
|
||||
|
||||
int id;
|
||||
#pragma omp parallel for
|
||||
for (id = 0; id < size; ++id) {
|
||||
|
||||
int src_id = id;
|
||||
int src_i = src_id % src_outputs;
|
||||
const int src_i = src_id % src_outputs;
|
||||
src_id /= src_outputs;
|
||||
int src_b = src_id;
|
||||
|
||||
out[id] = in[id];
|
||||
if(weights) out[id] = in[id] * weights[src_i / step]; // [0 or c or (c, h ,w)]
|
||||
else out[id] = in[id];
|
||||
|
||||
// layers
|
||||
for (int i = 0; i < n; ++i) {
|
||||
int i;
|
||||
for (i = 0; i < n; ++i) {
|
||||
int add_outputs = outputs_of_layers[i];
|
||||
if (src_i < add_outputs) {
|
||||
int add_index = add_outputs*src_b + src_i;
|
||||
int out_index = id;
|
||||
|
||||
float *add = layers_output[i];
|
||||
out[out_index] += add[add_index];
|
||||
const int weights_index = src_i / step + (i + 1)*layer_step; // [0 or c or (c, h ,w)]
|
||||
|
||||
if (weights) out[out_index] += add[add_index] * weights[weights_index]; // [0 or c or (c, h ,w)]
|
||||
else out[out_index] += add[add_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void backward_shortcut_multilayer_cpu(int size, int src_outputs, int batch, int n, int *outputs_of_layers,
|
||||
float **layers_delta, float *delta_out, float *delta_in)
|
||||
float **layers_delta, float *delta_out, float *delta_in, float *weights, float *weight_updates, int nweights, float *in, float **layers_output)
|
||||
{
|
||||
// nweights - l.n or l.n*l.c or (l.n*l.c*l.h*l.w)
|
||||
const int layer_step = nweights / (n + 1); // 1 or l.c or (l.c * l.h * l.w)
|
||||
const int step = src_outputs / layer_step; // (l.c * l.h * l.w) or (l.w*l.h) or 1
|
||||
|
||||
int id;
|
||||
#pragma omp parallel for
|
||||
for (id = 0; id < size; ++id) {
|
||||
|
@ -106,17 +119,29 @@ void backward_shortcut_multilayer_cpu(int size, int src_outputs, int batch, int
|
|||
src_id /= src_outputs;
|
||||
int src_b = src_id;
|
||||
|
||||
delta_out[id] += delta_in[id];
|
||||
if (weights) {
|
||||
delta_out[id] += delta_in[id] * weights[src_i / step]; // [0 or c or (c, h ,w)]
|
||||
weight_updates[src_i / step] += delta_in[id] * in[id];
|
||||
}
|
||||
else delta_out[id] += delta_in[id];
|
||||
|
||||
// layers
|
||||
for (int i = 0; i < n; ++i) {
|
||||
int i;
|
||||
for (i = 0; i < n; ++i) {
|
||||
int add_outputs = outputs_of_layers[i];
|
||||
if (src_i < add_outputs) {
|
||||
int add_index = add_outputs*src_b + src_i;
|
||||
int out_index = id;
|
||||
|
||||
float *layer_delta = layers_delta[i];
|
||||
layer_delta[add_index] += delta_in[id];
|
||||
if (weights) {
|
||||
float *add = layers_output[i];
|
||||
const int weights_index = src_i / step + (i + 1)*layer_step; // [0 or c or (c, h ,w)]
|
||||
|
||||
layer_delta[add_index] += delta_in[id] * weights[weights_index]; // [0 or c or (c, h ,w)]
|
||||
weight_updates[weights_index] += delta_in[id] * add[add_index];
|
||||
}
|
||||
else layer_delta[add_index] += delta_in[id];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -32,9 +32,9 @@ void fill_cpu(int N, float ALPHA, float * X, int INCX);
|
|||
float dot_cpu(int N, float *X, int INCX, float *Y, int INCY);
|
||||
void test_gpu_blas();
|
||||
void shortcut_cpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out);
|
||||
void shortcut_multilayer_cpu(int size, int src_outputs, int batch, int n, int *outputs_of_layers, float **layers_output, float *out, float *in);
|
||||
void shortcut_multilayer_cpu(int size, int src_outputs, int batch, int n, int *outputs_of_layers, float **layers_output, float *out, float *in, float *weights, int nweights);
|
||||
void backward_shortcut_multilayer_cpu(int size, int src_outputs, int batch, int n, int *outputs_of_layers,
|
||||
float **layers_delta, float *delta_out, float *delta_in);
|
||||
float **layers_delta, float *delta_out, float *delta_in, float *weights, float *weight_updates, int nweights, float *in, float **layers_output);
|
||||
|
||||
void mean_cpu(float *x, int batch, int filters, int spatial, float *mean);
|
||||
void variance_cpu(float *x, float *mean, int batch, int filters, int spatial, float *variance);
|
||||
|
@ -88,8 +88,9 @@ void fast_variance_delta_gpu(float *x, float *delta, float *mean, float *varianc
|
|||
void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance);
|
||||
void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *mean);
|
||||
void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out);
|
||||
void shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in);
|
||||
void backward_shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_delta_gpu, float *delta_out, float *delta_in);
|
||||
void shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in, float *weights_gpu, int nweights);
|
||||
void backward_shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_delta_gpu, float *delta_out, float *delta_in,
|
||||
float *weights, float *weight_updates, int nweights, float *in, float **layers_output);
|
||||
void input_shortcut_gpu(float *in, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out);
|
||||
void scale_bias_gpu(float *output, float *biases, int batch, int n, int size);
|
||||
void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates);
|
||||
|
|
|
@ -674,17 +674,22 @@ extern "C" void fill_ongpu(int N, float ALPHA, float * X, int INCX)
|
|||
CHECK_CUDA(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
__global__ void shortcut_multilayer_kernel(int size, int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in)
|
||||
__global__ void shortcut_multilayer_kernel(int size, int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in, float *weights_gpu, int nweights)
|
||||
{
|
||||
const int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
if (id >= size) return;
|
||||
|
||||
// nweights - l.n or l.n*l.c or (l.n*l.c*l.h*l.w)
|
||||
const int layer_step = nweights / (n + 1); // 1 or l.c or (l.c * l.h * l.w)
|
||||
const int step = src_outputs / layer_step; // (l.c * l.h * l.w) or (l.w*l.h) or 1
|
||||
|
||||
int src_id = id;
|
||||
int src_i = src_id % src_outputs;
|
||||
const int src_i = src_id % src_outputs;
|
||||
src_id /= src_outputs;
|
||||
int src_b = src_id;
|
||||
|
||||
out[id] = in[id];
|
||||
if (weights_gpu) out[id] = in[id] * weights_gpu[src_i / step]; // [0 or c or (c, h ,w)]
|
||||
else out[id] = in[id];
|
||||
|
||||
// layers
|
||||
for (int i = 0; i < n; ++i) {
|
||||
|
@ -694,33 +699,45 @@ __global__ void shortcut_multilayer_kernel(int size, int src_outputs, int batch,
|
|||
int out_index = id;
|
||||
|
||||
float *add = layers_output_gpu[i];
|
||||
out[out_index] += add[add_index];
|
||||
const int weights_index = src_i / step + (i + 1)*layer_step; // [0 or c or (c, h ,w)]
|
||||
|
||||
if (weights_gpu) out[out_index] += add[add_index] * weights_gpu[weights_index]; // [0 or c or (c, h ,w)]
|
||||
else out[out_index] += add[add_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" void shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in)
|
||||
extern "C" void shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in, float *weights_gpu, int nweights)
|
||||
{
|
||||
//printf(" src_outputs = %d, batch = %d, n = %d \n", src_outputs, batch, n);
|
||||
int size = batch * src_outputs;
|
||||
shortcut_multilayer_kernel << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> > (size, src_outputs, batch, n, outputs_of_layers_gpu, layers_output_gpu, out, in);
|
||||
shortcut_multilayer_kernel << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> > (size, src_outputs, batch, n, outputs_of_layers_gpu, layers_output_gpu, out, in, weights_gpu, nweights);
|
||||
CHECK_CUDA(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
|
||||
|
||||
__global__ void backward_shortcut_multilayer_kernel(int size, int src_outputs, int batch, int n, int *outputs_of_layers_gpu,
|
||||
float **layers_delta_gpu, float *delta_out, float *delta_in)
|
||||
float **layers_delta_gpu, float *delta_out, float *delta_in, float *weights_gpu, float *weight_updates_gpu, int nweights, float *in, float **layers_output_gpu)
|
||||
{
|
||||
const int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
if (id >= size) return;
|
||||
|
||||
// nweights - l.n or l.n*l.c or (l.n*l.c*l.h*l.w)
|
||||
const int layer_step = nweights / (n + 1); // 1 or l.c or (l.c * l.h * l.w)
|
||||
const int step = src_outputs / layer_step; // (l.c * l.h * l.w) or (l.w*l.h) or 1
|
||||
//if(id == 0) printf(" layer_step = %d, step = %d \n", layer_step, step);
|
||||
|
||||
int src_id = id;
|
||||
int src_i = src_id % src_outputs;
|
||||
const int src_i = src_id % src_outputs;
|
||||
src_id /= src_outputs;
|
||||
int src_b = src_id;
|
||||
|
||||
delta_out[id] += delta_in[id];
|
||||
if (weights_gpu) {
|
||||
delta_out[id] += delta_in[id] * weights_gpu[src_i / step]; // [0 or c or (c, h ,w)]
|
||||
weight_updates_gpu[src_i / step] += delta_in[id] * in[id];
|
||||
}
|
||||
else delta_out[id] += delta_in[id];
|
||||
|
||||
// layers
|
||||
for (int i = 0; i < n; ++i) {
|
||||
|
@ -730,16 +747,28 @@ __global__ void backward_shortcut_multilayer_kernel(int size, int src_outputs, i
|
|||
int out_index = id;
|
||||
|
||||
float *layer_delta = layers_delta_gpu[i];
|
||||
layer_delta[add_index] += delta_in[id];
|
||||
if (weights_gpu) {
|
||||
float *add = layers_output_gpu[i];
|
||||
const int weights_index = src_i / step + (i + 1)*layer_step; // [0 or c or (c, h ,w)]
|
||||
layer_delta[add_index] += delta_in[id] * weights_gpu[weights_index];
|
||||
weight_updates_gpu[weights_index] += delta_in[id] * add[add_index];
|
||||
}
|
||||
else layer_delta[add_index] += delta_in[id];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" void backward_shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_delta_gpu, float *delta_out, float *delta_in)
|
||||
extern "C" void backward_shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu,
|
||||
float **layers_delta_gpu, float *delta_out, float *delta_in, float *weights_gpu, float *weight_updates_gpu, int nweights, float *in, float **layers_output_gpu)
|
||||
{
|
||||
const int layer_step = nweights / (n + 1); // 1 or l.c or (l.c * l.h * l.w)
|
||||
const int step = src_outputs / layer_step; // (l.c * l.h * l.w) or (l.w*l.h) or 1
|
||||
//printf(" nweights = %d, n = %d, layer_step = %d, step = %d \n", nweights, n, layer_step, step);
|
||||
|
||||
//printf(" src_outputs = %d, batch = %d, n = %d \n", src_outputs, batch, n);
|
||||
int size = batch * src_outputs;
|
||||
backward_shortcut_multilayer_kernel << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> > (size, src_outputs, batch, n, outputs_of_layers_gpu, layers_delta_gpu, delta_out, delta_in);
|
||||
backward_shortcut_multilayer_kernel << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> > (size, src_outputs, batch, n, outputs_of_layers_gpu,
|
||||
layers_delta_gpu, delta_out, delta_in, weights_gpu, weight_updates_gpu, nweights, in, layers_output_gpu);
|
||||
CHECK_CUDA(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
|
|
49
src/parser.c
49
src/parser.c
|
@ -816,10 +816,10 @@ layer parse_shortcut(list *options, size_params params, network net)
|
|||
char *activation_s = option_find_str(options, "activation", "logistic");
|
||||
ACTIVATION activation = get_activation(activation_s);
|
||||
|
||||
int assisted_excitation = option_find_float_quiet(options, "assisted_excitation", 0);
|
||||
//char *l = option_find(options, "from");
|
||||
//int index = atoi(l);
|
||||
//if(index < 0) index = params.index + index;
|
||||
char *weights_type_srt = option_find_str_quiet(options, "weights_type", "none");
|
||||
WEIGHTS_TYPE_T weights_type = NO_WEIGHTS;
|
||||
if(strcmp(weights_type_srt, "per_feature") == 0) weights_type = PER_FEATURE;
|
||||
else if (strcmp(weights_type_srt, "per_channel") == 0) weights_type = PER_CHANNEL;
|
||||
|
||||
char *l = option_find(options, "from");
|
||||
int len = strlen(l);
|
||||
|
@ -854,7 +854,11 @@ layer parse_shortcut(list *options, size_params params, network net)
|
|||
}
|
||||
#endif// GPU
|
||||
|
||||
layer s = make_shortcut_layer(params.batch, n, layers, sizes, params.w, params.h, params.c, layers_output, layers_delta, layers_output_gpu, layers_delta_gpu, activation, params.train);
|
||||
layer s = make_shortcut_layer(params.batch, n, layers, sizes, params.w, params.h, params.c, layers_output, layers_delta,
|
||||
layers_output_gpu, layers_delta_gpu, weights_type, activation, params.train);
|
||||
|
||||
free(layers_output_gpu);
|
||||
free(layers_delta_gpu);
|
||||
|
||||
for (i = 0; i < n; ++i) {
|
||||
int index = layers[i];
|
||||
|
@ -1515,6 +1519,18 @@ void save_convolutional_weights_binary(layer l, FILE *fp)
|
|||
}
|
||||
}
|
||||
|
||||
void save_shortcut_weights(layer l, FILE *fp)
|
||||
{
|
||||
#ifdef GPU
|
||||
if (gpu_index >= 0) {
|
||||
pull_shortcut_layer(l);
|
||||
}
|
||||
#endif
|
||||
int num = l.nweights;
|
||||
fwrite(l.weights, sizeof(float), num, fp);
|
||||
|
||||
}
|
||||
|
||||
void save_convolutional_weights(layer l, FILE *fp)
|
||||
{
|
||||
if(l.binary){
|
||||
|
@ -1591,8 +1607,10 @@ void save_weights_upto(network net, char *filename, int cutoff)
|
|||
int i;
|
||||
for(i = 0; i < net.n && i < cutoff; ++i){
|
||||
layer l = net.layers[i];
|
||||
if(l.type == CONVOLUTIONAL && l.share_layer == NULL){
|
||||
if (l.type == CONVOLUTIONAL && l.share_layer == NULL) {
|
||||
save_convolutional_weights(l, fp);
|
||||
} if (l.type == SHORTCUT && l.nweights > 0) {
|
||||
save_shortcut_weights(l, fp);
|
||||
} if(l.type == CONNECTED){
|
||||
save_connected_weights(l, fp);
|
||||
} if(l.type == BATCHNORM){
|
||||
|
@ -1786,6 +1804,22 @@ void load_convolutional_weights(layer l, FILE *fp)
|
|||
#endif
|
||||
}
|
||||
|
||||
void load_shortcut_weights(layer l, FILE *fp)
|
||||
{
|
||||
if (l.binary) {
|
||||
//load_convolutional_weights_binary(l, fp);
|
||||
//return;
|
||||
}
|
||||
int num = l.nweights;
|
||||
int read_bytes;
|
||||
read_bytes = fread(l.weights, sizeof(float), num, fp);
|
||||
if (read_bytes > 0 && read_bytes < num) printf("\n Warning: Unexpected end of wights-file! l.weights - l.index = %d \n", l.index);
|
||||
#ifdef GPU
|
||||
if (gpu_index >= 0) {
|
||||
push_shortcut_layer(l);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void load_weights_upto(network *net, char *filename, int cutoff)
|
||||
{
|
||||
|
@ -1826,6 +1860,9 @@ void load_weights_upto(network *net, char *filename, int cutoff)
|
|||
if(l.type == CONVOLUTIONAL && l.share_layer == NULL){
|
||||
load_convolutional_weights(l, fp);
|
||||
}
|
||||
if (l.type == SHORTCUT && l.nweights > 0) {
|
||||
load_shortcut_weights(l, fp);
|
||||
}
|
||||
if(l.type == CONNECTED){
|
||||
load_connected_weights(l, fp, transpose);
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include <assert.h>
|
||||
|
||||
layer make_shortcut_layer(int batch, int n, int *input_layers, int* input_sizes, int w, int h, int c,
|
||||
float **layers_output, float **layers_delta, float **layers_output_gpu, float **layers_delta_gpu, ACTIVATION activation, int train)
|
||||
float **layers_output, float **layers_delta, float **layers_output_gpu, float **layers_delta_gpu, WEIGHTS_TYPE_T weights_type, ACTIVATION activation, int train)
|
||||
{
|
||||
fprintf(stderr, "Shortcut Layer: ");
|
||||
int i;
|
||||
|
@ -23,6 +23,7 @@ layer make_shortcut_layer(int batch, int n, int *input_layers, int* input_sizes,
|
|||
l.input_sizes = input_sizes;
|
||||
l.layers_output = layers_output;
|
||||
l.layers_delta = layers_delta;
|
||||
l.weights_type = weights_type;
|
||||
|
||||
//l.w = w2;
|
||||
//l.h = h2;
|
||||
|
@ -40,6 +41,18 @@ layer make_shortcut_layer(int batch, int n, int *input_layers, int* input_sizes,
|
|||
if (train) l.delta = (float*)calloc(l.outputs * batch, sizeof(float));
|
||||
l.output = (float*)calloc(l.outputs * batch, sizeof(float));
|
||||
|
||||
if (l.weights_type == PER_FEATURE) l.nweights = (l.n + 1);
|
||||
else if (l.weights_type == PER_CHANNEL) l.nweights = (l.n + 1) * l.c;
|
||||
|
||||
if (l.nweights > 0) {
|
||||
l.weights = (float*)calloc(l.nweights, sizeof(float));
|
||||
float scale = sqrt(2. / l.nweights);
|
||||
for (i = 0; i < l.nweights; ++i) l.weights[i] = 1;// scale*rand_uniform(-1, 1); // rand_normal();
|
||||
|
||||
if (train) l.weight_updates = (float*)calloc(l.nweights, sizeof(float));
|
||||
l.update = update_shortcut_layer;
|
||||
}
|
||||
|
||||
l.forward = forward_shortcut_layer;
|
||||
l.backward = backward_shortcut_layer;
|
||||
#ifndef GPU
|
||||
|
@ -52,16 +65,23 @@ layer make_shortcut_layer(int batch, int n, int *input_layers, int* input_sizes,
|
|||
l.forward_gpu = forward_shortcut_layer_gpu;
|
||||
l.backward_gpu = backward_shortcut_layer_gpu;
|
||||
|
||||
if (l.nweights > 0) {
|
||||
l.update_gpu = update_shortcut_layer_gpu;
|
||||
l.weights_gpu = cuda_make_array(l.weights, l.nweights);
|
||||
if (train) l.weight_updates_gpu = cuda_make_array(l.weight_updates, l.nweights);
|
||||
}
|
||||
|
||||
if (train) l.delta_gpu = cuda_make_array(l.delta, l.outputs*batch);
|
||||
l.output_gpu = cuda_make_array(l.output, l.outputs*batch);
|
||||
|
||||
l.input_sizes_gpu = cuda_make_int_array_new_api(input_sizes, l.n);
|
||||
l.layers_output_gpu = cuda_make_array_pointers((void**)layers_output_gpu, l.n);
|
||||
l.layers_delta_gpu = cuda_make_array_pointers((void**)layers_delta_gpu, l.n);
|
||||
l.layers_output_gpu = (float**)cuda_make_array_pointers((void**)layers_output_gpu, l.n);
|
||||
l.layers_delta_gpu = (float**)cuda_make_array_pointers((void**)layers_delta_gpu, l.n);
|
||||
#endif // GPU
|
||||
|
||||
l.bflops = l.out_w * l.out_h * l.out_c * l.n / 1000000000.;
|
||||
fprintf(stderr, " outputs:%4d x%4d x%4d %5.3f BF\n", l.out_w, l.out_h, l.out_c, l.bflops);
|
||||
if (l.weights_type) l.bflops *= 2;
|
||||
fprintf(stderr, " wt = %d, outputs:%4d x%4d x%4d %5.3f BF\n", l.weights_type, l.out_w, l.out_h, l.out_c, l.bflops);
|
||||
return l;
|
||||
}
|
||||
|
||||
|
@ -95,9 +115,21 @@ void resize_shortcut_layer(layer *l, int w, int h, network *net)
|
|||
l->delta_gpu = cuda_make_array(l->delta, l->outputs*l->batch);
|
||||
}
|
||||
|
||||
float **layers_output_gpu = (float **)calloc(l->n, sizeof(float *));
|
||||
float **layers_delta_gpu = (float **)calloc(l->n, sizeof(float *));
|
||||
|
||||
for (i = 0; i < l->n; ++i) {
|
||||
const int index = l->input_layers[i];
|
||||
layers_output_gpu[i] = net->layers[index].output_gpu;
|
||||
layers_delta_gpu[i] = net->layers[index].delta_gpu;
|
||||
}
|
||||
|
||||
memcpy_ongpu(l->input_sizes_gpu, l->input_sizes, l->n * sizeof(int));
|
||||
memcpy_ongpu(l->layers_output_gpu, l->layers_output, l->n * sizeof(float*));
|
||||
memcpy_ongpu(l->layers_delta_gpu, l->layers_delta, l->n * sizeof(float*));
|
||||
memcpy_ongpu(l->layers_output_gpu, layers_output_gpu, l->n * sizeof(float*));
|
||||
memcpy_ongpu(l->layers_delta_gpu, layers_delta_gpu, l->n * sizeof(float*));
|
||||
|
||||
free(layers_output_gpu);
|
||||
free(layers_delta_gpu);
|
||||
#endif
|
||||
|
||||
}
|
||||
|
@ -108,7 +140,7 @@ void forward_shortcut_layer(const layer l, network_state state)
|
|||
int from_h = state.net.layers[l.index].h;
|
||||
int from_c = state.net.layers[l.index].c;
|
||||
|
||||
if (l.n == 1 && from_w == l.w && from_h == l.h && from_c == l.c) {
|
||||
if (l.nweights == 0 && l.n == 1 && from_w == l.w && from_h == l.h && from_c == l.c) {
|
||||
int size = l.batch * l.w * l.h * l.c;
|
||||
int i;
|
||||
#pragma omp parallel for
|
||||
|
@ -116,7 +148,7 @@ void forward_shortcut_layer(const layer l, network_state state)
|
|||
l.output[i] = state.input[i] + state.net.layers[l.index].output[i];
|
||||
}
|
||||
else {
|
||||
shortcut_multilayer_cpu(l.outputs * l.batch, l.outputs, l.batch, l.n, l.input_sizes, l.layers_output, l.output, state.input);
|
||||
shortcut_multilayer_cpu(l.outputs * l.batch, l.outputs, l.batch, l.n, l.input_sizes, l.layers_output, l.output, state.input, l.weights, l.nweights);
|
||||
}
|
||||
|
||||
//copy_cpu(l.outputs*l.batch, state.input, 1, l.output, 1);
|
||||
|
@ -135,12 +167,24 @@ void backward_shortcut_layer(const layer l, network_state state)
|
|||
else gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta);
|
||||
|
||||
backward_shortcut_multilayer_cpu(l.outputs * l.batch, l.outputs, l.batch, l.n, l.input_sizes,
|
||||
l.layers_delta, state.delta, l.delta);
|
||||
l.layers_delta, state.delta, l.delta, l.weights, l.weight_updates, l.nweights, state.input, l.layers_output);
|
||||
|
||||
//axpy_cpu(l.outputs*l.batch, 1, l.delta, 1, state.delta, 1);
|
||||
//shortcut_cpu(l.batch, l.out_w, l.out_h, l.out_c, l.delta, l.w, l.h, l.c, state.net.layers[l.index].delta);
|
||||
}
|
||||
|
||||
void update_shortcut_layer(layer l, int batch, float learning_rate_init, float momentum, float decay)
|
||||
{
|
||||
float learning_rate = learning_rate_init*l.learning_rate_scale;
|
||||
//float momentum = a.momentum;
|
||||
//float decay = a.decay;
|
||||
//int batch = a.batch;
|
||||
|
||||
axpy_cpu(l.nweights, -decay*batch, l.weights, 1, l.weight_updates, 1);
|
||||
axpy_cpu(l.nweights, learning_rate / batch, l.weight_updates, 1, l.weights, 1);
|
||||
scal_cpu(l.nweights, momentum, l.weight_updates, 1);
|
||||
}
|
||||
|
||||
#ifdef GPU
|
||||
void forward_shortcut_layer_gpu(const layer l, network_state state)
|
||||
{
|
||||
|
@ -159,7 +203,7 @@ void forward_shortcut_layer_gpu(const layer l, network_state state)
|
|||
//}
|
||||
//else
|
||||
{
|
||||
shortcut_multilayer_gpu(l.outputs, l.batch, l.n, l.input_sizes_gpu, l.layers_output_gpu, l.output_gpu, state.input);
|
||||
shortcut_multilayer_gpu(l.outputs, l.batch, l.n, l.input_sizes_gpu, l.layers_output_gpu, l.output_gpu, state.input, l.weights_gpu, l.nweights);
|
||||
}
|
||||
|
||||
if (l.activation == SWISH) activate_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.output_gpu);
|
||||
|
@ -174,9 +218,42 @@ void backward_shortcut_layer_gpu(const layer l, network_state state)
|
|||
else if (l.activation == MISH) gradient_array_mish_ongpu(l.outputs*l.batch, l.activation_input_gpu, l.delta_gpu);
|
||||
else gradient_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation, l.delta_gpu);
|
||||
|
||||
backward_shortcut_multilayer_gpu(l.outputs, l.batch, l.n, l.input_sizes_gpu, l.layers_delta_gpu, state.delta, l.delta_gpu);
|
||||
backward_shortcut_multilayer_gpu(l.outputs, l.batch, l.n, l.input_sizes_gpu, l.layers_delta_gpu, state.delta, l.delta_gpu,
|
||||
l.weights_gpu, l.weight_updates_gpu, l.nweights, state.input, l.layers_output_gpu);
|
||||
|
||||
//axpy_ongpu(l.outputs*l.batch, 1, l.delta_gpu, 1, state.delta, 1);
|
||||
//shortcut_gpu(l.batch, l.out_w, l.out_h, l.out_c, l.delta_gpu, l.w, l.h, l.c, state.net.layers[l.index].delta_gpu);
|
||||
}
|
||||
|
||||
void update_shortcut_layer_gpu(layer l, int batch, float learning_rate_init, float momentum, float decay)
|
||||
{
|
||||
float learning_rate = learning_rate_init*l.learning_rate_scale;
|
||||
//float momentum = a.momentum;
|
||||
//float decay = a.decay;
|
||||
//int batch = a.batch;
|
||||
|
||||
fix_nan_and_inf(l.weight_updates_gpu, l.nweights);
|
||||
fix_nan_and_inf(l.weights_gpu, l.nweights);
|
||||
|
||||
axpy_ongpu(l.nweights, -decay*batch, l.weights_gpu, 1, l.weight_updates_gpu, 1);
|
||||
axpy_ongpu(l.nweights, learning_rate / batch, l.weight_updates_gpu, 1, l.weights_gpu, 1);
|
||||
scal_ongpu(l.nweights, momentum, l.weight_updates_gpu, 1);
|
||||
|
||||
//if (l.clip) {
|
||||
// constrain_gpu(l.nweights, l.clip, l.weights_gpu, 1);
|
||||
//}
|
||||
}
|
||||
|
||||
void pull_shortcut_layer(layer l)
|
||||
{
|
||||
cuda_pull_array_async(l.weights_gpu, l.weights, l.nweights);
|
||||
CHECK_CUDA(cudaPeekAtLastError());
|
||||
cudaStreamSynchronize(get_cuda_stream());
|
||||
}
|
||||
|
||||
void push_shortcut_layer(layer l)
|
||||
{
|
||||
cuda_push_array(l.weights_gpu, l.weights, l.nweights);
|
||||
CHECK_CUDA(cudaPeekAtLastError());
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -8,14 +8,18 @@
|
|||
extern "C" {
|
||||
#endif
|
||||
layer make_shortcut_layer(int batch, int n, int *input_layers, int* input_sizes, int w, int h, int c,
|
||||
float **layers_output, float **layers_delta, float **layers_output_gpu, float **layers_delta_gpu, ACTIVATION activation, int train);
|
||||
float **layers_output, float **layers_delta, float **layers_output_gpu, float **layers_delta_gpu, WEIGHTS_TYPE_T weights_type, ACTIVATION activation, int train);
|
||||
void forward_shortcut_layer(const layer l, network_state state);
|
||||
void backward_shortcut_layer(const layer l, network_state state);
|
||||
void update_shortcut_layer(layer l, int batch, float learning_rate_init, float momentum, float decay);
|
||||
void resize_shortcut_layer(layer *l, int w, int h, network *net);
|
||||
|
||||
#ifdef GPU
|
||||
void forward_shortcut_layer_gpu(const layer l, network_state state);
|
||||
void backward_shortcut_layer_gpu(const layer l, network_state state);
|
||||
void update_shortcut_layer_gpu(layer l, int batch, float learning_rate_init, float momentum, float decay);
|
||||
void pull_shortcut_layer(layer l);
|
||||
void push_shortcut_layer(layer l);
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
Loading…
Reference in New Issue