From ef637398a8c2060e15de438020c53331da8bd6dd Mon Sep 17 00:00:00 2001
From: Gunes Bayir
Date: Mon, 12 Feb 2024 21:32:51 +0000
Subject: Integrate new pretranspose_b_array with extra fused transpose of B

This patch fuses the transposition taking place in Acl with the
transformations done in arm_gemm (called pretranspose_b_array) if the
underlying kernel and transform support it. This should improve start-up
time (as it applies to constant Rhs matrices) and reduce the memory
footprint.

The transformations in arm_gemm are kernel specific. The Rhs matrix is
transformed into certain layouts to improve performance.

Resolves: COMPMID-6595
Change-Id: Id2932dd966e59f903c279417bebcea83d9a42464
Signed-off-by: Gunes Bayir
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11144
Tested-by: Arm Jenkins
Reviewed-by: Viet-Hoa Do
Comments-Addressed: Arm Jenkins
Benchmark: Arm Jenkins
---
 Android.bp                                          |   1 +
 docs/user_guide/release_version_and_change_log.dox  |   3 +
 filelist.json                                       |   1 +
 src/BUILD.bazel                                     |   1 +
 src/CMakeLists.txt                                  |   1 +
 src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp      |   8 +-
 .../NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp  |  18 +-
 .../kernels/arm_gemm/gemm_hybrid_quantized.hpp      |   8 +-
 .../NEON/kernels/arm_gemm/gemm_interleaved.hpp      |  20 +-
 src/core/NEON/kernels/arm_gemm/gemv_batched.hpp     |   6 +-
 .../NEON/kernels/arm_gemm/gemv_pretransposed.hpp    |   8 +-
 src/core/NEON/kernels/arm_gemm/interleave-8way.cpp  | 264 +++++++++++++++++++++
 .../arm_gemm/kernels/a64_hybrid_fp32_mla_4x24.hpp   |   6 +-
 .../arm_gemm/kernels/a64_hybrid_fp32_mla_6x16.hpp   |   6 +-
 .../kernels/arm_gemm/kernels/a64_sgemm_8x12.hpp     |   6 +-
 .../NEON/kernels/arm_gemm/quantize_wrapper.hpp      |   8 +-
 .../NEON/kernels/arm_gemm/std_transforms_fixed.hpp  |   9 +-
 .../kernels/arm_gemm/std_transforms_fixed_trB.hpp   |  87 +++++++
 .../NEON/kernels/arm_gemm/std_transforms_sme.hpp    |   9 +-
 .../NEON/kernels/arm_gemm/std_transforms_sve.hpp    |   9 +-
 src/core/NEON/kernels/arm_gemm/transform.cpp        |  11 +-
 src/cpu/kernels/assembly/gemm_common.hpp            |  41 +++-
 .../operators/internal/CpuGemmAssemblyDispatch.cpp  |  62 +++--
 23 files changed, 513 insertions(+), 80 deletions(-)
 create mode 100644 src/core/NEON/kernels/arm_gemm/interleave-8way.cpp
 create mode 100644 src/core/NEON/kernels/arm_gemm/std_transforms_fixed_trB.hpp

diff --git a/Android.bp b/Android.bp
index 670138b209..0d087c943b 100644
--- a/Android.bp
+++ b/Android.bp
@@ -332,6 +332,7 @@ cc_library_static {
         "src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp",
         "src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp",
         "src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp",
+        "src/core/NEON/kernels/arm_gemm/interleave-8way.cpp",
         "src/core/NEON/kernels/arm_gemm/interleave_indirect-sve.cpp",
         "src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp",
         "src/core/NEON/kernels/arm_gemm/mergeresults-fp16.cpp",
diff --git a/docs/user_guide/release_version_and_change_log.dox b/docs/user_guide/release_version_and_change_log.dox
index 676f1ca032..b788957dda 100644
--- a/docs/user_guide/release_version_and_change_log.dox
+++ b/docs/user_guide/release_version_and_change_log.dox
@@ -41,6 +41,9 @@ If there is more than one release in a month then an extra sequential number is
 
 @section S2_2_changelog Changelog
 
+v24.04 Public major release
+ - Optimize start-up time of @ref NEConvolutionLayer for some input configurations where GeMM is selected as the convolution algorithm
+
 v24.02 Public major release
 - Replace template writer with compute kernel writer in dynamic fusion.
- Performance optimizations: diff --git a/filelist.json b/filelist.json index dcf3204ecd..d44a7216ac 100644 --- a/filelist.json +++ b/filelist.json @@ -1592,6 +1592,7 @@ "src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp", "src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp", "src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp", + "src/core/NEON/kernels/arm_gemm/interleave-8way.cpp", "src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp", "src/core/NEON/kernels/arm_gemm/mergeresults-fp16.cpp", "src/core/NEON/kernels/arm_gemm/mergeresults.cpp", diff --git a/src/BUILD.bazel b/src/BUILD.bazel index 9d5ae63484..f9d166c525 100644 --- a/src/BUILD.bazel +++ b/src/BUILD.bazel @@ -517,6 +517,7 @@ filegroup( "core/NEON/kernels/arm_gemm/gemm_quint8.cpp", "core/NEON/kernels/arm_gemm/gemm_uint16.cpp", "core/NEON/kernels/arm_gemm/gemm_uint8.cpp", + "core/NEON/kernels/arm_gemm/interleave-8way.cpp", "core/NEON/kernels/arm_gemm/interleave_indirect.cpp", "core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_bf16fp32_mmla_6x16/generic.cpp", "core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp16_mla_6x32/generic.cpp", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index be7a6ef188..c5a172172b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -508,6 +508,7 @@ target_sources( core/NEON/kernels/arm_gemm/gemm_quint8.cpp core/NEON/kernels/arm_gemm/gemm_uint16.cpp core/NEON/kernels/arm_gemm/gemm_uint8.cpp + core/NEON/kernels/arm_gemm/interleave-8way.cpp core/NEON/kernels/arm_gemm/interleave_indirect.cpp core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_bf16fp32_mmla_6x16/generic.cpp core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp16_mla_6x32/generic.cpp diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp index 436316c0f7..a6c9677305 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2021, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -221,7 +221,9 @@ public: return roundup(_Nsize, strategy::out_width()) * roundup(_Ksize, strategy::k_unroll()) * _nmulti * sizeof(Toi); } - void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { + void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override { + assert(!transposed); + Toi *buffer = reinterpret_cast(in_buffer); _B_transposed = buffer; strategy strat(_ci); @@ -237,7 +239,7 @@ public: const unsigned int size = roundup(xmax-x0, strategy::out_width()) * k_size; strat.transforms.PrepareB( buffer, B + (multi * B_multi_stride), ldb, - x0, xmax, k0, kmax); + x0, xmax, k0, kmax, false); buffer += size; } diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp index 1780375c44..89c2d5a23e 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2023 Arm Limited. + * Copyright (c) 2017-2024 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -631,11 +631,16 @@ public: } } - void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { - pretranspose_B_array_part(in_buffer, B, ldb, B_multi_stride, 0, get_B_pretranspose_window_size()); + bool B_pretranspose_supports_transpose() const override { + strategy strat(_args._ci); + return strat.transforms.PrepareB_supports_transpose(); + } + + void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override { + pretranspose_B_array_part(in_buffer, B, ldb, B_multi_stride, transposed, 0, get_B_pretranspose_window_size()); } - void pretranspose_B_array_part(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, size_t start, size_t end) override { + void pretranspose_B_array_part(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed, size_t start, size_t end) override { if (end >= get_B_pretranspose_window_size()) { requantize_bias(in_buffer, B, ldb, B_multi_stride); } @@ -717,7 +722,8 @@ public: strat.transforms.PrepareB(buffer, B + (multi * B_multi_stride), ldb, x0, xmax, (k_section_base * _args._Ksize) + k_offset, // K starting point - compute row to read based on our section and the true section length. - (k_section_base * _args._Ksize) + k_offset + k_length); // K end point - starting point plus length computed above. + (k_section_base * _args._Ksize) + k_offset + k_length, // K end point - starting point plus length computed above. + transposed); // We need to modify our position based on the ROUNDED version of what we just did. unsigned int padded_length = roundup(k_length, strategy::k_unroll()); @@ -731,7 +737,7 @@ public: } else { // In the single K section case, can process the whole lot in one go. strat.transforms.PrepareB(buffer, B + (multi * B_multi_stride), ldb, - n_start, n_end, k0, std::min(kmax, _args._Ksize)); + n_start, n_end, k0, std::min(kmax, _args._Ksize), transposed); } } } diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp index efb5bd1bb4..f12efe4282 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2021, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -277,7 +277,9 @@ public: } } - void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { + void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override { + assert(!transposed); + requantize_bias(in_buffer, B, ldb, B_multi_stride); uintptr_t buffer_int = reinterpret_cast(in_buffer); @@ -296,7 +298,7 @@ public: const unsigned int size = roundup(xmax-x0, strategy::out_width()) * k_size; strat.transforms.PrepareB( buffer, B + (multi * B_multi_stride), ldb, - x0, xmax, k0, kmax); + x0, xmax, k0, kmax, false); buffer += size; } diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp index 362a3e30ea..4f732f7d94 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2023 Arm Limited. + * Copyright (c) 2017-2024 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -1067,11 +1067,18 @@ public: } } - void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { - pretranspose_B_array_part(in_buffer, B, ldb, B_multi_stride, 0, get_B_pretranspose_window_size()); + // Support for transposed B is a property of the strategy::transpose type + bool B_pretranspose_supports_transpose() const override { + typename transform_type::value>::type transforms; + + return transforms.PrepareB_supports_transpose(); + } + + void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, const bool transposed) override { + pretranspose_B_array_part(in_buffer, B, ldb, B_multi_stride, transposed, 0, get_B_pretranspose_window_size()); } - void pretranspose_B_array_part(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, size_t start, size_t end) override { + void pretranspose_B_array_part(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, const bool transposed, size_t start, size_t end) override { // Perform column sums etc as part of the last block. if (end >= get_B_pretranspose_window_size()) { requantize_bias(in_buffer, B, ldb, B_multi_stride); @@ -1134,7 +1141,8 @@ public: strat.transforms.PrepareB(buffer, B + (current.multi() * B_multi_stride), ldb, x0, xmax, (k_section_base * _Ksize) + k_offset, // K starting point - compute row to read based on our section and the true section length. - (k_section_base * _Ksize) + k_offset + k_length); // K end point - starting point plus length computed above. + (k_section_base * _Ksize) + k_offset + k_length, // K end point - starting point plus length computed above. + transposed); // We need to modify our position based on the ROUNDED version of what we just did. unsigned int padded_length = roundup(k_length, strategy::k_unroll()); @@ -1149,7 +1157,7 @@ public: // In the single K section case, can process the whole lot in one go. // Caution: 'blockwalker::kmax()' rounds up, so clamp to valid _Ksize. strat.transforms.PrepareB(buffer, B + (current.multi() * B_multi_stride), ldb, - current.x0(), current.xmax(), current.k0(), std::min(current.kmax(), _Ksize)); + current.x0(), current.xmax(), current.k0(), std::min(current.kmax(), _Ksize), transposed); buffer += roundup(current.xmax() - current.x0(), strategy::out_width()) * roundup(current.kmax() - current.k0(), strategy::k_unroll()); } diff --git a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp index 4fc9b3456a..ad504f2664 100644 --- a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2021, 2024 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -88,8 +88,8 @@ public: return _subgemm->get_B_pretransposed_array_size(); } - void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override { - _subgemm->pretranspose_B_array(buffer, B, ldb, B_multi_stride); + void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override { + _subgemm->pretranspose_B_array(buffer, B, ldb, B_multi_stride, transposed); } void set_pretransposed_B_data(void *buffer) override { diff --git a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp index 86b33d081f..f70fc98572 100644 --- a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2022 Arm Limited. + * Copyright (c) 2017-2022, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -215,7 +215,9 @@ public: } } - void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override { + void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override { + assert(!transposed); + requantize_bias(buffer, B, ldb, B_multi_stride); // The actual transposed buffer goes after the column sums (if any) @@ -225,7 +227,7 @@ public: strategy strat(_args._ci); for (unsigned int multi=0; multi<_args._nmulti; multi++) { - strat.transforms.PrepareB(B_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _args._Nsize, 0, _args._Ksize); + strat.transforms.PrepareB(B_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _args._Nsize, 0, _args._Ksize, false); } _B_pretransposed = B_buffer; diff --git a/src/core/NEON/kernels/arm_gemm/interleave-8way.cpp b/src/core/NEON/kernels/arm_gemm/interleave-8way.cpp new file mode 100644 index 0000000000..148678ba69 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/interleave-8way.cpp @@ -0,0 +1,264 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include + +#include +#include + +#include "transform.hpp" +#include "utils.hpp" + +namespace arm_gemm { + +namespace { + +// Helper function to interleave a single 4x4 block of 32-bin values +// together. + +// _full version doesn't need to worry about any padding. 
+static inline void transpose_block_32_full(const uint8_t * __restrict in_ptr0, const uint8_t * __restrict in_ptr1, const uint8_t * __restrict in_ptr2, const uint8_t * __restrict in_ptr3, uint8_t * __restrict out_ptr, long output_stride) { + uint32x4_t inputs[4]; + uint32x4_t inters[4]; + uint32x4_t outputs[4]; + + inputs[0] = vld1q_u32(reinterpret_cast(in_ptr0)); + inputs[1] = vld1q_u32(reinterpret_cast(in_ptr1)); + inputs[2] = vld1q_u32(reinterpret_cast(in_ptr2)); + inputs[3] = vld1q_u32(reinterpret_cast(in_ptr3)); + + inters[0] = vzip1q_u32(inputs[0], inputs[2]); + inters[1] = vzip2q_u32(inputs[0], inputs[2]); + inters[2] = vzip1q_u32(inputs[1], inputs[3]); + inters[3] = vzip2q_u32(inputs[1], inputs[3]); + + outputs[0] = vzip1q_u32(inters[0], inters[2]); + outputs[1] = vzip2q_u32(inters[0], inters[2]); + outputs[2] = vzip1q_u32(inters[1], inters[3]); + outputs[3] = vzip2q_u32(inters[1], inters[3]); + + vst1q_u32(reinterpret_cast(out_ptr), outputs[0]); + vst1q_u32(reinterpret_cast(out_ptr + output_stride), outputs[1]); + vst1q_u32(reinterpret_cast(out_ptr + output_stride*2), outputs[2]); + vst1q_u32(reinterpret_cast(out_ptr + output_stride*3), outputs[3]); +} + +// _part version: Only read "bytes_in" bytes, not a full vector. Only write +// out 4-byte blocks that have some live content (if bytes_in is not a +// multiple of 4 there will some padding in each 4-block) +static inline void transpose_block_32_part(const uint8_t *in_ptr0, const uint8_t *in_ptr1, const uint8_t *in_ptr2, const uint8_t *in_ptr3, uint8_t *out_ptr, long bytes_in, long output_stride) { + uint32x4_t inputs[4]; + uint32x4_t inters[4]; + uint32x4_t outputs[4]; + uint8_t scratch[16] = {0}; + + long num_outs = iceildiv(bytes_in, 4); + + memcpy(scratch, in_ptr0, bytes_in); + inputs[0] = vld1q_u32(reinterpret_cast(scratch)); + memcpy(scratch, in_ptr1, bytes_in); + inputs[1] = vld1q_u32(reinterpret_cast(scratch)); + memcpy(scratch, in_ptr2, bytes_in); + inputs[2] = vld1q_u32(reinterpret_cast(scratch)); + memcpy(scratch, in_ptr3, bytes_in); + inputs[3] = vld1q_u32(reinterpret_cast(scratch)); + + inters[0] = vzip1q_u32(inputs[0], inputs[2]); + inters[1] = vzip2q_u32(inputs[0], inputs[2]); + inters[2] = vzip1q_u32(inputs[1], inputs[3]); + inters[3] = vzip2q_u32(inputs[1], inputs[3]); + + outputs[0] = vzip1q_u32(inters[0], inters[2]); + outputs[1] = vzip2q_u32(inters[0], inters[2]); + outputs[2] = vzip1q_u32(inters[1], inters[3]); + outputs[3] = vzip2q_u32(inters[1], inters[3]); + + do { + vst1q_u32(reinterpret_cast(out_ptr), outputs[0]); + if (num_outs < 2) + break; + vst1q_u32(reinterpret_cast(out_ptr + output_stride), outputs[1]); + if (num_outs < 3) + break; + vst1q_u32(reinterpret_cast(out_ptr + output_stride*2), outputs[2]); + if (num_outs < 4) + break; + vst1q_u32(reinterpret_cast(out_ptr + output_stride*3), outputs[3]); + } while (0); +} + +template +struct Unroll { + template + static void run(F f) { + Unroll::run(f); + f(N-1); + } +}; + +template<> +struct Unroll<0> { + template + static void run(F) { + } +}; + +// Interleave some multiple of 4 rows together. +// +// The template parameter BLOCKS controls the size of the inner loop - each BLOCK is 4 rows. +// The function parameter interleave_multiple controls the number of times the inner loop is run. + +// The total interleave depth for a given run is therefore BLOCKS * interleave_multiple * 4. 
+template +void a64_interleave_1x4(uint8_t *out, const uint8_t *in, long width, long in_stride, long height, long interleave_multiple) { + const long total_interleave_depth = BLOCKS * 4 * interleave_multiple; + constexpr long loop_interleave_depth = BLOCKS * 4; + + uint8_t *pad_row = reinterpret_cast(alloca(width)); + + if (height % total_interleave_depth) { + memset(pad_row, 0, width); + } + + // Outer loop: process blocks of total_interleave_depth rows at a time. + for (long y0_base=0; y0_base::run( [&](unsigned y) { + in_ptrs[y] = (y+y0 < height) ? in + ((y+y0) * in_stride) : pad_row; + }); + + long bytes_left = width; + // Process full vectors using transpose_block_32_full() + while (bytes_left >= 16) { // 16 is the vector length in bytes + Unroll::run( [&](unsigned u) { + transpose_block_32_full(in_ptrs[u*4 + 0], in_ptrs[u*4 + 1], in_ptrs[u*4 + 2], in_ptrs[u*4 + 3], + out_ptr + 16*u, total_interleave_depth * 4); // 4 is the blocking depth + }); + + Unroll::run( [&](unsigned y) { + in_ptrs[y] += 16; // 16 is the vector length in bytes + }); + + out_ptr += total_interleave_depth * 16; // 16 is the vector length in bytes + bytes_left -= 16; // 16 is the vector length in bytes + } + + // Process any remaining bytes using transpose_block_32_part() + if (bytes_left) { + Unroll::run( [&](unsigned u) { + transpose_block_32_part(in_ptrs[u*4 + 0], in_ptrs[u*4 + 1], in_ptrs[u*4 + 2], in_ptrs[u*4 + 3], + out_ptr + 16*u, bytes_left, total_interleave_depth * 4); + }); + } + } + + // Update "out" pointer for next set of total_interleave_depth rows + out += total_interleave_depth * roundup(width, 4); + } +} + +} // anonymous namespace + +template<> +void Transform<16, 4, false, VLType::None>( + uint8_t *out, const uint8_t *in, int stride, int y0, int ymax, int x0, int xmax) +{ + a64_interleave_1x4<4>( + reinterpret_cast(out), + reinterpret_cast(in + y0 * stride + x0), + (xmax - x0), + stride, + (ymax - y0), + 1 + ); +} + +template<> +void Transform<16, 4, false, VLType::None>( + int8_t *out, const int8_t *in, int stride, int y0, int ymax, int x0, int xmax) +{ + a64_interleave_1x4<4>( + reinterpret_cast(out), + reinterpret_cast(in + y0 * stride + x0), + (xmax - x0), + stride, + (ymax - y0), + 1 + ); +} + +template<> +void Transform<12, 1, false, VLType::None>( + float *out, const float *in, int stride, int y0, int ymax, int x0, int xmax) +{ + a64_interleave_1x4<3>( + reinterpret_cast(out), + reinterpret_cast(in + y0 * stride + x0), + (xmax - x0) * sizeof(float), + stride * sizeof(float), + (ymax - y0), + 1 + ); +} + +template<> +void Transform<16, 1, false, VLType::None>( + float *out, const float *in, int stride, int y0, int ymax, int x0, int xmax) +{ + a64_interleave_1x4<4>( + reinterpret_cast(out), + reinterpret_cast(in + y0 * stride + x0), + (xmax - x0) * sizeof(float), + stride * sizeof(float), + (ymax - y0), + 1 + ); +} + +template<> +void Transform<24, 1, false, VLType::None>( + float *out, const float *in, int stride, int y0, int ymax, int x0, int xmax) +{ + a64_interleave_1x4<3>( + reinterpret_cast(out), + reinterpret_cast(in + y0 * stride + x0), + (xmax - x0) * sizeof(float), + stride * sizeof(float), + (ymax - y0), + 2 + ); +} + +} // namespace arm_gemm + +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x24.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x24.hpp index 171929e65e..bce4de74f7 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x24.hpp +++ 
b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x24.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, 2023 Arm Limited. + * Copyright (c) 2021, 2023-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -24,7 +24,7 @@ #pragma once #ifdef __aarch64__ -#include "../std_transforms_fixed.hpp" +#include "../std_transforms_fixed_trB.hpp" #include "../performance_parameters.hpp" #define ARGLIST \ @@ -71,7 +71,7 @@ public: return true; } - StdTransformsFixed transforms = {}; + StdTransformsFixedTRB transforms = {}; template static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci) { diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_6x16.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_6x16.hpp index 759729de5e..7f85d2dd42 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_6x16.hpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_6x16.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, 2023 Arm Limited. + * Copyright (c) 2019-2021, 2023-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -24,7 +24,7 @@ #pragma once #ifdef __aarch64__ -#include "../std_transforms_fixed.hpp" +#include "../std_transforms_fixed_trB.hpp" #include "../performance_parameters.hpp" #define ARGLIST \ @@ -71,7 +71,7 @@ public: return true; } - StdTransformsFixed transforms = {}; + StdTransformsFixedTRB transforms = {}; template static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci) { diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_8x12.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_8x12.hpp index 65ef407f79..19acfe8ae9 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_8x12.hpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_8x12.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2021, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -25,7 +25,7 @@ #ifdef __aarch64__ -#include "../std_transforms_fixed.hpp" +#include "../std_transforms_fixed_trB.hpp" #include "../performance_parameters.hpp" #include "../bfloat.hpp" @@ -68,7 +68,7 @@ public: } // Use the standard fixed size transforms. - StdTransformsFixed transforms = {}; + StdTransformsFixedTRB transforms = {}; template static PerformanceParameters get_performance_parameters(const CPUInfo *ci) { diff --git a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp index ce727032e6..d35825c428 100644 --- a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp +++ b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021 Arm Limited. + * Copyright (c) 2019-2021, 2024 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -184,9 +184,11 @@ public: col_sums_pretransposed(B, ldb, B_multi_stride); } - void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override { + void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override { + assert(!transposed); + uintptr_t buffer_int = reinterpret_cast(buffer); - _subgemm->pretranspose_B_array(reinterpret_cast(buffer_int + col_sum_size()), B, ldb, B_multi_stride); + _subgemm->pretranspose_B_array(reinterpret_cast(buffer_int + col_sum_size()), B, ldb, B_multi_stride, transposed); requantize_bias(buffer, B, ldb, B_multi_stride); } diff --git a/src/core/NEON/kernels/arm_gemm/std_transforms_fixed.hpp b/src/core/NEON/kernels/arm_gemm/std_transforms_fixed.hpp index 4669be9993..a9cbf4ec8d 100644 --- a/src/core/NEON/kernels/arm_gemm/std_transforms_fixed.hpp +++ b/src/core/NEON/kernels/arm_gemm/std_transforms_fixed.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 Arm Limited. + * Copyright (c) 2018-2020, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -63,9 +63,14 @@ public: ConvolutionInterleave(out, ptr, stride, conv, rounded_stringlen, y0, ymax, k0, kmax, integrate_sums, row_sum_multiplier); } + bool PrepareB_supports_transpose() const { + return false; + } + template void PrepareB(TOperand *out, const TIn *in, const int stride, const int x0, - const int xmax, const int k0, const int kmax) const { + const int xmax, const int k0, const int kmax, bool transposed) const { + assert(!transposed); Transform(out, in, stride, x0, xmax, k0, kmax); } diff --git a/src/core/NEON/kernels/arm_gemm/std_transforms_fixed_trB.hpp b/src/core/NEON/kernels/arm_gemm/std_transforms_fixed_trB.hpp new file mode 100644 index 0000000000..1db716455f --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/std_transforms_fixed_trB.hpp @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2018-2020, 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#include "convolver.hpp" +#include "mergeresults.hpp" +#include "transform.hpp" +#include "interleave_indirect.hpp" + +namespace arm_gemm { + +/* + * Define "standard" transforms for the blocked GEMMs with fixed vector + * length. This version supports accepting the RHS/B matrix in transposed + * format. 
+ * + * This assumes that A is interleaved 'height' ways, B is interleaved + * 'width' ways and transposed, and that the merge needs to work in 'height' + * x 'width' blocks. + * + * The optional 'block' parameter is for kernels using dot-product type + * instructions like UDOT and SDOT. + */ +template +class StdTransformsFixedTRB +{ +public: + template + void PrepareA(TOperand *out, const TIn *in, const int stride, const int y0, + const int ymax, const int k0, const int kmax, int32_t row_sum_multiplier) const { + Interleave(out, in, stride, y0, ymax, k0, kmax, integrate_sums, row_sum_multiplier); + } + + template + void PrepareA_indirect(TOperand *out, const TIn * const * const *ptr, size_t stringlen, size_t rounded_stringlen, const int y0, + const int ymax, const int k0, const int kmax, int32_t row_sum_multiplier) { + IndirectInterleave(out, ptr, stringlen, rounded_stringlen, y0, ymax, k0, kmax, integrate_sums, row_sum_multiplier); + } + + template + void PrepareA_convolution(TOperand *out, const TIn *ptr, size_t stride, const convolver &conv, size_t rounded_stringlen, + const int y0, const int ymax, const int k0, const int kmax, int32_t row_sum_multiplier) { + ConvolutionInterleave(out, ptr, stride, conv, rounded_stringlen, y0, ymax, k0, kmax, integrate_sums, row_sum_multiplier); + } + + bool PrepareB_supports_transpose() const { + return true; + } + + template + void PrepareB(TOperand *out, const TIn *in, const int stride, const int x0, + const int xmax, const int k0, const int kmax, bool transposed) const { + if (transposed) { + Transform(out, in, stride, x0, xmax, k0, kmax); + } else { + Transform(out, in, stride, x0, xmax, k0, kmax); + } + } + + template + void Merge(TOut *out, const TResult *in, int stride, int y0, int ymax, int x0, int xmax, const TOut *bias, const Activation act, bool append) const { + MergeResults(out, in, stride, y0, ymax, x0, xmax, bias, act, append); + } +}; + +} // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/std_transforms_sme.hpp b/src/core/NEON/kernels/arm_gemm/std_transforms_sme.hpp index afe24e7ce0..40f61626a1 100644 --- a/src/core/NEON/kernels/arm_gemm/std_transforms_sme.hpp +++ b/src/core/NEON/kernels/arm_gemm/std_transforms_sme.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023 Arm Limited. + * Copyright (c) 2022-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -60,9 +60,14 @@ public: ConvolutionInterleave(out, ptr, stride, conv, rounded_stringlen, y0, ymax, k0, kmax, integrate_sums, row_sum_multiplier); } + bool PrepareB_supports_transpose() const { + return false; + } + template void PrepareB(TOperand *out, const TIn *in, const int stride, const int x0, - const int xmax, const int k0, const int kmax) { + const int xmax, const int k0, const int kmax, bool transposed) { + assert (!transposed); Transform(out, in, stride, x0, xmax, k0, kmax); } diff --git a/src/core/NEON/kernels/arm_gemm/std_transforms_sve.hpp b/src/core/NEON/kernels/arm_gemm/std_transforms_sve.hpp index 3256d919ea..c516bfc456 100644 --- a/src/core/NEON/kernels/arm_gemm/std_transforms_sve.hpp +++ b/src/core/NEON/kernels/arm_gemm/std_transforms_sve.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 Arm Limited. + * Copyright (c) 2017-2018,2023-2024 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -61,9 +61,14 @@ public: ConvolutionInterleave(out, ptr, stride, conv, rounded_stringlen, y0, ymax, k0, kmax, integrate_sums, row_sum_multiplier); } + bool PrepareB_supports_transpose() const { + return false; + } + template void PrepareB(TOperand *out, const TIn *in, const int stride, const int x0, - const int xmax, const int k0, const int kmax) { + const int xmax, const int k0, const int kmax, bool transposed) { + assert (!transposed); Transform(out, in, stride, x0, xmax, k0, kmax); } diff --git a/src/core/NEON/kernels/arm_gemm/transform.cpp b/src/core/NEON/kernels/arm_gemm/transform.cpp index 5aa62f0fe4..45e4f0e1de 100644 --- a/src/core/NEON/kernels/arm_gemm/transform.cpp +++ b/src/core/NEON/kernels/arm_gemm/transform.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023 Arm Limited. + * Copyright (c) 2021-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -134,7 +134,14 @@ template void Transform<8, 1, true, VLType::None>(float *, const __fp16 *, int, #endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) #ifdef ARM_COMPUTE_ENABLE_BF16 template void Transform<8, 1, true, VLType::None>(float *, const bfloat16 *, int, int, int, int, int); -#endif +#endif // ARM_COMPUTE_ENABLE_BF16 #endif // AArch32 +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +template void Transform<12, 1, false, VLType::None>(float *, const __fp16 *, int, int, int, int, int); +#endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#ifdef ARM_COMPUTE_ENABLE_BF16 +template void Transform<12, 1, false, VLType::None>(float *, const bfloat16 *, int, int, int, int, int); +#endif // ARM_COMPUTE_ENABLE_BF16 + } // namespace arm_gemm diff --git a/src/cpu/kernels/assembly/gemm_common.hpp b/src/cpu/kernels/assembly/gemm_common.hpp index 6fe9f13f02..4825814e31 100644 --- a/src/cpu/kernels/assembly/gemm_common.hpp +++ b/src/cpu/kernels/assembly/gemm_common.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021,2023 Arm Limited. + * Copyright (c) 2017-2021,2023-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,6 +21,10 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ + +#ifndef ACL_SRC_CPU_KERNELS_ASSEMBLY_GEMM_COMMON_HPP +#define ACL_SRC_CPU_KERNELS_ASSEMBLY_GEMM_COMMON_HPP + #pragma once #include "convolution_parameters.hpp" @@ -116,6 +120,11 @@ public: { return false; } + /* Does pretranspose accept the transposed flag? */ + virtual bool B_pretranspose_supports_transpose() const + { + return false; + } /* Total number of bytes of space needed for pretransposed arrays. */ virtual size_t get_B_pretransposed_array_size() const { @@ -128,10 +137,10 @@ public: } /* Perform pretranspose - arguments are output, input, input row stride and input multi stride. */ /* The "real" version of this depends on the templated operand type (see below). */ - virtual void pretranspose_B_array_generic(void *, const void *, const int, const int) = 0; + virtual void pretranspose_B_array_generic(void *, const void *, const int, const int, bool) = 0; /* Threaded version with window start/end parameters */ virtual void - pretranspose_B_array_part_generic(void *, const void *, const int, const int, const size_t, const size_t) = 0; + pretranspose_B_array_part_generic(void *, const void *, const int, const int, bool, const size_t, const size_t) = 0; /* Set pretransposed data - the void * passed in must previously have been passed to pretranspose_B_array() for the same or a similar GEMM. 
*/ virtual void set_pretransposed_B_data(void *) @@ -251,28 +260,34 @@ public: /* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */ /* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */ - virtual void pretranspose_B_array(void *, const To *, const int, const int){}; + virtual void pretranspose_B_array(void *, const To *, const int, const int, bool){}; /* Implementation of the void * overload which casts its arguments to the appropriate type. */ - void pretranspose_B_array_generic(void *out, const void *in, const int row_stride, const int multi_stride) override + void pretranspose_B_array_generic( + void *out, const void *in, const int row_stride, const int multi_stride, bool transposed) override { - pretranspose_B_array(out, static_cast(in), row_stride, multi_stride); + pretranspose_B_array(out, static_cast(in), row_stride, multi_stride, transposed); } /* Threaded versions of the above. * The fallback/backwards compatible version of the threaded interface exposes a window size of 1 and * just calls the non-threaded functions to do the work. This is valid as with window size of 1 the only * legal values for start and end are 0 and 1 respectively. */ - virtual void - pretranspose_B_array_part(void *out, const To *in, const int row_stride, const int multi_stride, size_t, size_t) + virtual void pretranspose_B_array_part( + void *out, const To *in, const int row_stride, const int multi_stride, bool transposed, size_t, size_t) { - pretranspose_B_array(out, in, row_stride, multi_stride); + pretranspose_B_array(out, in, row_stride, multi_stride, transposed); }; - void pretranspose_B_array_part_generic( - void *out, const void *in, const int row_stride, const int multi_stride, size_t start, size_t end) override + void pretranspose_B_array_part_generic(void *out, + const void *in, + const int row_stride, + const int multi_stride, + bool transposed, + size_t start, + size_t end) override { - pretranspose_B_array_part(out, static_cast(in), row_stride, multi_stride, start, end); + pretranspose_B_array_part(out, static_cast(in), row_stride, multi_stride, transposed, start, end); } /*** Indirect interface ***/ @@ -287,3 +302,5 @@ public: }; } // namespace arm_gemm + +#endif // ACL_SRC_CPU_KERNELS_ASSEMBLY_GEMM_COMMON_HPP diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp index 611bc76463..58ee68fd49 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2023 Arm Limited. + * Copyright (c) 2018-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -60,7 +60,8 @@ void run_parallel_pretranspose_B_array(arm_gemm::GemmCommonpretranspose_B_array_part(dst->buffer(), src, src_ld, src_multi_stride, start, end); + gemm_asm->pretranspose_B_array_part(dst->buffer(), src, src_ld, src_multi_stride, transpose, start, + end); } }; } @@ -279,6 +281,8 @@ private: bool _B_pretranspose_required{false}; bool _is_b_constant{true}; bool _is_c_constant{true}; + bool _run_pre_pretranspose_b{false}; + bool _B_pre_pretranspose_required{false}; }; template @@ -443,8 +447,6 @@ void Fallback::configure(const ITensorInfo * const AsmGemmInfo &gemm_info, const OutputStage &os) { - ARM_COMPUTE_UNUSED(c); - _is_b_constant = b->are_values_constant(); _is_c_constant = c ? 
c->are_values_constant() : true; @@ -479,16 +481,23 @@ void Fallback::configure(const ITensorInfo * _optimised_kernel = std::move(acl_gemm_wrapper); _gemm_info = gemm_info; + // Check if we need to pre-pretranspose B. Fixed format kernels need no pre-pretranspose. - const bool run_pre_pretranspose_b = _gemm_info.transpose_b && !isVarWeightsKernel(); - if (run_pre_pretranspose_b) + _B_pre_pretranspose_required = _gemm_info.transpose_b && !isVarWeightsKernel(); + _B_pretranspose_required = _gemm_kernel_asm->B_pretranspose_required(); + + const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose(); + const bool kernel_can_fuse_transpose = _B_pretranspose_required && kernel_supports_transpose; + _run_pre_pretranspose_b = _B_pre_pretranspose_required && !kernel_can_fuse_transpose; + + if (_run_pre_pretranspose_b) { _pre_pretranspose_b = std::make_unique(); _pre_pretranspose_b->configure(b, &_pre_pretransposed_b_info); MemoryLifetime lifetime; if (_is_b_constant) { - if (_gemm_kernel_asm->B_pretranspose_required()) + if (_B_pretranspose_required) { // PrePretransposedB tensor is only used in prepare(), but is then succeeded by Pretranspose // So PrePretransposedB can be freed inside prepare() @@ -513,7 +522,7 @@ void Fallback::configure(const ITensorInfo * } // Check for pre-transposed support - if (_gemm_kernel_asm->B_pretranspose_required()) + if (_B_pretranspose_required) { // Fixed format kernels need no pretranspose. ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format( @@ -524,7 +533,6 @@ void Fallback::configure(const ITensorInfo * _pretranspose_info = TensorInfo(TensorShape(B_pretranspose_size), 1, DataType::U8); _aux_mem[Pretranspose] = MemoryInfo(offset_int_vec(Pretranspose), MemoryLifetime::Persistent, B_pretranspose_size, alignment); - _B_pretranspose_required = true; } // Handle indirect GEMM convolution @@ -550,15 +558,16 @@ void Fallback::prepare(ITensorPack &tensors) reinterpret_cast(c->buffer() + c->info()->offset_first_element_in_bytes()), 0); } const ITensor *b_to_use = b; + // Pre-pretranspose B if required - const bool run_pre_pretranspose_b = _gemm_info.transpose_b && !isVarWeightsKernel(); CpuAuxTensorHandler pre_pretransposed_b( offset_int_vec(PrePretransposedB), _pre_pretransposed_b_info, tensors, /*pack_inject: no need to inject into tensors*/ false, /*bypass_alloc: no need to allocate if pre-pretranspose B is not required as this handle will not be used*/ - !run_pre_pretranspose_b); - if (run_pre_pretranspose_b) + !_run_pre_pretranspose_b); + + if (_run_pre_pretranspose_b) { ARM_COMPUTE_ERROR_ON(_pre_pretranspose_b == nullptr); ITensorPack pre_pretranspose_pack{{ACL_SRC, b_to_use}, {ACL_DST, pre_pretransposed_b.get()}}; @@ -567,24 +576,29 @@ void Fallback::prepare(ITensorPack &tensors) } // Pretranspose B if required - if (_gemm_kernel_asm->B_pretranspose_required()) + if (_B_pretranspose_required) { // Fixed format kernels need no pretranspose. 
ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format( assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format))); + const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size(); const auto in1_ptr = reinterpret_cast(b_to_use->buffer() + b_to_use->info()->offset_first_element_in_bytes()); const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size(); CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false); + ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr); - run_parallel_pretranspose_B_array(_gemm_kernel_asm.get(), pretranspose.get(), - in1_ptr, ldb, multi_stride_b, - NEScheduler::get().num_threads()); + + const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose(); + run_parallel_pretranspose_B_array( + _gemm_kernel_asm.get(), pretranspose.get(), in1_ptr, ldb, multi_stride_b, + NEScheduler::get().num_threads(), _B_pre_pretranspose_required && kernel_supports_transpose); b->mark_as_unused(); - // Note that we don't need to mark b_to_use as unused, as if it's been assigned to pre_pretransposed_b, its memory will be auto-managed by the handler + // Note that we don't need to mark b_to_use as unused, as if it's been assigned to pre_pretransposed_b, + // its memory will be auto-managed by the handler } if (_gemm_info.method == AsmConvMethod::Indirect) @@ -640,12 +654,11 @@ void Fallback::run(ITensorPack &tensors) const ITensor *b_to_use = b; // Pre-pretranspose B if required - const bool run_pre_pretranspose_b = _gemm_info.transpose_b && !isVarWeightsKernel(); CpuAuxTensorHandler pre_pretransposed_b( offset_int_vec(PrePretransposedB), _pre_pretransposed_b_info, tensors, false /*pack_inject: no need to inject into tensors*/, - !run_pre_pretranspose_b /*bypass_alloc: no need to allocate if pre-pretranspose B is not required as this handle will not be used*/); - if (b_to_use && !_is_b_constant && run_pre_pretranspose_b) + !_run_pre_pretranspose_b /*bypass_alloc: no need to allocate if pre-pretranspose B is not required as this handle will not be used*/); + if (b_to_use && !_is_b_constant && _run_pre_pretranspose_b) { ARM_COMPUTE_ERROR_ON(_pre_pretranspose_b == nullptr); ITensorPack pre_pretranspose_pack{{ACL_SRC, b_to_use}, {ACL_DST, pre_pretransposed_b.get()}}; @@ -691,9 +704,10 @@ void Fallback::run(ITensorPack &tensors) } else { - run_parallel_pretranspose_B_array(_gemm_kernel_asm.get(), pretranspose.get(), - b_ptr, ldb, multi_stride_b, - NEScheduler::get().num_threads()); + const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose(); + run_parallel_pretranspose_B_array( + _gemm_kernel_asm.get(), pretranspose.get(), b_ptr, ldb, multi_stride_b, + NEScheduler::get().num_threads(), _B_pre_pretranspose_required && kernel_supports_transpose); } } } -- cgit v1.2.1
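
As a quick illustration of the dispatch logic changed above, the minimal stand-alone C++ sketch below mirrors the booleans computed in CpuGemmAssemblyDispatch::configure()/prepare(): the separate pre-pretranspose (CpuTranspose) of B only runs when the selected arm_gemm kernel cannot fuse the transpose, i.e. when B_pretranspose_supports_transpose() reports false; otherwise the new 'transposed' flag is forwarded to pretranspose_B_array(). GemmMock, dispatch() and main() are illustrative stand-ins, not Compute Library code; only the boolean logic follows the patch.

#include <cstdio>

// Stand-in for the relevant bits of arm_gemm::GemmCommon's pretranspose interface.
struct GemmMock
{
    bool pretranspose_required;  // what B_pretranspose_required() would report
    bool supports_transpose;     // what the new B_pretranspose_supports_transpose() would report

    void pretranspose_B_array(const char *src, bool transposed)
    {
        std::printf("  pretranspose_B_array(%s, transposed=%d)\n", src, transposed ? 1 : 0);
    }
};

// Mirrors the decision added to CpuGemmAssemblyDispatch: run a standalone
// pre-pretranspose of B only when the kernel's B transform cannot fuse it.
void dispatch(GemmMock &gemm, bool gemm_info_transpose_b, bool is_var_weights_kernel)
{
    const bool pre_pretranspose_required = gemm_info_transpose_b && !is_var_weights_kernel;
    const bool kernel_can_fuse_transpose = gemm.pretranspose_required && gemm.supports_transpose;
    const bool run_pre_pretranspose_b    = pre_pretranspose_required && !kernel_can_fuse_transpose;

    const char *b_to_use = "B";
    if (run_pre_pretranspose_b)
    {
        // Corresponds to the separate CpuTranspose step writing the PrePretransposedB buffer.
        std::printf("  running standalone transpose of B\n");
        b_to_use = "B_pre_pretransposed";
    }

    if (gemm.pretranspose_required)
    {
        // When fused, 'transposed' tells the kernel-specific B transform that B has not
        // been transposed beforehand, so the transform applies the transposition itself
        // while rearranging (StdTransformsFixedTRB selects the transposing Transform variant).
        gemm.pretranspose_B_array(b_to_use, pre_pretranspose_required && gemm.supports_transpose);
    }
}

int main()
{
    GemmMock fused{/*pretranspose_required=*/true, /*supports_transpose=*/true};
    GemmMock legacy{/*pretranspose_required=*/true, /*supports_transpose=*/false};

    std::printf("kernel with fused transpose support:\n");
    dispatch(fused, /*gemm_info_transpose_b=*/true, /*is_var_weights_kernel=*/false);

    std::printf("kernel without fused transpose support:\n");
    dispatch(legacy, /*gemm_info_transpose_b=*/true, /*is_var_weights_kernel=*/false);

    return 0;
}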