From 298d3a4a56ec5a86fcf6eb78a4be20899ccc9bfb Mon Sep 17 00:00:00 2001 From: Davis King Date: Sat, 10 Jan 2015 11:53:29 -0500 Subject: [PATCH 1/2] Fixed compute_lda_transform() so it works properly when the class covariance matrices are singular even after performing PCA. --- dlib/statistics/lda.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlib/statistics/lda.h b/dlib/statistics/lda.h index c81d7b08b..241e54f06 100644 --- a/dlib/statistics/lda.h +++ b/dlib/statistics/lda.h @@ -120,7 +120,7 @@ namespace dlib matrix W; svd3(Sw, A, W, H); W = sqrt(W); - W = reciprocal(round_zeros(W,max(W)*1e-5)); + W = reciprocal(lowerbound(W,max(W)*1e-5)); A = trans(H*diagm(W))*Sb*H*diagm(W); matrix v,s,u; svd3(A, v, s, u); From 5a2cfe7e8133d702f23b8eaa4bbb9eb2a13406cc Mon Sep 17 00:00:00 2001 From: Davis King Date: Sat, 10 Jan 2015 12:15:19 -0500 Subject: [PATCH 2/2] Added a test for compute_lda_transform() --- dlib/test/statistics.cpp | 54 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/dlib/test/statistics.cpp b/dlib/test/statistics.cpp index 2a28850db..203f4542c 100644 --- a/dlib/test/statistics.cpp +++ b/dlib/test/statistics.cpp @@ -733,6 +733,59 @@ namespace } + void test_lda () + { + // This test makes sure we pick the right direction in a simple 2D -> 1D LDA + typedef matrix sample_type; + + std::vector labels; + std::vector samples; + for (int i=0; i<4; i++) + { + sample_type s; + s(0) = i; + s(1) = i+1; + samples.push_back(s); + labels.push_back(1); + + sample_type s1; + s1(0) = i+1; + s1(1) = i; + samples.push_back(s1); + labels.push_back(2); + } + + matrix X; + X.set_size(8,2); + for (int i=0; i<8; i++){ + X(i,0) = samples[i](0); + X(i,1) = samples[i](1); + } + + matrix mean; + + dlib::compute_lda_transform(X,mean,labels,1); + + std::vector vals1, vals2; + for (unsigned long i = 0; i < samples.size(); ++i) + { + double val = X*samples[i]-mean; + if (i%2 == 0) + vals1.push_back(val); + else + vals2.push_back(val); + dlog << LINFO << "1D LDA output: " << val; + } + + if (vals1[0] > vals2[0]) + swap(vals1, vals2); + + const double err = equal_error_rate(vals1, vals2).first; + dlog << LINFO << "LDA ERR: " << err; + DLIB_TEST(err == 0); + DLIB_TEST(equal_error_rate(vals2, vals1).first == 1); + } + void perform_test ( ) { @@ -753,6 +806,7 @@ namespace test_randomize_samples2(); another_test(); test_average_precision(); + test_lda(); } } a;