From 5ef0bdd53dd2ce6bc7ad28077ffac3bf9e939b5f Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Thu, 19 Oct 2023 10:15:54 +0100 Subject: Fix SVE kernel using SVE2 instruction Resolves: COMPMID-6493 Signed-off-by: Viet-Hoa Do Change-Id: I038d91ba266e1e8bf124336bcd272ec77e92038c Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10490 Benchmark: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Anitha Raj Comments-Addressed: Arm Jenkins --- src/core/NEON/SVEMath.h | 63 ++++++++++++++++++++++++-- src/core/NEON/SVEMath.inl | 110 ++++++++++++++++++++++++++++++---------------- 2 files changed, 132 insertions(+), 41 deletions(-) (limited to 'src/core/NEON') diff --git a/src/core/NEON/SVEMath.h b/src/core/NEON/SVEMath.h index 6d69b330ba..49ed9df720 100644 --- a/src/core/NEON/SVEMath.h +++ b/src/core/NEON/SVEMath.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. + * Copyright (c) 2020-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef ARM_COMPUTE_SVEMATH_H -#define ARM_COMPUTE_SVEMATH_H +#ifndef ACL_SRC_CORE_NEON_SVEMATH_H +#define ACL_SRC_CORE_NEON_SVEMATH_H #if defined(ARM_COMPUTE_ENABLE_SVE) #include "src/core/NEON/wrapper/intrinsics/svcvt.h" @@ -96,6 +96,19 @@ svfloat16_t svtanh_f16_z(svbool_t pg, svfloat16_t val); */ svfloat16_t svexp_f16_z(svbool_t pg, svfloat16_t x); +#ifdef ARM_COMPUTE_ENABLE_SVE2 + +/** Calculate exponential + * + * @param[in] pg Input predicate. + * @param[in] x Input vector value in F16 format. + * + * @return The calculated exponent. + */ +svfloat16_t svexp_f16_z_sve2(svbool_t pg, svfloat16_t x); + +#endif // ARM_COMPUTE_ENABLE_SVE2 + /** Calculate reciprocal. * * @param[in] pg Input predicate. 
@@ -114,6 +127,19 @@ svfloat16_t svinv_f16_z(svbool_t pg, svfloat16_t x); */ svfloat16_t svlog_f16_z(svbool_t pg, svfloat16_t x); +#ifdef ARM_COMPUTE_ENABLE_SVE2 + +/** Calculate logarithm + * + * @param[in] pg Input predicate. + * @param[in] x Input vector value in F16 format. + * + * @return The calculated logarithm. + */ +svfloat16_t svlog_f16_z_sve2(svbool_t pg, svfloat16_t x); + +#endif // ARM_COMPUTE_ENABLE_SVE2 + /** Calculate inverse square root. * * @param[in] pg Input predicate. @@ -148,6 +174,19 @@ svfloat32_t svsin_f32_z(svbool_t pg, svfloat32_t val); */ svfloat16_t svsin_f16_z(svbool_t pg, svfloat16_t val); +#ifdef ARM_COMPUTE_ENABLE_SVE2 + +/** Calculate sine. + * + * @param[in] pg Input predicate. + * @param[in] val Input vector value in radians, F16 format. + * + * @return The calculated sine. + */ +svfloat16_t svsin_f16_z_sve2(svbool_t pg, svfloat16_t val); + +#endif // ARM_COMPUTE_ENABLE_SVE2 + /** Calculate n power of a number. * * pow(x,n) = e^(n*log(x)) @@ -172,6 +211,22 @@ svfloat32_t svpow_f32_z(svbool_t pg, svfloat32_t a, svfloat32_t b); */ svfloat16_t svpow_f16_z(svbool_t pg, svfloat16_t a, svfloat16_t b); +#ifdef ARM_COMPUTE_ENABLE_SVE2 + +/** Calculate n power of a number. + * + * pow(x,n) = e^(n*log(x)) + * + * @param[in] pg Input predicate. + * @param[in] a Input vector value in F16 format. + * @param[in] b Powers to raise the input to. + * + * @return The calculated power. 
+ */ +svfloat16_t svpow_f16_z_sve2(svbool_t pg, svfloat16_t a, svfloat16_t b); + +#endif // ARM_COMPUTE_ENABLE_SVE2 + /** Convert and pack four 32-bit float vectors into an 8-bit integer vector * * @param[in] in_0 The first float vector @@ -190,4 +245,4 @@ int_vec_type convert_float_to_int(const svfloat32_t &in_0, } // namespace arm_compute #include "src/core/NEON/SVEMath.inl" #endif /* defined(ARM_COMPUTE_ENABLE_SVE) */ -#endif /* ARM_COMPUTE_SVEMATH_H */ +#endif // ACL_SRC_CORE_NEON_SVEMATH_H diff --git a/src/core/NEON/SVEMath.inl b/src/core/NEON/SVEMath.inl index b30125dcb7..fdf94f0859 100644 --- a/src/core/NEON/SVEMath.inl +++ b/src/core/NEON/SVEMath.inl @@ -21,6 +21,10 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ + +#ifndef ACL_SRC_CORE_NEON_SVEMATH_INL +#define ACL_SRC_CORE_NEON_SVEMATH_INL + #include #include @@ -163,24 +167,31 @@ inline svfloat32_t svexp_f32_z(svbool_t pg, svfloat32_t x) inline svfloat16_t svexp_f16_z(svbool_t pg, svfloat16_t x) { auto bottom = svcvt_f32_z(pg, x); -#if defined(ARM_COMPUTE_ENABLE_SVE2) - auto top = svcvtlt_f32_x(pg, x); - auto pg_top = pg; -#else /* defined(ARM_COMPUTE_ENABLE_SVE2) */ auto pg_top = svptrue_b16(); auto top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(x)))); -#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */ bottom = svexp_f32_z(pg, bottom); top = svexp_f32_z(pg_top, top); -#if defined(ARM_COMPUTE_ENABLE_SVE2) - return svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top); -#else /* defined(ARM_COMPUTE_ENABLE_SVE2) */ return svtrn1(svcvt_f16_z(pg, bottom), svcvt_f16_z(pg_top, top)); -#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */ } +#ifdef ARM_COMPUTE_ENABLE_SVE2 + +inline svfloat16_t svexp_f16_z_sve2(svbool_t pg, svfloat16_t x) +{ + auto bottom = svcvt_f32_z(pg, x); + auto top = svcvtlt_f32_x(pg, x); + auto pg_top = pg; + + bottom = svexp_f32_z(pg, bottom); + top = svexp_f32_z(pg_top, top); + + return 
svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top); +} + +#endif // ARM_COMPUTE_ENABLE_SVE2 + inline svfloat32_t svtanh_f32_z(svbool_t pg, svfloat32_t val) { const svfloat32_t CONST_1 = svdup_n_f32(1.f); @@ -243,24 +254,31 @@ inline svfloat32_t svlog_f32_z(svbool_t pg, svfloat32_t x) inline svfloat16_t svlog_f16_z(svbool_t pg, svfloat16_t x) { auto bottom = svcvt_f32_z(pg, x); -#if defined(ARM_COMPUTE_ENABLE_SVE2) - auto top = svcvtlt_f32_x(pg, x); - auto pg_top = pg; -#else /* defined(ARM_COMPUTE_ENABLE_SVE2) */ auto pg_top = svptrue_b16(); auto top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(x)))); -#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */ bottom = svlog_f32_z(pg, bottom); top = svlog_f32_z(pg_top, top); -#if defined(ARM_COMPUTE_ENABLE_SVE2) - return svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top); -#else /* defined(ARM_COMPUTE_ENABLE_SVE2) */ return svtrn1(svcvt_f16_z(pg, bottom), svcvt_f16_z(pg_top, top)); -#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */ } +#ifdef ARM_COMPUTE_ENABLE_SVE2 + +inline svfloat16_t svlog_f16_z_sve2(svbool_t pg, svfloat16_t x) +{ + auto bottom = svcvt_f32_z(pg, x); + auto top = svcvtlt_f32_x(pg, x); + auto pg_top = pg; + + bottom = svlog_f32_z(pg, bottom); + top = svlog_f32_z(pg_top, top); + + return svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top); +} + +#endif // ARM_COMPUTE_ENABLE_SVE2 + inline svfloat32_t svsin_f32_z(svbool_t pg, svfloat32_t val) { using ScalarType = float; @@ -317,24 +335,31 @@ inline svfloat32_t svsin_f32_z(svbool_t pg, svfloat32_t val) inline svfloat16_t svsin_f16_z(svbool_t pg, svfloat16_t val) { auto bottom = svcvt_f32_z(pg, val); -#if defined(ARM_COMPUTE_ENABLE_SVE2) - auto top = svcvtlt_f32_x(pg, val); - auto pg_top = pg; -#else /* defined(ARM_COMPUTE_ENABLE_SVE2) */ auto pg_top = svptrue_b16(); auto top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(val)))); -#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */ bottom = 
svsin_f32_z(pg, bottom); top = svsin_f32_z(pg_top, top); -#if defined(ARM_COMPUTE_ENABLE_SVE2) - return svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top); -#else /* defined(ARM_COMPUTE_ENABLE_SVE2) */ return svtrn1(svcvt_f16_z(pg, bottom), svcvt_f16_z(pg_top, top)); -#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */ } +#ifdef ARM_COMPUTE_ENABLE_SVE2 + +inline svfloat16_t svsin_f16_z_sve2(svbool_t pg, svfloat16_t val) +{ + auto bottom = svcvt_f32_z(pg, val); + auto top = svcvtlt_f32_x(pg, val); + auto pg_top = pg; + + bottom = svsin_f32_z(pg, bottom); + top = svsin_f32_z(pg_top, top); + + return svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top); +} + +#endif // ARM_COMPUTE_ENABLE_SVE2 + inline svfloat32_t svpow_f32_z(svbool_t pg, svfloat32_t a, svfloat32_t b) { return svexp_f32_z(pg, svmul_z(pg, b, svlog_f32_z(pg, a))); @@ -345,26 +370,35 @@ inline svfloat16_t svpow_f16_z(svbool_t pg, svfloat16_t a, svfloat16_t b) auto a_bottom = svcvt_f32_z(pg, a); auto b_bottom = svcvt_f32_z(pg, b); -#if defined(ARM_COMPUTE_ENABLE_SVE2) - auto pg_top = pg; - auto a_top = svcvtlt_f32_x(pg, a); - auto b_top = svcvtlt_f32_x(pg, b); -#else /* defined(ARM_COMPUTE_ENABLE_SVE2) */ auto pg_top = svptrue_b16(); auto a_top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(a)))); auto b_top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(b)))); -#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */ auto res_bottom = svpow_f32_z(pg, a_bottom, b_bottom); auto res_top = svpow_f32_z(pg_top, a_top, b_top); -#if defined(ARM_COMPUTE_ENABLE_SVE2) - return svcvtnt_f16_m(svcvt_f16_z(pg, res_bottom), pg_top, res_top); -#else /* defined(ARM_COMPUTE_ENABLE_SVE2) */ return svtrn1(svcvt_f16_z(pg, res_bottom), svcvt_f16_z(pg_top, res_top)); -#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */ } +#ifdef ARM_COMPUTE_ENABLE_SVE2 + +inline svfloat16_t svpow_f16_z_sve2(svbool_t pg, svfloat16_t a, svfloat16_t b) +{ + auto a_bottom = svcvt_f32_z(pg, a); + auto 
b_bottom = svcvt_f32_z(pg, b); + + auto pg_top = pg; + auto a_top = svcvtlt_f32_x(pg, a); + auto b_top = svcvtlt_f32_x(pg, b); + + auto res_bottom = svpow_f32_z(pg, a_bottom, b_bottom); + auto res_top = svpow_f32_z(pg_top, a_top, b_top); + + return svcvtnt_f16_m(svcvt_f16_z(pg, res_bottom), pg_top, res_top); +} + +#endif // ARM_COMPUTE_ENABLE_SVE2 + #if defined(ARM_COMPUTE_ENABLE_SVE2) template <> inline svuint8_t convert_float_to_int(const svfloat32_t &in_0, @@ -443,3 +477,5 @@ inline svint8_t convert_float_to_int(const svfloat32_t &in_0, } // namespace arm_compute #endif /* defined(ARM_COMPUTE_ENABLE_SVE) */ + +#endif // ACL_SRC_CORE_NEON_SVEMATH_INL -- cgit v1.2.1