Detection is accelerated by 7 percent (fused conv and batch_norm layers)

This commit is contained in:
AlexeyAB 2018-04-03 16:42:00 +03:00
parent ee8a941922
commit 1b2c70f82a
4 changed files with 40 additions and 1 deletions

View File

@ -148,7 +148,7 @@ void demo(char *cfgfile, char *weightfile, float thresh, float hier_thresh, int
load_weights(&net, weightfile);
}
set_batch_network(&net, 1);
fuse_conv_batchnorm(net);
srand(2222222);
if(filename){

View File

@ -419,6 +419,7 @@ void validate_detector_recall(char *datacfg, char *cfgfile, char *weightfile)
load_weights(&net, weightfile);
}
set_batch_network(&net, 1);
fuse_conv_batchnorm(net);
srand(time(0));
//list *plist = get_paths("data/coco_val_5k.list");
@ -526,6 +527,7 @@ void validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float
load_weights(&net, weightfile);
}
set_batch_network(&net, 1);
fuse_conv_batchnorm(net);
srand(time(0));
list *plist = get_paths(valid_images);
@ -1022,6 +1024,7 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam
load_weights(&net, weightfile);
}
set_batch_network(&net, 1);
fuse_conv_batchnorm(net);
srand(2222222);
clock_t time;
char buff[256];

View File

@ -748,3 +748,38 @@ void free_network(network net)
free(net.workspace);
#endif
}
void fuse_conv_batchnorm(network net)
{
int j;
for (j = 0; j < net.n; ++j) {
layer *l = &net.layers[j];
if (l->type == CONVOLUTIONAL) {
printf(" Fuse Convolutional layer \t\t l->size = %d \n", l->size);
if (l->batch_normalize) {
int f;
for (f = 0; f < l->n; ++f)
{
l->biases[f] = l->biases[f] - l->scales[f] * l->rolling_mean[f] / (sqrtf(l->rolling_variance[f]) + .000001f);
const size_t filter_size = l->size*l->size*l->c;
int i;
for (i = 0; i < filter_size; ++i) {
int w_index = f*filter_size + i;
l->weights[w_index] = l->weights[w_index] * l->scales[f] / (sqrtf(l->rolling_variance[f]) + .000001f);
}
}
l->batch_normalize = 0;
push_convolutional_layer(*l);
}
}
else {
printf(" Skip layer: %d \n", l->type);
}
}
}

View File

@ -138,6 +138,7 @@ void free_detections(detection *dets, int n);
int get_network_nuisance(network net);
int get_network_background(network net);
void fuse_conv_batchnorm(network net);
#ifdef __cplusplus
}