mirror of https://github.com/AlexeyAB/darknet.git
more go
This commit is contained in:
parent
cff59ba135
commit
67794a52a1
33
src/go.c
33
src/go.c
|
@ -176,7 +176,7 @@ void flip_board(float *board)
|
|||
}
|
||||
}
|
||||
|
||||
void test_go(char *filename, char *weightfile)
|
||||
void test_go(char *filename, char *weightfile, int multi)
|
||||
{
|
||||
network net = parse_network_cfg(filename);
|
||||
if(weightfile){
|
||||
|
@ -191,25 +191,25 @@ void test_go(char *filename, char *weightfile)
|
|||
float *output = network_predict(net, board);
|
||||
copy_cpu(19*19, output, 1, move, 1);
|
||||
int i;
|
||||
#ifdef GPU
|
||||
image bim = float_to_image(19, 19, 1, board);
|
||||
for(i = 1; i < 8; ++i){
|
||||
rotate_image_cw(bim, i);
|
||||
if(i >= 4) flip_image(bim);
|
||||
if(multi){
|
||||
image bim = float_to_image(19, 19, 1, board);
|
||||
for(i = 1; i < 8; ++i){
|
||||
rotate_image_cw(bim, i);
|
||||
if(i >= 4) flip_image(bim);
|
||||
|
||||
float *output = network_predict(net, board);
|
||||
image oim = float_to_image(19, 19, 1, output);
|
||||
float *output = network_predict(net, board);
|
||||
image oim = float_to_image(19, 19, 1, output);
|
||||
|
||||
if(i >= 4) flip_image(oim);
|
||||
rotate_image_cw(oim, -i);
|
||||
if(i >= 4) flip_image(oim);
|
||||
rotate_image_cw(oim, -i);
|
||||
|
||||
axpy_cpu(19*19, 1, output, 1, move, 1);
|
||||
axpy_cpu(19*19, 1, output, 1, move, 1);
|
||||
|
||||
if(i >= 4) flip_image(bim);
|
||||
rotate_image_cw(bim, -i);
|
||||
if(i >= 4) flip_image(bim);
|
||||
rotate_image_cw(bim, -i);
|
||||
}
|
||||
scal_cpu(19*19, 1./8., move, 1);
|
||||
}
|
||||
scal_cpu(19*19, 1./8., move, 1);
|
||||
#endif
|
||||
for(i = 0; i < 19*19; ++i){
|
||||
if(board[i]) move[i] = 0;
|
||||
}
|
||||
|
@ -282,8 +282,9 @@ void run_go(int argc, char **argv)
|
|||
|
||||
char *cfg = argv[3];
|
||||
char *weights = (argc > 4) ? argv[4] : 0;
|
||||
int multi = find_arg(argc, argv, "-multi");
|
||||
if(0==strcmp(argv[2], "train")) train_go(cfg, weights);
|
||||
else if(0==strcmp(argv[2], "test")) test_go(cfg, weights);
|
||||
else if(0==strcmp(argv[2], "test")) test_go(cfg, weights, multi);
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue