mirror of https://github.com/davisking/dlib.git
remove branch from cuda kernel (#2045)
* remove branch from cuda kernel * promote lambda to a global function
This commit is contained in:
parent
57bb5eb58d
commit
d1d96e380c
|
@ -1405,7 +1405,7 @@ namespace dlib
|
|||
)
|
||||
{
|
||||
float* out = grad.device();
|
||||
const float *gi = gradient_input.device();
|
||||
const float* gi = gradient_input.device();
|
||||
if (out == gi)
|
||||
{
|
||||
launch_kernel(_cuda_leaky_relu_gradient_inplace, max_jobs(grad.size()),
|
||||
|
@ -1440,31 +1440,29 @@ namespace dlib
|
|||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
// Derivative of the mish activation, d/dx[x*tanh(softplus(x))], in the
// closed form e*omega/delta^2 where e = exp(x).  For |x| >= 8 the gradient
// has numerically saturated, so return its asymptotic values directly and
// avoid overflow in the e^3 term below.
__device__ float mish_compute_gradient(float x)
{
    if (x >= 8)
        return 1.f;
    if (x <= -8)
        return 0.f;

    // Use the single-precision intrinsic explicitly and hoist the
    // repeated e*e product instead of recomputing it three times.
    const float e = expf(x);
    const float e2 = e*e;
    const float delta = 2*e + e2 + 2;
    const float omega = 4*(x + 1) + 4*e2 + e2*e + e*(4*x + 6);
    return e*omega/(delta*delta);
}
|
||||
|
||||
// In-place mish backward pass: overwrite out[idx] with gi[idx]*mish'(s[idx]).
// out may alias gi — that aliasing is exactly why this _inplace variant
// exists.  grid_stride_range makes the kernel correct for any launch size.
__global__ void _cuda_mish_gradient_inplace(float* out, const float* s, const float* gi, size_t n)
{
    for (auto idx : grid_stride_range(0, n))
    {
        out[idx] = gi[idx] * mish_compute_gradient(s[idx]);
    }
}
|
||||
|
||||
// Accumulating mish backward pass: add gi[i]*mish'(s[i]) into out[i].
// NOTE(review): the scraped diff merged the old body (a per-thread lambda
// plus an out==gi branch) with the new one, which would have applied the
// gradient twice.  This is the post-commit version: the gradient math lives
// in the __device__ helper mish_compute_gradient(), and the aliasing case
// is handled by the separate _cuda_mish_gradient_inplace kernel, so this
// kernel body is branch-free.
__global__ void _cuda_mish_gradient(float* out, const float* s, const float* gi, size_t n)
{
    for (auto i : grid_stride_range(0, n))
        out[i] += gi[i]*mish_compute_gradient(s[i]);
}
|
||||
|
||||
void mish_gradient (
|
||||
|
@ -1473,7 +1471,12 @@ namespace dlib
|
|||
const tensor& gradient_input
|
||||
)
|
||||
{
|
||||
launch_kernel(_cuda_mish_gradient, max_jobs(grad.size()), grad.device(), src.device(), gradient_input.device(), grad.size());
|
||||
float* out = grad.device();
|
||||
const float* gi = gradient_input.device();
|
||||
if (out == gi)
|
||||
launch_kernel(_cuda_mish_gradient_inplace, max_jobs(grad.size()), out, src.device(), gi, grad.size());
|
||||
else
|
||||
launch_kernel(_cuda_mish_gradient, max_jobs(grad.size()), out, src.device(), gi, grad.size());
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
|
Loading…
Reference in New Issue