aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Aleksandr Nikolaev <aleksandr.nikolaev@arm.com> 2021-05-04 16:46:27 +0100
committer: Aleksandr Nikolaev <aleksandr.nikolaev@arm.com> 2021-05-05 12:13:54 +0000
commit: 7e9f34d219ae5dd3ddd5d26475f42aa02bcf010f (patch)
tree: 7dff0ce6bcd7579882388426327df78b2c6e1b41
parent: c9309f22a026dfce92365e2f0802c40e8e1c449e (diff)
download: ComputeLibrary-7e9f34d219ae5dd3ddd5d26475f42aa02bcf010f.tar.gz
Fix for tanh at small argument values
x - x^3/3 is a more accurate approximation of tanh(x) for |x| < 0.005 than (exp2x - 1)/(exp2x + 1).

Resolves: COMPMID-4098
Signed-off-by: Aleksandr Nikolaev <aleksandr.nikolaev@arm.com>
Change-Id: If6f9d7ce4d8d00d36d2dada7ab8f8d9f5b58f5c0
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/c/VisualCompute/ComputeLibrary/+/321354
Tested-by: bsgcomp <bsgcomp@arm.com>
Comments-Addressed: bsgcomp <bsgcomp@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5563
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r-- src/core/NEON/NEMath.inl 11
1 files changed, 7 insertions, 4 deletions
diff --git a/src/core/NEON/NEMath.inl b/src/core/NEON/NEMath.inl
index 29df5433bb..5ac62badcc 100644
--- a/src/core/NEON/NEMath.inl
+++ b/src/core/NEON/NEMath.inl
@@ -190,12 +190,15 @@ inline float32x4_t vtanhq_f32(float32x4_t val)
static const float32x4_t CONST_2 = vdupq_n_f32(2.f);
static const float32x4_t CONST_MIN_TANH = vdupq_n_f32(-10.f);
static const float32x4_t CONST_MAX_TANH = vdupq_n_f32(10.f);
+ static const float32x4_t CONST_THR = vdupq_n_f32(5.e-3);
+ static const float32x4_t CONST_1_3 = vdupq_n_f32(0.3333333f);
float32x4_t x = vminq_f32(vmaxq_f32(val, CONST_MIN_TANH), CONST_MAX_TANH);
- float32x4_t exp2x = vexpq_f32(vmulq_f32(CONST_2, x));
- float32x4_t num = vsubq_f32(exp2x, CONST_1);
- float32x4_t den = vaddq_f32(exp2x, CONST_1);
- float32x4_t tanh = vmulq_f32(num, vinvq_f32(den));
+ // x * (1 - x^2/3) if |x| < 5.e-3 or (exp2x - 1) / (exp2x + 1) otherwise
+ float32x4_t exp2x = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vexpq_f32(vmulq_f32(CONST_2, x)), vmulq_f32(x, x));
+ float32x4_t num = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vsubq_f32(exp2x, CONST_1), vmulq_f32(CONST_1_3, exp2x));
+ float32x4_t den = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vaddq_f32(exp2x, CONST_1), vsubq_f32(CONST_1, num));
+ float32x4_t tanh = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vmulq_f32(num, vinvq_f32(den)), vmulq_f32(x, den));
return tanh;
}