diff options
Diffstat (limited to 'src/core/NEON/SVEMath.inl')
-rw-r--r-- | src/core/NEON/SVEMath.inl | 256 |
1 files changed, 173 insertions, 83 deletions
diff --git a/src/core/NEON/SVEMath.inl b/src/core/NEON/SVEMath.inl index 7625e5be34..fdf94f0859 100644 --- a/src/core/NEON/SVEMath.inl +++ b/src/core/NEON/SVEMath.inl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. + * Copyright (c) 2020-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,10 +21,14 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ + +#ifndef ACL_SRC_CORE_NEON_SVEMATH_INL +#define ACL_SRC_CORE_NEON_SVEMATH_INL + #include <cmath> #include <limits> -#if defined(__ARM_FEATURE_SVE) && defined(ENABLE_SVE) +#if defined(__ARM_FEATURE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE) #ifndef M_PI #define M_PI (3.14159265358979323846) @@ -32,8 +36,16 @@ namespace arm_compute { -inline svfloat32_t svtaylor_poly_f32_z(svbool_t pg, svfloat32_t x, svfloat32_t coeff_1, svfloat32_t coeff_2, svfloat32_t coeff_3, - svfloat32_t coeff_4, svfloat32_t coeff_5, svfloat32_t coeff_6, svfloat32_t coeff_7, svfloat32_t coeff_8) +inline svfloat32_t svtaylor_poly_f32_z(svbool_t pg, + svfloat32_t x, + svfloat32_t coeff_1, + svfloat32_t coeff_2, + svfloat32_t coeff_3, + svfloat32_t coeff_4, + svfloat32_t coeff_5, + svfloat32_t coeff_6, + svfloat32_t coeff_7, + svfloat32_t coeff_8) { const auto A = svmla_f32_z(pg, coeff_1, coeff_5, x); const auto B = svmla_f32_z(pg, coeff_3, coeff_7, x); @@ -45,8 +57,16 @@ inline svfloat32_t svtaylor_poly_f32_z(svbool_t pg, svfloat32_t x, svfloat32_t c return res; } -inline svfloat16_t svtaylor_poly_f16_z(svbool_t pg, svfloat16_t x, svfloat16_t coeff_1, svfloat16_t coeff_2, svfloat16_t coeff_3, - svfloat16_t coeff_4, svfloat16_t coeff_5, svfloat16_t coeff_6, svfloat16_t coeff_7, svfloat16_t coeff_8) +inline svfloat16_t svtaylor_poly_f16_z(svbool_t pg, + svfloat16_t x, + svfloat16_t coeff_1, + svfloat16_t coeff_2, + svfloat16_t coeff_3, + svfloat16_t coeff_4, + svfloat16_t coeff_5, + svfloat16_t coeff_6, + svfloat16_t coeff_7, + svfloat16_t coeff_8) { const auto A = svmla_f16_z(pg, coeff_1, coeff_5, x); const auto B = svmla_f16_z(pg, coeff_3, coeff_7, x); @@ -74,67 +94,104 @@ inline svfloat32_t svinv_f32_z(svbool_t pg, svfloat32_t x) return recip; } +static const uint32_t svexp_f32_coeff[] = { + 0x3f7ffff6, // x^1: 0x1.ffffecp-1f + 0x3efffedb, // x^2: 0x1.fffdb6p-2f + 0x3e2aaf33, // x^3: 0x1.555e66p-3f + 0x3d2b9f17, // x^4: 0x1.573e2ep-5f + 0x3c072010, // x^5: 0x1.0e4020p-7f +}; + inline svfloat32_t svexp_f32_z(svbool_t pg, svfloat32_t x) { - const auto CONST_LN2 = svdup_n_f32(0.6931471805f); // ln(2) - const auto CONST_INV_LN2 = svdup_n_f32(1.4426950408f); // 1/ln(2) - const auto CONST_INF = svdup_n_f32(std::numeric_limits<float>::infinity()); - const auto CONST_MAX_INPUT = svdup_n_f32(88.7f); - const auto CONST_0 = svdup_n_f32(0.f); - const auto CONST_NEGATIVE_126 = svdup_n_s32(-126); - - /** Exponent polynomial coefficients */ - const svfloat32_t exp_tab_1 = svdup_n_f32(1.f); - const svfloat32_t exp_tab_2 = svdup_n_f32(0.0416598916054f); - const svfloat32_t exp_tab_3 = svdup_n_f32(0.500000596046f); - const svfloat32_t exp_tab_4 = svdup_n_f32(0.0014122662833f); - const svfloat32_t exp_tab_5 = svdup_n_f32(1.00000011921f); - const svfloat32_t exp_tab_6 = svdup_n_f32(0.00833693705499f); - const svfloat32_t exp_tab_7 = svdup_n_f32(0.166665703058f); - const svfloat32_t exp_tab_8 = svdup_n_f32(0.000195780929062f); - - // Perform range reduction [-log(2),log(2)] - auto m = svcvt_s32_f32_z(pg, svmul_f32_z(pg, x, CONST_INV_LN2)); - auto val = svmls_f32_z(pg, x, svcvt_f32_s32_z(pg, m), CONST_LN2); - - // Polynomial Approximation - auto poly = svtaylor_poly_f32_z(pg, val, exp_tab_1, exp_tab_2, exp_tab_3, exp_tab_4, exp_tab_5, exp_tab_6, exp_tab_7, exp_tab_8); + const auto c1 = svreinterpret_f32_u32(svdup_n_u32(svexp_f32_coeff[0])); + const auto c2 = svreinterpret_f32_u32(svdup_n_u32(svexp_f32_coeff[1])); + const auto c3 = svreinterpret_f32_u32(svdup_n_u32(svexp_f32_coeff[2])); + const auto c4 = svreinterpret_f32_u32(svdup_n_u32(svexp_f32_coeff[3])); + const auto c5 = svreinterpret_f32_u32(svdup_n_u32(svexp_f32_coeff[4])); + + const auto shift = svreinterpret_f32_u32(svdup_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f + const auto inv_ln2 = svreinterpret_f32_u32(svdup_n_u32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f + const auto neg_ln2_hi = + svreinterpret_f32_u32(svdup_n_u32(0xbf317200)); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f + const auto neg_ln2_lo = + svreinterpret_f32_u32(svdup_n_u32(0xb5bfbe8e)); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f + + const auto inf = svdup_n_f32(std::numeric_limits<float>::infinity()); + const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5) + const auto zero = svdup_n_f32(0.f); + const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125) + + // Range reduction: + // e^x = 2^n * e^r + // where: + // n = floor(x / ln(2)) + // r = x - n * ln(2) + // + // By adding x / ln(2) with 2^23 + 127 (shift): + // * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127 forces decimal part + // of x / ln(2) out of the result. The integer part of x / ln(2) (i.e. n) + 127 will occupy + // the whole fraction part of z in FP32 format. + // Subtracting 2^23 + 127 (shift) from z will result in the integer part of x / ln(2) + // (i.e. n) because the decimal part has been pushed out and lost. + // * The addition of 127 makes the FP32 fraction part of z ready to be used as the exponent + // in FP32 format. Left shifting z by 23 bits will result in 2^n. + const auto z = svmla_f32_z(pg, shift, x, inv_ln2); + const auto n = svsub_f32_z(pg, z, shift); + const auto scale = svreinterpret_f32_u32(svlsl_n_u32_z(pg, svreinterpret_u32_f32(z), 23)); // 2^n + + // The calculation of n * ln(2) is done using 2 steps to achieve accuracy beyond FP32. + // This outperforms longer Taylor series (3-4 tabs) both in term of accuracy and performance. + const auto r_hi = svmla_f32_z(pg, x, n, neg_ln2_hi); + const auto r = svmla_f32_z(pg, r_hi, n, neg_ln2_lo); + + // Compute the truncated Taylor series of e^r. + // poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5) + const auto r2 = svmul_f32_z(pg, r, r); + + const auto p1 = svmul_f32_z(pg, c1, r); + const auto p23 = svmla_f32_z(pg, c2, c3, r); + const auto p45 = svmla_f32_z(pg, c4, c5, r); + const auto p2345 = svmla_f32_z(pg, p23, p45, r2); + const auto p12345 = svmla_f32_z(pg, p1, p2345, r2); + + auto poly = svmla_f32_z(pg, scale, p12345, scale); + + // Handle underflow and overflow. + poly = svsel_f32(svcmplt_f32(pg, x, min_input), zero, poly); + poly = svsel_f32(svcmpgt_f32(pg, x, max_input), inf, poly); - // Reconstruct - poly = svreinterpret_f32_s32(svqadd_s32(svreinterpret_s32_f32(poly), svlsl_n_s32_z(pg, m, 23))); + return poly; +} - // Handle underflow - svbool_t ltpg = svcmplt_s32(pg, m, CONST_NEGATIVE_126); - poly = svsel_f32(ltpg, CONST_0, poly); +inline svfloat16_t svexp_f16_z(svbool_t pg, svfloat16_t x) +{ + auto bottom = svcvt_f32_z(pg, x); + auto pg_top = svptrue_b16(); + auto top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(x)))); - // Handle overflow - svbool_t gtpg = svcmpgt_f32(pg, x, CONST_MAX_INPUT); - poly = svsel_f32(gtpg, CONST_INF, poly); + bottom = svexp_f32_z(pg, bottom); + top = svexp_f32_z(pg_top, top); - return poly; + return svtrn1(svcvt_f16_z(pg, bottom), svcvt_f16_z(pg_top, top)); } -inline svfloat16_t svexp_f16_z(svbool_t pg, svfloat16_t x) +#ifdef ARM_COMPUTE_ENABLE_SVE2 + +inline svfloat16_t svexp_f16_z_sve2(svbool_t pg, svfloat16_t x) { auto bottom = svcvt_f32_z(pg, x); -#if defined(__ARM_FEATURE_SVE2) auto top = svcvtlt_f32_x(pg, x); auto pg_top = pg; -#else /* defined(__ARM_FEATURE_SVE2) */ - auto pg_top = svptrue_b16(); - auto top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(x)))); -#endif /* defined(__ARM_FEATURE_SVE2) */ bottom = svexp_f32_z(pg, bottom); top = svexp_f32_z(pg_top, top); -#if defined(__ARM_FEATURE_SVE2) return svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top); -#else /* defined(__ARM_FEATURE_SVE2) */ - return svtrn1(svcvt_f16_z(pg, bottom), svcvt_f16_z(pg_top, top)); -#endif /* defined(__ARM_FEATURE_SVE2) */ } +#endif // ARM_COMPUTE_ENABLE_SVE2 + inline svfloat32_t svtanh_f32_z(svbool_t pg, svfloat32_t val) { const svfloat32_t CONST_1 = svdup_n_f32(1.f); @@ -185,7 +242,8 @@ inline svfloat32_t svlog_f32_z(svbool_t pg, svfloat32_t x) auto val = svreinterpret_f32_s32(svsub_s32_z(pg, svreinterpret_s32_f32(x), svlsl_n_s32_z(pg, m, 23))); // Polynomial Approximation - auto poly = svtaylor_poly_f32_z(pg, val, log_tab_1, log_tab_2, log_tab_3, log_tab_4, log_tab_5, log_tab_6, log_tab_7, log_tab_8); + auto poly = svtaylor_poly_f32_z(pg, val, log_tab_1, log_tab_2, log_tab_3, log_tab_4, log_tab_5, log_tab_6, + log_tab_7, log_tab_8); // Reconstruct poly = svmla_f32_z(pg, poly, svcvt_f32_s32_z(pg, m), CONST_LN2); @@ -196,24 +254,31 @@ inline svfloat32_t svlog_f32_z(svbool_t pg, svfloat32_t x) inline svfloat16_t svlog_f16_z(svbool_t pg, svfloat16_t x) { auto bottom = svcvt_f32_z(pg, x); -#if defined(__ARM_FEATURE_SVE2) - auto top = svcvtlt_f32_x(pg, x); - auto pg_top = pg; -#else /* defined(__ARM_FEATURE_SVE2) */ auto pg_top = svptrue_b16(); auto top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(x)))); -#endif /* defined(__ARM_FEATURE_SVE2) */ bottom = svlog_f32_z(pg, bottom); top = svlog_f32_z(pg_top, top); -#if defined(__ARM_FEATURE_SVE2) - return svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top); -#else /* defined(__ARM_FEATURE_SVE2) */ return svtrn1(svcvt_f16_z(pg, bottom), svcvt_f16_z(pg_top, top)); -#endif /* defined(__ARM_FEATURE_SVE2) */ } +#ifdef ARM_COMPUTE_ENABLE_SVE2 + +inline svfloat16_t svlog_f16_z_sve2(svbool_t pg, svfloat16_t x) +{ + auto bottom = svcvt_f32_z(pg, x); + auto top = svcvtlt_f32_x(pg, x); + auto pg_top = pg; + + bottom = svlog_f32_z(pg, bottom); + top = svlog_f32_z(pg_top, top); + + return svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top); +} + +#endif // ARM_COMPUTE_ENABLE_SVE2 + inline svfloat32_t svsin_f32_z(svbool_t pg, svfloat32_t val) { using ScalarType = float; @@ -231,7 +296,8 @@ inline svfloat32_t svsin_f32_z(svbool_t pg, svfloat32_t val) //Find positive or negative const auto c_v = svabs_z(pg, wrapper::svcvt_z<int32_t>(pg, svmul_z(pg, val, ipi_v))); const auto sign_v = svcmple(pg, val, wrapper::svdup_n(ScalarType(0))); - const auto odd_v = svcmpne(pg, svand_z(pg, wrapper::svreinterpret<IntType>(c_v), wrapper::svdup_n(IntType(1))), wrapper::svdup_n(IntType(0))); + const auto odd_v = svcmpne(pg, svand_z(pg, wrapper::svreinterpret<IntType>(c_v), wrapper::svdup_n(IntType(1))), + wrapper::svdup_n(IntType(0))); auto neg_v = sveor_z(pg, odd_v, sign_v); @@ -269,24 +335,31 @@ inline svfloat32_t svsin_f32_z(svbool_t pg, svfloat32_t val) inline svfloat16_t svsin_f16_z(svbool_t pg, svfloat16_t val) { auto bottom = svcvt_f32_z(pg, val); -#if defined(__ARM_FEATURE_SVE2) - auto top = svcvtlt_f32_x(pg, val); - auto pg_top = pg; -#else /* defined(__ARM_FEATURE_SVE2) */ auto pg_top = svptrue_b16(); auto top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(val)))); -#endif /* defined(__ARM_FEATURE_SVE2) */ bottom = svsin_f32_z(pg, bottom); top = svsin_f32_z(pg_top, top); -#if defined(__ARM_FEATURE_SVE2) - return svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top); -#else /* defined(__ARM_FEATURE_SVE2) */ return svtrn1(svcvt_f16_z(pg, bottom), svcvt_f16_z(pg_top, top)); -#endif /* defined(__ARM_FEATURE_SVE2) */ } +#ifdef ARM_COMPUTE_ENABLE_SVE2 + +inline svfloat16_t svsin_f16_z_sve2(svbool_t pg, svfloat16_t val) +{ + auto bottom = svcvt_f32_z(pg, val); + auto top = svcvtlt_f32_x(pg, val); + auto pg_top = pg; + + bottom = svsin_f32_z(pg, bottom); + top = svsin_f32_z(pg_top, top); + + return svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top); +} + +#endif // ARM_COMPUTE_ENABLE_SVE2 + inline svfloat32_t svpow_f32_z(svbool_t pg, svfloat32_t a, svfloat32_t b) { return svexp_f32_z(pg, svmul_z(pg, b, svlog_f32_z(pg, a))); @@ -297,29 +370,41 @@ inline svfloat16_t svpow_f16_z(svbool_t pg, svfloat16_t a, svfloat16_t b) auto a_bottom = svcvt_f32_z(pg, a); auto b_bottom = svcvt_f32_z(pg, b); -#if defined(__ARM_FEATURE_SVE2) - auto pg_top = pg; - auto a_top = svcvtlt_f32_x(pg, a); - auto b_top = svcvtlt_f32_x(pg, b); -#else /* defined(__ARM_FEATURE_SVE2) */ auto pg_top = svptrue_b16(); auto a_top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(a)))); auto b_top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(b)))); -#endif /* defined(__ARM_FEATURE_SVE2) */ auto res_bottom = svpow_f32_z(pg, a_bottom, b_bottom); auto res_top = svpow_f32_z(pg_top, a_top, b_top); -#if defined(__ARM_FEATURE_SVE2) - return svcvtnt_f16_m(svcvt_f16_z(pg, res_bottom), pg_top, res_top); -#else /* defined(__ARM_FEATURE_SVE2) */ return svtrn1(svcvt_f16_z(pg, res_bottom), svcvt_f16_z(pg_top, res_top)); -#endif /* defined(__ARM_FEATURE_SVE2) */ } -#if defined(__ARM_FEATURE_SVE2) +#ifdef ARM_COMPUTE_ENABLE_SVE2 + +inline svfloat16_t svpow_f16_z_sve2(svbool_t pg, svfloat16_t a, svfloat16_t b) +{ + auto a_bottom = svcvt_f32_z(pg, a); + auto b_bottom = svcvt_f32_z(pg, b); + + auto pg_top = pg; + auto a_top = svcvtlt_f32_x(pg, a); + auto b_top = svcvtlt_f32_x(pg, b); + + auto res_bottom = svpow_f32_z(pg, a_bottom, b_bottom); + auto res_top = svpow_f32_z(pg_top, a_top, b_top); + + return svcvtnt_f16_m(svcvt_f16_z(pg, res_bottom), pg_top, res_top); +} + +#endif // ARM_COMPUTE_ENABLE_SVE2 + +#if defined(ARM_COMPUTE_ENABLE_SVE2) template <> -inline svuint8_t convert_float_to_int<svuint8_t>(const svfloat32_t &in_0, const svfloat32_t &in_1, const svfloat32_t &in_2, const svfloat32_t &in_3) +inline svuint8_t convert_float_to_int<svuint8_t>(const svfloat32_t &in_0, + const svfloat32_t &in_1, + const svfloat32_t &in_2, + const svfloat32_t &in_3) { svuint8_t out; const auto all_true_pg = svptrue_b32(); @@ -353,7 +438,10 @@ inline svuint8_t convert_float_to_int<svuint8_t>(const svfloat32_t &in_0, const } template <> -inline svint8_t convert_float_to_int<svint8_t>(const svfloat32_t &in_0, const svfloat32_t &in_1, const svfloat32_t &in_2, const svfloat32_t &in_3) +inline svint8_t convert_float_to_int<svint8_t>(const svfloat32_t &in_0, + const svfloat32_t &in_1, + const svfloat32_t &in_2, + const svfloat32_t &in_3) { svint8_t out; const auto all_true_pg = svptrue_b32(); @@ -385,7 +473,9 @@ inline svint8_t convert_float_to_int<svint8_t>(const svfloat32_t &in_0, const sv return out; } -#endif /* defined(__ARM_FEATURE_SVE2) */ +#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */ } // namespace arm_compute -#endif /* defined(ENABLE_SVE) */ +#endif /* defined(ARM_COMPUTE_ENABLE_SVE) */ + +#endif // ACL_SRC_CORE_NEON_SVEMATH_INL |