Updated the equal() function so that it can compare complex matrices.

I also changed a matrix test case to be more robust.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403075
This commit is contained in:
Davis King 2009-05-29 17:08:00 +00:00
parent eeb0b8061d
commit f36501d761
3 changed files with 74 additions and 4 deletions

View File

@ -16,10 +16,32 @@
#include "matrix_expressions.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
/*!A is_complex
This is a template that can be used to determine if a type is a specialization
of the std::complex template class.
For example:
is_complex<float>::value == false
is_complex<std::complex<float> >::value == true
!*/
template <typename T>
struct is_complex { static const bool value = false; };
template <typename T>
struct is_complex<std::complex<T> > { static const bool value = true; };
template <typename T>
struct is_complex<std::complex<T>& > { static const bool value = true; };
template <typename T>
struct is_complex<const std::complex<T>& > { static const bool value = true; };
template <typename T>
struct is_complex<const std::complex<T> > { static const bool value = true; };
// ----------------------------------------------------------------------------------------
template <typename EXP>
@ -1657,7 +1679,7 @@ namespace dlib
typename EXP1,
typename EXP2
>
bool equal (
typename disable_if<is_complex<typename EXP1::type>,bool>::type equal (
const matrix_exp<EXP1>& a,
const matrix_exp<EXP2>& b,
const typename EXP1::type eps = 100*std::numeric_limits<typename EXP1::type>::epsilon()
@ -1680,6 +1702,34 @@ namespace dlib
return true;
}
template <
typename EXP1,
typename EXP2
>
typename enable_if<is_complex<typename EXP1::type>,bool>::type equal (
const matrix_exp<EXP1>& a,
const matrix_exp<EXP2>& b,
const typename EXP1::type::value_type eps = 100*std::numeric_limits<typename EXP1::type::value_type>::epsilon()
)
{
// check if the dimensions don't match
if (a.nr() != b.nr() || a.nc() != b.nc())
return false;
for (long r = 0; r < a.nr(); ++r)
{
for (long c = 0; c < a.nc(); ++c)
{
if (std::abs(real(a(r,c)-b(r,c))) > eps ||
std::abs(imag(a(r,c)-b(r,c))) > eps)
return false;
}
}
// no non-equal points found so we return true
return true;
}
// ----------------------------------------------------------------------------------------
struct op_scale_columns

View File

@ -501,6 +501,7 @@ namespace dlib
// ----------------------------------------------------------------------------------------
// if matrix_exp contains non-complex types (e.g. float, double)
bool equal (
const matrix_exp& a,
const matrix_exp& b,
@ -516,6 +517,25 @@ namespace dlib
- returns true
!*/
// ----------------------------------------------------------------------------------------
// if matrix_exp contains std::complex types
bool equal (
const matrix_exp& a,
const matrix_exp& b,
const matrix_exp::type::value_type epsilon = 100*std::numeric_limits<matrix_exp::type::value_type>::epsilon()
);
/*!
ensures
- if (a and b don't have the same dimensions) then
- returns false
- else if (there exists an r and c such that abs(real(a(r,c)-b(r,c))) > epsilon
or abs(imag(a(r,c)-b(r,c))) > epsilon) then
- returns false
- else
- returns true
!*/
// ----------------------------------------------------------------------------------------
const matrix_exp pointwise_multiply (

View File

@ -857,7 +857,7 @@ namespace
m = val1;
m2 = val2;
DLIB_TEST(reciprocal(m) == m2);
DLIB_TEST(equal(reciprocal(m) , m2));
}
{
matrix<complex<float> > m(2,2), m2(2,2);
@ -865,7 +865,7 @@ namespace
m = val1;
m2 = val2;
DLIB_TEST(reciprocal(m) == m2);
DLIB_TEST(equal(reciprocal(m) , m2));
}
{