Reduce code duplication a bit and make equal_error_rate() give correct results when called on data where all detection scores are identical.

Previously it would say the EER was 0, but really it should have said 1 in this case.
This commit is contained in:
Davis King 2020-04-18 13:57:56 -04:00
parent 0e923cff93
commit a2e45f00b2
2 changed files with 76 additions and 46 deletions

View File

@ -137,49 +137,6 @@ namespace dlib
mean = X*mean;
}
// ----------------------------------------------------------------------------------------
inline std::pair<double,double> equal_error_rate (
const std::vector<double>& low_vals,
const std::vector<double>& high_vals
)
{
std::vector<std::pair<double,int> > temp;
temp.reserve(low_vals.size()+high_vals.size());
for (unsigned long i = 0; i < low_vals.size(); ++i)
temp.push_back(std::make_pair(low_vals[i], -1));
for (unsigned long i = 0; i < high_vals.size(); ++i)
temp.push_back(std::make_pair(high_vals[i], +1));
std::sort(temp.begin(), temp.end());
if (temp.size() == 0)
return std::make_pair(0,0);
double thresh = temp[0].first;
unsigned long num_low_wrong = low_vals.size();
unsigned long num_high_wrong = 0;
double low_error = num_low_wrong/(double)low_vals.size();
double high_error = num_high_wrong/(double)high_vals.size();
for (unsigned long i = 0; i < temp.size() && high_error < low_error; ++i)
{
thresh = temp[i].first;
if (temp[i].second > 0)
{
num_high_wrong++;
high_error = num_high_wrong/(double)high_vals.size();
}
else
{
num_low_wrong--;
low_error = num_low_wrong/(double)low_vals.size();
}
}
return std::make_pair((low_error+high_error)/2, thresh);
}
// ----------------------------------------------------------------------------------------
struct roc_point
@ -199,10 +156,15 @@ namespace dlib
std::vector<std::pair<double,int> > temp;
temp.reserve(true_detections.size()+false_detections.size());
// We use -1 for true labels and +1 for false so when we call std::sort() below it will sort
// runs with equal detection scores so true come first. This will avoid it seeming like we
// can separate true from false when scores are equal in the loop below.
const int true_label = -1;
const int false_label = +1;
for (unsigned long i = 0; i < true_detections.size(); ++i)
temp.push_back(std::make_pair(true_detections[i], +1));
temp.push_back(std::make_pair(true_detections[i], true_label));
for (unsigned long i = 0; i < false_detections.size(); ++i)
temp.push_back(std::make_pair(false_detections[i], -1));
temp.push_back(std::make_pair(false_detections[i], false_label));
std::sort(temp.rbegin(), temp.rend());
@ -214,7 +176,7 @@ namespace dlib
double num_true_included = 0;
for (unsigned long i = 0; i < temp.size(); ++i)
{
if (temp[i].second > 0)
if (temp[i].second == true_label)
num_true_included++;
else
num_false_included++;
@ -229,6 +191,39 @@ namespace dlib
return roc_curve;
}
// ----------------------------------------------------------------------------------------
inline std::pair<double,double> equal_error_rate (
const std::vector<double>& low_vals,
const std::vector<double>& high_vals
)
{
if (low_vals.size() == 0 && high_vals.size() == 0)
return std::make_pair(0,0);
else if (low_vals.size() == 0)
return std::make_pair(0, min(mat(high_vals)));
else if (high_vals.size() == 0)
return std::make_pair(0, max(mat(low_vals))+1);
// Find the point of equal error rates
double best_thresh = 0;
double best_error = 0;
double best_delta = std::numeric_limits<double>::infinity();
for (const auto& pt : compute_roc_curve(high_vals, low_vals))
{
const double false_negative_rate = 1-pt.true_positive_rate;
const double delta = std::abs(false_negative_rate - pt.false_positive_rate);
if (delta < best_delta)
{
best_delta = delta;
best_error = std::max(false_negative_rate, pt.false_positive_rate);
best_thresh = pt.detection_threshold;
}
}
return std::make_pair(best_error, best_thresh);
}
// ----------------------------------------------------------------------------------------
}

View File

@ -801,6 +801,40 @@ namespace
DLIB_TEST(equal_error_rate(vals2, vals1).first == 1);
}
void test_equal_error_rate()
{
auto result = equal_error_rate({}, {});
DLIB_TEST(result.first == 0);
DLIB_TEST(result.second == 0);
// no error case
result = equal_error_rate({1,1,1}, {2,2,2});
DLIB_TEST_MSG(result.first == 0, result.first);
DLIB_TEST_MSG(result.second == 2, result.second);
// max error case
result = equal_error_rate({2,2,2}, {1,1,1});
DLIB_TEST_MSG(result.first == 1, result.first);
DLIB_TEST_MSG(result.second == 2, result.second);
// Another way to have max error
result = equal_error_rate({1,1,1}, {1,1,1});
DLIB_TEST_MSG(result.second == 1, result.second);
DLIB_TEST_MSG(result.first == 1, result.first);
// wildly unbalanced
result = equal_error_rate({}, {1,1,1});
DLIB_TEST_MSG(result.first == 0, result.first);
// wildly unbalanced
result = equal_error_rate({1,1,1}, {});
DLIB_TEST_MSG(result.first == 0, result.first);
// 25% error case
result = equal_error_rate({1,1,1,3}, {2, 2, 0, 2});
DLIB_TEST_MSG(result.first == 0.25, result.first);
DLIB_TEST_MSG(result.second == 2, result.second);
}
void test_running_stats_decayed()
{
print_spinner();
@ -907,6 +941,7 @@ namespace
test_event_corr();
test_running_stats_decayed();
test_running_scalar_covariance_decayed();
test_equal_error_rate();
}
} a;