aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/utils.hpp
diff options
context:
space:
mode:
authorViet-Hoa Do <viet-hoa.do@arm.com>2022-06-01 11:47:14 +0100
committerViet-Hoa Do <viet-hoa.do@arm.com>2022-11-28 16:57:42 +0000
commit03b2971ac69a86f10a1566938d1a25afee15746c (patch)
treeaec7cfc047e1da278b4b71a706cda7b1b0faa158 /src/core/NEON/kernels/arm_gemm/utils.hpp
parent7dc0234331f2150a6b4ac5c2b49de419870f7cf5 (diff)
downloadComputeLibrary-03b2971ac69a86f10a1566938d1a25afee15746c.tar.gz
Integrate SME2 kernels
* Add SME/SME2 detection. * Integrate SME2 implementation for: - Normal convolution - Winograd - Depthwise convolution - Pooling Resolves: COMPMID-5700 Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com> Change-Id: I2f1ca1d05f8cfeee9309ed1c0a36096a4a6aad5c Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8692 Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/utils.hpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/utils.hpp21
1 files changed, 20 insertions, 1 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/utils.hpp b/src/core/NEON/kernels/arm_gemm/utils.hpp
index d7b5398488..a28ddadc68 100644
--- a/src/core/NEON/kernels/arm_gemm/utils.hpp
+++ b/src/core/NEON/kernels/arm_gemm/utils.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -80,6 +80,7 @@ inline T roundup(const T a, const T b) {
enum class VLType {
None,
SVE,
+ SME
};
template<typename T>
@@ -191,6 +192,20 @@ inline unsigned long get_vector_length() {
#endif // defined(__aarch64__)
}
+#ifdef ARM_COMPUTE_ENABLE_SME
+namespace sme {
+
+// function from misc-sve.cpp
+extern unsigned int raw_vector_length();
+
+template <typename T>
+inline unsigned long get_vector_length() {
+ return raw_vector_length() / sizeof(T);
+}
+
+} // namespace sme
+#endif // ARM_COMPUTE_ENABLE_SME
+
// get_vector_length(VLType): Returns vector length for type "T".
//
// This has the same requirements and constraints as the SVE-only form above, so we call into that code for SVE.
@@ -198,6 +213,10 @@ inline unsigned long get_vector_length() {
template <typename T>
inline unsigned long get_vector_length(VLType vl_type) {
switch (vl_type) {
+#ifdef ARM_COMPUTE_ENABLE_SME
+ case VLType::SME:
+ return sme::get_vector_length<T>();
+#endif // ARM_COMPUTE_ENABLE_SME
case VLType::SVE:
return get_vector_length<T>();
default: