diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/utils.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/utils.hpp | 58 |
1 files changed, 44 insertions, 14 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/utils.hpp b/src/core/NEON/kernels/arm_gemm/utils.hpp
index d0a8635604..9d8e31870d 100644
--- a/src/core/NEON/kernels/arm_gemm/utils.hpp
+++ b/src/core/NEON/kernels/arm_gemm/utils.hpp
@@ -176,6 +176,7 @@ namespace utils {
 // which then calls SVE kernels (compiled accordingly) iff SVE is detected at runtime.
 template <typename T>
 inline unsigned long get_vector_length() {
+#if defined(__aarch64__)
     uint64_t vl;
 
     __asm __volatile (
@@ -187,26 +188,24 @@ inline unsigned long get_vector_length() {
     );
 
     return vl / sizeof(T);
+#else // !defined(__aarch64__)
+    return 16 / sizeof(T);
+#endif // defined(__aarch64__)
 }
 
+#ifdef ARM_COMPUTE_ENABLE_SME
 namespace sme {
 
-template <typename T>
-inline uint64_t get_vector_length() {
-    uint64_t raw_vector_length;
-
-    __asm __volatile (
-        ".inst 0x04bf5821\n" // RDSVL X1, #1
-        "mov %0, X1\n"
-        : "=r" (raw_vector_length)
-        :
-        : "x1"
-    );
+// function from misc-sve.cpp
+extern unsigned int raw_vector_length();
 
-    return raw_vector_length / sizeof(T);
+template <typename T>
+inline unsigned long get_vector_length() {
+    return raw_vector_length() / sizeof(T);
 }
 
 } // namespace sme
+#endif // ARM_COMPUTE_ENABLE_SME
 
 // get_vector_length(VLType): Returns vector length for type "T".
 //
@@ -215,17 +214,48 @@ inline uint64_t get_vector_length() {
 template <typename T>
 inline unsigned long get_vector_length(VLType vl_type) {
     switch (vl_type) {
-#ifdef ARM_COMPUTE_ENABLE_SVE
+#ifdef ARM_COMPUTE_ENABLE_SME
         case VLType::SME:
             return sme::get_vector_length<T>();
+#endif // ARM_COMPUTE_ENABLE_SME
         case VLType::SVE:
             return get_vector_length<T>();
-#endif
         default:
             return 16 / sizeof(T);
     }
 }
 
+// get_default_activation_values(): Returns the default values for activation min and max for integer activation.
+template <typename T>
+inline std::tuple<T, T> get_default_activation_values()
+{
+    const T min = static_cast<T>(std::numeric_limits<T>::min());
+    const T max = static_cast<T>(std::numeric_limits<T>::max());
+
+    return std::make_tuple(min, max);
+}
+
+// get_default_activation_values(): Returns the default values for activation min and max for float activation.
+template <>
+inline std::tuple<float, float> get_default_activation_values()
+{
+    const float min = static_cast<float>(-std::numeric_limits<float>::infinity());
+    const float max = static_cast<float>(std::numeric_limits<float>::infinity());
+
+    return std::make_tuple(min, max);
+}
+
+#if defined(__ARM_FP16_ARGS)
+// get_default_activation_values(): Returns the default values for activation min and max for __fp16 activation.
+template <>
+inline std::tuple<__fp16, __fp16> get_default_activation_values()
+{
+    const __fp16 min = static_cast<__fp16>(-std::numeric_limits<float>::infinity());
+    const __fp16 max = static_cast<__fp16>(std::numeric_limits<float>::infinity());
+
+    return std::make_tuple(min, max);
+}
+#endif // defined(__ARM_FP16_ARGS)
 } // utils namespace
 } // arm_gemm namespace
 