path: root/src/core/NEON/NEMath.inl
authorViet-Hoa Do <viet-hoa.do@arm.com>2023-09-27 16:39:05 +0100
committerViet-Hoa Do <viet-hoa.do@arm.com>2023-09-28 14:57:48 +0000
commitef9da00cad2b92633a130d43cb8a196278d49e85 (patch)
tree9955690805fa41c187e5d4a6a62020df292df699 /src/core/NEON/NEMath.inl
parent50dd2ee0cce42c72628b97686b02fc6ec073ca9c (diff)
downloadComputeLibrary-ef9da00cad2b92633a130d43cb8a196278d49e85.tar.gz
Reimplement erf function
* The current implementation has significant inaccuracy and the issue
  cascades to GELU.
* Use the implementation from Arm® Optimized Routines.
  The maximum error is 1.93 ULP.

Resolves: COMPMID-6554
Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Change-Id: If80131e164b7a078e34dd8e05b1506698f31d17a
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10395
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: TeresaARM <teresa.charlinreyes@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
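For reference, a minimal scalar sketch of the new scheme (illustrative only, not the committed code). It assumes the 513-entry table holds interleaved {erf(r), scale(r)} pairs with scale(r) = 2/sqrt(pi) * exp(-r^2), the derivative of erf at r, which is consistent with the two-float-per-index gather in the diff below; here the pair is computed on the fly with std::erf instead of reading erf_f32_lut.

#include <cmath>

float erf_scalar_sketch(float x)
{
    const float ax = std::fabs(x);
    if (ax > 3.9375f)                          // beyond the table, erf(x) has saturated to +/-1
        return std::copysign(1.0f, x);

    const long  i = std::lround(ax * 128.0f);  // nearest 1/128 step, 0 <= i <= 504
    const float r = static_cast<float>(i) * (1.0f / 128.0f);
    const float d = ax - r;                    // |d| <= 1/256

    const float erf_r   = std::erf(r);                        // table entry: erf(r)
    const float scale_r = 1.1283791671f * std::exp(-r * r);   // table entry: 2/sqrt(pi) * exp(-r^2)

    const float y = erf_r + scale_r * d * (1.0f - r * d - d * d * (1.0f / 3.0f));
    return std::copysign(y, x);
}

The vectorized version rounds to the 1/128 grid with a floating-point shift trick rather than std::lround; see the notes after the diff.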
Diffstat (limited to 'src/core/NEON/NEMath.inl')
-rw-r--r--  src/core/NEON/NEMath.inl | 71
1 file changed, 50 insertions, 21 deletions
diff --git a/src/core/NEON/NEMath.inl b/src/core/NEON/NEMath.inl
index f875917988..a5aba0bf23 100644
--- a/src/core/NEON/NEMath.inl
+++ b/src/core/NEON/NEMath.inl
@@ -21,6 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
+
+#include "src/core/utils/Math.h"
#include "support/ToolchainSupport.h"
#include <cmath>
@@ -224,35 +226,62 @@ inline float32x4_t vexpq_f32(float32x4_t x)
#ifdef __aarch64__
inline float32x4_t verfq_f32(float32x4_t x)
{
- static const float erffdata[4] = {0.278393f, 0.230389f, 0.000972f, 0.078108f};
- static const float32x4_t coeffdata = vld1q_f32(erffdata);
- static const float32x4_t onev{vdupq_n_f32(1.0f)};
+ const float32x4_t max_value = vdupq_n_f32(3.9375); // 4 - 8/128
+ const float32x4_t shift = vdupq_n_f32(65536); // 2^16
+ const float32x4_t third = vdupq_n_f32(0.3333333333); // 1/3
+ const float32x4_t one = vdupq_n_f32(1.f);
+ const uint32x4_t max_index = vdupq_n_u32(512);
+ const uint32x4_t sign_mask = vdupq_n_u32(0x7fffffff);
+
+ const float32x4_t x_abs = vabsq_f32(x);
- uint32x4_t selector = vcltzq_f32(x);
+ // erf(x) for x in [0, 3.9375] is approximated as follows:
+ //
+ // erf(x) = erf(r) + scale(r) * d * (1 - r * d - 1/3 * d^2)
+ //
+ // where:
+ // r = floor(x * 128) / 128
+ // d = x - r
+ //
+ // erf(r) and scale(r) are stored in a 513-entry lookup table.
+ // The LUT covers the range from 0 to 4 in steps of 1/128.
+ //
+ // Special cases:
+ // erf(x) = 1 for x > 3.9375
+ // erf(x) = -1 for x < -3.9375
+
+ // Find the LUT indices by rounding the input value to the step of 1/128.
+ //
+ // `shift` is used to push out the 16 LSBs of the input value. Only 7 bits in the fraction part
+ // of the input value are preserved.
+ const float32x4_t z = x_abs + shift;
+ const float32x4_t r = z - shift;
- float32x4_t absx = vabsq_f32(x);
- float32x4_t absx2 = vmulq_f32(x, x);
- float32x4_t absx3 = vmulq_f32(absx2, absx);
- float32x4_t absx4 = vmulq_f32(absx2, absx2);
+ uint32x4_t index = vreinterpretq_u32_f32(z) - vreinterpretq_u32_f32(shift);
+ index = vminq_u32(index, max_index);
- float32x4_t denom = onev;
- denom = vfmaq_laneq_f32(denom, absx, coeffdata, 0);
- denom = vfmaq_laneq_f32(denom, absx2, coeffdata, 1);
- denom = vfmaq_laneq_f32(denom, absx3, coeffdata, 2);
- denom = vfmaq_laneq_f32(denom, absx4, coeffdata, 3);
+ // Lookup erf(r) and scale(r).
+ const float64_t entry_0 = *reinterpret_cast<const float64_t *>(&erf_f32_lut[index[0]]);
+ const float64_t entry_1 = *reinterpret_cast<const float64_t *>(&erf_f32_lut[index[1]]);
+ const float64_t entry_2 = *reinterpret_cast<const float64_t *>(&erf_f32_lut[index[2]]);
+ const float64_t entry_3 = *reinterpret_cast<const float64_t *>(&erf_f32_lut[index[3]]);
- denom = vmulq_f32(denom, denom);
- denom = vmulq_f32(denom, denom);
+ const float32x4_t entry_01 = vreinterpretq_f32_f64(float64x2_t{entry_0, entry_1});
+ const float32x4_t entry_23 = vreinterpretq_f32_f64(float64x2_t{entry_2, entry_3});
- float32x4_t fract = onev;
- fract = vdivq_f32(fract, denom);
+ const float32x4_t erf_r = vuzp1q_f32(entry_01, entry_23);
+ const float32x4_t scale_r = vuzp2q_f32(entry_01, entry_23);
- float32x4_t result = onev;
- result = vsubq_f32(result, fract);
+ // Approximate erf(x) = erf(r) + scale(r) * d * (1 - r * d - 1/3 * d^2).
+ const float32x4_t d = x_abs - r;
+ const float32x4_t d2 = d * d;
- float32x4_t inverse = vnegq_f32(result);
+ const float32x4_t t0 = vfmaq_f32(r, third, d); // t0 = r + 1/3 * d.
+ const float32x4_t t1 = vfmsq_f32(d, d2, t0); // t1 = d - d2 * t0 = d * (1 - r * d - 1/3 * d^2).
+ const float32x4_t erf_x = vfmaq_f32(erf_r, scale_r, t1);
- result = vbslq_f32(selector, inverse, result);
+ const float32x4_t clamped = vbslq_f32(x_abs > max_value, one, erf_x);
+ const float32x4_t result = vbslq_f32(sign_mask, clamped, x);
return result;
}
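A note on the index computation above, as a scalar sketch (assuming the default round-to-nearest FP mode; the helper names are hypothetical):

#include <cstdint>
#include <cstring>

// Scalar analogue of vreinterpretq_u32_f32: the raw bits of a float.
static uint32_t bits_of(float f)
{
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return u;
}

// For z in [2^16, 2^17) the float spacing is 2^16 * 2^-23 = 1/128, so adding
// 2^16 rounds x (0 <= x < 4) to the nearest multiple of 1/128. Consecutive
// floats in that range also have consecutive bit patterns, so subtracting
// the bit patterns recovers the integer index round(x * 128) directly.
void quantize_to_128ths(float x, float &r, uint32_t &index)
{
    const float shift = 65536.0f;            // 2^16
    const float z     = x + shift;
    r     = z - shift;                       // x rounded to a multiple of 1/128
    index = bits_of(z) - bits_of(shift);     // == round(x * 128)
}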
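Similarly, the final vbslq_f32(sign_mask, clamped, x) is a branch-free copysign: with sign_mask = 0x7fffffff it keeps the 31 magnitude bits of the clamped result and takes the remaining bit, the sign, from the raw input. A self-contained scalar sketch:

#include <cstdint>
#include <cstring>

float copysign_sketch(float clamped, float x)
{
    uint32_t m, s;
    std::memcpy(&m, &clamped, sizeof(m));
    std::memcpy(&s, &x, sizeof(s));
    const uint32_t out = (m & 0x7fffffffu) | (s & 0x80000000u); // magnitude from clamped, sign from x
    float f;
    std::memcpy(&f, &out, sizeof(f));
    return f; // same result as std::copysign(clamped, x)
}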