// Patch summary: gives dlib's pinv() a tolerance argument `tol`.
//   - matrix_la.h: adds a pinv(matrix_diag_exp, tol) overload, threads `tol` through
//     pinv_helper(), and gives pinv(matrix_exp, tol) a default of tol = 0.
//   - matrix_la_abstract.h: documents the new argument (requires tol >= 0).
//   - test/matrix.cpp: calls pinv(m, tol) and pinv(diagm(diag(m)), tol) with
//     tol = 1e-12 and tol = 0.
// Threshold used by pinv_helper(): tol when tol != 0, otherwise
// machine_eps*std::max(m.nr(),m.nc())*max(w). The matrix_diag_exp overload instead
// passes tol straight to round_zeros(), so it has no tol == 0 fallback.
// NOTE(review): template argument lists and part of each DLIB_ASSERT look stripped in
// this copy (e.g. "matrix_diag_op > pinv") -- confirm against the upstream commit
// before applying.
diff --git a/dlib/matrix/matrix_la.h b/dlib/matrix/matrix_la.h index 700c7aea0..8b648dee2 100644 --- a/dlib/matrix/matrix_la.h +++ b/dlib/matrix/matrix_la.h @@ -1290,6 +1290,25 @@ convergence: return matrix_diag_op(op(reciprocal(diag(m)))); } +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + const matrix_diag_op > pinv ( + const matrix_diag_exp& m, + double tol + ) + { + DLIB_ASSERT(tol >= 0, + "\tconst matrix_exp::type pinv(const matrix_exp& m)" + << "\n\t tol can't be negative" + << "\n\t tol: "< op; + return matrix_diag_op(op(reciprocal(round_zeros(diag(m),tol)))); + } + // ---------------------------------------------------------------------------------------- template @@ -1519,7 +1538,8 @@ convergence: typename EXP > const matrix pinv_helper ( - const matrix_exp& m + const matrix_exp& m, + double tol ) /*! ensures @@ -1541,8 +1561,8 @@ convergence: const double machine_eps = std::numeric_limits::epsilon(); // compute a reasonable epsilon below which we round to zero before doing the - // reciprocal - const double eps = machine_eps*std::max(m.nr(),m.nc())*max(w); + // reciprocal. However, if a non-zero tol is given then we just use tol instead. + const double eps = (tol!=0) ? 
tol : machine_eps*std::max(m.nr(),m.nc())*max(w); // now compute the pseudoinverse return tmp(scale_columns(v,reciprocal(round_zeros(w,eps))))*trans(u); @@ -1552,15 +1572,21 @@ convergence: typename EXP > const matrix pinv ( - const matrix_exp& m + const matrix_exp& m, + double tol = 0 ) { + DLIB_ASSERT(tol >= 0, + "\tconst matrix_exp::type pinv(const matrix_exp& m)" + << "\n\t tol can't be negative" + << "\n\t tol: "< m.nr()) - return trans(pinv_helper(trans(m))); + return trans(pinv_helper(trans(m),tol)); else - return pinv_helper(m); + return pinv_helper(m,tol); } // ---------------------------------------------------------------------------------------- diff --git a/dlib/matrix/matrix_la_abstract.h b/dlib/matrix/matrix_la_abstract.h index a942537c4..ea8fee5bf 100644 --- a/dlib/matrix/matrix_la_abstract.h +++ b/dlib/matrix/matrix_la_abstract.h @@ -31,12 +31,20 @@ namespace dlib // ---------------------------------------------------------------------------------------- const matrix pinv ( - const matrix_exp& m + const matrix_exp& m, + double tol = 0 ); /*! + requires + - tol >= 0 ensures - returns the Moore-Penrose pseudoinverse of m. - The returned matrix has m.nc() rows and m.nr() columns. + - if (tol == 0) then + - singular values less than max(m.nr(),m.nc()) times the machine epsilon + times the largest singular value are ignored. + - else + - singular values less than tol are ignored. 
!*/ // ---------------------------------------------------------------------------------------- diff --git a/dlib/test/matrix.cpp b/dlib/test/matrix.cpp index aefb3b33a..5409cd161 100644 --- a/dlib/test/matrix.cpp +++ b/dlib/test/matrix.cpp @@ -67,6 +67,40 @@ namespace DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix()))); DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix(m)))); DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix(m)))); + + mi = pinv(m,1e-12); + DLIB_TEST(mi.nr() == m.nc()); + DLIB_TEST(mi.nc() == m.nr()); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix()))); + DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix()))); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix(m)))); + DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix(m)))); + + m = diagm(diag(m)); + mi = pinv(diagm(diag(m)),1e-12); + DLIB_TEST(mi.nr() == m.nc()); + DLIB_TEST(mi.nc() == m.nr()); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix()))); + DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix()))); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix(m)))); + DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix(m)))); + + mi = pinv(m,0); + DLIB_TEST(mi.nr() == m.nc()); + DLIB_TEST(mi.nc() == m.nr()); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix()))); + DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix()))); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix(m)))); + DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix(m)))); + + m = diagm(diag(m)); + mi = pinv(diagm(diag(m)),0); + DLIB_TEST(mi.nr() == m.nc()); + DLIB_TEST(mi.nc() == m.nr()); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix()))); + DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix()))); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix(m)))); + DLIB_TEST((equal(round_zeros(m*mi,0.000001) , 
identity_matrix(m)))); } { matrix m(5,5);