diff --git a/src/tests.c b/src/tests.c index a6c3cd32..eb85a5f6 100644 --- a/src/tests.c +++ b/src/tests.c @@ -548,7 +548,9 @@ void visualize_imagenet_topk(char *filename) score[i] = calloc(topk, sizeof(float)); } + int count = 0; while(n){ + ++count; char *image_path = (char *)n->val; image im = load_image(image_path, 0, 0); n = n->next; @@ -560,37 +562,46 @@ void visualize_imagenet_topk(char *filename) forward_network(net, im.data); image out = get_network_image(net); - int dh = (im.h - h)/h; - int dw = (im.w - w)/w; - for(i = 0; i < out.h; ++i){ - for(j = 0; j < out.w; ++j){ - image sub = get_sub_image(im, dh*i, dw*j, h, w); - for(k = 0; k < out.c; ++k){ + int dh = (im.h - h)/(out.h-1); + int dw = (im.w - w)/(out.w-1); + //printf("%d %d\n", dh, dw); + for(k = 0; k < out.c; ++k){ + float topv = 0; + int topi = -1; + int topj = -1; + for(i = 0; i < out.h; ++i){ + for(j = 0; j < out.w; ++j){ float val = get_pixel(out, i, j, k); - //printf("%f, ", val); - image sub_c = copy_image(sub); - for(l = 0; l < topk; ++l){ - if(val > score[k][l]){ - float swap = score[k][l]; - score[k][l] = val; - val = swap; - - image swapi = vizs[k][l]; - vizs[k][l] = sub_c; - sub_c = swapi; - } + if(val > topv){ + topv = val; + topi = i; + topj = j; + } + } + } + if(topv){ + image sub = get_sub_image(im, dh*topi, dw*topj, h, w); + for(l = 0; l < topk; ++l){ + if(topv > score[k][l]){ + float swap = score[k][l]; + score[k][l] = topv; + topv = swap; + + image swapi = vizs[k][l]; + vizs[k][l] = sub; + sub = swapi; } - free_image(sub_c); } free_image(sub); } } free_image(im); - //printf("\n"); - image grid = grid_images(vizs, num, topk); - show_image(grid, "IMAGENET Visualization"); - save_image(grid, "IMAGENET Grid"); - free_image(grid); + if(count%50 == 0){ + image grid = grid_images(vizs, num, topk); + //show_image(grid, "IMAGENET Visualization"); + save_image(grid, "IMAGENET Grid Single Nonorm"); + free_image(grid); + } } //cvWaitKey(0); } @@ -644,7 +655,7 @@ void visualize_cat() printf("Processing %dx%d image\n", im.h, im.w); resize_network(net, im.h, im.w, im.c); forward_network(net, im.data); - + image out = get_network_image(net); visualize_network(net); cvWaitKey(1000); @@ -778,7 +789,7 @@ int main(int argc, char *argv[]) //features_VOC_image(argv[1], argv[2], argv[3]); //features_VOC_image_size(argv[1], atoi(argv[2]), atoi(argv[3])); //visualize_imagenet_features("data/assira/train.list"); - visualize_imagenet_topk("data/VOC2011.list"); + visualize_imagenet_topk("data/VOC2012.list"); //visualize_cat(); //flip_network(); //test_visualize();