diff --git a/dlib/statistics/statistics.h b/dlib/statistics/statistics.h index 492ddf8e7..914f5482f 100644 --- a/dlib/statistics/statistics.h +++ b/dlib/statistics/statistics.h @@ -1827,20 +1827,20 @@ namespace dlib // ---------------------------------------------------------------------------------------- inline double binomial_random_vars_are_different ( - uint64_t k1, - uint64_t n1, - uint64_t k2, - uint64_t n2 + double k1, + double n1, + double k2, + double n2 ) { DLIB_ASSERT(k1 <= n1, "k1: "<< k1 << " n1: "<< n1); DLIB_ASSERT(k2 <= n2, "k2: "<< k2 << " n2: "<< n2); - const double p1 = k1/(double)n1; - const double p2 = k2/(double)n2; - const double p = (k1+k2)/(double)(n1+n2); + const double p1 = n1 != 0 ? k1/n1 : 0; + const double p2 = n2 != 0 ? k2/n2 : 0; + const double p = (k1+k2)/(n1+n2); - auto ll = [](double p, uint64_t k, uint64_t n) { + auto ll = [](double p, double k, double n) { if (p == 0 || p == 1) return 0.0; return k*std::log(p) + (n-k)*std::log(1-p); @@ -1860,10 +1860,10 @@ namespace dlib // ---------------------------------------------------------------------------------------- inline double event_correlation ( - uint64_t A_count, - uint64_t B_count, - uint64_t AB_count, - uint64_t total_num_observations + double A_count, + double B_count, + double AB_count, + double total_num_observations ) { DLIB_ASSERT(AB_count <= A_count && A_count <= total_num_observations, diff --git a/dlib/statistics/statistics_abstract.h b/dlib/statistics/statistics_abstract.h index b5738196d..24432ded2 100644 --- a/dlib/statistics/statistics_abstract.h +++ b/dlib/statistics/statistics_abstract.h @@ -108,10 +108,10 @@ namespace dlib // ---------------------------------------------------------------------------------------- double binomial_random_vars_are_different ( - uint64_t k1, - uint64_t n1, - uint64_t k2, - uint64_t n2 + double k1, + double n1, + double k2, + double n2 ); /*! requires @@ -138,10 +138,10 @@ namespace dlib // ---------------------------------------------------------------------------------------- double event_correlation ( - uint64_t A_count, - uint64_t B_count, - uint64_t AB_count, - uint64_t total_num_observations + double A_count, + double B_count, + double AB_count, + double total_num_observations ); /*! requires