Fixed ignore_thresh

This commit is contained in:
AlexeyAB 2019-11-14 23:18:21 +03:00
parent 509ba13acf
commit ee370e765d
1 changed files with 14 additions and 20 deletions

View File

@ -129,23 +129,18 @@ box get_yolo_box(float *x, float *biases, int n, int index, int i, int j, int lw
} }
int get_yolo_class(float *output, int classes, int class_index, int stride, float objectness) int compare_yolo_class(float *output, int classes, int class_index, int stride, float objectness, int class_id)
{ {
int class_id = 0; const float conf_thresh = 0.25;
float max_prob = FLT_MIN;
int j; int j;
for (j = 0; j < classes; ++j) { for (j = 0; j < classes; ++j) {
float prob = objectness * output[class_index + stride*j]; float prob = objectness * output[class_index + stride*j];
if (prob > max_prob) { if (prob > conf_thresh) {
max_prob = prob; return 1;
class_id = j;
} }
//int class_index = entry_index(l, 0, n*l.w*l.h + i, 4 + 1 + j);
//float prob = objectness*predictions[class_index];
//dets[count].prob[j] = (prob > thresh) ? prob : 0;
} }
return class_id; return 0;
} }
ious delta_yolo_box(box truth, float *x, float *biases, int n, int index, int i, int j, int lw, int lh, int w, int h, float *delta, float scale, int stride, float iou_normalizer, IOU_LOSS iou_loss) ious delta_yolo_box(box truth, float *x, float *biases, int n, int index, int i, int j, int lw, int lh, int w, int h, float *delta, float scale, int stride, float iou_normalizer, IOU_LOSS iou_loss)
@ -280,6 +275,8 @@ void forward_yolo_layer(const layer l, network_state state)
for (n = 0; n < l.n; ++n) { for (n = 0; n < l.n; ++n) {
int box_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 0); int box_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 0);
box pred = get_yolo_box(l.output, l.biases, l.mask[n], box_index, i, j, l.w, l.h, state.net.w, state.net.h, l.w*l.h); box pred = get_yolo_box(l.output, l.biases, l.mask[n], box_index, i, j, l.w, l.h, state.net.w, state.net.h, l.w*l.h);
float best_match_iou = 0;
int best_match_t = 0;
float best_iou = 0; float best_iou = 0;
int best_t = 0; int best_t = 0;
for (t = 0; t < l.max_boxes; ++t) { for (t = 0; t < l.max_boxes; ++t) {
@ -296,14 +293,14 @@ void forward_yolo_layer(const layer l, network_state state)
int class_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4 + 1); int class_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4 + 1);
int obj_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4); int obj_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4);
float objectness = l.output[obj_index]; float objectness = l.output[obj_index];
int pred_class_id = get_yolo_class(l.output, l.classes, class_index, l.w*l.h, objectness); int class_id_match = compare_yolo_class(l.output, l.classes, class_index, l.w*l.h, objectness, class_id);
int class_id_match = 0;
if (class_id == pred_class_id) class_id_match = 1;
else class_id_match = 0;
float iou = box_iou(pred, truth); float iou = box_iou(pred, truth);
//if (iou > best_iou) { if (iou > best_match_iou && class_id_match == 1) {
if (iou > best_iou && class_id_match == 1) { best_match_iou = iou;
best_match_t = t;
}
if (iou > best_iou) {
best_iou = iou; best_iou = iou;
best_t = t; best_t = t;
} }
@ -311,7 +308,7 @@ void forward_yolo_layer(const layer l, network_state state)
int obj_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4); int obj_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4);
avg_anyobj += l.output[obj_index]; avg_anyobj += l.output[obj_index];
l.delta[obj_index] = l.cls_normalizer * (0 - l.output[obj_index]); l.delta[obj_index] = l.cls_normalizer * (0 - l.output[obj_index]);
if (best_iou > l.ignore_thresh) { if (best_match_iou > l.ignore_thresh) {
l.delta[obj_index] = 0; l.delta[obj_index] = 0;
} }
if (best_iou > l.truth_thresh) { if (best_iou > l.truth_thresh) {
@ -376,9 +373,6 @@ void forward_yolo_layer(const layer l, network_state state)
++count; ++count;
++class_count; ++class_count;
//if(iou > .5) recall += 1;
//if(iou > .75) recall75 += 1;
//avg_iou += iou;
if (all_ious.iou > .5) recall += 1; if (all_ious.iou > .5) recall += 1;
if (all_ious.iou > .75) recall75 += 1; if (all_ious.iou > .75) recall75 += 1;
} }