diff options
Diffstat (limited to 'src/core/NEON/SVEMath.inl')
-rw-r--r-- | src/core/NEON/SVEMath.inl | 201 |
1 files changed, 132 insertions, 69 deletions
diff --git a/src/core/NEON/SVEMath.inl b/src/core/NEON/SVEMath.inl index 5ebfeaa5c5..86592f6dc3 100644 --- a/src/core/NEON/SVEMath.inl +++ b/src/core/NEON/SVEMath.inl @@ -26,6 +26,10 @@ #if defined(__ARM_FEATURE_SVE) +#ifndef M_PI +#define M_PI (3.14159265358979323846) +#endif // M_PI + namespace arm_compute { inline svfloat32_t svtaylor_poly_f32_z(svbool_t pg, svfloat32_t x, const std::array<svfloat32_t, 8> &coeffs) @@ -115,47 +119,23 @@ inline svfloat32_t svexp_f32_z(svbool_t pg, svfloat32_t x) inline svfloat16_t svexp_f16_z(svbool_t pg, svfloat16_t x) { - const auto CONST_LN2 = svdup_n_f16(0.6931471805f); // ln(2) - const auto CONST_INV_LN2 = svdup_n_f16(1.4426950408f); // 1/ln(2) - const auto CONST_INF = svdup_n_f16(std::numeric_limits<float16_t>::infinity()); - const auto CONST_MAX_INPUT = svdup_n_f16(88.7f); - const auto CONST_0 = svdup_n_f16(0.f); - const auto CONST_NEGATIVE_126 = svdup_n_s16(-126); - - /** Exponent polynomial coefficients */ - const std::array<svfloat16_t, 8> exp_tab = - { - { - svdup_n_f16(1.f), - svdup_n_f16(0.0416598916054f), - svdup_n_f16(0.500000596046f), - svdup_n_f16(0.0014122662833f), - svdup_n_f16(1.00000011921f), - svdup_n_f16(0.00833693705499f), - svdup_n_f16(0.166665703058f), - svdup_n_f16(0.000195780929062f), - } - }; - - // Perform range reduction [-log(2),log(2)] - auto m = svcvt_s16_f16_z(pg, svmul_f16_z(pg, x, CONST_INV_LN2)); - auto val = svmls_f16_z(pg, x, svcvt_f16_s16_z(pg, m), CONST_LN2); - - // Polynomial Approximation - auto poly = svtaylor_poly_f16_z(pg, val, exp_tab); - - // Reconstruct - poly = svreinterpret_f16_s16(svqadd_s16(svreinterpret_s16_f16(poly), svlsl_n_s16_z(pg, m, 11))); - - // Handle underflow - svbool_t ltpg = svcmplt_s16(pg, m, CONST_NEGATIVE_126); - poly = svsel_f16(ltpg, CONST_0, poly); + auto bottom = svcvt_f32_z(pg, x); +#if defined(__ARM_FEATURE_SVE2) + auto top = svcvtlt_f32_x(pg, x); + auto pg_top = pg; +#else /* defined(__ARM_FEATURE_SVE2) */ + auto pg_top = svptrue_b16(); + auto top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(x)))); +#endif /* defined(__ARM_FEATURE_SVE2) */ - // Handle overflow - svbool_t gtpg = svcmpgt_f16(pg, x, CONST_MAX_INPUT); - poly = svsel_f16(gtpg, CONST_INF, poly); + bottom = svexp_f32_z(pg, bottom); + top = svexp_f32_z(pg_top, top); - return poly; +#if defined(__ARM_FEATURE_SVE2) + return svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top); +#else /* defined(__ARM_FEATURE_SVE2) */ + return svtrn1(svcvt_f16_z(pg, bottom), svcvt_f16_z(pg_top, top)); +#endif /* defined(__ARM_FEATURE_SVE2) */ } inline svfloat32_t svtanh_f32_z(svbool_t pg, svfloat32_t val) @@ -190,9 +170,6 @@ inline svfloat16_t svtanh_f16_z(svbool_t pg, svfloat16_t val) inline svfloat32_t svlog_f32_z(svbool_t pg, svfloat32_t x) { -#if defined(__ARM_FEATURE_SVE2) - return svcvt_f32_s32_z(pg, svlogb_f32_z(pg, x)); -#else /* !defined(__ARM_FEATURE_SVE2) */ /** Logarithm polynomial coefficients */ const std::array<svfloat32_t, 8> log_tab = { @@ -222,45 +199,131 @@ inline svfloat32_t svlog_f32_z(svbool_t pg, svfloat32_t x) poly = svmla_f32_z(pg, poly, svcvt_f32_s32_z(pg, m), CONST_LN2); return poly; -#endif /* defined(__ARM_FEATURE_SVE2) */ } inline svfloat16_t svlog_f16_z(svbool_t pg, svfloat16_t x) { + auto bottom = svcvt_f32_z(pg, x); #if defined(__ARM_FEATURE_SVE2) - return svcvt_f16_s16_z(pg, svlogb_f16_z(pg, x)); -#else /* !defined(__ARM_FEATURE_SVE2) */ + auto top = svcvtlt_f32_x(pg, x); + auto pg_top = pg; +#else /* defined(__ARM_FEATURE_SVE2) */ + auto pg_top = svptrue_b16(); + auto top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(x)))); +#endif /* defined(__ARM_FEATURE_SVE2) */ - /** Logarithm polynomial coefficients */ - const std::array<svfloat16_t, 8> log_tab - { - { - svdup_n_f16(-2.29561495781f), - svdup_n_f16(-2.47071170807f), - svdup_n_f16(-5.68692588806f), - svdup_n_f16(-0.165253549814f), - svdup_n_f16(5.17591238022f), - svdup_n_f16(0.844007015228f), - svdup_n_f16(4.58445882797f), - svdup_n_f16(0.0141278216615f), - } - }; + bottom = svlog_f32_z(pg, bottom); + top = svlog_f32_z(pg_top, top); - const auto CONST_7 = svdup_n_s16(7); // 7 - const auto CONST_LN2 = svdup_n_f16(0.6931471805f); // ln(2) +#if defined(__ARM_FEATURE_SVE2) + return svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top); +#else /* defined(__ARM_FEATURE_SVE2) */ + return svtrn1(svcvt_f16_z(pg, bottom), svcvt_f16_z(pg_top, top)); +#endif /* defined(__ARM_FEATURE_SVE2) */ +} - // Extract exponent - auto m = svsub_s16_z(pg, svasr_n_s16_z(pg, svreinterpret_s16_f16(x), 11), CONST_7); - auto val = svreinterpret_f16_s16(svsub_s16_z(pg, svreinterpret_s16_f16(x), svlsl_n_s16_z(pg, m, 11))); +inline svfloat32_t svsin_f32_z(svbool_t pg, svfloat32_t val) +{ + using ScalarType = float; + using IntType = u32; - // Polynomial Approximation - auto poly = svtaylor_poly_f16_z(pg, val, log_tab); + constexpr float te_sin_coeff2 = 0.166666666666f; // 1/(2*3) + constexpr float te_sin_coeff3 = 0.05f; // 1/(4*5) + constexpr float te_sin_coeff4 = 0.023809523810f; // 1/(6*7) + constexpr float te_sin_coeff5 = 0.013888888889f; // 1/(8*9) - // Reconstruct - poly = svmla_f16_z(pg, poly, svcvt_f16_s16_z(pg, m), CONST_LN2); + const auto pi_v = wrapper::svdup_n(ScalarType(M_PI)); + const auto pio2_v = wrapper::svdup_n(ScalarType(M_PI / 2)); + const auto ipi_v = wrapper::svdup_n(ScalarType(1 / M_PI)); - return poly; + //Find positive or negative + const auto c_v = svabs_z(pg, wrapper::svcvt_z<int32_t>(pg, svmul_z(pg, val, ipi_v))); + const auto sign_v = svcmple(pg, val, wrapper::svdup_n(ScalarType(0))); + const auto odd_v = svcmpne(pg, svand_z(pg, wrapper::svreinterpret<IntType>(c_v), wrapper::svdup_n(IntType(1))), wrapper::svdup_n(IntType(0))); + + auto neg_v = sveor_z(pg, odd_v, sign_v); + + //Modulus a - (n * int(a*(1/n))) + auto ma = svsub_z(pg, svabs_z(pg, val), svmul_z(pg, pi_v, wrapper::svcvt_z<ScalarType>(pg, c_v))); + const auto reb_v = svcmpge(pg, ma, pio2_v); + + //Rebase a between 0 and pi/2 + ma = svsel(reb_v, svsub_z(pg, pi_v, ma), ma); + + //Taylor series + const auto ma2 = svmul_z(pg, ma, ma); + + //2nd elem: x^3 / 3! + auto elem = svmul_z(pg, svmul_z(pg, ma, ma2), wrapper::svdup_n(ScalarType(te_sin_coeff2))); + auto res = svsub_z(pg, ma, elem); + + //3rd elem: x^5 / 5! + elem = svmul_z(pg, svmul_z(pg, elem, ma2), wrapper::svdup_n(ScalarType(te_sin_coeff3))); + res = svadd_z(pg, res, elem); + + //4th elem: x^7 / 7!float32x2_t vsin_f32(float32x2_t val) + elem = svmul_z(pg, svmul_z(pg, elem, ma2), wrapper::svdup_n(ScalarType(te_sin_coeff4))); + res = svsub_z(pg, res, elem); + + //5th elem: x^9 / 9! + elem = svmul_z(pg, svmul_z(pg, elem, ma2), wrapper::svdup_n(ScalarType(te_sin_coeff5))); + res = svadd_z(pg, res, elem); + + //Change of sign + res = svneg_m(res, neg_v, res); + return res; +} + +inline svfloat16_t svsin_f16_z(svbool_t pg, svfloat16_t val) +{ + auto bottom = svcvt_f32_z(pg, val); +#if defined(__ARM_FEATURE_SVE2) + auto top = svcvtlt_f32_x(pg, val); + auto pg_top = pg; +#else /* defined(__ARM_FEATURE_SVE2) */ + auto pg_top = svptrue_b16(); + auto top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(val)))); +#endif /* defined(__ARM_FEATURE_SVE2) */ + + bottom = svsin_f32_z(pg, bottom); + top = svsin_f32_z(pg_top, top); + +#if defined(__ARM_FEATURE_SVE2) + return svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top); +#else /* defined(__ARM_FEATURE_SVE2) */ + return svtrn1(svcvt_f16_z(pg, bottom), svcvt_f16_z(pg_top, top)); #endif /* defined(__ARM_FEATURE_SVE2) */ } + +inline svfloat32_t svpow_f32_z(svbool_t pg, svfloat32_t a, svfloat32_t b) +{ + return svexp_f32_z(pg, svmul_z(pg, b, svlog_f32_z(pg, a))); +} + +inline svfloat16_t svpow_f16_z(svbool_t pg, svfloat16_t a, svfloat16_t b) +{ + auto a_bottom = svcvt_f32_z(pg, a); + auto b_bottom = svcvt_f32_z(pg, b); + +#if defined(__ARM_FEATURE_SVE2) + auto pg_top = pg; + auto a_top = svcvtlt_f32_x(pg, a); + auto b_top = svcvtlt_f32_x(pg, b) +#else /* defined(__ARM_FEATURE_SVE2) */ + auto pg_top = svptrue_b16(); + auto a_top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(a)))); + auto b_top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(b)))); +#endif /* defined(__ARM_FEATURE_SVE2) */ + + auto res_bottom = svpow_f32_z(pg, a_bottom, b_bottom); + auto res_top = svpow_f32_z(pg_top, a_top, b_top); + +#if defined(__ARM_FEATURE_SVE2) + return svcvtnt_f16_m(svcvt_f16_z(pg, res_bottom), pg_top, res_top); +#else /* defined(__ARM_FEATURE_SVE2) */ + return svtrn1(svcvt_f16_z(pg, res_bottom), svcvt_f16_z(pg_top, res_top)); +#endif /* defined(__ARM_FEATURE_SVE2) */ +} + } // namespace arm_compute #endif /* defined(__ARM_FEATURE_SVE) */ |