From a7ea7d00fe4b25cf6efa4868698e693dac4e52e9 Mon Sep 17 00:00:00 2001
From: Davis King <davis@dlib.net>
Date: Wed, 18 Nov 2015 18:32:28 -0500
Subject: [PATCH] Implemented CPU version of tanh

---
 dlib/dnn/cpu_dlib.cpp | 13 +++++++++----
 dlib/test/dnn.cpp     | 38 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 47 insertions(+), 4 deletions(-)

diff --git a/dlib/dnn/cpu_dlib.cpp b/dlib/dnn/cpu_dlib.cpp
index ded0510a7..47e988dbe 100644
--- a/dlib/dnn/cpu_dlib.cpp
+++ b/dlib/dnn/cpu_dlib.cpp
@@ -650,8 +650,10 @@ namespace dlib
             const tensor& src
         )
         {
-            // TODO
-            DLIB_CASSERT(false,"");
+            const auto d = dest.host();
+            const auto s = src.host();
+            for (size_t i = 0; i < src.size(); ++i)
+                d[i] = std::tanh(s[i]);
         }
 
         void tanh_gradient (
@@ -660,8 +662,11 @@ namespace dlib
             const tensor& gradient_input
         )
         {
-            // TODO
-            DLIB_CASSERT(false,"");
+            const auto g = grad.host();
+            const auto d = dest.host();
+            const auto in = gradient_input.host();
+            for (size_t i = 0; i < dest.size(); ++i)
+                g[i] = in[i]*(1-d[i]*d[i]);
         }
 
     // ------------------------------------------------------------------------------------
diff --git a/dlib/test/dnn.cpp b/dlib/test/dnn.cpp
index a456db73b..4a90c52ae 100644
--- a/dlib/test/dnn.cpp
+++ b/dlib/test/dnn.cpp
@@ -39,6 +39,43 @@ namespace
         return max_error;
     }
 
+// ----------------------------------------------------------------------------------------
+
+    void test_tanh()
+    {
+        print_spinner();
+        resizable_tensor src(5,5), dest(5,5), gradient_input(5,5);
+        src = matrix_cast<float>(gaussian_randm(5,5, 0));
+        dest = matrix_cast<float>(gaussian_randm(5,5, 1));
+        gradient_input = matrix_cast<float>(gaussian_randm(5,5, 2));
+
+
+
+        auto grad_src = [&](long idx) {
+            auto f = [&](float eps) {
+                const float old = src.host()[idx];
+                src.host()[idx] += eps;
+                tanh(dest, src);
+                float result = dot(gradient_input, dest);
+                src.host()[idx] = old;
+                return result;
+            };
+            const float eps = 0.01;
+            return (f(+eps)-f(-eps))/(2*eps);
+        };
+
+        resizable_tensor src_grad;
+        src_grad.copy_size(src);
+        src_grad = 0;
+
+        tanh(dest, src);
+        tanh_gradient(src_grad, dest, gradient_input);
+
+        auto grad_error = compare_gradients(src_grad, grad_src);
+        dlog << LINFO << "src error: " << grad_error;
+        DLIB_TEST(grad_error < 0.001);
+    }
+
     void test_sigmoid()
     {
         print_spinner();
@@ -324,6 +361,7 @@ namespace
     void perform_test (
     )
     {
+        test_tanh();
         test_softmax();
         test_sigmoid();
         test_batch_normalize();
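
A note on the backward pass (not part of the patch): tanh_gradient() takes the forward output dest rather than the input src because the derivative of tanh is cheapest to express in terms of the already-computed output, d/dx tanh(x) = 1 - tanh(x)^2, so each gradient element is just gradient_input[i]*(1 - dest[i]*dest[i]). The standalone C++ sketch below, independent of dlib, checks that identity against a central finite difference, the same scheme test_tanh() uses (there with eps = 0.01):

    // Standalone sketch (not dlib code): verify d/dx tanh(x) = 1 - tanh(x)^2,
    // the identity tanh_gradient() relies on, against a central finite
    // difference like the f(+eps)/f(-eps) lambda in test_tanh() above.
    #include <cassert>
    #include <cmath>
    #include <cstdio>

    int main()
    {
        const double eps = 1e-4;
        for (double x = -2.0; x <= 2.0; x += 0.25)
        {
            const double d = std::tanh(x);               // forward output, plays the role of dest
            const double analytic = 1 - d*d;             // formula used in tanh_gradient()
            const double numeric = (std::tanh(x + eps) - std::tanh(x - eps)) / (2*eps);
            std::printf("x=%+5.2f  analytic=%9.6f  numeric=%9.6f\n", x, analytic, numeric);
            assert(std::fabs(analytic - numeric) < 1e-6);
        }
        return 0;
    }

Reusing dest this way means the backward pass needs no extra call to std::tanh, which is why the test runs tanh(dest, src) first and then feeds dest into tanh_gradient().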