aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/utils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/utils.hpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/utils.hpp129
1 files changed, 105 insertions, 24 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/utils.hpp b/src/core/NEON/kernels/arm_gemm/utils.hpp
index 6d483a3b9d..11b1bd3e05 100644
--- a/src/core/NEON/kernels/arm_gemm/utils.hpp
+++ b/src/core/NEON/kernels/arm_gemm/utils.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,9 +24,11 @@
#pragma once
-#include "arm_gemm.hpp"
+#include "src/cpu/kernels/assembly/arm_gemm.hpp"
#include <cstddef>
+#include <limits>
+#include <tuple>
// Macro for unreachable code (e.g. impossible default cases on switch)
#define UNREACHABLE(why) __builtin_unreachable()
@@ -37,6 +39,29 @@
namespace arm_gemm {
template<typename T>
+std::string get_type_name() {
+#ifdef __GNUC__
+ std::string s = __PRETTY_FUNCTION__;
+
+ auto start = s.find("cls_");
+
+ if (start==std::string::npos) {
+ return "(unknown)";
+ }
+
+ for(size_t x = start+4; x<s.size(); x++) {
+ if (s[x] == ';' || s[x] == ']') {
+ return s.substr(start+4, x-(start+4));
+ }
+ }
+
+ return "(unknown)";
+#else
+ return "(unsupported)";
+#endif
+}
+
+template<typename T>
inline T iceildiv(const T a, const T b) {
return (a + b - 1) / b;
}
@@ -55,6 +80,8 @@ inline T roundup(const T a, const T b) {
enum class VLType {
None,
SVE,
+ SME,
+ SME2
};
template<typename T>
@@ -141,40 +168,94 @@ struct IndirectInputArg {
};
namespace utils {
-namespace {
-#ifdef __ARM_FEATURE_SVE
-template<size_t sz>
-inline unsigned long get_vector_length_sz() {
- unsigned long v;
+// get_vector_length(): Returns SVE vector length for type "T".
+//
+// It is required that this can be compiled by a compiler in non-SVE mode, but it must be prevented from running (at
+// runtime) if SVE is not enabled. Typically this is used by switchyard/driver code which is built in normal mode
+// which then calls SVE kernels (compiled accordingly) iff SVE is detected at runtime.
+template <typename T>
+inline unsigned long get_vector_length() {
+#if defined(__aarch64__)
+ uint64_t vl;
- __asm (
- "cntb %0"
- : "=r" (v)
+ __asm __volatile (
+ ".inst 0x0420e3e0\n" // CNTB X0, ALL, MUL #1
+ "mov %0, X0\n"
+ : "=r" (vl)
+ :
+ : "x0"
);
- return v / sz;
+ return vl / sizeof(T);
+#else // !defined(__aarch64__)
+ return 16 / sizeof(T);
+#endif // defined(__aarch64__)
}
-#define VEC_LEN_SPEC(sz, opcode) template <> inline unsigned long get_vector_length_sz<sz>() { unsigned long v; __asm ( opcode " %0" : "=r" (v)); return v; }
-
-VEC_LEN_SPEC(8, "cntd")
-VEC_LEN_SPEC(4, "cntw")
-VEC_LEN_SPEC(2, "cnth")
-VEC_LEN_SPEC(1, "cntb")
-#endif
+#ifdef ARM_COMPUTE_ENABLE_SME
+namespace sme {
-} // anonymous namespace
+// function from misc-sve.cpp
+extern unsigned int raw_vector_length();
template <typename T>
inline unsigned long get_vector_length() {
-#ifdef __ARM_FEATURE_SVE
- return get_vector_length_sz<sizeof(T)>();
-#else
- return 16 / sizeof(T);
-#endif
+ return raw_vector_length() / sizeof(T);
+}
+
+} // namespace sme
+#endif // ARM_COMPUTE_ENABLE_SME
+
+// get_vector_length(VLType): Returns vector length for type "T".
+//
+// This has the same requirements and constraints as the SVE-only form above, so we call into that code for SVE.
+
+template <typename T>
+inline unsigned long get_vector_length(VLType vl_type) {
+ switch (vl_type) {
+#ifdef ARM_COMPUTE_ENABLE_SME
+ case VLType::SME:
+ return sme::get_vector_length<T>();
+#endif // ARM_COMPUTE_ENABLE_SME
+ case VLType::SVE:
+ return get_vector_length<T>();
+ default:
+ return 16 / sizeof(T);
+ }
+}
+
+// get_default_activation_values(): Returns the default values for activation min and max for integer activation.
+template <typename T>
+inline std::tuple<T, T> get_default_activation_values()
+{
+ const T min = static_cast<T>(std::numeric_limits<T>::min());
+ const T max = static_cast<T>(std::numeric_limits<T>::max());
+
+ return std::make_tuple(min, max);
+}
+
+// get_default_activation_values(): Returns the default values for activation min and max for float activation.
+template <>
+inline std::tuple<float, float> get_default_activation_values()
+{
+ const float min = static_cast<float>(-std::numeric_limits<float>::infinity());
+ const float max = static_cast<float>(std::numeric_limits<float>::infinity());
+
+ return std::make_tuple(min, max);
}
+#if defined(__ARM_FP16_ARGS)
+// get_default_activation_values(): Returns the default values for activation min and max for __fp16 activation.
+template <>
+inline std::tuple<__fp16, __fp16> get_default_activation_values()
+{
+ const __fp16 min = static_cast<__fp16>(-std::numeric_limits<float>::infinity());
+ const __fp16 max = static_cast<__fp16>(std::numeric_limits<float>::infinity());
+
+ return std::make_tuple(min, max);
+}
+#endif // defined(__ARM_FP16_ARGS)
} // utils namespace
} // arm_gemm namespace