remove branch from cuda kernel (#2045)

* remove branch from cuda kernel

* promote lambda to a global function
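In effect, the patch moves the out == gi test (whether the gradient is written in place or accumulated) out of the kernel and into the host-side launcher, and promotes the per-element lambda to a __device__ function so both resulting kernels can share it. The sketch below shows that pattern in isolation; all names (square, apply_inplace, apply_accumulate, apply) are made up for the example and the launch configuration is arbitrary, so it illustrates the idea rather than dlib's actual code:

// Sketch only: the same host-side dispatch pattern as the patch, with made-up
// names (square, apply_inplace, apply_accumulate, apply); not dlib's real API.
#include <cstddef>
#include <cuda_runtime.h>

__device__ float square(float x) { return x*x; }  // stands in for the promoted helper

// out aliases gi in the in-place case, so the result simply overwrites it.
__global__ void apply_inplace(float* out, const float* s, const float* gi, size_t n)
{
    for (size_t i = blockIdx.x*blockDim.x + threadIdx.x; i < n; i += (size_t)gridDim.x*blockDim.x)
        out[i] = gi[i]*square(s[i]);
}

// Separate output buffer: accumulate instead of overwriting.
__global__ void apply_accumulate(float* out, const float* s, const float* gi, size_t n)
{
    for (size_t i = blockIdx.x*blockDim.x + threadIdx.x; i < n; i += (size_t)gridDim.x*blockDim.x)
        out[i] += gi[i]*square(s[i]);
}

// The aliasing test is uniform across the whole launch, so it is decided once
// on the host and each kernel body stays branch-free.
void apply(float* out, const float* s, const float* gi, size_t n)
{
    const unsigned int blocks = (unsigned int)((n + 255)/256);
    if (out == gi)
        apply_inplace<<<blocks, 256>>>(out, s, gi, n);
    else
        apply_accumulate<<<blocks, 256>>>(out, s, gi, n);
}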
Adrià Arrufat 2020-04-01 08:33:25 +09:00 committed by GitHub
parent 57bb5eb58d
commit d1d96e380c
1 changed file with 28 additions and 25 deletions


@@ -1405,7 +1405,7 @@ namespace dlib
         )
         {
             float* out = grad.device();
-            const float *gi = gradient_input.device();
+            const float* gi = gradient_input.device();
             if (out == gi)
             {
                 launch_kernel(_cuda_leaky_relu_gradient_inplace, max_jobs(grad.size()),
@@ -1440,31 +1440,29 @@ namespace dlib
     // ----------------------------------------------------------------------------------------

+        __device__ float mish_compute_gradient(float x)
+        {
+            if (x >= 8)
+                return 1.f;
+            if (x <= -8)
+                return 0.f;
+            const auto e = std::exp(x);
+            const auto delta = 2*e + e*e + 2;
+            const auto omega = 4*(x + 1) + 4*e*e + e*e*e + e*(4*x + 6);
+            return e*omega/(delta*delta);
+        }
+
+        __global__ void _cuda_mish_gradient_inplace(float* out, const float* s, const float* gi, size_t n)
+        {
+            for (auto i : grid_stride_range(0, n))
+                out[i] = gi[i]*mish_compute_gradient(s[i]);
+        }
+
         __global__ void _cuda_mish_gradient(float* out, const float* s, const float* gi, size_t n)
         {
-            const auto calculate_gradient = [](float x)
-            {
-                if (x >= 8)
-                    return 1.f;
-                if (x <= -8)
-                    return 0.f;
-                const auto e = std::exp(x);
-                const auto delta = 2*e + e*e + 2;
-                const auto omega = 4*(x + 1) + 4*e*e + e*e*e + e*(4*x + 6);
-                return e*omega/(delta*delta);
-            };
-            if (out == gi)
-            {
-                for (auto i : grid_stride_range(0, n))
-                    out[i] = gi[i]*calculate_gradient(s[i]);
-            }
-            else
-            {
-                for (auto i : grid_stride_range(0, n))
-                    out[i] += gi[i]*calculate_gradient(s[i]);
-            }
+            for (auto i : grid_stride_range(0, n))
+                out[i] += gi[i]*mish_compute_gradient(s[i]);
         }

         void mish_gradient (
@@ -1473,7 +1471,12 @@ namespace dlib
             const tensor& gradient_input
         )
         {
-            launch_kernel(_cuda_mish_gradient, max_jobs(grad.size()), grad.device(), src.device(), gradient_input.device(), grad.size());
+            float* out = grad.device();
+            const float* gi = gradient_input.device();
+            if (out == gi)
+                launch_kernel(_cuda_mish_gradient_inplace, max_jobs(grad.size()), out, src.device(), gi, grad.size());
+            else
+                launch_kernel(_cuda_mish_gradient, max_jobs(grad.size()), out, src.device(), gi, grad.size());
         }

     // ----------------------------------------------------------------------------------------
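For reference, the constants in mish_compute_gradient match the usual closed form of the mish derivative, assuming the standard definition mish(x) = x tanh(ln(1 + e^x)):

\[
\frac{d}{dx}\,\mathrm{mish}(x) = \frac{e^x\,\omega}{\delta^2},
\qquad
\delta = 2e^x + e^{2x} + 2,
\qquad
\omega = 4(x + 1) + 4e^{2x} + e^{3x} + e^x(4x + 6).
\]

The early returns clamp to the asymptotic values of this expression (1 for large positive x, 0 for large negative x), which also sidesteps float overflow of e^{3x} for large positive inputs.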