diff --git a/dlib/statistics/average_precision.h b/dlib/statistics/average_precision.h index 978a2c4c6..87408904d 100644 --- a/dlib/statistics/average_precision.h +++ b/dlib/statistics/average_precision.h @@ -28,17 +28,28 @@ namespace dlib ) { using namespace dlib::impl; - double precision_sum = 0; double relevant_count = 0; + // find the precision values + std::vector precision; for (unsigned long i = 0; i < items.size(); ++i) { if (get_bool_part(items[i])) { ++relevant_count; - precision_sum += relevant_count / (i+1); + precision.push_back(relevant_count / (i+1)); } } + double precision_sum = 0; + double max_val = 0; + // now sum over the interpolated precision values + for (std::vector::reverse_iterator i = precision.rbegin(); i != precision.rend(); ++i) + { + max_val = std::max(max_val, *i); + precision_sum += max_val; + } + + relevant_count += missing_relevant_items; if (relevant_count != 0) diff --git a/dlib/statistics/average_precision_abstract.h b/dlib/statistics/average_precision_abstract.h index 147c3ea8f..daf95619b 100644 --- a/dlib/statistics/average_precision_abstract.h +++ b/dlib/statistics/average_precision_abstract.h @@ -29,6 +29,12 @@ namespace dlib the second true has a precision of 0.5, giving an average of 0.75). - As a special case, if item contains no true elements then the average precision is considered to be 1. + - Note that we use the interpolated precision. That is, the interpolated + precision at a recall value r is set to the maximum precision obtained at any + higher recall value. Or in other words, we interpolate the precision/recall + curve so that precision is monotonically decreasing. Therefore, the average + precision value returned by this function is the area under this interpolated + precision/recall curve. - This function will add in missing_relevant_items number of items with a precision of zero into the average value returned. For example, the average precision of the ranking [true, true] if there are 2 missing relevant items diff --git a/dlib/test/statistics.cpp b/dlib/test/statistics.cpp index ab2e4c3e7..d7e0d7cbe 100644 --- a/dlib/test/statistics.cpp +++ b/dlib/test/statistics.cpp @@ -446,6 +446,11 @@ namespace items.push_back(true); DLIB_TEST(std::abs(average_precision(items) - (2.0+3.0/4.0)/3.0) < 1e-14); + + items.push_back(true); + + DLIB_TEST(std::abs(average_precision(items) - (2.0 + 4.0/5.0 + 4.0/5.0)/4.0) < 1e-14); + DLIB_TEST(std::abs(average_precision(items,1) - (2.0 + 4.0/5.0 + 4.0/5.0)/5.0) < 1e-14); } void perform_test (