mirror of https://github.com/davisking/dlib.git
Added CPU implementation of sigmoid() and sigmoid_gradient()
This commit is contained in:
parent
7c65c8d28a
commit
3124aa0dba
|
@ -523,8 +523,10 @@ namespace dlib
|
|||
const tensor& src
|
||||
)
|
||||
{
|
||||
// TODO
|
||||
DLIB_CASSERT(false,"");
|
||||
const auto d = dest.host();
|
||||
const auto s = src.host();
|
||||
for (size_t i = 0; i < src.size(); ++i)
|
||||
d[i] = 1/(1+std::exp(-s[i]));
|
||||
}
|
||||
|
||||
void sigmoid_gradient (
|
||||
|
@ -533,8 +535,11 @@ namespace dlib
|
|||
const tensor& gradient_input
|
||||
)
|
||||
{
|
||||
// TODO
|
||||
DLIB_CASSERT(false,"");
|
||||
const auto g = grad.host();
|
||||
const auto d = dest.host();
|
||||
const auto in = gradient_input.host();
|
||||
for (size_t i = 0; i < dest.size(); ++i)
|
||||
g[i] = in[i]*d[i]*(1-d[i]);
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------------------
|
||||
|
|
|
@ -39,6 +39,41 @@ namespace
|
|||
return max_error;
|
||||
}
|
||||
|
||||
void test_sigmoid()
|
||||
{
|
||||
print_spinner();
|
||||
resizable_tensor src(5,5), dest(5,5), gradient_input(5,5);
|
||||
src = matrix_cast<float>(gaussian_randm(5,5, 0));
|
||||
dest = matrix_cast<float>(gaussian_randm(5,5, 1));
|
||||
gradient_input = matrix_cast<float>(gaussian_randm(5,5, 2));
|
||||
|
||||
|
||||
|
||||
auto grad_src = [&](long idx) {
|
||||
auto f = [&](float eps) {
|
||||
const float old = src.host()[idx];
|
||||
src.host()[idx] += eps;
|
||||
sigmoid(dest, src);
|
||||
float result = dot(gradient_input, dest);
|
||||
src.host()[idx] = old;
|
||||
return result;
|
||||
};
|
||||
const float eps = 0.01;
|
||||
return (f(+eps)-f(-eps))/(2*eps);
|
||||
};
|
||||
|
||||
resizable_tensor src_grad;
|
||||
src_grad.copy_size(src);
|
||||
src_grad = 0;
|
||||
|
||||
sigmoid(dest, src);
|
||||
sigmoid_gradient(src_grad, dest, gradient_input);
|
||||
|
||||
auto grad_error = compare_gradients(src_grad, grad_src);
|
||||
dlog << LINFO << "src error: " << grad_error;
|
||||
DLIB_TEST(grad_error < 0.001);
|
||||
}
|
||||
|
||||
void test_batch_normalize()
|
||||
{
|
||||
print_spinner();
|
||||
|
@ -254,6 +289,7 @@ namespace
|
|||
void perform_test (
|
||||
)
|
||||
{
|
||||
test_sigmoid();
|
||||
test_batch_normalize();
|
||||
test_batch_normalize_conv();
|
||||
test_basic_tensor_ops();
|
||||
|
|
Loading…
Reference in New Issue