From ef9da00cad2b92633a130d43cb8a196278d49e85 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Wed, 27 Sep 2023 16:39:05 +0100 Subject: Reimplement erf function MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * The current implementation has significant inaccuracy and the issue cascades to GELU. * Use the implementation from Arm® Optimized Routines. The maximum error is 1.93 ULP. Resolves: COMPMID-6554 Signed-off-by: Viet-Hoa Do Change-Id: If80131e164b7a078e34dd8e05b1506698f31d17a Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10395 Tested-by: Arm Jenkins Reviewed-by: TeresaARM Reviewed-by: SiCong Li Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- src/core/NEON/NEMath.inl | 71 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 50 insertions(+), 21 deletions(-) (limited to 'src/core/NEON') diff --git a/src/core/NEON/NEMath.inl b/src/core/NEON/NEMath.inl index f875917988..a5aba0bf23 100644 --- a/src/core/NEON/NEMath.inl +++ b/src/core/NEON/NEMath.inl @@ -21,6 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
*/ + +#include "src/core/utils/Math.h" #include "support/ToolchainSupport.h" #include @@ -224,35 +226,62 @@ inline float32x4_t vexpq_f32(float32x4_t x) #ifdef __aarch64__ inline float32x4_t verfq_f32(float32x4_t x) { - static const float erffdata[4] = {0.278393f, 0.230389f, 0.000972f, 0.078108f}; - static const float32x4_t coeffdata = vld1q_f32(erffdata); - static const float32x4_t onev{vdupq_n_f32(1.0f)}; + const float32x4_t max_value = vdupq_n_f32(3.9375); // 4 - 8/128 + const float32x4_t shift = vdupq_n_f32(65536); // 2^16 + const float32x4_t third = vdupq_n_f32(0.3333333333); // 1/3 + const float32x4_t one = vdupq_n_f32(1.f); + const uint32x4_t max_index = vdupq_n_u32(512); + const uint32x4_t sign_mask = vdupq_n_u32(0x7fffffff); + + const float32x4_t x_abs = vabsq_f32(x); - uint32x4_t selector = vcltzq_f32(x); + // erf(x) for x in [0, 3.9375] is approximated as follows: + // + // erf(x) = erf(r) + scale(r) * d * (1 - r * d - 1/3 * d^2) + // + // where: + // r = round(x * 128) / 128 + // d = x - r + // + // erf(r) and scale(r) are stored in a 513-entry lookup table. + // The LUT covers the range from 0 to 4 with the step of 1/128. + // + // Special cases: + // erf(x) = 1 for x > 3.9375 + // erf(x) = -1 for x < -3.9375 + + // Find the LUT indices by rounding the input value to the step of 1/128. + // + // `shift` is used to push out the 16 LSBs of the input value. Only 7 bits in the fraction part + // of the input value are preserved. 
+ const float32x4_t z = x_abs + shift; + const float32x4_t r = z - shift; - float32x4_t absx = vabsq_f32(x); - float32x4_t absx2 = vmulq_f32(x, x); - float32x4_t absx3 = vmulq_f32(absx2, absx); - float32x4_t absx4 = vmulq_f32(absx2, absx2); + uint32x4_t index = vreinterpretq_u32_f32(z) - vreinterpretq_u32_f32(shift); + index = vminq_u32(index, max_index); - float32x4_t denom = onev; - denom = vfmaq_laneq_f32(denom, absx, coeffdata, 0); - denom = vfmaq_laneq_f32(denom, absx2, coeffdata, 1); - denom = vfmaq_laneq_f32(denom, absx3, coeffdata, 2); - denom = vfmaq_laneq_f32(denom, absx4, coeffdata, 3); + // Lookup erf(r) and scale(r). + const float64_t entry_0 = *reinterpret_cast(&erf_f32_lut[index[0]]); + const float64_t entry_1 = *reinterpret_cast(&erf_f32_lut[index[1]]); + const float64_t entry_2 = *reinterpret_cast(&erf_f32_lut[index[2]]); + const float64_t entry_3 = *reinterpret_cast(&erf_f32_lut[index[3]]); - denom = vmulq_f32(denom, denom); - denom = vmulq_f32(denom, denom); + const float32x4_t entry_01 = vreinterpretq_f32_f64(float64x2_t{entry_0, entry_1}); + const float32x4_t entry_23 = vreinterpretq_f32_f64(float64x2_t{entry_2, entry_3}); - float32x4_t fract = onev; - fract = vdivq_f32(fract, denom); + const float32x4_t erf_r = vuzp1q_f32(entry_01, entry_23); + const float32x4_t scale_r = vuzp2q_f32(entry_01, entry_23); - float32x4_t result = onev; - result = vsubq_f32(result, fract); + // Approximate erf(x) = erf(r) + scale(r) * d * (1 - r * d - 1/3 * d^2). + const float32x4_t d = x_abs - r; + const float32x4_t d2 = d * d; - float32x4_t inverse = vnegq_f32(result); + const float32x4_t t0 = vfmaq_f32(r, third, d); // t0 = r + 1/3 * d. + const float32x4_t t1 = vfmsq_f32(d, d2, t0); // t1 = d - d2 * t0 = d * (1 - r * d - 1/3 * d^2). 
+ const float32x4_t erf_x = vfmaq_f32(erf_r, scale_r, t1); - result = vbslq_f32(selector, inverse, result); + const float32x4_t clamped = vbslq_f32(x_abs > max_value, one, erf_x); + const float32x4_t result = vbslq_f32(sign_mask, clamped, x); return result; } -- cgit v1.2.1