mirror of https://github.com/davisking/dlib.git
Added missing input validation to loss_mmod_. Specifically, the loss layer now
checks whether the user has supplied truth boxes that cannot all be detected because the non-max suppression settings would prevent them from being output at the same time. If this happens, we print a warning message and set one of the offending boxes to "ignore".
This commit is contained in:
parent
bf55c4e8e1
commit
29db3ee566
|
@ -697,12 +697,12 @@ namespace dlib
|
||||||
double loss = 0;
|
double loss = 0;
|
||||||
|
|
||||||
float* g = grad.host_write_only();
|
float* g = grad.host_write_only();
|
||||||
// zero initialize grad.
|
for (size_t i = 0; i < grad.size(); ++i)
|
||||||
for (auto&& x : grad)
|
g[i] = 0;
|
||||||
x = 0;
|
|
||||||
|
|
||||||
const float* out_data = output_tensor.host();
|
const float* out_data = output_tensor.host();
|
||||||
|
|
||||||
|
std::vector<size_t> truth_idxs; truth_idxs.reserve(truth->size());
|
||||||
std::vector<intermediate_detection> dets;
|
std::vector<intermediate_detection> dets;
|
||||||
for (long i = 0; i < output_tensor.num_samples(); ++i)
|
for (long i = 0; i < output_tensor.num_samples(); ++i)
|
||||||
{
|
{
|
||||||
|
@ -726,14 +726,17 @@ namespace dlib
|
||||||
loss -= 1;
|
loss -= 1;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
loss -= out_data[(k*output_tensor.nr() + p.y())*output_tensor.nc() + p.x()];
|
const size_t idx = (k*output_tensor.nr() + p.y())*output_tensor.nc() + p.x();
|
||||||
|
loss -= out_data[idx];
|
||||||
// compute gradient
|
// compute gradient
|
||||||
g[(k*output_tensor.nr() + p.y())*output_tensor.nc() + p.x()] = -scale;
|
g[idx] = -scale;
|
||||||
|
truth_idxs.push_back(idx);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
// This box was ignored so shouldn't have been counted in the loss.
|
// This box was ignored so shouldn't have been counted in the loss.
|
||||||
loss -= 1;
|
loss -= 1;
|
||||||
|
truth_idxs.push_back(0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -772,6 +775,33 @@ namespace dlib
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if any of the truth boxes are unobtainable because the NMS is
|
||||||
|
// killing them. If so, automatically set those unobtainable boxes to
|
||||||
|
// ignore and print a warning message to the user.
|
||||||
|
for (size_t i = 0; i < hit_truth_table.size(); ++i)
|
||||||
|
{
|
||||||
|
if (!hit_truth_table[i] && !(*truth)[i].ignore)
|
||||||
|
{
|
||||||
|
// So we didn't hit this truth box. Is that because there is
|
||||||
|
// another, different truth box, that overlaps it according to NMS?
|
||||||
|
const std::pair<double,unsigned int> hittruth = find_best_match(*truth, (*truth)[i], i);
|
||||||
|
if (hittruth.second == i)
|
||||||
|
continue;
|
||||||
|
rectangle best_matching_truth_box = (*truth)[hittruth.second];
|
||||||
|
if (options.overlaps_nms(best_matching_truth_box, (*truth)[i]))
|
||||||
|
{
|
||||||
|
const size_t idx = truth_idxs[i];
|
||||||
|
// We are ignoring this box so we shouldn't have counted it in the
|
||||||
|
// loss in the first place. So we subtract out the loss values we
|
||||||
|
// added for it in the code above.
|
||||||
|
loss -= 1-out_data[idx];
|
||||||
|
g[idx] = 0;
|
||||||
|
std::cout << "Warning, ignoring object. We encountered a truth rectangle located at " << (*truth)[i].rect;
|
||||||
|
std::cout << " that is suppressed by non-max-suppression ";
|
||||||
|
std::cout << "because it is overlapped by another truth rectangle located at " << best_matching_truth_box << "." << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
hit_truth_table.assign(hit_truth_table.size(), false);
|
hit_truth_table.assign(hit_truth_table.size(), false);
|
||||||
final_dets.clear();
|
final_dets.clear();
|
||||||
|
@ -1012,12 +1042,21 @@ namespace dlib
|
||||||
const std::vector<mmod_rect>& boxes,
|
const std::vector<mmod_rect>& boxes,
|
||||||
const rectangle& rect
|
const rectangle& rect
|
||||||
) const
|
) const
|
||||||
|
{
|
||||||
|
return find_best_match(boxes, rect, boxes.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<double,unsigned int> find_best_match(
|
||||||
|
const std::vector<mmod_rect>& boxes,
|
||||||
|
const rectangle& rect,
|
||||||
|
const size_t excluded_idx
|
||||||
|
) const
|
||||||
{
|
{
|
||||||
double match = 0;
|
double match = 0;
|
||||||
unsigned int best_idx = 0;
|
unsigned int best_idx = 0;
|
||||||
for (unsigned long i = 0; i < boxes.size(); ++i)
|
for (unsigned long i = 0; i < boxes.size(); ++i)
|
||||||
{
|
{
|
||||||
if (boxes[i].ignore)
|
if (boxes[i].ignore || excluded_idx == i)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
const double new_match = box_intersection_over_union(rect, boxes[i]);
|
const double new_match = box_intersection_over_union(rect, boxes[i]);
|
||||||
|
|
Loading…
Reference in New Issue