aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/utils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/utils.hpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/utils.hpp220
1 files changed, 195 insertions, 25 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/utils.hpp b/src/core/NEON/kernels/arm_gemm/utils.hpp
index 7dbbe91ba2..11b1bd3e05 100644
--- a/src/core/NEON/kernels/arm_gemm/utils.hpp
+++ b/src/core/NEON/kernels/arm_gemm/utils.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,7 +24,11 @@
#pragma once
+#include "src/cpu/kernels/assembly/arm_gemm.hpp"
+
#include <cstddef>
+#include <limits>
+#include <tuple>
// Macro for unreachable code (e.g. impossible default cases on switch)
#define UNREACHABLE(why) __builtin_unreachable()
@@ -32,7 +36,30 @@
// Paranoid option for the above with assert
// #define UNREACHABLE(why) assert(0 && why)
-#define UNUSED(x) (void)(x)
+namespace arm_gemm {
+
+template<typename T>
+std::string get_type_name() {
+#ifdef __GNUC__
+ std::string s = __PRETTY_FUNCTION__;
+
+ auto start = s.find("cls_");
+
+ if (start==std::string::npos) {
+ return "(unknown)";
+ }
+
+ for(size_t x = start+4; x<s.size(); x++) {
+ if (s[x] == ';' || s[x] == ']') {
+ return s.substr(start+4, x-(start+4));
+ }
+ }
+
+ return "(unknown)";
+#else
+ return "(unsupported)";
+#endif
+}
template<typename T>
inline T iceildiv(const T a, const T b) {
@@ -50,42 +77,185 @@ inline T roundup(const T a, const T b) {
}
}
-namespace arm_gemm {
-namespace utils {
-namespace {
+enum class VLType {
+ None,
+ SVE,
+ SME,
+ SME2
+};
-#ifdef __ARM_FEATURE_SVE
-template<size_t sz>
-inline unsigned long get_vector_length_sz() {
- unsigned long v;
+template<typename T>
+struct IndirectOutputArg {
+ struct {
+ T *base;
+ size_t stride;
+ } direct = {};
+ struct {
+ T * const *ptr;
+ size_t offset;
+ } indirect = {};
+ bool is_indirect;
- __asm (
- "cntb %0"
- : "=r" (v)
- );
+ // Direct
+ IndirectOutputArg(T *base, size_t stride) : is_indirect(false) {
+ direct.base = base;
+ direct.stride = stride;
+ }
+
+ // Indirect
+ IndirectOutputArg(T * const * ptr, size_t offset) : is_indirect(true) {
+ indirect.ptr = ptr;
+ indirect.offset = offset;
+ }
- return v / sz;
+ IndirectOutputArg() : is_indirect(false) {
+ direct.base = nullptr;
+ direct.stride = 0;
+ }
+};
+
+// Check that the provided Requantize32 doesn't have a left shift.
+inline bool quant_no_left_shift(const Requantize32 &qp) {
+ if (qp.per_channel_requant) {
+ return (qp.per_channel_left_shifts == nullptr);
+ } else {
+ return (qp.per_layer_left_shift == 0);
+ }
}
-#define VEC_LEN_SPEC(sz, opcode) template <> inline unsigned long get_vector_length_sz<sz>() { unsigned long v; __asm ( opcode " %0" : "=r" (v)); return v; }
+// Check that the provided Requantize32 is compatible with the "symmetric" hybrid kernels. These don't include row
+// sums, so the 'b_offset' has to be zero.
+inline bool quant_hybrid_symmetric(const Requantize32 &qp) {
+ return quant_no_left_shift(qp) && qp.b_offset == 0;
+}
-VEC_LEN_SPEC(8, "cntd")
-VEC_LEN_SPEC(4, "cntw")
-VEC_LEN_SPEC(2, "cnth")
-VEC_LEN_SPEC(1, "cntb")
-#endif
+// Check that the provided Requantize32 is compatible with the "asymmetric" hybrid kernels. These don't support per
+// channel quantization. Technically b_offset==0 cases would work, but it is a waste to sum and then multiply by 0...
+inline bool quant_hybrid_asymmetric(const Requantize32 &qp) {
+ return quant_no_left_shift(qp) /* && qp.b_offset != 0 */ && qp.per_channel_requant==false;
+}
-} // anonymous namespace
+template<typename T>
+struct IndirectInputArg {
+ struct {
+ const T *base;
+ size_t stride;
+ } direct = {};
+ struct {
+ const T * const * const * ptr;
+ unsigned int start_row;
+ unsigned int start_col;
+ } indirect = {};
+ bool is_indirect;
+ // Direct
+ IndirectInputArg(const T *base, size_t stride) : is_indirect(false) {
+ direct.base = base;
+ direct.stride = stride;
+ }
+
+ // Indirect
+ IndirectInputArg(const T * const * const *ptr, unsigned int start_row, unsigned int start_col) : is_indirect(true) {
+ indirect.ptr = ptr;
+ indirect.start_row = start_row;
+ indirect.start_col = start_col;
+ }
+
+ IndirectInputArg() : is_indirect(false) {
+ direct.base = nullptr;
+ direct.stride = 0;
+ }
+};
+
+namespace utils {
+
+// get_vector_length(): Returns SVE vector length for type "T".
+//
+// It is required that this can be compiled by a compiler in non-SVE mode, but it must be prevented from running (at
+// runtime) if SVE is not enabled. Typically this is used by switchyard/driver code which is built in normal mode
+// which then calls SVE kernels (compiled accordingly) iff SVE is detected at runtime.
template <typename T>
inline unsigned long get_vector_length() {
-#ifdef __ARM_FEATURE_SVE
- return get_vector_length_sz<sizeof(T)>();
-#else
+#if defined(__aarch64__)
+ uint64_t vl;
+
+ __asm __volatile (
+ ".inst 0x0420e3e0\n" // CNTB X0, ALL, MUL #1
+ "mov %0, X0\n"
+ : "=r" (vl)
+ :
+ : "x0"
+ );
+
+ return vl / sizeof(T);
+#else // !defined(__aarch64__)
return 16 / sizeof(T);
-#endif
+#endif // defined(__aarch64__)
+}
+
+#ifdef ARM_COMPUTE_ENABLE_SME
+namespace sme {
+
+// function from misc-sve.cpp
+extern unsigned int raw_vector_length();
+
+template <typename T>
+inline unsigned long get_vector_length() {
+ return raw_vector_length() / sizeof(T);
+}
+
+} // namespace sme
+#endif // ARM_COMPUTE_ENABLE_SME
+
+// get_vector_length(VLType): Returns vector length for type "T".
+//
+// This has the same requirements and constraints as the SVE-only form above, so we call into that code for SVE.
+
+template <typename T>
+inline unsigned long get_vector_length(VLType vl_type) {
+ switch (vl_type) {
+#ifdef ARM_COMPUTE_ENABLE_SME
+ case VLType::SME:
+ return sme::get_vector_length<T>();
+#endif // ARM_COMPUTE_ENABLE_SME
+ case VLType::SVE:
+ return get_vector_length<T>();
+ default:
+ return 16 / sizeof(T);
+ }
}
+// get_default_activation_values(): Returns the default values for activation min and max for integer activation.
+template <typename T>
+inline std::tuple<T, T> get_default_activation_values()
+{
+ const T min = static_cast<T>(std::numeric_limits<T>::min());
+ const T max = static_cast<T>(std::numeric_limits<T>::max());
+
+ return std::make_tuple(min, max);
+}
+
+// get_default_activation_values(): Returns the default values for activation min and max for float activation.
+template <>
+inline std::tuple<float, float> get_default_activation_values()
+{
+ const float min = static_cast<float>(-std::numeric_limits<float>::infinity());
+ const float max = static_cast<float>(std::numeric_limits<float>::infinity());
+
+ return std::make_tuple(min, max);
+}
+
+#if defined(__ARM_FP16_ARGS)
+// get_default_activation_values(): Returns the default values for activation min and max for __fp16 activation.
+template <>
+inline std::tuple<__fp16, __fp16> get_default_activation_values()
+{
+ const __fp16 min = static_cast<__fp16>(-std::numeric_limits<float>::infinity());
+ const __fp16 max = static_cast<__fp16>(std::numeric_limits<float>::infinity());
+
+ return std::make_tuple(min, max);
+}
+#endif // defined(__ARM_FP16_ARGS)
} // utils namespace
} // arm_gemm namespace