aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON
diff options
context:
space:
mode:
author: Jonathan Deakin <jonathan.deakin@arm.com> 2022-10-13 10:50:25 +0000
committer: Pablo Marquez Tello <pablo.tello@arm.com> 2022-10-24 09:38:39 +0000
commit: 2bc8cfe274b1f524013fbb6561930daca496b3ec (patch)
tree: 7db162020c02fd8c25743c4e58272f9c9047a2f6 /src/core/NEON
parent: a4f887021503507194774aeab3f76dca10888b97 (diff)
download: ComputeLibrary-2bc8cfe274b1f524013fbb6561930daca496b3ec.tar.gz
Add FP16 tanh based on rational approximation
Use rational approximation with optimised coefficients to calculate tanh_f16. The method is ~2.5x faster than the previous implementation and has lower relative and absolute error. This will fix https://github.com/ARM-software/ComputeLibrary/issues/1002. Credit to George Steed for suggesting use of vcageq instead of min+max. Signed-off-by: Jonathan Deakin <jonathan.deakin@arm.com> Change-Id: Id70da3aab666c68b0d798266a837b59c00937bf7 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8480 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Pablo Marquez Tello <pablo.tello@arm.com> Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON')
-rw-r--r-- src/core/NEON/NEMath.inl 49
1 files changed, 38 insertions, 11 deletions
diff --git a/src/core/NEON/NEMath.inl b/src/core/NEON/NEMath.inl
index 1755974d2d..acac36e1a5 100644
--- a/src/core/NEON/NEMath.inl
+++ b/src/core/NEON/NEMath.inl
@@ -518,17 +518,44 @@ inline float16x8_t vinvq_f16(float16x8_t x)
return recip;
}
-inline float16x8_t vtanhq_f16(float16x8_t val)
-{
- const float16x8_t CONST_MIN_TANH = vdupq_n_f16(-10.f);
- const float16x8_t CONST_MAX_TANH = vdupq_n_f16(10.f);
- const float16x8_t x = vminq_f16(vmaxq_f16(val, CONST_MIN_TANH), CONST_MAX_TANH);
- const auto expx = vexpq_f16(x);
- const auto expmx = vinvq_f16(expx);
- const auto ab = vsubq_f16(expx, expmx);
- const auto cd = vaddq_f16(expx, expmx);
- const float16x8_t tanh = vdivq_f16(ab, cd);
- return tanh;
+inline float16x4_t vtanh_rational_approx_f16(float16x4_t x16)
+{
+ // Evaluate the rational approximation of tanh on a half-register (4 lanes) of F16, doing all arithmetic in F32
+ // Note: large |x| is not handled here; callers must replace the result with sign(x) for |x| >= 4.508 (as vtanhq_f16 does)
+ const float32x4_t x = vcvt_f32_f16(x16); // widen the 4 F16 lanes to F32
+
+ const float32x4_t ONE = vdupq_n_f32(1.0f);
+ const float32x4_t C1 = vdupq_n_f32(0.43760237f); // denominator x^2 coefficient
+ const float32x4_t C2 = vdupq_n_f32(0.104402f); // numerator x^2 coefficient
+ const float32x4_t C3 = vdupq_n_f32(0.013442706f); // denominator x^4 coefficient
+ const float32x4_t C4 = vdupq_n_f32(0.00073561433f); // numerator x^4 coefficient
+
+ const float32x4_t x2 = vmulq_f32(x,x);
+
+ // Denominator polynomial 1 + C1*x^2 + C3*x^4, evaluated in Horner form with multiply-accumulates
+ float32x4_t denom = vfmaq_f32(C1, C3, x2); // C1 + C3*x^2
+ denom = vfmaq_f32(ONE, x2, denom); // 1 + x^2*(C1 + C3*x^2)
+
+ // Numerator polynomial x*(1 + C2*x^2 + C4*x^4), same scheme followed by a multiply by x
+ float32x4_t numer = vfmaq_f32(C2, C4, x2); // C2 + C4*x^2
+ numer = vfmaq_f32(ONE, x2, numer); // 1 + x^2*(C2 + C4*x^2)
+ numer = vmulq_f32(numer, x);
+
+ return vcvt_f16_f32(vdivq_f32(numer, denom)); // divide in F32, then narrow the quotient back to F16
+}
+
+inline float16x8_t vtanhq_f16(float16x8_t x)
+{
+ // Split into low/high halves, run the F32-based rational approximation on each half and recombine
+ const float16x8_t tanh = vcombine_f16(vtanh_rational_approx_f16( vget_low_f16(x)),
+ vtanh_rational_approx_f16(vget_high_f16(x)));
+
+ // tanh(x) == sign(x) to F16 precision for |x| >= 4.508, so those lanes take +/-1 instead of the approximation
+ const float16x8_t ONE = vdupq_n_f16(1.0f);
+ const float16x8_t MAX_X = vdupq_n_f16(4.508f);
+ const auto at_limit = vcageq_f16(x, MAX_X); // |x| >= 4.508
+ const float16x8_t sign_x = vbslq_f16(vclezq_f16(x), -ONE, ONE); // -1 where x <= 0, +1 otherwise
+ return vbslq_f16(at_limit, sign_x, tanh); // per lane: sign_x where at_limit is set, otherwise tanh
}
inline float16x8_t vtaylor_polyq_f16(float16x8_t x, const std::array<float16x8_t, 8> &coeffs)