diff --git a/dlib/dnn/loss.h b/dlib/dnn/loss.h index 95177705b..6f62c6701 100644 --- a/dlib/dnn/loss.h +++ b/dlib/dnn/loss.h @@ -725,6 +725,7 @@ namespace dlib DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); DLIB_CASSERT(output_tensor.k() == (long)options.detector_windows.size()); + double det_thresh_speed_adjust = 0; // we will scale the loss so that it doesn't get really huge @@ -741,9 +742,17 @@ namespace dlib std::vector dets; for (long i = 0; i < output_tensor.num_samples(); ++i) { - tensor_to_dets(input_tensor, output_tensor, i, dets, -options.loss_per_false_alarm, sub); + tensor_to_dets(input_tensor, output_tensor, i, dets, -options.loss_per_false_alarm + det_thresh_speed_adjust, sub); const unsigned long max_num_dets = 50 + truth->size()*5; + // Prevent calls to tensor_to_dets() from running for a really long time + // due to the production of an obscene number of detections. + const unsigned long max_num_initial_dets = max_num_dets*100; + if (dets.size() >= max_num_initial_dets) + { + det_thresh_speed_adjust = std::max(det_thresh_speed_adjust,dets[max_num_initial_dets].detection_confidence + options.loss_per_false_alarm); + } + // The loss will measure the number of incorrect detections. A detection is // incorrect if it doesn't hit a truth rectangle or if it is a duplicate detection