diff options
author | Viet-Hoa Do <viet-hoa.do@arm.com> | 2023-09-27 16:39:05 +0100 |
---|---|---|
committer | Viet-Hoa Do <viet-hoa.do@arm.com> | 2023-09-28 14:57:48 +0000 |
commit | ef9da00cad2b92633a130d43cb8a196278d49e85 (patch) | |
tree | 9955690805fa41c187e5d4a6a62020df292df699 /src/core/NEON | |
parent | 50dd2ee0cce42c72628b97686b02fc6ec073ca9c (diff) | |
download | ComputeLibrary-ef9da00cad2b92633a130d43cb8a196278d49e85.tar.gz |
Reimplement erf function
* The current implementation has signfinicant inaccuracy
and the issue cascades to GELU.
* Use the implementation from ArmĀ® Optimized Routines.
The maximum error is 1.93 ULP.
Resolves: COMPMID-6554
Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Change-Id: If80131e164b7a078e34dd8e05b1506698f31d17a
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10395
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: TeresaARM <teresa.charlinreyes@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON')
-rw-r--r-- | src/core/NEON/NEMath.inl | 71 |
1 files changed, 50 insertions, 21 deletions
diff --git a/src/core/NEON/NEMath.inl b/src/core/NEON/NEMath.inl index f875917988..a5aba0bf23 100644 --- a/src/core/NEON/NEMath.inl +++ b/src/core/NEON/NEMath.inl @@ -21,6 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ + +#include "src/core/utils/Math.h" #include "support/ToolchainSupport.h" #include <cmath> @@ -224,35 +226,62 @@ inline float32x4_t vexpq_f32(float32x4_t x) #ifdef __aarch64__ inline float32x4_t verfq_f32(float32x4_t x) { - static const float erffdata[4] = {0.278393f, 0.230389f, 0.000972f, 0.078108f}; - static const float32x4_t coeffdata = vld1q_f32(erffdata); - static const float32x4_t onev{vdupq_n_f32(1.0f)}; + const float32x4_t max_value = vdupq_n_f32(3.9375); // 4 - 8/128 + const float32x4_t shift = vdupq_n_f32(65536); // 2^16 + const float32x4_t third = vdupq_n_f32(0.3333333333); // 1/3 + const float32x4_t one = vdupq_n_f32(1.f); + const uint32x4_t max_index = vdupq_n_u32(512); + const uint32x4_t sign_mask = vdupq_n_u32(0x7fffffff); + + const float32x4_t x_abs = vabsq_f32(x); - uint32x4_t selector = vcltzq_f32(x); + // erf(x) for x in [0, 3.9375] is approxiated as follows: + // + // erf(x) = erf(r) + scale(r) * d * (1 - r * d - 1/3 * d^2) + // + // where: + // r = floor(x * 128) / 128 + // d = x - r + // + // erf(r) and scale(r) are stored in a 513-entry lookup table. + // The LUT covers the range from 0 to 4 with the step of 1/128. + // + // Special cases: + // erf(x) = 1 for x > 3.9375 + // erf(x) = -1 for x < -3.9375 + + // Find the LUT indices by rounding the input value to the step of 1/128. + // + // `shift` is used to push out the 16 LSBs of the input value. Only 7 bits in the fraction part + // of the input value is preserved. + const float32x4_t z = x_abs + shift; + const float32x4_t r = z - shift; - float32x4_t absx = vabsq_f32(x); - float32x4_t absx2 = vmulq_f32(x, x); - float32x4_t absx3 = vmulq_f32(absx2, absx); - float32x4_t absx4 = vmulq_f32(absx2, absx2); + uint32x4_t index = vreinterpretq_u32_f32(z) - vreinterpretq_u32_f32(shift); + index = vminq_u32(index, max_index); - float32x4_t denom = onev; - denom = vfmaq_laneq_f32(denom, absx, coeffdata, 0); - denom = vfmaq_laneq_f32(denom, absx2, coeffdata, 1); - denom = vfmaq_laneq_f32(denom, absx3, coeffdata, 2); - denom = vfmaq_laneq_f32(denom, absx4, coeffdata, 3); + // Lookup erf(r) and scale(r). + const float64_t entry_0 = *reinterpret_cast<const float64_t *>(&erf_f32_lut[index[0]]); + const float64_t entry_1 = *reinterpret_cast<const float64_t *>(&erf_f32_lut[index[1]]); + const float64_t entry_2 = *reinterpret_cast<const float64_t *>(&erf_f32_lut[index[2]]); + const float64_t entry_3 = *reinterpret_cast<const float64_t *>(&erf_f32_lut[index[3]]); - denom = vmulq_f32(denom, denom); - denom = vmulq_f32(denom, denom); + const float32x4_t entry_01 = vreinterpretq_f32_f64(float64x2_t{entry_0, entry_1}); + const float32x4_t entry_23 = vreinterpretq_f32_f64(float64x2_t{entry_2, entry_3}); - float32x4_t fract = onev; - fract = vdivq_f32(fract, denom); + const float32x4_t erf_r = vuzp1q_f32(entry_01, entry_23); + const float32x4_t scale_r = vuzp2q_f32(entry_01, entry_23); - float32x4_t result = onev; - result = vsubq_f32(result, fract); + // Approximate erf(x) = erf(r) + scale(r) * d * (1 - r * d - 1/3 * d^2). + const float32x4_t d = x_abs - r; + const float32x4_t d2 = d * d; - float32x4_t inverse = vnegq_f32(result); + const float32x4_t t0 = vfmaq_f32(r, third, d); // t0 = r + 1/3 * d. + const float32x4_t t1 = vfmsq_f32(d, d2, t0); // t1 = d - d2 * t0 = d * (1 - r * d - 1/3 * d^2). + const float32x4_t erf_x = vfmaq_f32(erf_r, scale_r, t1); - result = vbslq_f32(selector, inverse, result); + const float32x4_t clamped = vbslq_f32(x_abs > max_value, one, erf_x); + const float32x4_t result = vbslq_f32(sign_mask, clamped, x); return result; } |