Fixed adversarial training

This commit is contained in:
AlexeyAB 2021-06-29 23:54:58 +03:00
parent 85c6278ef1
commit c2221f07f8
3 changed files with 18 additions and 5 deletions

View File

@ -1056,7 +1056,7 @@ LIB_API void optimize_picture(network *net, image orig, int max_layer, float sca
// image.h // image.h
LIB_API void make_image_red(image im); LIB_API void make_image_red(image im);
LIB_API image make_attention_image(int img_size, float *original_delta_cpu, float *original_input_cpu, int w, int h, int c); LIB_API image make_attention_image(int img_size, float *original_delta_cpu, float *original_input_cpu, int w, int h, int c, float alpha);
LIB_API image resize_image(image im, int w, int h); LIB_API image resize_image(image im, int w, int h);
LIB_API void quantize_image(image im); LIB_API void quantize_image(image im);
LIB_API void copy_image_from_bytes(image im, char *pdata); LIB_API void copy_image_from_bytes(image im, char *pdata);

View File

@ -1351,7 +1351,7 @@ void make_image_red(image im)
} }
} }
image make_attention_image(int img_size, float *original_delta_cpu, float *original_input_cpu, int w, int h, int c) image make_attention_image(int img_size, float *original_delta_cpu, float *original_input_cpu, int w, int h, int c, float alpha)
{ {
image attention_img; image attention_img;
attention_img.w = w; attention_img.w = w;
@ -1379,7 +1379,7 @@ image make_attention_image(int img_size, float *original_delta_cpu, float *origi
image resized = resize_image(attention_img, w / 4, h / 4); image resized = resize_image(attention_img, w / 4, h / 4);
attention_img = resize_image(resized, w, h); attention_img = resize_image(resized, w, h);
free_image(resized); free_image(resized);
for (k = 0; k < img_size; ++k) attention_img.data[k] += original_input_cpu[k]; for (k = 0; k < img_size; ++k) attention_img.data[k] = attention_img.data[k]*alpha + (1-alpha)*original_input_cpu[k];
//normalize_image(attention_img); //normalize_image(attention_img);
//show_image(attention_img, "delta"); //show_image(attention_img, "delta");

View File

@ -76,7 +76,7 @@ void forward_network_gpu(network net, network_state state)
for(i = 0; i < net.n; ++i){ for(i = 0; i < net.n; ++i){
state.index = i; state.index = i;
layer l = net.layers[i]; layer l = net.layers[i];
if(l.delta_gpu && state.train && l.train){ if(l.delta_gpu && state.train){
fill_ongpu(l.outputs * l.batch, 0, l.delta_gpu, 1); fill_ongpu(l.outputs * l.batch, 0, l.delta_gpu, 1);
} }
@ -235,12 +235,25 @@ void backward_network_gpu(network net, network_state state)
cuda_pull_array(original_input, original_input_cpu, img_size); cuda_pull_array(original_input, original_input_cpu, img_size);
cuda_pull_array(original_delta, original_delta_cpu, img_size); cuda_pull_array(original_delta, original_delta_cpu, img_size);
image attention_img = make_attention_image(img_size, original_delta_cpu, original_input_cpu, net.w, net.h, net.c); image attention_img = make_attention_image(img_size, original_delta_cpu, original_input_cpu, net.w, net.h, net.c, 0.7);
show_image(attention_img, "attention_img"); show_image(attention_img, "attention_img");
resize_window_cv("attention_img", 500, 500); resize_window_cv("attention_img", 500, 500);
//static int img_counter = 0;
//img_counter++;
//char buff[256];
//sprintf(buff, "attention_img_%d.png", img_counter);
//save_image_png(attention_img, buff);
free_image(attention_img); free_image(attention_img);
image attention_mask_img = make_attention_image(img_size, original_delta_cpu, original_delta_cpu, net.w, net.h, net.c, 1.0);
show_image(attention_mask_img, "attention_mask_img");
resize_window_cv("attention_mask_img", 500, 500);
//sprintf(buff, "attention_mask_img_%d.png", img_counter);
//save_image_png(attention_mask_img, buff);
free_image(attention_mask_img);
free(original_input_cpu); free(original_input_cpu);
free(original_delta_cpu); free(original_delta_cpu);
} }