diff options
author | Jonathan Deakin <jonathan.deakin@arm.com> | 2022-10-13 10:50:25 +0000 |
---|---|---|
committer | Pablo Marquez Tello <pablo.tello@arm.com> | 2022-10-24 09:38:39 +0000 |
commit | 2bc8cfe274b1f524013fbb6561930daca496b3ec (patch) | |
tree | 7db162020c02fd8c25743c4e58272f9c9047a2f6 /src/core/NEON | |
parent | a4f887021503507194774aeab3f76dca10888b97 (diff) | |
download | ComputeLibrary-2bc8cfe274b1f524013fbb6561930daca496b3ec.tar.gz |
Add FP16 tanh based on rational approximation
Use rational approximation with optimised coefficients
to calculate tanh_f16. Method is ~2.5x faster than previous
and has lower relative and absolute error.
This will fix https://github.com/ARM-software/ComputeLibrary/issues/1002
Credit to George Steed for suggesting use of vcageq instead of min+max
Signed-off-by: Jonathan Deakin <jonathan.deakin@arm.com>
Change-Id: Id70da3aab666c68b0d798266a837b59c00937bf7
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8480
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Marquez Tello <pablo.tello@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON')
-rw-r--r-- | src/core/NEON/NEMath.inl | 49 |
1 files changed, 38 insertions, 11 deletions
diff --git a/src/core/NEON/NEMath.inl b/src/core/NEON/NEMath.inl index 1755974d2d..acac36e1a5 100644 --- a/src/core/NEON/NEMath.inl +++ b/src/core/NEON/NEMath.inl @@ -518,17 +518,44 @@ inline float16x8_t vinvq_f16(float16x8_t x) return recip; } -inline float16x8_t vtanhq_f16(float16x8_t val) -{ - const float16x8_t CONST_MIN_TANH = vdupq_n_f16(-10.f); - const float16x8_t CONST_MAX_TANH = vdupq_n_f16(10.f); - const float16x8_t x = vminq_f16(vmaxq_f16(val, CONST_MIN_TANH), CONST_MAX_TANH); - const auto expx = vexpq_f16(x); - const auto expmx = vinvq_f16(expx); - const auto ab = vsubq_f16(expx, expmx); - const auto cd = vaddq_f16(expx, expmx); - const float16x8_t tanh = vdivq_f16(ab, cd); - return tanh; +inline float16x4_t vtanh_rational_approx_f16(float16x4_t x16) +{ + // Calculate rational approximation part of tanh exactly on a half-register of F16 by using F32s + // Note: doesn't handle overflows, needs truncating at |x| = 4.508 + const float32x4_t x = vcvt_f32_f16(x16); + + const float32x4_t ONE = vdupq_n_f32(1.0f); + const float32x4_t C1 = vdupq_n_f32(0.43760237f); + const float32x4_t C2 = vdupq_n_f32(0.104402f); + const float32x4_t C3 = vdupq_n_f32(0.013442706f); + const float32x4_t C4 = vdupq_n_f32(0.00073561433f); + + const float32x4_t x2 = vmulq_f32(x,x); + + // Denominator polynomial 1 + C1*x^2 + C3*x^4 + float32x4_t denom = vfmaq_f32(C1, C3, x2); + denom = vfmaq_f32(ONE, x2, denom); + + // Numerator polynomial x*(1 + C2*x^2 + C4*x^4) + float32x4_t numer = vfmaq_f32(C2, C4, x2); + numer = vfmaq_f32(ONE, x2, numer); + numer = vmulq_f32(numer, x); + + return vcvt_f16_f32(vdivq_f32(numer, denom)); +} + +inline float16x8_t vtanhq_f16(float16x8_t x) +{ + // Split into high/low and use rational approximation on both parts exactly + const float16x8_t tanh = vcombine_f16(vtanh_rational_approx_f16( vget_low_f16(x)), + vtanh_rational_approx_f16(vget_high_f16(x))); + + // tanh(x) == sign(x) to F16 precision for |x| >= 4.508, use sign after this + const float16x8_t ONE = vdupq_n_f16(1.0f); + const float16x8_t MAX_X = vdupq_n_f16(4.508f); + const auto at_limit = vcageq_f16(x, MAX_X); // |x| >= 4.508 + const float16x8_t sign_x = vbslq_f16(vclezq_f16(x), -ONE, ONE); + return vbslq_f16(at_limit, sign_x, tanh); } inline float16x8_t vtaylor_polyq_f16(float16x8_t x, const std::array<float16x8_t, 8> &coeffs) |