Added the running_scalar_covariance object.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403807
This commit is contained in:
Davis King 2010-08-14 18:42:40 +00:00
parent c0a50d8092
commit 7b90dccb7d
3 changed files with 341 additions and 0 deletions

View File

@ -167,6 +167,182 @@ namespace dlib
T max_value;
};
// ----------------------------------------------------------------------------------------
template <
    typename T
    >
class running_scalar_covariance
{
    /*!
        Accumulates running sums over a stream of (x,y) pairs so that the means,
        unbiased sample variances, standard deviations, covariance, and correlation
        of the pairs seen so far can be computed at any time without storing the
        individual samples.

        T must be float, double, or long double (enforced at compile time in the
        constructor).

        State kept by this object:
            - n      == number of pairs given to add() since the last clear()
            - sum_x  == sum of all x values,   sum_y  == sum of all y values
            - sum_xx == sum of all x*x values, sum_yy == sum of all y*y values
            - sum_xy == sum of all x*y values
    !*/
public:

    // Creates an empty accumulator and rejects non floating point T at compile time.
    running_scalar_covariance()
    {
        clear();

        COMPILE_TIME_ASSERT ((
                is_same_type<float,T>::value ||
                is_same_type<double,T>::value ||
                is_same_type<long double,T>::value
        ));
    }

    // Forgets every pair seen so far, returning the object to its initial state.
    void clear()
    {
        sum_xy = 0;
        sum_x = 0;
        sum_y = 0;
        sum_xx = 0;
        sum_yy = 0;
        n = 0;
    }

    // Folds the pair (x,y) into the running sums and increments the sample count.
    void add (
        const T& x,
        const T& y
    )
    {
        sum_xy += x*y;

        sum_xx += x*x;
        sum_yy += y*y;

        sum_x += x;
        sum_y += y;

        n += 1;
    }

    // Returns the number of pairs added so far (stored and reported as a T).
    T current_n (
    ) const
    {
        return n;
    }

    // Returns the mean of the x values, or 0 if nothing has been added yet.
    T mean_x (
    ) const
    {
        if (n != 0)
            return sum_x/n;
        else
            return 0;
    }

    // Returns the mean of the y values, or 0 if nothing has been added yet.
    T mean_y (
    ) const
    {
        if (n != 0)
            return sum_y/n;
        else
            return 0;
    }

    // Returns the unbiased (n-1 normalised) sample covariance between x and y.
    // Requires current_n() > 1.
    T covariance (
    ) const
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(current_n() > 1,
            "\tT running_scalar_covariance::covariance()"
            << "\n\tyou have to add some numbers to this object first"
            << "\n\tthis: " << this
            );

        // Unlike a variance, a covariance is legitimately negative whenever x and y
        // tend to move in opposite directions, so the value must NOT be clamped at
        // zero here.  Clamping would also make correlation() unable to ever report
        // a negative coefficient.
        return 1/(n-1) * (sum_xy - sum_y*sum_x/n);
    }

    // Returns the correlation coefficient between x and y, computed as
    // covariance() / sqrt(variance_x()*variance_y()).
    // Requires current_n() > 1.
    // NOTE(review): if either variance is 0 (e.g. all x values identical) this
    // divides by zero; callers should confirm their data has non-zero spread.
    T correlation (
    ) const
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(current_n() > 1,
            "\tT running_scalar_covariance::correlation()"
            << "\n\tyou have to add some numbers to this object first"
            << "\n\tthis: " << this
            );

        return covariance() / std::sqrt(variance_x()*variance_y());
    }

    // Returns the unbiased (n-1 normalised) sample variance of the x values.
    // Requires current_n() > 1.
    T variance_x (
    ) const
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(current_n() > 1,
            "\tT running_scalar_covariance::variance_x()"
            << "\n\tyou have to add some numbers to this object first"
            << "\n\tthis: " << this
            );

        T temp = 1/(n-1) * (sum_xx - sum_x*sum_x/n);
        // make sure the variance is never negative.  This might
        // happen due to numerical errors.
        if (temp >= 0)
            return temp;
        else
            return 0;
    }

    // Returns the unbiased (n-1 normalised) sample variance of the y values.
    // Requires current_n() > 1.
    T variance_y (
    ) const
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(current_n() > 1,
            "\tT running_scalar_covariance::variance_y()"
            << "\n\tyou have to add some numbers to this object first"
            << "\n\tthis: " << this
            );

        T temp = 1/(n-1) * (sum_yy - sum_y*sum_y/n);
        // make sure the variance is never negative.  This might
        // happen due to numerical errors.
        if (temp >= 0)
            return temp;
        else
            return 0;
    }

    // Returns sqrt(variance_x()).  Requires current_n() > 1.
    T stddev_x (
    ) const
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(current_n() > 1,
            "\tT running_scalar_covariance::stddev_x()"
            << "\n\tyou have to add some numbers to this object first"
            << "\n\tthis: " << this
            );

        return std::sqrt(variance_x());
    }

    // Returns sqrt(variance_y()).  Requires current_n() > 1.
    T stddev_y (
    ) const
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(current_n() > 1,
            "\tT running_scalar_covariance::stddev_y()"
            << "\n\tyou have to add some numbers to this object first"
            << "\n\tthis: " << this
            );

        return std::sqrt(variance_y());
    }

private:

    T sum_xy;   // running sum of x*y
    T sum_x;    // running sum of x
    T sum_y;    // running sum of y
    T sum_xx;   // running sum of x*x
    T sum_yy;   // running sum of y*y
    T n;        // number of pairs accumulated
};
// ----------------------------------------------------------------------------------------
template <

View File

@ -157,6 +157,139 @@ namespace dlib
!*/
};
// ----------------------------------------------------------------------------------------
template <
typename T
>
class running_scalar_covariance
{
/*!
    REQUIREMENTS ON T
        - T must be a float, double, or long double type

    INITIAL VALUE
        - mean_x() == 0
        - mean_y() == 0
        - current_n() == 0

    WHAT THIS OBJECT REPRESENTS
        This object represents something that can compute the running covariance
        of a stream of real number pairs.  Alongside the covariance it also
        reports the running mean, variance, and standard deviation of each
        coordinate, as well as the correlation between the two coordinates.
!*/
public:
running_scalar_covariance(
);
/*!
    ensures
        - this object is properly initialized
!*/
void clear(
);
/*!
    ensures
        - this object has its initial value
        - clears all memory of any previous data points
!*/
void add (
const T& x,
const T& y
);
/*!
    ensures
        - updates the statistics stored in this object so that
          the new pair (x,y) is factored into them.
        - #current_n() == current_n() + 1
!*/
T current_n (
) const;
/*!
    ensures
        - returns the number of points given to this object so far.
          (Note that the count is reported as a value of type T.)
!*/
T mean_x (
) const;
/*!
    ensures
        - returns the mean value of all x samples presented to this object
          via add().
        - if (current_n() == 0) then
            - returns 0
!*/
T mean_y (
) const;
/*!
    ensures
        - returns the mean value of all y samples presented to this object
          via add().
        - if (current_n() == 0) then
            - returns 0
!*/
T covariance (
) const;
/*!
    requires
        - current_n() > 1
    ensures
        - returns the sample covariance (normalised by current_n()-1) between
          all the x and y samples presented to this object via add()
!*/
T correlation (
) const;
/*!
    requires
        - current_n() > 1
    ensures
        - returns the correlation coefficient between all the x and y samples
          presented to this object via add().  That is, returns
          covariance()/sqrt(variance_x()*variance_y()).
        - NOTE(review): the value does not appear to be well defined when
          variance_x() or variance_y() is 0 -- confirm before relying on it.
!*/
T variance_x (
) const;
/*!
    requires
        - current_n() > 1
    ensures
        - returns the unbiased sample variance value of all x samples presented
          to this object via add().
        - the returned value is never negative
!*/
T variance_y (
) const;
/*!
    requires
        - current_n() > 1
    ensures
        - returns the unbiased sample variance value of all y samples presented
          to this object via add().
        - the returned value is never negative
!*/
T stddev_x (
) const;
/*!
    requires
        - current_n() > 1
    ensures
        - returns the sample standard deviation of all x samples presented to
          this object via add().  That is, returns sqrt(variance_x()).
!*/
T stddev_y (
) const;
/*!
    requires
        - current_n() > 1
    ensures
        - returns the sample standard deviation of all y samples presented to
          this object via add().  That is, returns sqrt(variance_y()).
!*/
};
// ----------------------------------------------------------------------------------------
template <

View File

@ -191,12 +191,44 @@ namespace
}
void test_running_stats()
{
    print_spinner();

    // Reference accumulator plus the two covariance trackers under test.
    running_stats<double> rs;
    running_scalar_covariance<double> rsc1;
    running_scalar_covariance<double> rsc2;

    // Feed the values 0..99.  rsc1 only ever sees pairs with x == y, while rsc2
    // sees every x paired with both +x and -x.
    double val = 0;
    while (val < 100)
    {
        rs.add(val);

        rsc1.add(val, val);

        rsc2.add(val, val);
        rsc2.add(val, -val);

        ++val;
    }

    // make sure the running_stats and running_scalar_covariance agree
    DLIB_TEST_MSG(std::abs(rs.mean() - rsc1.mean_x()) < 1e-10, std::abs(rs.mean() - rsc1.mean_x()));
    DLIB_TEST(std::abs(rs.mean() - rsc1.mean_y()) < 1e-10);

    DLIB_TEST(std::abs(rs.stddev() - rsc1.stddev_x()) < 1e-10);
    DLIB_TEST(std::abs(rs.stddev() - rsc1.stddev_y()) < 1e-10);

    DLIB_TEST(std::abs(rs.variance() - rsc1.variance_x()) < 1e-10);
    DLIB_TEST(std::abs(rs.variance() - rsc1.variance_y()) < 1e-10);

    DLIB_TEST(rs.current_n() == rsc1.current_n());

    // Identical coordinates should correlate at 1; the mirrored +x/-x pairs in
    // rsc2 should come out at 0.
    DLIB_TEST(std::abs(rsc1.correlation() - 1) < 1e-10);
    DLIB_TEST(std::abs(rsc2.correlation() - 0) < 1e-10);
}
// Entry point for this test object: runs each of the individual test routines
// below, one after another, in the order listed.
void perform_test (
)
{
test_random_subset_selector();
test_random_subset_selector2();
test_running_covariance();
// exercises running_stats together with running_scalar_covariance
test_running_stats();
}
} a;