Changed average_precision() to use interpolated precision. So now it uses the

same metric as the one used by the Pascal VOC.
This commit is contained in:
Davis King 2013-03-31 11:12:43 -04:00
parent 7ff4f6f485
commit 238effb9c6
3 changed files with 24 additions and 2 deletions

View File

@ -28,17 +28,28 @@ namespace dlib
)
{
using namespace dlib::impl;
double precision_sum = 0;
double relevant_count = 0;
// find the precision values
std::vector<double> precision;
for (unsigned long i = 0; i < items.size(); ++i)
{
if (get_bool_part(items[i]))
{
++relevant_count;
precision_sum += relevant_count / (i+1);
precision.push_back(relevant_count / (i+1));
}
}
double precision_sum = 0;
double max_val = 0;
// now sum over the interpolated precision values
for (std::vector<double>::reverse_iterator i = precision.rbegin(); i != precision.rend(); ++i)
{
max_val = std::max(max_val, *i);
precision_sum += max_val;
}
relevant_count += missing_relevant_items;
if (relevant_count != 0)

View File

@ -29,6 +29,12 @@ namespace dlib
the second true has a precision of 0.5, giving an average of 0.75).
- As a special case, if item contains no true elements then the average
precision is considered to be 1.
- Note that we use the interpolated precision. That is, the interpolated
precision at a recall value r is set to the maximum precision obtained at any
higher recall value. Or in other words, we interpolate the precision/recall
curve so that precision is monotonically decreasing. Therefore, the average
precision value returned by this function is the area under this interpolated
precision/recall curve.
- This function will add in missing_relevant_items number of items with a
precision of zero into the average value returned. For example, the average
precision of the ranking [true, true] if there are 2 missing relevant items

View File

@ -446,6 +446,11 @@ namespace
items.push_back(true);
DLIB_TEST(std::abs(average_precision(items) - (2.0+3.0/4.0)/3.0) < 1e-14);
items.push_back(true);
DLIB_TEST(std::abs(average_precision(items) - (2.0 + 4.0/5.0 + 4.0/5.0)/4.0) < 1e-14);
DLIB_TEST(std::abs(average_precision(items,1) - (2.0 + 4.0/5.0 + 4.0/5.0)/5.0) < 1e-14);
}
void perform_test (