aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON')
-rw-r--r--src/core/NEON/NEMath.inl49
1 files changed, 38 insertions, 11 deletions
diff --git a/src/core/NEON/NEMath.inl b/src/core/NEON/NEMath.inl
index 1755974d2d..acac36e1a5 100644
--- a/src/core/NEON/NEMath.inl
+++ b/src/core/NEON/NEMath.inl
@@ -518,17 +518,44 @@ inline float16x8_t vinvq_f16(float16x8_t x)
return recip;
}
-inline float16x8_t vtanhq_f16(float16x8_t val)
-{
- const float16x8_t CONST_MIN_TANH = vdupq_n_f16(-10.f);
- const float16x8_t CONST_MAX_TANH = vdupq_n_f16(10.f);
- const float16x8_t x = vminq_f16(vmaxq_f16(val, CONST_MIN_TANH), CONST_MAX_TANH);
- const auto expx = vexpq_f16(x);
- const auto expmx = vinvq_f16(expx);
- const auto ab = vsubq_f16(expx, expmx);
- const auto cd = vaddq_f16(expx, expmx);
- const float16x8_t tanh = vdivq_f16(ab, cd);
- return tanh;
+inline float16x4_t vtanh_rational_approx_f16(float16x4_t x16)
+{
+ // Calculate rational approximation part of tanh exactly on a half-register of F16 by using F32s
+ // Note: doesn't handle overflows, needs truncating at |x| = 4.508
+ const float32x4_t x = vcvt_f32_f16(x16);
+
+ const float32x4_t ONE = vdupq_n_f32(1.0f);
+ const float32x4_t C1 = vdupq_n_f32(0.43760237f);
+ const float32x4_t C2 = vdupq_n_f32(0.104402f);
+ const float32x4_t C3 = vdupq_n_f32(0.013442706f);
+ const float32x4_t C4 = vdupq_n_f32(0.00073561433f);
+
+ const float32x4_t x2 = vmulq_f32(x,x);
+
+ // Denominator polynomial 1 + C1*x^2 + C3*x^4
+ float32x4_t denom = vfmaq_f32(C1, C3, x2);
+ denom = vfmaq_f32(ONE, x2, denom);
+
+ // Numerator polynomial x*(1 + C2*x^2 + C4*x^4)
+ float32x4_t numer = vfmaq_f32(C2, C4, x2);
+ numer = vfmaq_f32(ONE, x2, numer);
+ numer = vmulq_f32(numer, x);
+
+ return vcvt_f16_f32(vdivq_f32(numer, denom));
+}
+
+inline float16x8_t vtanhq_f16(float16x8_t x)
+{
+ // Split into high/low and use rational approximation on both parts exactly
+ const float16x8_t tanh = vcombine_f16(vtanh_rational_approx_f16( vget_low_f16(x)),
+ vtanh_rational_approx_f16(vget_high_f16(x)));
+
+ // tanh(x) == sign(x) to F16 precision for |x| >= 4.508, use sign after this
+ const float16x8_t ONE = vdupq_n_f16(1.0f);
+ const float16x8_t MAX_X = vdupq_n_f16(4.508f);
+ const auto at_limit = vcageq_f16(x, MAX_X); // |x| >= 4.508
+ const float16x8_t sign_x = vbslq_f16(vclezq_f16(x), -ONE, ONE);
+ return vbslq_f16(at_limit, sign_x, tanh);
}
inline float16x8_t vtaylor_polyq_f16(float16x8_t x, const std::array<float16x8_t, 8> &coeffs)