From 7fefac722568d997b4d9e136925e93c7abeb564a Mon Sep 17 00:00:00 2001 From: Ramy Elgammal Date: Thu, 20 Apr 2023 12:32:03 +0100 Subject: Fix rounding to nearest even for armv7a - If input value to round has the decimal point beyond the fraction part (23 bit) then simply return the rounded value = input value. Resolves: COMPMID-6025 Signed-off-by: Ramy Elgammal Change-Id: I1994e49a9bca7daeaeec7681aec099c63a97b53f Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9473 Comments-Addressed: Arm Jenkins Reviewed-by: Jakub Sujak Reviewed-by: Viet-Hoa Do Benchmark: Arm Jenkins Tested-by: Arm Jenkins --- src/core/NEON/NEMath.inl | 89 ++++++++++++++++++++++++++++-------------------- 1 file changed, 53 insertions(+), 36 deletions(-) diff --git a/src/core/NEON/NEMath.inl b/src/core/NEON/NEMath.inl index 8b2d1c3c37..6198a257fc 100644 --- a/src/core/NEON/NEMath.inl +++ b/src/core/NEON/NEMath.inl @@ -54,7 +54,7 @@ inline float32x4_t prefer_vfmaq_f32(float32x4_t a, float32x4_t b, float32x4_t c) { #ifdef __aarch64__ return vfmaq_f32(a, b, c); -#else // __aarch64__ +#else // __aarch64__ return vmlaq_f32(a, b, c); #endif // __aarch64__ } @@ -73,20 +73,36 @@ inline float32x4_t vroundq_rte_f32(float32x4_t val) { #ifdef __aarch64__ return vrndnq_f32(val); -#else // __aarch64__ +#else // __aarch64__ static const float32x4_t CONST_HALF_FLOAT = vdupq_n_f32(0.5f); static const float32x4_t CONST_1_FLOAT = vdupq_n_f32(1.f); static const int32x4_t CONST_1_INT = vdupq_n_s32(1); const float32x4_t floor_val = vfloorq_f32(val); const float32x4_t diff = vsubq_f32(val, floor_val); + const float32x4_t fp32_upper_limit = vreinterpretq_f32_u32(vdupq_n_u32(0x4B000000)); // 0x4B000000 = (23U + 127U) << 23U /* - * Select the floor value when (diff<0.5 || (diff==0.5 && floor_val%2==0). - * This condition is checked by vorrq_u32(vcltq_f32(diff, CONST_HALF_FLOAT) ,vandq_u32(vceqq_f32(diff, CONST_HALF_FLOAT) , vmvnq_u32(vtstq_s32(vandq_s32(vcvtq_s32_f32(floor_val), CONST_1_INT),CONST_1_INT)))) + * 1. Select the floor value when (diff<0.5 || (diff==0.5 && floor_val%2==0). + * This condition is checked by vorrq_u32(vcltq_f32(diff, CONST_HALF_FLOAT) ,vandq_u32(vceqq_f32(diff, CONST_HALF_FLOAT) , vmvnq_u32(vtstq_s32(vandq_s32(vcvtq_s32_f32(floor_val), CONST_1_INT),CONST_1_INT)))) + * + * 2. In case the input value (val) is out of signed int32 range, then simple use the input value as the rounded value + * Because: + * in this case converting to int32 would saturate + * If the input float value is >= 2^23 * 1.00... 23 Zeros ..0 then the rounded value is exactly equal to the input value. + * Because: + * in IEEE single precision floating point representation the fraction part is 23 bit, so if exponent is 23 it means the fraction part = 0 as any digits after decimal point are truncated. + * Hence, rounding has no effect: + * Threshold upper limit with format |S|E(8bits)| Fraction(23bits) | = (23 + 127) << 23 (assuming positive sign): Adding 127, because 127 represents the actual zero in this format. */ - return vbslq_f32(vorrq_u32(vcltq_f32(diff, CONST_HALF_FLOAT), vandq_u32(vceqq_f32(diff, CONST_HALF_FLOAT), vmvnq_u32(vtstq_s32(vandq_s32(vcvtq_s32_f32(floor_val), CONST_1_INT), CONST_1_INT)))), - floor_val, vaddq_f32(floor_val, CONST_1_FLOAT)); + float32x4_t rounded_val = vbslq_f32(vorrq_u32(vcltq_f32(diff, CONST_HALF_FLOAT), + vandq_u32(vceqq_f32(diff, CONST_HALF_FLOAT), + vmvnq_u32(vtstq_s32(vandq_s32(vcvtq_s32_f32(floor_val), CONST_1_INT),CONST_1_INT)))), + floor_val, vaddq_f32(floor_val, CONST_1_FLOAT)); + + float32x4_t result = vbslq_f32(vcgeq_f32(vabsq_f32(val), fp32_upper_limit), val, rounded_val); + + return result; #endif // __aarch64__ } @@ -136,7 +152,8 @@ inline float32x4_t vtaylor_polyq_f32(float32x4_t x, const std::array::infinity()); - const auto max_input = vdupq_n_f32(88.37f); // Approximately ln(2^127.5) + const auto max_input = vdupq_n_f32(88.37f); // Approximately ln(2^127.5) const auto zero = vdupq_n_f32(0.f); - const auto min_input = vdupq_n_f32(-86.64f); // Approximately ln(2^-125) + const auto min_input = vdupq_n_f32(-86.64f); // Approximately ln(2^-125) // Range reduction: // e^x = 2^n * e^r @@ -176,23 +193,23 @@ inline float32x4_t vexpq_f32(float32x4_t x) // (i.e. n) because the decimal part has been pushed out and lost. // * The addition of 127 makes the FP32 fraction part of z ready to be used as the exponent // in FP32 format. Left shifting z by 23 bits will result in 2^n. - const auto z = prefer_vfmaq_f32(shift, x, inv_ln2); - const auto n = z - shift; - const auto scale = vreinterpretq_f32_u32(vreinterpretq_u32_f32(z) << 23); // 2^n + const auto z = prefer_vfmaq_f32(shift, x, inv_ln2); + const auto n = z - shift; + const auto scale = vreinterpretq_f32_u32(vreinterpretq_u32_f32(z) << 23); // 2^n // The calculation of n * ln(2) is done using 2 steps to achieve accuracy beyond FP32. // This outperforms longer Taylor series (3-4 tabs) both in term of accuracy and performance. const auto r_hi = prefer_vfmaq_f32(x, n, neg_ln2_hi); - const auto r = prefer_vfmaq_f32(r_hi, n, neg_ln2_lo); + const auto r = prefer_vfmaq_f32(r_hi, n, neg_ln2_lo); // Compute the truncated Taylor series of e^r. // poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5) const auto r2 = r * r; - const auto p1 = c1 * r; - const auto p23 = prefer_vfmaq_f32(c2, c3, r); - const auto p45 = prefer_vfmaq_f32(c4, c5, r); - const auto p2345 = prefer_vfmaq_f32(p23, p45, r2); + const auto p1 = c1 * r; + const auto p23 = prefer_vfmaq_f32(c2, c3, r); + const auto p45 = prefer_vfmaq_f32(c4, c5, r); + const auto p2345 = prefer_vfmaq_f32(p23, p45, r2); const auto p12345 = prefer_vfmaq_f32(p1, p2345, r2); auto poly = prefer_vfmaq_f32(scale, p12345, scale); @@ -450,7 +467,7 @@ inline void convert_float32x4x3_to_uint8x8x3(const float32x4x3_t &in1, const flo inline void convert_float32x4x4_to_uint8x16(const float32x4x4_t &in, uint8x16_t &out) { const auto low = vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in.val[0])), - vqmovn_u32(vcvtq_u32_f32(in.val[1]))); + vqmovn_u32(vcvtq_u32_f32(in.val[1]))); const auto high = vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in.val[2])), vqmovn_u32(vcvtq_u32_f32(in.val[3]))); out = vcombine_u8(vqmovn_u16(low), vqmovn_u16(high)); @@ -459,7 +476,7 @@ inline void convert_float32x4x4_to_uint8x16(const float32x4x4_t &in, uint8x16_t inline void convert_float32x4x4_to_int8x16(const float32x4x4_t &in, int8x16_t &out) { const auto low = vcombine_s16(vqmovn_s32(vcvtq_s32_f32(in.val[0])), - vqmovn_s32(vcvtq_s32_f32(in.val[1]))); + vqmovn_s32(vcvtq_s32_f32(in.val[1]))); const auto high = vcombine_s16(vqmovn_s32(vcvtq_s32_f32(in.val[2])), vqmovn_s32(vcvtq_s32_f32(in.val[3]))); out = vcombine_s8(vqmovn_s16(low), vqmovn_s16(high)); @@ -563,21 +580,21 @@ inline float16x4_t vtanh_rational_approx_f16(float16x4_t x16) const float32x4_t x = vcvt_f32_f16(x16); const float32x4_t ONE = vdupq_n_f32(1.0f); - const float32x4_t C1 = vdupq_n_f32(0.43760237f); - const float32x4_t C2 = vdupq_n_f32(0.104402f); - const float32x4_t C3 = vdupq_n_f32(0.013442706f); - const float32x4_t C4 = vdupq_n_f32(0.00073561433f); + const float32x4_t C1 = vdupq_n_f32(0.43760237f); + const float32x4_t C2 = vdupq_n_f32(0.104402f); + const float32x4_t C3 = vdupq_n_f32(0.013442706f); + const float32x4_t C4 = vdupq_n_f32(0.00073561433f); - const float32x4_t x2 = vmulq_f32(x,x); + const float32x4_t x2 = vmulq_f32(x, x); // Denominator polynomial 1 + C1*x^2 + C3*x^4 float32x4_t denom = vfmaq_f32(C1, C3, x2); - denom = vfmaq_f32(ONE, x2, denom); + denom = vfmaq_f32(ONE, x2, denom); // Numerator polynomial x*(1 + C2*x^2 + C4*x^4) float32x4_t numer = vfmaq_f32(C2, C4, x2); - numer = vfmaq_f32(ONE, x2, numer); - numer = vmulq_f32(numer, x); + numer = vfmaq_f32(ONE, x2, numer); + numer = vmulq_f32(numer, x); return vcvt_f16_f32(vdivq_f32(numer, denom)); } @@ -585,14 +602,14 @@ inline float16x4_t vtanh_rational_approx_f16(float16x4_t x16) inline float16x8_t vtanhq_f16(float16x8_t x) { // Split into high/low and use rational approximation on both parts exactly - const float16x8_t tanh = vcombine_f16(vtanh_rational_approx_f16( vget_low_f16(x)), + const float16x8_t tanh = vcombine_f16(vtanh_rational_approx_f16(vget_low_f16(x)), vtanh_rational_approx_f16(vget_high_f16(x))); // tanh(x) == sign(x) to F16 precision for |x| >= 4.508, use sign after this - const float16x8_t ONE = vdupq_n_f16(1.0f); - const float16x8_t MAX_X = vdupq_n_f16(4.508f); - const auto at_limit = vcageq_f16(x, MAX_X); // |x| >= 4.508 - const float16x8_t sign_x = vbslq_f16(vclezq_f16(x), -ONE, ONE); + const float16x8_t ONE = vdupq_n_f16(1.0f); + const float16x8_t MAX_X = vdupq_n_f16(4.508f); + const auto at_limit = vcageq_f16(x, MAX_X); // |x| >= 4.508 + const float16x8_t sign_x = vbslq_f16(vclezq_f16(x), -ONE, ONE); return vbslq_f16(at_limit, sign_x, tanh); } -- cgit v1.2.1