From f0c0b307d5253de14f21fa8bc9acd58217d13901 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Arrufat?= <1671644+arrufat@users.noreply.github.com> Date: Wed, 8 Dec 2021 03:49:06 +0100 Subject: [PATCH] Fix crash when truth center is outside of the image (#2471) --- dlib/dnn/loss.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dlib/dnn/loss.h b/dlib/dnn/loss.h index 1a75559b5..a690c1bc9 100644 --- a/dlib/dnn/loss.h +++ b/dlib/dnn/loss.h @@ -3682,7 +3682,7 @@ namespace dlib tensor& grad = layer(sub).get_gradient_input(); DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); - const double input_area = input_tensor.nr() * input_tensor.nc(); + const rectangle input_rect(input_tensor.nc(), input_tensor.nr()); float* g = grad.host(); // Compute the objectness loss for all grid cells @@ -3709,7 +3709,7 @@ namespace dlib double best_iou = 0; for (const yolo_rect& truth_box : *truth) { - if (truth_box.ignore) + if (truth_box.ignore || !input_rect.contains(center(truth_box.rect))) continue; best_iou = std::max(best_iou, box_intersection_over_union(truth_box.rect, pred.rect)); } @@ -3724,7 +3724,7 @@ namespace dlib // Now find the best anchor box for each truth box for (const yolo_rect& truth_box : *truth) { - if (truth_box.ignore) + if (truth_box.ignore || !input_rect.contains(center(truth_box.rect))) continue; const dpoint t_center = dcenter(truth_box); double best_iou = 0; @@ -3780,7 +3780,7 @@ namespace dlib const double th = truth_box.rect.height() / (anchors[a].height + truth_box.rect.height()); // Scale regression error according to the truth size - const double scale_box = options.lambda_box * (2 - truth_box.rect.area() / input_area); + const double scale_box = options.lambda_box * (2 - truth_box.rect.area() / input_rect.area()); // Compute the gradient for the box coordinates g[x_idx] = scale_box * (out_data[x_idx] - tx);