Make event_correlation() work on fractional counts

This commit is contained in:
Davis King 2024-10-22 22:26:14 -04:00
parent 6d29e0c7d4
commit 39240959fa
2 changed files with 20 additions and 20 deletions

View File

@ -1827,20 +1827,20 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
inline double binomial_random_vars_are_different ( inline double binomial_random_vars_are_different (
uint64_t k1, double k1,
uint64_t n1, double n1,
uint64_t k2, double k2,
uint64_t n2 double n2
) )
{ {
DLIB_ASSERT(k1 <= n1, "k1: "<< k1 << " n1: "<< n1); DLIB_ASSERT(k1 <= n1, "k1: "<< k1 << " n1: "<< n1);
DLIB_ASSERT(k2 <= n2, "k2: "<< k2 << " n2: "<< n2); DLIB_ASSERT(k2 <= n2, "k2: "<< k2 << " n2: "<< n2);
const double p1 = k1/(double)n1; const double p1 = n1 != 0 ? k1/n1 : 0;
const double p2 = k2/(double)n2; const double p2 = n2 != 0 ? k2/n2 : 0;
const double p = (k1+k2)/(double)(n1+n2); const double p = (k1+k2)/(n1+n2);
auto ll = [](double p, uint64_t k, uint64_t n) { auto ll = [](double p, double k, double n) {
if (p == 0 || p == 1) if (p == 0 || p == 1)
return 0.0; return 0.0;
return k*std::log(p) + (n-k)*std::log(1-p); return k*std::log(p) + (n-k)*std::log(1-p);
@ -1860,10 +1860,10 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
inline double event_correlation ( inline double event_correlation (
uint64_t A_count, double A_count,
uint64_t B_count, double B_count,
uint64_t AB_count, double AB_count,
uint64_t total_num_observations double total_num_observations
) )
{ {
DLIB_ASSERT(AB_count <= A_count && A_count <= total_num_observations, DLIB_ASSERT(AB_count <= A_count && A_count <= total_num_observations,

View File

@ -108,10 +108,10 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
double binomial_random_vars_are_different ( double binomial_random_vars_are_different (
uint64_t k1, double k1,
uint64_t n1, double n1,
uint64_t k2, double k2,
uint64_t n2 double n2
); );
/*! /*!
requires requires
@ -138,10 +138,10 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
double event_correlation ( double event_correlation (
uint64_t A_count, double A_count,
uint64_t B_count, double B_count,
uint64_t AB_count, double AB_count,
uint64_t total_num_observations double total_num_observations
); );
/*! /*!
requires requires