Fixed MISH activation with 2 thresholds in Softplus

This commit is contained in:
AlexeyAB 2019-11-22 14:20:53 +03:00
parent 7713a0209c
commit b9ca5ec781
3 changed files with 22 additions and 12 deletions

View File

@ -35,6 +35,11 @@ __device__ float relie_activate_kernel(float x){return (x>0) ? x : .01f*x;}
__device__ float ramp_activate_kernel(float x){return x*(x>0)+.1f*x;} __device__ float ramp_activate_kernel(float x){return x*(x>0)+.1f*x;}
__device__ float leaky_activate_kernel(float x){return (x>0) ? x : .1f*x;} __device__ float leaky_activate_kernel(float x){return (x>0) ? x : .1f*x;}
__device__ float tanh_activate_kernel(float x){return (2/(1 + expf(-2*x)) - 1);} __device__ float tanh_activate_kernel(float x){return (2/(1 + expf(-2*x)) - 1);}
__device__ float softplus_kernel(float x, float threshold = 20) {
if (x > threshold) return x; // too large
else if (x < -threshold) return expf(x); // too small
return logf(expf(x) + 1);
}
__device__ float plse_activate_kernel(float x) __device__ float plse_activate_kernel(float x)
{ {
if(x < -4) return .01f * (x + 4); if(x < -4) return .01f * (x + 4);
@ -207,11 +212,12 @@ __global__ void activate_array_mish_kernel(float *x, int n, float *activation_in
const float MISH_THRESHOLD = 20; const float MISH_THRESHOLD = 20;
float x_val = x[i]; float x_val = x[i];
activation_input[i] = x_val; // store value before activation activation_input[i] = x_val; // store value before activation
//output_gpu[i] = x_val * tanh_activate_kernel(log(1 + expf(x_val))); //output_gpu[i] = x_val * tanh_activate_kernel(logf(1 + expf(x_val)));
// https://github.com/thomasbrandon/mish-cuda/blob/master/csrc/mish.h#L17-L20 // Pytorch: https://github.com/thomasbrandon/mish-cuda/blob/master/csrc/mish.h#L17-L20
if (x_val < MISH_THRESHOLD) output_gpu[i] = x_val * tanh_activate_kernel(log(expf(x_val))); // TF: https://github.com/tensorflow/addons/blob/093cdfa85d334cbe19a37624c33198f3140109ed/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h#L40-L49
else output_gpu[i] = x_val * tanh_activate_kernel(x_val); // log1p(x) == log(x + 1)
output_gpu[i] = x_val * tanh_activate_kernel( softplus_kernel(x_val, MISH_THRESHOLD) );
} }
} }
@ -286,11 +292,12 @@ __global__ void gradient_array_mish_kernel(int n, float *activation_input_gpu, f
if (i < n) { if (i < n) {
const float MISH_THRESHOLD = 20.0f; const float MISH_THRESHOLD = 20.0f;
// implementation from TensorFlow: https://github.com/tensorflow/addons/commit/093cdfa85d334cbe19a37624c33198f3140109ed // implementation from TensorFlow: https://github.com/tensorflow/addons/blob/093cdfa85d334cbe19a37624c33198f3140109ed/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h#L66-L80
// implementation from Pytorch: https://github.com/thomasbrandon/mish-cuda/blob/master/csrc/mish.h#L26-L31 // implementation from Pytorch: https://github.com/thomasbrandon/mish-cuda/blob/master/csrc/mish.h#L26-L31
float inp = activation_input_gpu[i]; // log1p(x) == log(x + 1)
const float sp = (inp < MISH_THRESHOLD) ? log1p(exp(inp)) : inp; const float inp = activation_input_gpu[i];
const float grad_sp = 1 - exp(-sp); const float sp = softplus_kernel(inp, MISH_THRESHOLD);
const float grad_sp = 1 - expf(-sp);
const float tsp = tanh(sp); const float tsp = tanh(sp);
const float grad_tsp = (1 - tsp*tsp) * grad_sp; const float grad_tsp = (1 - tsp*tsp) * grad_sp;
const float grad = inp * grad_tsp + tsp; const float grad = inp * grad_tsp + tsp;

View File

@ -143,9 +143,7 @@ void activate_array_mish(float *x, const int n, float * activation_input, float
for (i = 0; i < n; ++i) { for (i = 0; i < n; ++i) {
float x_val = x[i]; float x_val = x[i];
activation_input[i] = x_val; // store value before activation activation_input[i] = x_val; // store value before activation
//output[i] = x_val * tanh_activate(log(1 + expf(x_val))); output[i] = x_val * tanh_activate( softplus_activate(x_val, MISH_THRESHOLD) );
if (x_val < MISH_THRESHOLD) output[i] = x_val * tanh_activate(log(expf(x_val)));
else output[i] = x_val * tanh_activate(x_val);
} }
} }
@ -215,7 +213,7 @@ void gradient_array_mish(const int n, const float * activation_input, float * de
// implementation from TensorFlow: https://github.com/tensorflow/addons/commit/093cdfa85d334cbe19a37624c33198f3140109ed // implementation from TensorFlow: https://github.com/tensorflow/addons/commit/093cdfa85d334cbe19a37624c33198f3140109ed
// implementation from Pytorch: https://github.com/thomasbrandon/mish-cuda/blob/master/csrc/mish.h#L26-L31 // implementation from Pytorch: https://github.com/thomasbrandon/mish-cuda/blob/master/csrc/mish.h#L26-L31
float inp = activation_input[i]; float inp = activation_input[i];
const float sp = (inp < MISH_THRESHOLD) ? log1p(exp(inp)) : inp; const float sp = softplus_activate(inp, MISH_THRESHOLD);
const float grad_sp = 1 - exp(-sp); const float grad_sp = 1 - exp(-sp);
const float tsp = tanh(sp); const float tsp = tanh(sp);
const float grad_tsp = (1 - tsp*tsp) * grad_sp; const float grad_tsp = (1 - tsp*tsp) * grad_sp;

View File

@ -53,6 +53,11 @@ static inline float relie_activate(float x){return (x>0) ? x : .01f*x;}
static inline float ramp_activate(float x){return x*(x>0)+.1f*x;} static inline float ramp_activate(float x){return x*(x>0)+.1f*x;}
static inline float leaky_activate(float x){return (x>0) ? x : .1f*x;} static inline float leaky_activate(float x){return (x>0) ? x : .1f*x;}
static inline float tanh_activate(float x){return (expf(2*x)-1)/(expf(2*x)+1);} static inline float tanh_activate(float x){return (expf(2*x)-1)/(expf(2*x)+1);}
static inline float softplus_activate(float x, float threshold) {
if (x > threshold) return x; // too large
else if (x < -threshold) return expf(x); // too small
return logf(expf(x) + 1);
}
static inline float plse_activate(float x) static inline float plse_activate(float x)
{ {
if(x < -4) return .01f * (x + 4); if(x < -4) return .01f * (x + 4);