From a2e45f00b28583328aaa297fb9f43e90804c76e1 Mon Sep 17 00:00:00 2001 From: Davis King Date: Sat, 18 Apr 2020 13:57:56 -0400 Subject: [PATCH] Reduce code duplication a bit and make equal_error_rate() give correct results when called on data where all detection scores are identical. Previously it would say the EER was 0, but really it should have said 1 in this case. --- dlib/statistics/lda.h | 87 +++++++++++++++++++--------------------- dlib/test/statistics.cpp | 35 ++++++++++++++++ 2 files changed, 76 insertions(+), 46 deletions(-) diff --git a/dlib/statistics/lda.h b/dlib/statistics/lda.h index 38de3fd1e..36bbaf989 100644 --- a/dlib/statistics/lda.h +++ b/dlib/statistics/lda.h @@ -137,49 +137,6 @@ namespace dlib mean = X*mean; } -// ---------------------------------------------------------------------------------------- - - inline std::pair equal_error_rate ( - const std::vector& low_vals, - const std::vector& high_vals - ) - { - std::vector > temp; - temp.reserve(low_vals.size()+high_vals.size()); - for (unsigned long i = 0; i < low_vals.size(); ++i) - temp.push_back(std::make_pair(low_vals[i], -1)); - for (unsigned long i = 0; i < high_vals.size(); ++i) - temp.push_back(std::make_pair(high_vals[i], +1)); - - std::sort(temp.begin(), temp.end()); - - if (temp.size() == 0) - return std::make_pair(0,0); - - double thresh = temp[0].first; - - unsigned long num_low_wrong = low_vals.size(); - unsigned long num_high_wrong = 0; - double low_error = num_low_wrong/(double)low_vals.size(); - double high_error = num_high_wrong/(double)high_vals.size(); - for (unsigned long i = 0; i < temp.size() && high_error < low_error; ++i) - { - thresh = temp[i].first; - if (temp[i].second > 0) - { - num_high_wrong++; - high_error = num_high_wrong/(double)high_vals.size(); - } - else - { - num_low_wrong--; - low_error = num_low_wrong/(double)low_vals.size(); - } - } - - return std::make_pair((low_error+high_error)/2, thresh); - } - // ---------------------------------------------------------------------------------------- struct roc_point @@ -199,10 +156,15 @@ namespace dlib std::vector > temp; temp.reserve(true_detections.size()+false_detections.size()); + // We use -1 for true labels and +1 for false so when we call std::sort() below it will sort + // runs with equal detection scores so true come first. This will avoid it seeming like we + // can separate true from false when scores are equal in the loop below. + const int true_label = -1; + const int false_label = +1; for (unsigned long i = 0; i < true_detections.size(); ++i) - temp.push_back(std::make_pair(true_detections[i], +1)); + temp.push_back(std::make_pair(true_detections[i], true_label)); for (unsigned long i = 0; i < false_detections.size(); ++i) - temp.push_back(std::make_pair(false_detections[i], -1)); + temp.push_back(std::make_pair(false_detections[i], false_label)); std::sort(temp.rbegin(), temp.rend()); @@ -214,7 +176,7 @@ namespace dlib double num_true_included = 0; for (unsigned long i = 0; i < temp.size(); ++i) { - if (temp[i].second > 0) + if (temp[i].second == true_label) num_true_included++; else num_false_included++; @@ -229,6 +191,39 @@ namespace dlib return roc_curve; } +// ---------------------------------------------------------------------------------------- + + inline std::pair equal_error_rate ( + const std::vector& low_vals, + const std::vector& high_vals + ) + { + if (low_vals.size() == 0 && high_vals.size() == 0) + return std::make_pair(0,0); + else if (low_vals.size() == 0) + return std::make_pair(0, min(mat(high_vals))); + else if (high_vals.size() == 0) + return std::make_pair(0, max(mat(low_vals))+1); + + // Find the point of equal error rates + double best_thresh = 0; + double best_error = 0; + double best_delta = std::numeric_limits::infinity(); + for (const auto& pt : compute_roc_curve(high_vals, low_vals)) + { + const double false_negative_rate = 1-pt.true_positive_rate; + const double delta = std::abs(false_negative_rate - pt.false_positive_rate); + if (delta < best_delta) + { + best_delta = delta; + best_error = std::max(false_negative_rate, pt.false_positive_rate); + best_thresh = pt.detection_threshold; + } + } + + return std::make_pair(best_error, best_thresh); + } + // ---------------------------------------------------------------------------------------- } diff --git a/dlib/test/statistics.cpp b/dlib/test/statistics.cpp index 0394286ad..93d564b49 100644 --- a/dlib/test/statistics.cpp +++ b/dlib/test/statistics.cpp @@ -801,6 +801,40 @@ namespace DLIB_TEST(equal_error_rate(vals2, vals1).first == 1); } + void test_equal_error_rate() + { + auto result = equal_error_rate({}, {}); + DLIB_TEST(result.first == 0); + DLIB_TEST(result.second == 0); + + // no error case + result = equal_error_rate({1,1,1}, {2,2,2}); + DLIB_TEST_MSG(result.first == 0, result.first); + DLIB_TEST_MSG(result.second == 2, result.second); + + // max error case + result = equal_error_rate({2,2,2}, {1,1,1}); + DLIB_TEST_MSG(result.first == 1, result.first); + DLIB_TEST_MSG(result.second == 2, result.second); + // Another way to have max error + result = equal_error_rate({1,1,1}, {1,1,1}); + DLIB_TEST_MSG(result.second == 1, result.second); + DLIB_TEST_MSG(result.first == 1, result.first); + + // wildly unbalanced + result = equal_error_rate({}, {1,1,1}); + DLIB_TEST_MSG(result.first == 0, result.first); + + // wildly unbalanced + result = equal_error_rate({1,1,1}, {}); + DLIB_TEST_MSG(result.first == 0, result.first); + + // 25% error case + result = equal_error_rate({1,1,1,3}, {2, 2, 0, 2}); + DLIB_TEST_MSG(result.first == 0.25, result.first); + DLIB_TEST_MSG(result.second == 2, result.second); + } + void test_running_stats_decayed() { print_spinner(); @@ -907,6 +941,7 @@ namespace test_event_corr(); test_running_stats_decayed(); test_running_scalar_covariance_decayed(); + test_equal_error_rate(); } } a;