add thresholds (#6831)

This commit is contained in:
8k 2021-07-16 06:28:25 +09:00 committed by GitHub
parent 01cf2b4c86
commit 4361b7a322
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 6 deletions

View File

@ -1047,7 +1047,7 @@ LIB_API void reset_rnn(network *net);
LIB_API float *network_predict_image(network *net, image im);
LIB_API float *network_predict_image_letterbox(network *net, image im);
LIB_API float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, const int map_points, int letter_box, network *existing_net);
LIB_API void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show, int calc_map, int mjpeg_port, int show_imgs, int benchmark_layers, char* chart_path);
LIB_API void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show, int calc_map, float thresh, float iou_thresh, int mjpeg_port, int show_imgs, int benchmark_layers, char* chart_path);
LIB_API void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh,
float hier_thresh, int dont_show, int ext_output, int save_labels, char *outfile, int letter_box, int benchmark_layers);
LIB_API int network_width(network *net);

View File

@ -23,7 +23,7 @@ int check_mistakes = 0;
static int coco_ids[] = { 1,2,3,4,5,6,7,8,9,10,11,13,14,15,16,17,18,19,20,21,22,23,24,25,27,28,31,32,33,34,35,36,37,38,39,40,41,42,43,44,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,67,70,72,73,74,75,76,77,78,79,80,81,82,84,85,86,87,88,89,90 };
void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show, int calc_map, int mjpeg_port, int show_imgs, int benchmark_layers, char* chart_path)
void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show, int calc_map, float thresh, float iou_thresh, int mjpeg_port, int show_imgs, int benchmark_layers, char* chart_path)
{
list *options = read_data_cfg(datacfg);
char *train_images = option_find_str(options, "train", "data/train.txt");
@ -305,7 +305,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
//next_map_calc = fmax(next_map_calc, 400);
if (calc_map) {
printf("\n (next mAP calculation at %d iterations) ", next_map_calc);
if (mean_average_precision > 0) printf("\n Last accuracy mAP@0.5 = %2.2f %%, best = %2.2f %% ", mean_average_precision * 100, best_map * 100);
if (mean_average_precision > 0) printf("\n Last accuracy mAP@%0.2f = %2.2f %%, best = %2.2f %% ", iou_thresh, mean_average_precision * 100, best_map * 100);
}
if (net.cudnn_half) {
@ -353,8 +353,8 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
//network net_combined = combine_train_valid_networks(net, net_map);
iter_map = iteration;
mean_average_precision = validate_detector_map(datacfg, cfgfile, weightfile, 0.25, 0.5, 0, net.letter_box, &net_map);// &net_combined);
printf("\n mean_average_precision (mAP@0.5) = %f \n", mean_average_precision);
mean_average_precision = validate_detector_map(datacfg, cfgfile, weightfile, thresh, iou_thresh, 0, net.letter_box, &net_map);// &net_combined);
printf("\n mean_average_precision (mAP@%0.2f) = %f \n", iou_thresh, mean_average_precision);
if (mean_average_precision > best_map) {
best_map = mean_average_precision;
printf("New best mAP!\n");
@ -2017,7 +2017,7 @@ void run_detector(int argc, char **argv)
if (weights[strlen(weights) - 1] == 0x0d) weights[strlen(weights) - 1] = 0;
char *filename = (argc > 6) ? argv[6] : 0;
if (0 == strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh, hier_thresh, dont_show, ext_output, save_labels, outfile, letter_box, benchmark_layers);
else if (0 == strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear, dont_show, calc_map, mjpeg_port, show_imgs, benchmark_layers, chart_path);
else if (0 == strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear, dont_show, calc_map, thresh, iou_thresh, mjpeg_port, show_imgs, benchmark_layers, chart_path);
else if (0 == strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights, outfile);
else if (0 == strcmp(argv[2], "recall")) validate_detector_recall(datacfg, cfg, weights);
else if (0 == strcmp(argv[2], "map")) validate_detector_map(datacfg, cfg, weights, thresh, iou_thresh, map_points, letter_box, NULL);