diff --git a/src/demo.c b/src/demo.c index 2337a0ad..6ae3c336 100644 --- a/src/demo.c +++ b/src/demo.c @@ -148,7 +148,7 @@ void demo(char *cfgfile, char *weightfile, float thresh, float hier_thresh, int load_weights(&net, weightfile); } set_batch_network(&net, 1); - + fuse_conv_batchnorm(net); srand(2222222); if(filename){ diff --git a/src/detector.c b/src/detector.c index 45d3b226..25968e66 100644 --- a/src/detector.c +++ b/src/detector.c @@ -419,6 +419,7 @@ void validate_detector_recall(char *datacfg, char *cfgfile, char *weightfile) load_weights(&net, weightfile); } set_batch_network(&net, 1); + fuse_conv_batchnorm(net); srand(time(0)); //list *plist = get_paths("data/coco_val_5k.list"); @@ -526,6 +527,7 @@ void validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float load_weights(&net, weightfile); } set_batch_network(&net, 1); + fuse_conv_batchnorm(net); srand(time(0)); list *plist = get_paths(valid_images); @@ -1022,6 +1024,7 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam load_weights(&net, weightfile); } set_batch_network(&net, 1); + fuse_conv_batchnorm(net); srand(2222222); clock_t time; char buff[256]; diff --git a/src/network.c b/src/network.c index 175c1029..bfade3cc 100644 --- a/src/network.c +++ b/src/network.c @@ -748,3 +748,38 @@ void free_network(network net) free(net.workspace); #endif } + + +void fuse_conv_batchnorm(network net) +{ + int j; + for (j = 0; j < net.n; ++j) { + layer *l = &net.layers[j]; + + if (l->type == CONVOLUTIONAL) { + printf(" Fuse Convolutional layer \t\t l->size = %d \n", l->size); + + if (l->batch_normalize) { + int f; + for (f = 0; f < l->n; ++f) + { + l->biases[f] = l->biases[f] - l->scales[f] * l->rolling_mean[f] / (sqrtf(l->rolling_variance[f]) + .000001f); + + const size_t filter_size = l->size*l->size*l->c; + int i; + for (i = 0; i < filter_size; ++i) { + int w_index = f*filter_size + i; + + l->weights[w_index] = l->weights[w_index] * l->scales[f] / (sqrtf(l->rolling_variance[f]) + .000001f); + } + } + + l->batch_normalize = 0; + push_convolutional_layer(*l); + } + } + else { + printf(" Skip layer: %d \n", l->type); + } + } +} diff --git a/src/network.h b/src/network.h index 5b405c4a..965693ad 100644 --- a/src/network.h +++ b/src/network.h @@ -138,6 +138,7 @@ void free_detections(detection *dets, int n); int get_network_nuisance(network net); int get_network_background(network net); +void fuse_conv_batchnorm(network net); #ifdef __cplusplus }