From 2bc8cfe274b1f524013fbb6561930daca496b3ec Mon Sep 17 00:00:00 2001 From: Jonathan Deakin Date: Thu, 13 Oct 2022 10:50:25 +0000 Subject: Add FP16 tanh based on rational approximation Use rational approximation with optimised coefficients to calculate tanh_f16. Method is ~2.5x faster than previous and has lower relative and absolute error. This will fix https://github.com/ARM-software/ComputeLibrary/issues/1002 Credit to George Steed for suggesting use of vcageq instead of min+max Signed-off-by: Jonathan Deakin Change-Id: Id70da3aab666c68b0d798266a837b59c00937bf7 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8480 Tested-by: Arm Jenkins Reviewed-by: Pablo Marquez Tello Reviewed-by: Viet-Hoa Do Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- src/core/NEON/NEMath.inl | 49 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 38 insertions(+), 11 deletions(-) (limited to 'src/core/NEON') diff --git a/src/core/NEON/NEMath.inl b/src/core/NEON/NEMath.inl index 1755974d2d..acac36e1a5 100644 --- a/src/core/NEON/NEMath.inl +++ b/src/core/NEON/NEMath.inl @@ -518,17 +518,44 @@ inline float16x8_t vinvq_f16(float16x8_t x) return recip; } -inline float16x8_t vtanhq_f16(float16x8_t val) -{ - const float16x8_t CONST_MIN_TANH = vdupq_n_f16(-10.f); - const float16x8_t CONST_MAX_TANH = vdupq_n_f16(10.f); - const float16x8_t x = vminq_f16(vmaxq_f16(val, CONST_MIN_TANH), CONST_MAX_TANH); - const auto expx = vexpq_f16(x); - const auto expmx = vinvq_f16(expx); - const auto ab = vsubq_f16(expx, expmx); - const auto cd = vaddq_f16(expx, expmx); - const float16x8_t tanh = vdivq_f16(ab, cd); - return tanh; +inline float16x4_t vtanh_rational_approx_f16(float16x4_t x16) +{ + // Calculate rational approximation part of tanh exactly on a half-register of F16 by using F32s + // Note: doesn't handle overflows, needs truncating at |x| = 4.508 + const float32x4_t x = vcvt_f32_f16(x16); + + const float32x4_t ONE = vdupq_n_f32(1.0f); + const float32x4_t C1 = vdupq_n_f32(0.43760237f); + const float32x4_t C2 = vdupq_n_f32(0.104402f); + const float32x4_t C3 = vdupq_n_f32(0.013442706f); + const float32x4_t C4 = vdupq_n_f32(0.00073561433f); + + const float32x4_t x2 = vmulq_f32(x,x); + + // Denominator polynomial 1 + C1*x^2 + C3*x^4 + float32x4_t denom = vfmaq_f32(C1, C3, x2); + denom = vfmaq_f32(ONE, x2, denom); + + // Numerator polynomial x*(1 + C2*x^2 + C4*x^4) + float32x4_t numer = vfmaq_f32(C2, C4, x2); + numer = vfmaq_f32(ONE, x2, numer); + numer = vmulq_f32(numer, x); + + return vcvt_f16_f32(vdivq_f32(numer, denom)); +} + +inline float16x8_t vtanhq_f16(float16x8_t x) +{ + // Split into high/low and use rational approximation on both parts exactly + const float16x8_t tanh = vcombine_f16(vtanh_rational_approx_f16( vget_low_f16(x)), + vtanh_rational_approx_f16(vget_high_f16(x))); + + // tanh(x) == sign(x) to F16 precision for |x| >= 4.508, use sign after this + const float16x8_t ONE = vdupq_n_f16(1.0f); + const float16x8_t MAX_X = vdupq_n_f16(4.508f); + const auto at_limit = vcageq_f16(x, MAX_X); // |x| >= 4.508 + const float16x8_t sign_x = vbslq_f16(vclezq_f16(x), -ONE, ONE); + return vbslq_f16(at_limit, sign_x, tanh); } inline float16x8_t vtaylor_polyq_f16(float16x8_t x, const std::array &coeffs) -- cgit v1.2.1