mirror of https://github.com/davisking/dlib.git
Implemented CPU version of tanh
This commit is contained in:
parent
9b36bb980e
commit
a7ea7d00fe
|
@ -650,8 +650,10 @@ namespace dlib
|
|||
const tensor& src
|
||||
)
|
||||
{
|
||||
// TODO
|
||||
DLIB_CASSERT(false,"");
|
||||
const auto d = dest.host();
|
||||
const auto s = src.host();
|
||||
for (size_t i = 0; i < src.size(); ++i)
|
||||
d[i] = std::tanh(s[i]);
|
||||
}
|
||||
|
||||
void tanh_gradient (
|
||||
|
@ -660,8 +662,11 @@ namespace dlib
|
|||
const tensor& gradient_input
|
||||
)
|
||||
{
|
||||
// TODO
|
||||
DLIB_CASSERT(false,"");
|
||||
const auto g = grad.host();
|
||||
const auto d = dest.host();
|
||||
const auto in = gradient_input.host();
|
||||
for (size_t i = 0; i < dest.size(); ++i)
|
||||
g[i] = in[i]*(1-d[i]*d[i]);
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------------------
|
||||
|
|
|
@ -39,6 +39,43 @@ namespace
|
|||
return max_error;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void test_tanh()
|
||||
{
|
||||
print_spinner();
|
||||
resizable_tensor src(5,5), dest(5,5), gradient_input(5,5);
|
||||
src = matrix_cast<float>(gaussian_randm(5,5, 0));
|
||||
dest = matrix_cast<float>(gaussian_randm(5,5, 1));
|
||||
gradient_input = matrix_cast<float>(gaussian_randm(5,5, 2));
|
||||
|
||||
|
||||
|
||||
auto grad_src = [&](long idx) {
|
||||
auto f = [&](float eps) {
|
||||
const float old = src.host()[idx];
|
||||
src.host()[idx] += eps;
|
||||
tanh(dest, src);
|
||||
float result = dot(gradient_input, dest);
|
||||
src.host()[idx] = old;
|
||||
return result;
|
||||
};
|
||||
const float eps = 0.01;
|
||||
return (f(+eps)-f(-eps))/(2*eps);
|
||||
};
|
||||
|
||||
resizable_tensor src_grad;
|
||||
src_grad.copy_size(src);
|
||||
src_grad = 0;
|
||||
|
||||
tanh(dest, src);
|
||||
tanh_gradient(src_grad, dest, gradient_input);
|
||||
|
||||
auto grad_error = compare_gradients(src_grad, grad_src);
|
||||
dlog << LINFO << "src error: " << grad_error;
|
||||
DLIB_TEST(grad_error < 0.001);
|
||||
}
|
||||
|
||||
void test_sigmoid()
|
||||
{
|
||||
print_spinner();
|
||||
|
@ -324,6 +361,7 @@ namespace
|
|||
void perform_test (
|
||||
)
|
||||
{
|
||||
test_tanh();
|
||||
test_softmax();
|
||||
test_sigmoid();
|
||||
test_batch_normalize();
|
||||
|
|
Loading…
Reference in New Issue