From 00e18a8d595543e3b08bd598115b80165eb7dd8d Mon Sep 17 00:00:00 2001 From: Davis King Date: Mon, 9 Nov 2015 17:37:15 -0500 Subject: [PATCH] Added tests for some of the new tensor operators. --- dlib/test/dnn.cpp | 48 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/dlib/test/dnn.cpp b/dlib/test/dnn.cpp index ccf1559fd..e7e5c9589 100644 --- a/dlib/test/dnn.cpp +++ b/dlib/test/dnn.cpp @@ -41,6 +41,7 @@ namespace void test_batch_normalize() { + print_spinner(); resizable_tensor src(5,5), gamma(1,5), beta(1,5), dest, means, vars, gradient_input(5,5); src = matrix_cast(gaussian_randm(5,5, 0)); gamma = matrix_cast(gaussian_randm(1,5, 1)); @@ -115,6 +116,7 @@ namespace void test_batch_normalize_conv() { + print_spinner(); resizable_tensor src(5,5,4,4), gamma(1,5), beta(1,5), dest, means, vars, gradient_input(5,5,4,4); src = matrix_cast(gaussian_randm(5,5*4*4, 0)); gamma = matrix_cast(gaussian_randm(1,5, 1)); @@ -190,6 +192,51 @@ namespace } +// ---------------------------------------------------------------------------------------- + + void test_basic_tensor_ops() + { + print_spinner(); + resizable_tensor dest, src(3,4), A(1,4), B(1,4); + src = 2; + affine_transform(dest, src, 2, 3); + dlog << LINFO << mat(dest); + matrix truth1(3,4), truth2(3,4); + + truth1 = 7; + truth2 = 7, 10, 7, 7, + 7, 10, 7, 7, + 7, 10, 7, 7; + DLIB_TEST(max(abs(truth1-mat(dest))) < 1e-5); + + A = 2; + B = 3; + A.host()[1] = 3; + B.host()[1] = 4; + dest = 0; + affine_transform(dest, src, A, B); + dlog << LINFO << mat(dest); + DLIB_TEST(max(abs(truth2-mat(dest))) < 1e-5); + + A.set_size(3,4); + B.set_size(3,4); + A = matrix_cast(gaussian_randm(3,4, 1)); + B = matrix_cast(gaussian_randm(3,4, 2)); + affine_transform(dest, src, A, B); + dlog << LINFO << mat(dest); + matrix truth3 = pointwise_multiply(mat(src), mat(A)) + mat(B); + DLIB_TEST(max(abs(truth3-mat(dest))) < 1e-5); + + matrix truth4 = pointwise_multiply(mat(A), mat(B)); + multiply(A, B); + DLIB_TEST(max(abs(truth4-mat(A))) < 1e-5); + + matrix truth5 = mat(B) > 0.1; + dlog << LINFO << truth5; + threshold(B, 0.1); + DLIB_TEST(max(abs(truth5-mat(B))) < 1e-5); + } + // ---------------------------------------------------------------------------------------- class dnn_tester : public tester @@ -206,6 +253,7 @@ namespace { test_batch_normalize(); test_batch_normalize_conv(); + test_basic_tensor_ops(); } } a;