aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/core/NEON/NEMath.inl11
1 files changed, 7 insertions, 4 deletions
diff --git a/src/core/NEON/NEMath.inl b/src/core/NEON/NEMath.inl
index 29df5433bb..5ac62badcc 100644
--- a/src/core/NEON/NEMath.inl
+++ b/src/core/NEON/NEMath.inl
@@ -190,12 +190,15 @@ inline float32x4_t vtanhq_f32(float32x4_t val)
static const float32x4_t CONST_2 = vdupq_n_f32(2.f);
static const float32x4_t CONST_MIN_TANH = vdupq_n_f32(-10.f);
static const float32x4_t CONST_MAX_TANH = vdupq_n_f32(10.f);
+ static const float32x4_t CONST_THR = vdupq_n_f32(5.e-3);
+ static const float32x4_t CONST_1_3 = vdupq_n_f32(0.3333333f);
float32x4_t x = vminq_f32(vmaxq_f32(val, CONST_MIN_TANH), CONST_MAX_TANH);
- float32x4_t exp2x = vexpq_f32(vmulq_f32(CONST_2, x));
- float32x4_t num = vsubq_f32(exp2x, CONST_1);
- float32x4_t den = vaddq_f32(exp2x, CONST_1);
- float32x4_t tanh = vmulq_f32(num, vinvq_f32(den));
+ // tanh(x) ~= x * (1 - x^2/3) if |x| <= 5.e-3, or (exp2x - 1) / (exp2x + 1) otherwise
+ float32x4_t exp2x = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vexpq_f32(vmulq_f32(CONST_2, x)), vmulq_f32(x, x));
+ float32x4_t num = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vsubq_f32(exp2x, CONST_1), vmulq_f32(CONST_1_3, exp2x));
+ float32x4_t den = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vaddq_f32(exp2x, CONST_1), vsubq_f32(CONST_1, num));
+ float32x4_t tanh = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vmulq_f32(num, vinvq_f32(den)), vmulq_f32(x, den));
return tanh;
}