mirror of https://github.com/davisking/dlib.git
remove branch from cuda kernel (#2045)
* remove branch from cuda kernel
* promote lambda to a global function
parent 57bb5eb58d
commit d1d96e380c
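The change hoists the runtime check if (out == gi) out of the mish CUDA kernel: the gradient formula is promoted from a lambda inside _cuda_mish_gradient to a standalone __device__ function, mish_compute_gradient, which is shared by a new in-place kernel and the existing accumulating kernel, and the host-side mish_gradient now decides which kernel to launch. Below is a minimal self-contained sketch of that pattern for context. It is illustrative only: it uses a placeholder gradient formula, writes out the grid-stride loop by hand instead of using dlib's launch_kernel/max_jobs/grid_stride_range helpers, and every name in it (compute_gradient, launch_gradient, and so on) is invented rather than taken from dlib.

    // Illustrative sketch only, not dlib code. One __device__ helper shared by
    // two kernels, with the in-place/accumulate choice made once on the host.
    #include <cuda_runtime.h>
    #include <cstddef>

    __device__ float compute_gradient(float x)
    {
        // placeholder formula; dlib's real one is mish_compute_gradient
        return 1.0f/(1.0f + expf(-x));
    }

    __global__ void gradient_inplace(float* out, const float* s, const float* gi, size_t n)
    {
        // hand-written grid-stride loop (dlib uses grid_stride_range instead)
        for (size_t i = blockIdx.x*blockDim.x + threadIdx.x; i < n; i += (size_t)gridDim.x*blockDim.x)
            out[i] = gi[i]*compute_gradient(s[i]);     // overwrite: out aliases gi
    }

    __global__ void gradient_accumulate(float* out, const float* s, const float* gi, size_t n)
    {
        for (size_t i = blockIdx.x*blockDim.x + threadIdx.x; i < n; i += (size_t)gridDim.x*blockDim.x)
            out[i] += gi[i]*compute_gradient(s[i]);    // accumulate into a separate buffer
    }

    void launch_gradient(float* out, const float* s, const float* gi, size_t n)
    {
        const int threads = 256;
        const int blocks  = (int)((n + threads - 1)/threads);
        if (out == gi)                                 // branch once on the host,
            gradient_inplace<<<blocks, threads>>>(out, s, gi, n);
        else                                           // not in every device thread
            gradient_accumulate<<<blocks, threads>>>(out, s, gi, n);
    }

    int main()
    {
        // buffers are left uninitialized; this only demonstrates kernel selection
        const size_t n = 1024;
        float *s = nullptr, *gi = nullptr, *out = nullptr;
        cudaMalloc(&s,   n*sizeof(float));
        cudaMalloc(&gi,  n*sizeof(float));
        cudaMalloc(&out, n*sizeof(float));
        cudaMemset(out, 0, n*sizeof(float));
        launch_gradient(out, s, gi, n);   // distinct buffers -> accumulate kernel
        launch_gradient(gi,  s, gi, n);   // out == gi        -> in-place kernel
        cudaDeviceSynchronize();
        cudaFree(s); cudaFree(gi); cudaFree(out);
        return 0;
    }

The point of the split is that the aliasing question is answered once per launch on the host rather than by every thread, and each kernel body stays a single trivial loop.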
@@ -1405,7 +1405,7 @@ namespace dlib
         )
         {
             float* out = grad.device();
-            const float *gi = gradient_input.device();
+            const float* gi = gradient_input.device();
             if (out == gi)
             {
                 launch_kernel(_cuda_leaky_relu_gradient_inplace, max_jobs(grad.size()),
@@ -1440,31 +1440,29 @@ namespace dlib

         // ----------------------------------------------------------------------------------------

+        __device__ float mish_compute_gradient(float x)
+        {
+            if (x >= 8)
+                return 1.f;
+            if (x <= -8)
+                return 0.f;
+
+            const auto e = std::exp(x);
+            const auto delta = 2*e + e*e + 2;
+            const auto omega = 4*(x + 1) + 4*e*e + e*e*e + e*(4*x + 6);
+            return e*omega/(delta*delta);
+        }
+
+        __global__ void _cuda_mish_gradient_inplace(float* out, const float* s, const float* gi, size_t n)
+        {
+            for (auto i : grid_stride_range(0, n))
+                out[i] = gi[i]*mish_compute_gradient(s[i]);
+        }
+
         __global__ void _cuda_mish_gradient(float* out, const float* s, const float* gi, size_t n)
         {
-            const auto calculate_gradient = [](float x)
-            {
-                if (x >= 8)
-                    return 1.f;
-                if (x <= -8)
-                    return 0.f;
-
-                const auto e = std::exp(x);
-                const auto delta = 2*e + e*e + 2;
-                const auto omega = 4*(x + 1) + 4*e*e + e*e*e + e*(4*x + 6);
-                return e*omega/(delta*delta);
-            };
-
-            if (out == gi)
-            {
-                for (auto i : grid_stride_range(0, n))
-                    out[i] = gi[i]*calculate_gradient(s[i]);
-            }
-            else
-            {
-                for (auto i : grid_stride_range(0, n))
-                    out[i] += gi[i]*calculate_gradient(s[i]);
-            }
+            for (auto i : grid_stride_range(0, n))
+                out[i] += gi[i]*mish_compute_gradient(s[i]);
         }

         void mish_gradient (
@@ -1473,7 +1471,12 @@ namespace dlib
             const tensor& gradient_input
         )
         {
-            launch_kernel(_cuda_mish_gradient, max_jobs(grad.size()), grad.device(), src.device(), gradient_input.device(), grad.size());
+            float* out = grad.device();
+            const float* gi = gradient_input.device();
+            if (out == gi)
+                launch_kernel(_cuda_mish_gradient_inplace, max_jobs(grad.size()), out, src.device(), gi, grad.size());
+            else
+                launch_kernel(_cuda_mish_gradient, max_jobs(grad.size()), out, src.device(), gi, grad.size());
         }

         // ----------------------------------------------------------------------------------------
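For reference (my own note, not part of the commit): mish_compute_gradient appears to implement the standard closed form for the derivative of mish(x) = x*tanh(ln(1 + e^x)), with the early returns at x >= 8 and x <= -8 clamping the gradient to its asymptotic values of 1 and 0 and keeping exp from overflowing in float. In LaTeX, matching the code's omega and delta (where e stands for exp(x)):

\[
\operatorname{mish}'(x) = \frac{e^{x}\,\omega}{\delta^{2}},
\qquad
\omega = 4(x+1) + 4e^{2x} + e^{3x} + e^{x}(4x+6),
\qquad
\delta = e^{2x} + 2e^{x} + 2.
\]

As a quick sanity check, at x = 0 this gives 1*15/25 = 0.6, the familiar value of the mish derivative at zero.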