aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRamy Elgammal <ramy.elgammal@arm.com>2023-04-20 12:32:03 +0100
committerRamy Elgammal <ramy.elgammal@arm.com>2023-04-25 10:12:37 +0000
commit7fefac722568d997b4d9e136925e93c7abeb564a (patch)
tree965fc36cee632aa12d66953d9f996871169bc3f6
parent05a65e3f397946cacf5c17d8c528a9fad3f1b322 (diff)
downloadComputeLibrary-7fefac722568d997b4d9e136925e93c7abeb564a.tar.gz
Fix rounding to nearest even for armv7a
- If input value to round has the decimal point beyond the fraction part (23 bit) then simply return the rounded value = input value. Resolves: COMPMID-6025 Signed-off-by: Ramy Elgammal <ramy.elgammal@arm.com> Change-Id: I1994e49a9bca7daeaeec7681aec099c63a97b53f Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9473 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Jakub Sujak <jakub.sujak@arm.com> Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--src/core/NEON/NEMath.inl89
1 files changed, 53 insertions, 36 deletions
diff --git a/src/core/NEON/NEMath.inl b/src/core/NEON/NEMath.inl
index 8b2d1c3c37..6198a257fc 100644
--- a/src/core/NEON/NEMath.inl
+++ b/src/core/NEON/NEMath.inl
@@ -54,7 +54,7 @@ inline float32x4_t prefer_vfmaq_f32(float32x4_t a, float32x4_t b, float32x4_t c)
{
#ifdef __aarch64__
return vfmaq_f32(a, b, c);
-#else // __aarch64__
+#else // __aarch64__
return vmlaq_f32(a, b, c);
#endif // __aarch64__
}
@@ -73,20 +73,36 @@ inline float32x4_t vroundq_rte_f32(float32x4_t val)
{
#ifdef __aarch64__
return vrndnq_f32(val);
-#else // __aarch64__
+#else // __aarch64__
static const float32x4_t CONST_HALF_FLOAT = vdupq_n_f32(0.5f);
static const float32x4_t CONST_1_FLOAT = vdupq_n_f32(1.f);
static const int32x4_t CONST_1_INT = vdupq_n_s32(1);
const float32x4_t floor_val = vfloorq_f32(val);
const float32x4_t diff = vsubq_f32(val, floor_val);
+ const float32x4_t fp32_upper_limit = vreinterpretq_f32_u32(vdupq_n_u32(0x4B000000)); // 0x4B000000 = (23U + 127U) << 23U
/*
- * Select the floor value when (diff<0.5 || (diff==0.5 && floor_val%2==0).
- * This condition is checked by vorrq_u32(vcltq_f32(diff, CONST_HALF_FLOAT) ,vandq_u32(vceqq_f32(diff, CONST_HALF_FLOAT) , vmvnq_u32(vtstq_s32(vandq_s32(vcvtq_s32_f32(floor_val), CONST_1_INT),CONST_1_INT))))
+ * 1. Select the floor value when (diff<0.5 || (diff==0.5 && floor_val%2==0).
+ * This condition is checked by vorrq_u32(vcltq_f32(diff, CONST_HALF_FLOAT) ,vandq_u32(vceqq_f32(diff, CONST_HALF_FLOAT) , vmvnq_u32(vtstq_s32(vandq_s32(vcvtq_s32_f32(floor_val), CONST_1_INT),CONST_1_INT))))
+ *
+ * 2. In case the input value (val) is out of signed int32 range, then simple use the input value as the rounded value
+ * Because:
+ * in this case converting to int32 would saturate
+ * If the input float value is >= 2^23 * 1.00... 23 Zeros ..0 then the rounded value is exactly equal to the input value.
+ * Because:
+ * in IEEE single precision floating point representation the fraction part is 23 bit, so if exponent is 23 it means the fraction part = 0 as any digits after decimal point are truncated.
+ * Hence, rounding has no effect:
+ * Threshold upper limit with format |S|E(8bits)| Fraction(23bits) | = (23 + 127) << 23 (assuming positive sign): Adding 127, because 127 represents the actual zero in this format.
*/
- return vbslq_f32(vorrq_u32(vcltq_f32(diff, CONST_HALF_FLOAT), vandq_u32(vceqq_f32(diff, CONST_HALF_FLOAT), vmvnq_u32(vtstq_s32(vandq_s32(vcvtq_s32_f32(floor_val), CONST_1_INT), CONST_1_INT)))),
- floor_val, vaddq_f32(floor_val, CONST_1_FLOAT));
+ float32x4_t rounded_val = vbslq_f32(vorrq_u32(vcltq_f32(diff, CONST_HALF_FLOAT),
+ vandq_u32(vceqq_f32(diff, CONST_HALF_FLOAT),
+ vmvnq_u32(vtstq_s32(vandq_s32(vcvtq_s32_f32(floor_val), CONST_1_INT),CONST_1_INT)))),
+ floor_val, vaddq_f32(floor_val, CONST_1_FLOAT));
+
+ float32x4_t result = vbslq_f32(vcgeq_f32(vabsq_f32(val), fp32_upper_limit), val, rounded_val);
+
+ return result;
#endif // __aarch64__
}
@@ -136,7 +152,8 @@ inline float32x4_t vtaylor_polyq_f32(float32x4_t x, const std::array<float32x4_t
return res;
}
-static const uint32_t exp_f32_coeff[] = {
+static const uint32_t exp_f32_coeff[] =
+{
0x3f7ffff6, // x^1: 0x1.ffffecp-1f
0x3efffedb, // x^2: 0x1.fffdb6p-2f
0x3e2aaf33, // x^3: 0x1.555e66p-3f
@@ -152,15 +169,15 @@ inline float32x4_t vexpq_f32(float32x4_t x)
const auto c4 = vreinterpretq_f32_u32(vdupq_n_u32(exp_f32_coeff[3]));
const auto c5 = vreinterpretq_f32_u32(vdupq_n_u32(exp_f32_coeff[4]));
- const auto shift = vreinterpretq_f32_u32(vdupq_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f
- const auto inv_ln2 = vreinterpretq_f32_u32(vdupq_n_u32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f
- const auto neg_ln2_hi = vreinterpretq_f32_u32(vdupq_n_u32(0xbf317200)); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f
- const auto neg_ln2_lo = vreinterpretq_f32_u32(vdupq_n_u32(0xb5bfbe8e)); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f
+ const auto shift = vreinterpretq_f32_u32(vdupq_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f
+ const auto inv_ln2 = vreinterpretq_f32_u32(vdupq_n_u32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f
+ const auto neg_ln2_hi = vreinterpretq_f32_u32(vdupq_n_u32(0xbf317200)); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f
+ const auto neg_ln2_lo = vreinterpretq_f32_u32(vdupq_n_u32(0xb5bfbe8e)); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f
const auto inf = vdupq_n_f32(std::numeric_limits<float>::infinity());
- const auto max_input = vdupq_n_f32(88.37f); // Approximately ln(2^127.5)
+ const auto max_input = vdupq_n_f32(88.37f); // Approximately ln(2^127.5)
const auto zero = vdupq_n_f32(0.f);
- const auto min_input = vdupq_n_f32(-86.64f); // Approximately ln(2^-125)
+ const auto min_input = vdupq_n_f32(-86.64f); // Approximately ln(2^-125)
// Range reduction:
// e^x = 2^n * e^r
@@ -176,23 +193,23 @@ inline float32x4_t vexpq_f32(float32x4_t x)
// (i.e. n) because the decimal part has been pushed out and lost.
// * The addition of 127 makes the FP32 fraction part of z ready to be used as the exponent
// in FP32 format. Left shifting z by 23 bits will result in 2^n.
- const auto z = prefer_vfmaq_f32(shift, x, inv_ln2);
- const auto n = z - shift;
- const auto scale = vreinterpretq_f32_u32(vreinterpretq_u32_f32(z) << 23); // 2^n
+ const auto z = prefer_vfmaq_f32(shift, x, inv_ln2);
+ const auto n = z - shift;
+ const auto scale = vreinterpretq_f32_u32(vreinterpretq_u32_f32(z) << 23); // 2^n
// The calculation of n * ln(2) is done using 2 steps to achieve accuracy beyond FP32.
// This outperforms longer Taylor series (3-4 tabs) both in term of accuracy and performance.
const auto r_hi = prefer_vfmaq_f32(x, n, neg_ln2_hi);
- const auto r = prefer_vfmaq_f32(r_hi, n, neg_ln2_lo);
+ const auto r = prefer_vfmaq_f32(r_hi, n, neg_ln2_lo);
// Compute the truncated Taylor series of e^r.
// poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5)
const auto r2 = r * r;
- const auto p1 = c1 * r;
- const auto p23 = prefer_vfmaq_f32(c2, c3, r);
- const auto p45 = prefer_vfmaq_f32(c4, c5, r);
- const auto p2345 = prefer_vfmaq_f32(p23, p45, r2);
+ const auto p1 = c1 * r;
+ const auto p23 = prefer_vfmaq_f32(c2, c3, r);
+ const auto p45 = prefer_vfmaq_f32(c4, c5, r);
+ const auto p2345 = prefer_vfmaq_f32(p23, p45, r2);
const auto p12345 = prefer_vfmaq_f32(p1, p2345, r2);
auto poly = prefer_vfmaq_f32(scale, p12345, scale);
@@ -450,7 +467,7 @@ inline void convert_float32x4x3_to_uint8x8x3(const float32x4x3_t &in1, const flo
inline void convert_float32x4x4_to_uint8x16(const float32x4x4_t &in, uint8x16_t &out)
{
const auto low = vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in.val[0])),
- vqmovn_u32(vcvtq_u32_f32(in.val[1])));
+ vqmovn_u32(vcvtq_u32_f32(in.val[1])));
const auto high = vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in.val[2])),
vqmovn_u32(vcvtq_u32_f32(in.val[3])));
out = vcombine_u8(vqmovn_u16(low), vqmovn_u16(high));
@@ -459,7 +476,7 @@ inline void convert_float32x4x4_to_uint8x16(const float32x4x4_t &in, uint8x16_t
inline void convert_float32x4x4_to_int8x16(const float32x4x4_t &in, int8x16_t &out)
{
const auto low = vcombine_s16(vqmovn_s32(vcvtq_s32_f32(in.val[0])),
- vqmovn_s32(vcvtq_s32_f32(in.val[1])));
+ vqmovn_s32(vcvtq_s32_f32(in.val[1])));
const auto high = vcombine_s16(vqmovn_s32(vcvtq_s32_f32(in.val[2])),
vqmovn_s32(vcvtq_s32_f32(in.val[3])));
out = vcombine_s8(vqmovn_s16(low), vqmovn_s16(high));
@@ -563,21 +580,21 @@ inline float16x4_t vtanh_rational_approx_f16(float16x4_t x16)
const float32x4_t x = vcvt_f32_f16(x16);
const float32x4_t ONE = vdupq_n_f32(1.0f);
- const float32x4_t C1 = vdupq_n_f32(0.43760237f);
- const float32x4_t C2 = vdupq_n_f32(0.104402f);
- const float32x4_t C3 = vdupq_n_f32(0.013442706f);
- const float32x4_t C4 = vdupq_n_f32(0.00073561433f);
+ const float32x4_t C1 = vdupq_n_f32(0.43760237f);
+ const float32x4_t C2 = vdupq_n_f32(0.104402f);
+ const float32x4_t C3 = vdupq_n_f32(0.013442706f);
+ const float32x4_t C4 = vdupq_n_f32(0.00073561433f);
- const float32x4_t x2 = vmulq_f32(x,x);
+ const float32x4_t x2 = vmulq_f32(x, x);
// Denominator polynomial 1 + C1*x^2 + C3*x^4
float32x4_t denom = vfmaq_f32(C1, C3, x2);
- denom = vfmaq_f32(ONE, x2, denom);
+ denom = vfmaq_f32(ONE, x2, denom);
// Numerator polynomial x*(1 + C2*x^2 + C4*x^4)
float32x4_t numer = vfmaq_f32(C2, C4, x2);
- numer = vfmaq_f32(ONE, x2, numer);
- numer = vmulq_f32(numer, x);
+ numer = vfmaq_f32(ONE, x2, numer);
+ numer = vmulq_f32(numer, x);
return vcvt_f16_f32(vdivq_f32(numer, denom));
}
@@ -585,14 +602,14 @@ inline float16x4_t vtanh_rational_approx_f16(float16x4_t x16)
inline float16x8_t vtanhq_f16(float16x8_t x)
{
// Split into high/low and use rational approximation on both parts exactly
- const float16x8_t tanh = vcombine_f16(vtanh_rational_approx_f16( vget_low_f16(x)),
+ const float16x8_t tanh = vcombine_f16(vtanh_rational_approx_f16(vget_low_f16(x)),
vtanh_rational_approx_f16(vget_high_f16(x)));
// tanh(x) == sign(x) to F16 precision for |x| >= 4.508, use sign after this
- const float16x8_t ONE = vdupq_n_f16(1.0f);
- const float16x8_t MAX_X = vdupq_n_f16(4.508f);
- const auto at_limit = vcageq_f16(x, MAX_X); // |x| >= 4.508
- const float16x8_t sign_x = vbslq_f16(vclezq_f16(x), -ONE, ONE);
+ const float16x8_t ONE = vdupq_n_f16(1.0f);
+ const float16x8_t MAX_X = vdupq_n_f16(4.508f);
+ const auto at_limit = vcageq_f16(x, MAX_X); // |x| >= 4.508
+ const float16x8_t sign_x = vbslq_f16(vclezq_f16(x), -ONE, ONE);
return vbslq_f16(at_limit, sign_x, tanh);
}