mirror of https://github.com/davisking/dlib.git
Reduce code duplication a bit and make equal_error_rate() give correct results when called on data where all detection scores are identical.
Previously it would say the EER was 0, but really it should have said 1 in this case.
This commit is contained in:
parent
0e923cff93
commit
a2e45f00b2
|
@ -137,49 +137,6 @@ namespace dlib
|
|||
mean = X*mean;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
// Finds the threshold that best separates low_vals from high_vals under the rule
// "a score >= threshold is classified as high", and reports the error rate there.
//
// Returns a pair P where:
//   - P.first  is the equal error rate: the larger of the false positive rate
//     (fraction of low_vals that are >= threshold) and the false negative rate
//     (fraction of high_vals that are < threshold) at the point of the sweep where
//     those two rates are closest to each other.
//   - P.second is the threshold at that point.
//
// Degenerate inputs:
//   - both inputs empty  -> (0, 0)
//   - low_vals empty     -> (0, smallest high value)
//   - high_vals empty    -> (0, largest low value + 1)
inline std::pair<double,double> equal_error_rate (
    const std::vector<double>& low_vals,
    const std::vector<double>& high_vals
)
{
    // Handle empty inputs up front.  Dividing by a zero count below would otherwise
    // produce NaN error rates.
    if (low_vals.empty() && high_vals.empty())
        return std::make_pair(0.0, 0.0);
    if (low_vals.empty())
        return std::make_pair(0.0, *std::min_element(high_vals.begin(), high_vals.end()));
    if (high_vals.empty())
        return std::make_pair(0.0, *std::max_element(low_vals.begin(), low_vals.end()) + 1);

    // Label high values -1 and low values +1.  The sort below is descending, so within
    // a run of identical scores the low values are visited first.  That keeps tied
    // scores from looking separable while we sweep.
    const int high_label = -1;
    const int low_label = +1;

    std::vector<std::pair<double,int> > temp;
    temp.reserve(low_vals.size()+high_vals.size());
    for (unsigned long i = 0; i < low_vals.size(); ++i)
        temp.push_back(std::make_pair(low_vals[i], low_label));
    for (unsigned long i = 0; i < high_vals.size(); ++i)
        temp.push_back(std::make_pair(high_vals[i], high_label));

    // Sorting through reverse iterators yields descending order.
    std::sort(temp.rbegin(), temp.rend());

    // Sweep from the highest score downward.  After visiting element i, everything
    // visited so far is treated as "classified high".
    double num_low_included = 0;
    double num_high_included = 0;

    double best_thresh = 0;
    double best_error = 0;
    double best_delta = 0;
    for (unsigned long i = 0; i < temp.size(); ++i)
    {
        if (temp[i].second == high_label)
            ++num_high_included;
        else
            ++num_low_included;

        const double false_positive_rate = num_low_included/low_vals.size();
        const double false_negative_rate = 1 - num_high_included/high_vals.size();
        const double delta = false_negative_rate > false_positive_rate ?
            false_negative_rate - false_positive_rate :
            false_positive_rate - false_negative_rate;

        // Keep the first point at which the two error rates are closest.
        if (i == 0 || delta < best_delta)
        {
            best_delta = delta;
            best_error = std::max(false_negative_rate, false_positive_rate);
            best_thresh = temp[i].first;
        }
    }

    return std::make_pair(best_error, best_thresh);
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
struct roc_point
|
||||
|
@ -199,10 +156,15 @@ namespace dlib
|
|||
|
||||
std::vector<std::pair<double,int> > temp;
|
||||
temp.reserve(true_detections.size()+false_detections.size());
|
||||
// We use -1 for true labels and +1 for false so that, because the std::sort() call below sorts
|
||||
// in descending order, runs with equal detection scores put the false detections first.  This
|
||||
// avoids it seeming like we can separate true from false when scores are equal in the loop below.
|
||||
const int true_label = -1;
|
||||
const int false_label = +1;
|
||||
for (unsigned long i = 0; i < true_detections.size(); ++i)
|
||||
temp.push_back(std::make_pair(true_detections[i], +1));
|
||||
temp.push_back(std::make_pair(true_detections[i], true_label));
|
||||
for (unsigned long i = 0; i < false_detections.size(); ++i)
|
||||
temp.push_back(std::make_pair(false_detections[i], -1));
|
||||
temp.push_back(std::make_pair(false_detections[i], false_label));
|
||||
|
||||
std::sort(temp.rbegin(), temp.rend());
|
||||
|
||||
|
@ -214,7 +176,7 @@ namespace dlib
|
|||
double num_true_included = 0;
|
||||
for (unsigned long i = 0; i < temp.size(); ++i)
|
||||
{
|
||||
if (temp[i].second > 0)
|
||||
if (temp[i].second == true_label)
|
||||
num_true_included++;
|
||||
else
|
||||
num_false_included++;
|
||||
|
@ -229,6 +191,39 @@ namespace dlib
|
|||
return roc_curve;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
inline std::pair<double,double> equal_error_rate (
|
||||
const std::vector<double>& low_vals,
|
||||
const std::vector<double>& high_vals
|
||||
)
|
||||
{
|
||||
if (low_vals.size() == 0 && high_vals.size() == 0)
|
||||
return std::make_pair(0,0);
|
||||
else if (low_vals.size() == 0)
|
||||
return std::make_pair(0, min(mat(high_vals)));
|
||||
else if (high_vals.size() == 0)
|
||||
return std::make_pair(0, max(mat(low_vals))+1);
|
||||
|
||||
// Find the point of equal error rates
|
||||
double best_thresh = 0;
|
||||
double best_error = 0;
|
||||
double best_delta = std::numeric_limits<double>::infinity();
|
||||
for (const auto& pt : compute_roc_curve(high_vals, low_vals))
|
||||
{
|
||||
const double false_negative_rate = 1-pt.true_positive_rate;
|
||||
const double delta = std::abs(false_negative_rate - pt.false_positive_rate);
|
||||
if (delta < best_delta)
|
||||
{
|
||||
best_delta = delta;
|
||||
best_error = std::max(false_negative_rate, pt.false_positive_rate);
|
||||
best_thresh = pt.detection_threshold;
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_pair(best_error, best_thresh);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
|
|
@ -801,6 +801,40 @@ namespace
|
|||
DLIB_TEST(equal_error_rate(vals2, vals1).first == 1);
|
||||
}
|
||||
|
||||
void test_equal_error_rate()
|
||||
{
|
||||
auto result = equal_error_rate({}, {});
|
||||
DLIB_TEST(result.first == 0);
|
||||
DLIB_TEST(result.second == 0);
|
||||
|
||||
// no error case
|
||||
result = equal_error_rate({1,1,1}, {2,2,2});
|
||||
DLIB_TEST_MSG(result.first == 0, result.first);
|
||||
DLIB_TEST_MSG(result.second == 2, result.second);
|
||||
|
||||
// max error case
|
||||
result = equal_error_rate({2,2,2}, {1,1,1});
|
||||
DLIB_TEST_MSG(result.first == 1, result.first);
|
||||
DLIB_TEST_MSG(result.second == 2, result.second);
|
||||
// Another way to have max error
|
||||
result = equal_error_rate({1,1,1}, {1,1,1});
|
||||
DLIB_TEST_MSG(result.second == 1, result.second);
|
||||
DLIB_TEST_MSG(result.first == 1, result.first);
|
||||
|
||||
// wildly unbalanced
|
||||
result = equal_error_rate({}, {1,1,1});
|
||||
DLIB_TEST_MSG(result.first == 0, result.first);
|
||||
|
||||
// wildly unbalanced
|
||||
result = equal_error_rate({1,1,1}, {});
|
||||
DLIB_TEST_MSG(result.first == 0, result.first);
|
||||
|
||||
// 25% error case
|
||||
result = equal_error_rate({1,1,1,3}, {2, 2, 0, 2});
|
||||
DLIB_TEST_MSG(result.first == 0.25, result.first);
|
||||
DLIB_TEST_MSG(result.second == 2, result.second);
|
||||
}
|
||||
|
||||
void test_running_stats_decayed()
|
||||
{
|
||||
print_spinner();
|
||||
|
@ -907,6 +941,7 @@ namespace
|
|||
test_event_corr();
|
||||
test_running_stats_decayed();
|
||||
test_running_scalar_covariance_decayed();
|
||||
test_equal_error_rate();
|
||||
}
|
||||
} a;
|
||||
|
||||
|
|
Loading…
Reference in New Issue