diff options
author | Aleksandr Nikolaev <aleksandr.nikolaev@arm.com> | 2021-05-04 16:46:27 +0100 |
---|---|---|
committer | Aleksandr Nikolaev <aleksandr.nikolaev@arm.com> | 2021-05-05 12:13:54 +0000 |
commit | 7e9f34d219ae5dd3ddd5d26475f42aa02bcf010f (patch) | |
tree | 7dff0ce6bcd7579882388426327df78b2c6e1b41 /src | |
parent | c9309f22a026dfce92365e2f0802c40e8e1c449e (diff) | |
download | ComputeLibrary-7e9f34d219ae5dd3ddd5d26475f42aa02bcf010f.tar.gz |
Fix for tanh at small argument values
x - x^3/3 is a more accurate approximation of tanh(x) for |x| < 0.005
than (exp2x - 1)/(exp2x + 1), where exp2x = exp(2x).
Resolves: COMPMID-4098
Signed-off-by: Aleksandr Nikolaev <aleksandr.nikolaev@arm.com>
Change-Id: If6f9d7ce4d8d00d36d2dada7ab8f8d9f5b58f5c0
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/c/VisualCompute/ComputeLibrary/+/321354
Tested-by: bsgcomp <bsgcomp@arm.com>
Comments-Addressed: bsgcomp <bsgcomp@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5563
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src')
-rw-r--r-- | src/core/NEON/NEMath.inl | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/src/core/NEON/NEMath.inl b/src/core/NEON/NEMath.inl index 29df5433bb..5ac62badcc 100644 --- a/src/core/NEON/NEMath.inl +++ b/src/core/NEON/NEMath.inl @@ -190,12 +190,15 @@ inline float32x4_t vtanhq_f32(float32x4_t val) static const float32x4_t CONST_2 = vdupq_n_f32(2.f); static const float32x4_t CONST_MIN_TANH = vdupq_n_f32(-10.f); static const float32x4_t CONST_MAX_TANH = vdupq_n_f32(10.f); + static const float32x4_t CONST_THR = vdupq_n_f32(5.e-3); + static const float32x4_t CONST_1_3 = vdupq_n_f32(0.3333333f); float32x4_t x = vminq_f32(vmaxq_f32(val, CONST_MIN_TANH), CONST_MAX_TANH); - float32x4_t exp2x = vexpq_f32(vmulq_f32(CONST_2, x)); - float32x4_t num = vsubq_f32(exp2x, CONST_1); - float32x4_t den = vaddq_f32(exp2x, CONST_1); - float32x4_t tanh = vmulq_f32(num, vinvq_f32(den)); + // x * (1 - x^2/3) if |x| < 5.e-3 or (exp2x - 1) / (exp2x + 1) otherwise + float32x4_t exp2x = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vexpq_f32(vmulq_f32(CONST_2, x)), vmulq_f32(x, x)); + float32x4_t num = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vsubq_f32(exp2x, CONST_1), vmulq_f32(CONST_1_3, exp2x)); + float32x4_t den = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vaddq_f32(exp2x, CONST_1), vsubq_f32(CONST_1, num)); + float32x4_t tanh = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vmulq_f32(num, vinvq_f32(den)), vmulq_f32(x, den)); return tanh; } |