diff --git a/src/go.c b/src/go.c index 6607e7a9..9d31539a 100644 --- a/src/go.c +++ b/src/go.c @@ -176,7 +176,7 @@ void flip_board(float *board) } } -void test_go(char *filename, char *weightfile) +void test_go(char *filename, char *weightfile, int multi) { network net = parse_network_cfg(filename); if(weightfile){ @@ -191,25 +191,25 @@ void test_go(char *filename, char *weightfile) float *output = network_predict(net, board); copy_cpu(19*19, output, 1, move, 1); int i; -#ifdef GPU - image bim = float_to_image(19, 19, 1, board); - for(i = 1; i < 8; ++i){ - rotate_image_cw(bim, i); - if(i >= 4) flip_image(bim); + if(multi){ + image bim = float_to_image(19, 19, 1, board); + for(i = 1; i < 8; ++i){ + rotate_image_cw(bim, i); + if(i >= 4) flip_image(bim); - float *output = network_predict(net, board); - image oim = float_to_image(19, 19, 1, output); + float *output = network_predict(net, board); + image oim = float_to_image(19, 19, 1, output); - if(i >= 4) flip_image(oim); - rotate_image_cw(oim, -i); + if(i >= 4) flip_image(oim); + rotate_image_cw(oim, -i); - axpy_cpu(19*19, 1, output, 1, move, 1); + axpy_cpu(19*19, 1, output, 1, move, 1); - if(i >= 4) flip_image(bim); - rotate_image_cw(bim, -i); + if(i >= 4) flip_image(bim); + rotate_image_cw(bim, -i); + } + scal_cpu(19*19, 1./8., move, 1); } - scal_cpu(19*19, 1./8., move, 1); -#endif for(i = 0; i < 19*19; ++i){ if(board[i]) move[i] = 0; } @@ -282,8 +282,9 @@ void run_go(int argc, char **argv) char *cfg = argv[3]; char *weights = (argc > 4) ? argv[4] : 0; + int multi = find_arg(argc, argv, "-multi"); if(0==strcmp(argv[2], "train")) train_go(cfg, weights); - else if(0==strcmp(argv[2], "test")) test_go(cfg, weights); + else if(0==strcmp(argv[2], "test")) test_go(cfg, weights, multi); }