aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON
diff options
context:
space:
mode:
authorViet-Hoa Do <viet-hoa.do@arm.com>2023-10-19 10:15:54 +0100
committerViet-Hoa Do <viet-hoa.do@arm.com>2023-10-31 10:17:30 +0000
commit5ef0bdd53dd2ce6bc7ad28077ffac3bf9e939b5f (patch)
tree9268cc52debd91a4c146493b53e28ea937fef4e1 /src/core/NEON
parent29254aeb11a76c86449c2f38587e9144b2f2aacb (diff)
downloadComputeLibrary-5ef0bdd53dd2ce6bc7ad28077ffac3bf9e939b5f.tar.gz
Fix SVE kernel using SVE2 instruction
Resolves: COMPMID-6493 Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com> Change-Id: I038d91ba266e1e8bf124336bcd272ec77e92038c Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10490 Benchmark: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Anitha Raj <Anitha.Raj@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON')
-rw-r--r--src/core/NEON/SVEMath.h63
-rw-r--r--src/core/NEON/SVEMath.inl110
2 files changed, 132 insertions, 41 deletions
diff --git a/src/core/NEON/SVEMath.h b/src/core/NEON/SVEMath.h
index 6d69b330ba..49ed9df720 100644
--- a/src/core/NEON/SVEMath.h
+++ b/src/core/NEON/SVEMath.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020-2021 Arm Limited.
+ * Copyright (c) 2020-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_SVEMATH_H
-#define ARM_COMPUTE_SVEMATH_H
+#ifndef ACL_SRC_CORE_NEON_SVEMATH_H
+#define ACL_SRC_CORE_NEON_SVEMATH_H
#if defined(ARM_COMPUTE_ENABLE_SVE)
#include "src/core/NEON/wrapper/intrinsics/svcvt.h"
@@ -96,6 +96,19 @@ svfloat16_t svtanh_f16_z(svbool_t pg, svfloat16_t val);
*/
svfloat16_t svexp_f16_z(svbool_t pg, svfloat16_t x);
+#ifdef ARM_COMPUTE_ENABLE_SVE2
+
+/** Calculate exponential
+ *
+ * @param[in] pg Input predicate.
+ * @param[in] x Input vector value in F16 format.
+ *
+ * @return The calculated exponent.
+ */
+svfloat16_t svexp_f16_z_sve2(svbool_t pg, svfloat16_t x);
+
+#endif // ARM_COMPUTE_ENABLE_SVE2
+
/** Calculate reciprocal.
*
* @param[in] pg Input predicate.
@@ -114,6 +127,19 @@ svfloat16_t svinv_f16_z(svbool_t pg, svfloat16_t x);
*/
svfloat16_t svlog_f16_z(svbool_t pg, svfloat16_t x);
+#ifdef ARM_COMPUTE_ENABLE_SVE2
+
+/** Calculate logarithm
+ *
+ * @param[in] pg Input predicate.
+ * @param[in] x Input vector value in F32 format.
+ *
+ * @return The calculated logarithm.
+ */
+svfloat16_t svlog_f16_z_sve2(svbool_t pg, svfloat16_t x);
+
+#endif // ARM_COMPUTE_ENABLE_SVE2
+
/** Calculate inverse square root.
*
* @param[in] pg Input predicate.
@@ -148,6 +174,19 @@ svfloat32_t svsin_f32_z(svbool_t pg, svfloat32_t val);
*/
svfloat16_t svsin_f16_z(svbool_t pg, svfloat16_t val);
+#ifdef ARM_COMPUTE_ENABLE_SVE2
+
+/** Calculate sine.
+ *
+ * @param[in] pg Input predicate.
+ * @param[in] val Input vector value in radians, F16 format.
+ *
+ * @return The calculated sine.
+ */
+svfloat16_t svsin_f16_z_sve2(svbool_t pg, svfloat16_t val);
+
+#endif // ARM_COMPUTE_ENABLE_SVE2
+
/** Calculate n power of a number.
*
* pow(x,n) = e^(n*log(x))
@@ -172,6 +211,22 @@ svfloat32_t svpow_f32_z(svbool_t pg, svfloat32_t a, svfloat32_t b);
*/
svfloat16_t svpow_f16_z(svbool_t pg, svfloat16_t a, svfloat16_t b);
+#ifdef ARM_COMPUTE_ENABLE_SVE2
+
+/** Calculate n power of a number.
+ *
+ * pow(x,n) = e^(n*log(x))
+ *
+ * @param[in] pg Input predicate.
+ * @param[in] a Input vector value in F16 format.
+ * @param[in] b Powers to raise the input to.
+ *
+ * @return The calculated power.
+ */
+svfloat16_t svpow_f16_z_sve2(svbool_t pg, svfloat16_t a, svfloat16_t b);
+
+#endif // ARM_COMPUTE_ENABLE_SVE2
+
/** Convert and pack four 32-bit float vectors into an 8-bit integer vector
*
* @param[in] in_0 The first float vector
@@ -190,4 +245,4 @@ int_vec_type convert_float_to_int(const svfloat32_t &in_0,
} // namespace arm_compute
#include "src/core/NEON/SVEMath.inl"
#endif /* defined(ARM_COMPUTE_ENABLE_SVE) */
-#endif /* ARM_COMPUTE_SVEMATH_H */
+#endif // ACL_SRC_CORE_NEON_SVEMATH_H
diff --git a/src/core/NEON/SVEMath.inl b/src/core/NEON/SVEMath.inl
index b30125dcb7..fdf94f0859 100644
--- a/src/core/NEON/SVEMath.inl
+++ b/src/core/NEON/SVEMath.inl
@@ -21,6 +21,10 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
+
+#ifndef ACL_SRC_CORE_NEON_SVEMATH_INL
+#define ACL_SRC_CORE_NEON_SVEMATH_INL
+
#include <cmath>
#include <limits>
@@ -163,24 +167,31 @@ inline svfloat32_t svexp_f32_z(svbool_t pg, svfloat32_t x)
inline svfloat16_t svexp_f16_z(svbool_t pg, svfloat16_t x)
{
auto bottom = svcvt_f32_z(pg, x);
-#if defined(ARM_COMPUTE_ENABLE_SVE2)
- auto top = svcvtlt_f32_x(pg, x);
- auto pg_top = pg;
-#else /* defined(ARM_COMPUTE_ENABLE_SVE2) */
auto pg_top = svptrue_b16();
auto top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(x))));
-#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */
bottom = svexp_f32_z(pg, bottom);
top = svexp_f32_z(pg_top, top);
-#if defined(ARM_COMPUTE_ENABLE_SVE2)
- return svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top);
-#else /* defined(ARM_COMPUTE_ENABLE_SVE2) */
return svtrn1(svcvt_f16_z(pg, bottom), svcvt_f16_z(pg_top, top));
-#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */
}
+#ifdef ARM_COMPUTE_ENABLE_SVE2
+
+inline svfloat16_t svexp_f16_z_sve2(svbool_t pg, svfloat16_t x)
+{
+ auto bottom = svcvt_f32_z(pg, x);
+ auto top = svcvtlt_f32_x(pg, x);
+ auto pg_top = pg;
+
+ bottom = svexp_f32_z(pg, bottom);
+ top = svexp_f32_z(pg_top, top);
+
+ return svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top);
+}
+
+#endif // ARM_COMPUTE_ENABLE_SVE2
+
inline svfloat32_t svtanh_f32_z(svbool_t pg, svfloat32_t val)
{
const svfloat32_t CONST_1 = svdup_n_f32(1.f);
@@ -243,24 +254,31 @@ inline svfloat32_t svlog_f32_z(svbool_t pg, svfloat32_t x)
inline svfloat16_t svlog_f16_z(svbool_t pg, svfloat16_t x)
{
auto bottom = svcvt_f32_z(pg, x);
-#if defined(ARM_COMPUTE_ENABLE_SVE2)
- auto top = svcvtlt_f32_x(pg, x);
- auto pg_top = pg;
-#else /* defined(ARM_COMPUTE_ENABLE_SVE2) */
auto pg_top = svptrue_b16();
auto top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(x))));
-#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */
bottom = svlog_f32_z(pg, bottom);
top = svlog_f32_z(pg_top, top);
-#if defined(ARM_COMPUTE_ENABLE_SVE2)
- return svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top);
-#else /* defined(ARM_COMPUTE_ENABLE_SVE2) */
return svtrn1(svcvt_f16_z(pg, bottom), svcvt_f16_z(pg_top, top));
-#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */
}
+#ifdef ARM_COMPUTE_ENABLE_SVE2
+
+inline svfloat16_t svlog_f16_z_sve2(svbool_t pg, svfloat16_t x)
+{
+ auto bottom = svcvt_f32_z(pg, x);
+ auto top = svcvtlt_f32_x(pg, x);
+ auto pg_top = pg;
+
+ bottom = svlog_f32_z(pg, bottom);
+ top = svlog_f32_z(pg_top, top);
+
+ return svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top);
+}
+
+#endif // ARM_COMPUTE_ENABLE_SVE2
+
inline svfloat32_t svsin_f32_z(svbool_t pg, svfloat32_t val)
{
using ScalarType = float;
@@ -317,24 +335,31 @@ inline svfloat32_t svsin_f32_z(svbool_t pg, svfloat32_t val)
inline svfloat16_t svsin_f16_z(svbool_t pg, svfloat16_t val)
{
auto bottom = svcvt_f32_z(pg, val);
-#if defined(ARM_COMPUTE_ENABLE_SVE2)
- auto top = svcvtlt_f32_x(pg, val);
- auto pg_top = pg;
-#else /* defined(ARM_COMPUTE_ENABLE_SVE2) */
auto pg_top = svptrue_b16();
auto top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(val))));
-#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */
bottom = svsin_f32_z(pg, bottom);
top = svsin_f32_z(pg_top, top);
-#if defined(ARM_COMPUTE_ENABLE_SVE2)
- return svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top);
-#else /* defined(ARM_COMPUTE_ENABLE_SVE2) */
return svtrn1(svcvt_f16_z(pg, bottom), svcvt_f16_z(pg_top, top));
-#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */
}
+#ifdef ARM_COMPUTE_ENABLE_SVE2
+
+inline svfloat16_t svsin_f16_z_sve2(svbool_t pg, svfloat16_t val)
+{
+ auto bottom = svcvt_f32_z(pg, val);
+ auto top = svcvtlt_f32_x(pg, val);
+ auto pg_top = pg;
+
+ bottom = svsin_f32_z(pg, bottom);
+ top = svsin_f32_z(pg_top, top);
+
+ return svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top);
+}
+
+#endif // ARM_COMPUTE_ENABLE_SVE2
+
inline svfloat32_t svpow_f32_z(svbool_t pg, svfloat32_t a, svfloat32_t b)
{
return svexp_f32_z(pg, svmul_z(pg, b, svlog_f32_z(pg, a)));
@@ -345,26 +370,35 @@ inline svfloat16_t svpow_f16_z(svbool_t pg, svfloat16_t a, svfloat16_t b)
auto a_bottom = svcvt_f32_z(pg, a);
auto b_bottom = svcvt_f32_z(pg, b);
-#if defined(ARM_COMPUTE_ENABLE_SVE2)
- auto pg_top = pg;
- auto a_top = svcvtlt_f32_x(pg, a);
- auto b_top = svcvtlt_f32_x(pg, b);
-#else /* defined(ARM_COMPUTE_ENABLE_SVE2) */
auto pg_top = svptrue_b16();
auto a_top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(a))));
auto b_top = svcvt_f32_z(pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(b))));
-#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */
auto res_bottom = svpow_f32_z(pg, a_bottom, b_bottom);
auto res_top = svpow_f32_z(pg_top, a_top, b_top);
-#if defined(ARM_COMPUTE_ENABLE_SVE2)
- return svcvtnt_f16_m(svcvt_f16_z(pg, res_bottom), pg_top, res_top);
-#else /* defined(ARM_COMPUTE_ENABLE_SVE2) */
return svtrn1(svcvt_f16_z(pg, res_bottom), svcvt_f16_z(pg_top, res_top));
-#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */
}
+#ifdef ARM_COMPUTE_ENABLE_SVE2
+
+inline svfloat16_t svpow_f16_z_sve2(svbool_t pg, svfloat16_t a, svfloat16_t b)
+{
+ auto a_bottom = svcvt_f32_z(pg, a);
+ auto b_bottom = svcvt_f32_z(pg, b);
+
+ auto pg_top = pg;
+ auto a_top = svcvtlt_f32_x(pg, a);
+ auto b_top = svcvtlt_f32_x(pg, b);
+
+ auto res_bottom = svpow_f32_z(pg, a_bottom, b_bottom);
+ auto res_top = svpow_f32_z(pg_top, a_top, b_top);
+
+ return svcvtnt_f16_m(svcvt_f16_z(pg, res_bottom), pg_top, res_top);
+}
+
+#endif // ARM_COMPUTE_ENABLE_SVE2
+
#if defined(ARM_COMPUTE_ENABLE_SVE2)
template <>
inline svuint8_t convert_float_to_int<svuint8_t>(const svfloat32_t &in_0,
@@ -443,3 +477,5 @@ inline svint8_t convert_float_to_int<svint8_t>(const svfloat32_t &in_0,
} // namespace arm_compute
#endif /* defined(ARM_COMPUTE_ENABLE_SVE) */
+
+#endif // ACL_SRC_CORE_NEON_SVEMATH_INL