- Added the tmax, tmin, and tabs templates

- Changed the diag() function so that it is allowed to take
     non-square matrices.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402697
This commit is contained in:
Davis King 2008-12-03 23:36:11 +00:00
parent cae2fb2867
commit 7206a12c44
4 changed files with 56 additions and 29 deletions

View File

@ -568,6 +568,52 @@ namespace dlib
inline double put_in_range(const double& a, const double& b, const double& val)
{ return put_in_range<double>(a,b,val); }
// ----------------------------------------------------------------------------------------
/*! tabs
This is a template to compute the absolute value a number at compile time.
For example,
abs<-4>::value == 4
abs<4>::value == 4
!*/
template <long x, typename enabled=void>
struct tabs { const static long value = x; };
template <long x>
struct tabs<x,typename enable_if_c<(x < 0)>::type> { const static long value = -x; };
// ----------------------------------------------------------------------------------------
/*! tmax
This is a template to compute the max of two values at compile time
For example,
abs<4,7>::value == 7
!*/
template <long x, long y, typename enabled=void>
struct tmax { const static long value = x; };
template <long x, long y>
struct tmax<x,y,typename enable_if_c<(y > x)>::type> { const static long value = y; };
// ----------------------------------------------------------------------------------------
/*! tmin
This is a template to compute the min of two values at compile time
For example,
abs<4,7>::value == 4
!*/
template <long x, long y, typename enabled=void>
struct tmin { const static long value = x; };
template <long x, long y>
struct tmin<x,y,typename enable_if_c<(y < x)>::type> { const static long value = y; };
// ----------------------------------------------------------------------------------------
/*!A is_function

View File

@ -2948,7 +2948,7 @@ convergence:
template <typename EXP>
struct op : has_destructive_aliasing
{
const static long NR = EXP::NC;
const static long NR = (EXP::NC&&EXP::NR)? (tmin<EXP::NR,EXP::NC>::value) : (0);
const static long NC = 1;
typedef typename EXP::type type;
typedef typename EXP::mem_manager_type mem_manager_type;
@ -2957,7 +2957,7 @@ convergence:
{ return m(r,r); }
template <typename M>
static long nr (const M& m) { return m.nr(); }
static long nr (const M& m) { return std::min(m.nc(),m.nr()); }
template <typename M>
static long nc (const M& m) { return 1; }
};
@ -2970,14 +2970,6 @@ convergence:
const matrix_exp<EXP>& m
)
{
// You can only get the diagonal for square matrices.
COMPILE_TIME_ASSERT(EXP::NR == EXP::NC);
DLIB_ASSERT(m.nr() == m.nc(),
"\tconst matrix_exp diag(const matrix_exp& m)"
<< "\n\tYou can only apply diag() to a square matrix"
<< "\n\tm.nr(): " << m.nr()
<< "\n\tm.nc(): " << m.nc()
);
typedef matrix_unary_exp<matrix_exp<EXP>,op_diag> exp;
return matrix_exp<exp>(exp(m));
}

View File

@ -8,6 +8,7 @@
#include "../matrix.h"
#include "../rand.h"
#include "../enable_if.h"
#include "../algs.h"
#include "quantum_computing_abstract.h"
namespace dlib
@ -34,22 +35,6 @@ namespace dlib
// ------------------------------------------------------------------------------------
// This is a template to compute the absolute value a number at compile time
template <long x, typename enabled=void>
struct abs { const static long value = x; };
template <long x>
struct abs<x,typename enable_if_c<(x < 0)>::type> { const static long value = -x; };
// ------------------------------------------------------------------------------------
// This is a template to compute the max of two values at compile time
template <long x, long y, typename enabled=void>
struct max { const static long value = x; };
template <long x, long y>
struct max<x,y,typename enable_if_c<(y > x)>::type> { const static long value = y; };
// ------------------------------------------------------------------------------------
}
typedef std::complex<double> qc_scalar_type;
@ -657,7 +642,7 @@ namespace dlib
target_mask <<= 1;
}
static const long num_bits = qc_helpers::abs<control_bit-target_bit>::value+1;
static const long num_bits = tabs<control_bit-target_bit>::value+1;
static const long dims = qc_helpers::exp_2_n<num_bits>::value;
const qc_scalar_type operator() (long r, long c) const
@ -742,8 +727,8 @@ namespace dlib
target_mask <<= 1;
}
static const long num_bits = qc_helpers::max<qc_helpers::abs<control_bit1-target_bit>::value,
qc_helpers::abs<control_bit2-target_bit>::value>::value+1;
static const long num_bits = tmax<tabs<control_bit1-target_bit>::value,
tabs<control_bit2-target_bit>::value>::value+1;
static const long dims = qc_helpers::exp_2_n<num_bits>::value;
const qc_scalar_type operator() (long r, long c) const

View File

@ -120,6 +120,10 @@ namespace
mrc.set_size(3,4);
set_all_elements(mrc,1);
DLIB_CASSERT(diag(mrc) == uniform_matrix<double>(3,1,1),"");
DLIB_CASSERT(diag(matrix<double>(mrc)) == uniform_matrix<double>(3,1,1),"");
matrix<double,2,3> mrc2;
set_all_elements(mrc2,1);
DLIB_CASSERT((removerc<1,1>(mrc) == mrc2),"");