From 7e9f34d219ae5dd3ddd5d26475f42aa02bcf010f Mon Sep 17 00:00:00 2001 From: Aleksandr Nikolaev Date: Tue, 4 May 2021 16:46:27 +0100 Subject: Fix for tanh at small argument values x - x^3/3 is more accurate approximation for |x| < 0.005 than (exp2x - 1)/(exp2x + 1). Resolves: COMPMID-4098 Signed-off-by: Aleksandr Nikolaev Change-Id: If6f9d7ce4d8d00d36d2dada7ab8f8d9f5b58f5c0 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/c/VisualCompute/ComputeLibrary/+/321354 Tested-by: bsgcomp Comments-Addressed: bsgcomp Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5563 Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- src/core/NEON/NEMath.inl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/core/NEON/NEMath.inl b/src/core/NEON/NEMath.inl index 29df5433bb..5ac62badcc 100644 --- a/src/core/NEON/NEMath.inl +++ b/src/core/NEON/NEMath.inl @@ -190,12 +190,15 @@ inline float32x4_t vtanhq_f32(float32x4_t val) static const float32x4_t CONST_2 = vdupq_n_f32(2.f); static const float32x4_t CONST_MIN_TANH = vdupq_n_f32(-10.f); static const float32x4_t CONST_MAX_TANH = vdupq_n_f32(10.f); + static const float32x4_t CONST_THR = vdupq_n_f32(5.e-3); + static const float32x4_t CONST_1_3 = vdupq_n_f32(0.3333333f); float32x4_t x = vminq_f32(vmaxq_f32(val, CONST_MIN_TANH), CONST_MAX_TANH); - float32x4_t exp2x = vexpq_f32(vmulq_f32(CONST_2, x)); - float32x4_t num = vsubq_f32(exp2x, CONST_1); - float32x4_t den = vaddq_f32(exp2x, CONST_1); - float32x4_t tanh = vmulq_f32(num, vinvq_f32(den)); + // x * (1 - x^2/3) if |x| < 5.e-3 or (exp2x - 1) / (exp2x + 1) otherwise + float32x4_t exp2x = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vexpq_f32(vmulq_f32(CONST_2, x)), vmulq_f32(x, x)); + float32x4_t num = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vsubq_f32(exp2x, CONST_1), vmulq_f32(CONST_1_3, exp2x)); + float32x4_t den = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vaddq_f32(exp2x, CONST_1), vsubq_f32(CONST_1, num)); + float32x4_t tanh = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vmulq_f32(num, vinvq_f32(den)), vmulq_f32(x, den)); return tanh; } -- cgit v1.2.1