Fixed mish gradient

This commit is contained in:
AlexeyAB 2020-05-12 20:15:10 +03:00
parent 2500278784
commit bef28445e5
1 changed files with 6 additions and 4 deletions

View File

@ -40,7 +40,8 @@ __device__ float gelu_activate_kernel(float x){return (0.5*x*(1 + tanhf(0.797885
__device__ float softplus_kernel(float x, float threshold = 20) {
if (x > threshold) return x; // too large
else if (x < -threshold) return expf(x); // too small
return logf(expf(x) + 1);
return log1pf(expf(x));
//return logf(expf(x) + 1);
}
__device__ float plse_activate_kernel(float x)
{
@ -257,8 +258,8 @@ __global__ void activate_array_mish_kernel(float *x, int n, float *activation_in
// Pytorch: https://github.com/thomasbrandon/mish-cuda/blob/master/csrc/mish.h#L17-L20
// TF: https://github.com/tensorflow/addons/blob/093cdfa85d334cbe19a37624c33198f3140109ed/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h#L40-L49
// log1p(x) == log(x + 1)
output_gpu[i] = x_val * tanh_activate_kernel( softplus_kernel(x_val, MISH_THRESHOLD) );
//output_gpu[i] = mish_yashas(x_val);
//output_gpu[i] = x_val * tanh_activate_kernel( softplus_kernel(x_val, MISH_THRESHOLD) );
output_gpu[i] = mish_yashas(x_val);
//output_gpu[i] = mish_njuffa(x_val);
}
}
@ -355,7 +356,8 @@ __global__ void gradient_array_mish_kernel(int n, float *activation_input_gpu, f
// log1p(x) == log(x + 1)
const float inp = activation_input_gpu[i];
const float sp = softplus_kernel(inp, MISH_THRESHOLD);
const float grad_sp = 1 - expf(-sp);
const float grad_sp = -expm1f(-sp);
//const float grad_sp = 1 - expf(-sp);
const float tsp = tanh(sp);
const float grad_tsp = (1 - tsp*tsp) * grad_sp;
const float grad = inp * grad_tsp + tsp;