diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/interleave_indirect_impl.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/interleave_indirect_impl.hpp | 22 |
1 files changed, 21 insertions, 1 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/interleave_indirect_impl.hpp b/src/core/NEON/kernels/arm_gemm/interleave_indirect_impl.hpp index 4f25da2877..b921fd16d2 100644 --- a/src/core/NEON/kernels/arm_gemm/interleave_indirect_impl.hpp +++ b/src/core/NEON/kernels/arm_gemm/interleave_indirect_impl.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Arm Limited. + * Copyright (c) 2022-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -39,8 +39,12 @@ */ template<unsigned int height_vectors, unsigned int block, VLType vlt, bool integrate_sums, typename TIn, typename TOut> void interleave_block( TOut * &out, const TIn * const *in, size_t width, size_t height, size_t row_offset, bool first) { +#ifdef ARM_COMPUTE_ENABLE_SVE const unsigned int int_by = height_vectors * (vlt == VLType::SVE ? get_vector_length<TOut>() / block : (vlt == VLType::SME ? sme::get_vector_length<TOut>() / block : 1 )); +#else + const unsigned int int_by = height_vectors; +#endif std::vector<int32_t> the_sums; @@ -104,8 +108,12 @@ void interleave_block( TOut * &out, const TIn * const *in, size_t width, size_t template<unsigned int height_vectors, unsigned int block, VLType vlt, typename TOut> inline void FixupRowSums(TOut * &out, const int32_t row_sum_multiplier) { +#ifdef ARM_COMPUTE_ENABLE_SVE const unsigned int height = height_vectors * (vlt == VLType::SVE ? get_vector_length<TOut>() / block : (vlt == VLType::SME ? sme::get_vector_length<TOut>() / block : 1 )); +#else + const unsigned int height = height_vectors; +#endif // If we are integrating row sums, we need to do some fix up, depending on whether the multiplier is non-zero or not. if (row_sum_multiplier) { @@ -138,8 +146,12 @@ void IndirectInterleave(TOut *out, const TIn * const * const *ptr, unsigned int unsigned int rounded_stringlen, const unsigned int y0, const unsigned int ymax, const unsigned int k0, const unsigned int kmax, bool integrate_sums, const int32_t row_sum_multiplier) { +#ifdef ARM_COMPUTE_ENABLE_SVE const unsigned int height = height_vectors * (vlt == VLType::SVE ? get_vector_length<TOut>() / block : (vlt == VLType::SME ? sme::get_vector_length<TOut>() / block : 1 )); +#else + const unsigned int height = height_vectors; +#endif // 'interleave_block' implementations are entitled to read a pointer for each row they handle from the input // pointer array, even for out of range rows (although they must not subsequently dereference those pointers for @@ -208,8 +220,12 @@ void IndirectInterleave(TOut *out, const TIn * const * const *ptr, unsigned int template<unsigned int height_vectors, unsigned int block, VLType vlt, typename TIn, typename TOut> void ConvolutionInterleave(TOut *out, const TIn *in, size_t in_stride, const convolver<TIn> &conv, const unsigned int rounded_stringlen, const unsigned int y0, const unsigned int ymax, const unsigned int k0, const unsigned int kmax, bool integrate_sums, const int32_t row_sum_multiplier) { +#ifdef ARM_COMPUTE_ENABLE_SVE const unsigned int height = height_vectors * (vlt == VLType::SVE ? get_vector_length<TOut>() / block : (vlt == VLType::SME ? sme::get_vector_length<TOut>() / block : 1 )); +#else + const unsigned int height = height_vectors; +#endif auto conv_cols = conv.process_columns(in, in_stride, k0, kmax, rounded_stringlen); // Use alloca here as a std::vector can be expensive in highly threaded scenarios. @@ -246,8 +262,12 @@ void ConvolutionInterleave(TOut *out, const TIn *in, size_t in_stride, const con template<unsigned int height_vectors, unsigned int block, VLType vlt, typename TIn, typename TOut> void Interleave(TOut *out, const TIn *in, size_t in_stride, const unsigned int y0, const unsigned int ymax, const unsigned int k0, const unsigned int kmax, bool integrate_sums, const int32_t row_sum_multiplier) { +#ifdef ARM_COMPUTE_ENABLE_SVE const unsigned int height = height_vectors * (vlt == VLType::SVE ? get_vector_length<TOut>() / block : (vlt == VLType::SME ? sme::get_vector_length<TOut>() / block : 1 )); +#else + const unsigned int height = height_vectors; +#endif // Use alloca here as a std::vector can be expensive in highly threaded scenarios. const TIn **row_ptrs = reinterpret_cast<const TIn **>(alloca(height * sizeof(const TIn *))); |