mirror of https://github.com/davisking/dlib.git
Added missing input validation to loss_mmod_. Specifically, the loss layer now
checks whether the user has supplied truth boxes that cannot all be detected, because the non-max suppression settings would prevent them from being output at the same time. If this happens, we print a warning message and set one of the offending boxes to "ignore".
This commit is contained in:
parent
bf55c4e8e1
commit
29db3ee566
|
@ -697,12 +697,12 @@ namespace dlib
|
|||
double loss = 0;
|
||||
|
||||
float* g = grad.host_write_only();
|
||||
// zero initialize grad.
|
||||
for (auto&& x : grad)
|
||||
x = 0;
|
||||
for (size_t i = 0; i < grad.size(); ++i)
|
||||
g[i] = 0;
|
||||
|
||||
const float* out_data = output_tensor.host();
|
||||
|
||||
std::vector<size_t> truth_idxs; truth_idxs.reserve(truth->size());
|
||||
std::vector<intermediate_detection> dets;
|
||||
for (long i = 0; i < output_tensor.num_samples(); ++i)
|
||||
{
|
||||
|
@ -726,14 +726,17 @@ namespace dlib
|
|||
loss -= 1;
|
||||
continue;
|
||||
}
|
||||
loss -= out_data[(k*output_tensor.nr() + p.y())*output_tensor.nc() + p.x()];
|
||||
const size_t idx = (k*output_tensor.nr() + p.y())*output_tensor.nc() + p.x();
|
||||
loss -= out_data[idx];
|
||||
// compute gradient
|
||||
g[(k*output_tensor.nr() + p.y())*output_tensor.nc() + p.x()] = -scale;
|
||||
g[idx] = -scale;
|
||||
truth_idxs.push_back(idx);
|
||||
}
|
||||
else
|
||||
{
|
||||
// This box was ignored so shouldn't have been counted in the loss.
|
||||
loss -= 1;
|
||||
truth_idxs.push_back(0);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -772,6 +775,33 @@ namespace dlib
|
|||
}
|
||||
}
|
||||
|
||||
// Check if any of the truth boxes are unobtainable because the NMS is
|
||||
// killing them. If so, automatically set those unobtainable boxes to
|
||||
// ignore and print a warning message to the user.
|
||||
for (size_t i = 0; i < hit_truth_table.size(); ++i)
|
||||
{
|
||||
if (!hit_truth_table[i] && !(*truth)[i].ignore)
|
||||
{
|
||||
// So we didn't hit this truth box. Is that because there is
|
||||
// another, different truth box, that overlaps it according to NMS?
|
||||
const std::pair<double,unsigned int> hittruth = find_best_match(*truth, (*truth)[i], i);
|
||||
if (hittruth.second == i)
|
||||
continue;
|
||||
rectangle best_matching_truth_box = (*truth)[hittruth.second];
|
||||
if (options.overlaps_nms(best_matching_truth_box, (*truth)[i]))
|
||||
{
|
||||
const size_t idx = truth_idxs[i];
|
||||
// We are ignoring this box so we shouldn't have counted it in the
|
||||
// loss in the first place. So we subtract out the loss values we
|
||||
// added for it in the code above.
|
||||
loss -= 1-out_data[idx];
|
||||
g[idx] = 0;
|
||||
std::cout << "Warning, ignoring object. We encountered a truth rectangle located at " << (*truth)[i].rect;
|
||||
std::cout << " that is suppressed by non-max-suppression ";
|
||||
std::cout << "because it is overlapped by another truth rectangle located at " << best_matching_truth_box << "." << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
hit_truth_table.assign(hit_truth_table.size(), false);
|
||||
final_dets.clear();
|
||||
|
@ -1012,12 +1042,21 @@ namespace dlib
|
|||
const std::vector<mmod_rect>& boxes,
|
||||
const rectangle& rect
|
||||
) const
|
||||
{
|
||||
return find_best_match(boxes, rect, boxes.size());
|
||||
}
|
||||
|
||||
std::pair<double,unsigned int> find_best_match(
|
||||
const std::vector<mmod_rect>& boxes,
|
||||
const rectangle& rect,
|
||||
const size_t excluded_idx
|
||||
) const
|
||||
{
|
||||
double match = 0;
|
||||
unsigned int best_idx = 0;
|
||||
for (unsigned long i = 0; i < boxes.size(); ++i)
|
||||
{
|
||||
if (boxes[i].ignore)
|
||||
if (boxes[i].ignore || excluded_idx == i)
|
||||
continue;
|
||||
|
||||
const double new_match = box_intersection_over_union(rect, boxes[i]);
|
||||
|
|
Loading…
Reference in New Issue