From 86689cdd95f634fb374f3875f62a4cb3408e1699 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Mon, 21 Nov 2022 17:17:56 +0000 Subject: Optimize CPU base-e exponential function on FP32 Resolves: COMPMID-5664 Signed-off-by: Viet-Hoa Do Change-Id: I4182752e213aade19005ee984a488c2490453f8f Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8747 Benchmark: Arm Jenkins Reviewed-by: Pablo Marquez Tello Reviewed-by: Gian Marco Iodice Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- src/core/NEON/NEMath.inl | 104 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 71 insertions(+), 33 deletions(-) (limited to 'src/core/NEON') diff --git a/src/core/NEON/NEMath.inl b/src/core/NEON/NEMath.inl index acac36e1a5..94bbc10ad8 100644 --- a/src/core/NEON/NEMath.inl +++ b/src/core/NEON/NEMath.inl @@ -28,21 +28,6 @@ namespace arm_compute { -/** Exponent polynomial coefficients */ -const std::array exp_tab = -{ - { - vdupq_n_f32(1.f), - vdupq_n_f32(0.0416598916054f), - vdupq_n_f32(0.500000596046f), - vdupq_n_f32(0.0014122662833f), - vdupq_n_f32(1.00000011921f), - vdupq_n_f32(0.00833693705499f), - vdupq_n_f32(0.166665703058f), - vdupq_n_f32(0.000195780929062f), - } -}; - /** Logarithm polynomial coefficients */ const std::array log_tab = { @@ -65,6 +50,15 @@ constexpr float te_sin_coeff4 = 0.023809523810f; // 1/(6*7) constexpr float te_sin_coeff5 = 0.013888888889f; // 1/(8*9) #ifndef DOXYGEN_SKIP_THIS +inline float32x4_t prefer_vfmaq_f32(float32x4_t a, float32x4_t b, float32x4_t c) +{ +#ifdef __aarch64__ + return vfmaq_f32(a, b, c); +#else // __aarch64__ + return vmlaq_f32(a, b, c); +#endif // __aarch64__ +} + inline float32x4_t vfloorq_f32(float32x4_t val) { static const float32x4_t CONST_1 = vdupq_n_f32(1.f); @@ -142,26 +136,70 @@ inline float32x4_t vtaylor_polyq_f32(float32x4_t x, const std::array::infinity()); - static const float32x4_t CONST_MAX_INPUT = vdupq_n_f32(88.7f); - static const float32x4_t CONST_0 = vdupq_n_f32(0.f); - static const int32x4_t CONST_NEGATIVE_126 = vdupq_n_s32(-126); - - // Perform range reduction [-log(2),log(2)] - int32x4_t m = vcvtq_s32_f32(vmulq_f32(x, CONST_INV_LN2)); - float32x4_t val = vmlsq_f32(x, vcvtq_f32_s32(m), CONST_LN2); - - // Polynomial Approximation - float32x4_t poly = vtaylor_polyq_f32(val, exp_tab); - - // Reconstruct - poly = vreinterpretq_f32_s32(vqaddq_s32(vreinterpretq_s32_f32(poly), vqshlq_n_s32(m, 23))); - poly = vbslq_f32(vcltq_s32(m, CONST_NEGATIVE_126), CONST_0, poly); // Handle underflow - poly = vbslq_f32(vcgtq_f32(x, CONST_MAX_INPUT), CONST_INF, poly); // Handle overflow + const auto c1 = vreinterpretq_f32_u32(vdupq_n_u32(exp_f32_coeff[0])); + const auto c2 = vreinterpretq_f32_u32(vdupq_n_u32(exp_f32_coeff[1])); + const auto c3 = vreinterpretq_f32_u32(vdupq_n_u32(exp_f32_coeff[2])); + const auto c4 = vreinterpretq_f32_u32(vdupq_n_u32(exp_f32_coeff[3])); + const auto c5 = vreinterpretq_f32_u32(vdupq_n_u32(exp_f32_coeff[4])); + + const auto shift = vreinterpretq_f32_u32(vdupq_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f + const auto inv_ln2 = vreinterpretq_f32_u32(vdupq_n_u32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f + const auto neg_ln2_hi = vreinterpretq_f32_u32(vdupq_n_u32(0xbf317200)); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f + const auto neg_ln2_lo = vreinterpretq_f32_u32(vdupq_n_u32(0xb5bfbe8e)); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f + + const auto inf = vdupq_n_f32(std::numeric_limits::infinity()); + const auto max_input = vdupq_n_f32(88.7f); // Approximately ln(0x1.fffffep+127) + const auto zero = vdupq_n_f32(0.f); + const auto min_input = vdupq_n_f32(-86.6f); // Approximately ln(2^-125) + + // Range reduction: + // e^x = 2^n * e^r + // where: + // n = floor(x / ln(2)) + // r = x - n * ln(2) + // + // By adding x / ln(2) with 2^23 + 127 (shift): + // * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127 forces decimal part + // of x / ln(2) out of the result. The integer part of x / ln(2) (i.e. n) + 127 will occupy + // the whole fraction part of z in FP32 format. + // Subtracting 2^23 + 127 (shift) from z will result in the integer part of x / ln(2) + // (i.e. n) because the decimal part has been pushed out and lost. + // * The addition of 127 makes the FP32 fraction part of z ready to be used as the exponent + // in FP32 format. Left shifting z by 23 bits will result in 2^n. + const auto z = prefer_vfmaq_f32(shift, x, inv_ln2); + const auto n = z - shift; + const auto scale = vreinterpretq_f32_u32(vreinterpretq_u32_f32(z) << 23); // 2^n + + // The calculation of n * ln(2) is done using 2 steps to achieve accuracy beyond FP32. + // This outperforms longer Taylor series (3-4 tabs) both in term of accuracy and performance. + const auto r_hi = prefer_vfmaq_f32(x, n, neg_ln2_hi); + const auto r = prefer_vfmaq_f32(r_hi, n, neg_ln2_lo); + + // Compute the truncated Taylor series of e^r. + // poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5) + const auto r2 = r * r; + + const auto p1 = c1 * r; + const auto p23 = prefer_vfmaq_f32(c2, c3, r); + const auto p45 = prefer_vfmaq_f32(c4, c5, r); + const auto p2345 = prefer_vfmaq_f32(p23, p45, r2); + const auto p12345 = prefer_vfmaq_f32(p1, p2345, r2); + + auto poly = prefer_vfmaq_f32(scale, p12345, scale); + + // Handle underflow and overflow. + poly = vbslq_f32(vcltq_f32(x, min_input), zero, poly); + poly = vbslq_f32(vcgtq_f32(x, max_input), inf, poly); return poly; } -- cgit v1.2.1