From f1f1f87132690a8061801ef1a4638d637c780df7 Mon Sep 17 00:00:00 2001 From: Radu Salavat Date: Tue, 27 Feb 2024 18:32:26 +0000 Subject: Add in place summation to CPU GEMM kernels Instead of dispatching the sum postop for GEMM kernels to a separate kernel + add, that requires an extra destination sized allocation, plus 3 extra load/stores per element, just do it in the GEMM kernel. Resolves: ONCPUML-1442 Signed-off-by: Radu Salavat Co-authored-by: Milos Puzovic Change-Id: I7a1f2da3300875fa1ac88b705a34390969518077 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11298 Reviewed-by: Gunes Bayir Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp | 4 ++-- src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp | 6 +++--- src/core/NEON/kernels/arm_gemm/gemm_int8.cpp | 6 +++--- src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp | 9 +++++---- src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp | 6 +++--- src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp | 2 +- src/cpu/kernels/assembly/arm_gemm.hpp | 11 ++++++++++- src/cpu/operators/CpuGemm.cpp | 13 ++++++++++++- src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp | 10 +++++++++- src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp | 7 +++---- src/cpu/operators/internal/CpuGemmAssemblyDispatch.h | 3 ++- 11 files changed, 53 insertions(+), 24 deletions(-) (limited to 'src') diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp index e85dd59425..290fe87230 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp @@ -293,14 +293,14 @@ GemmImplementation::with_estimate( { GemmMethod::GEMM_HYBRID, "a64_smallK_hybrid_fp32_mla_8x4", - [](const GemmArgs &args) { return args._Ksize <= 8 && (args._Nsize % 4)==0 && !args._indirect_input; }, + [](const GemmArgs &args) { return args._Ksize <= 8 && (args._Nsize % 4)==0 && !args._indirect_input && !args._accumulate; }, nullptr, [](const GemmArgs &args) { return new GemmHybrid(args); } }, { GemmMethod::GEMM_HYBRID, "a64_smallK_hybrid_fp32_mla_6x4", - [](const GemmArgs &args) { return (args._Ksize > 8 && args._Ksize <= 16) && (args._Nsize % 4)==0 && !args._indirect_input; }, + [](const GemmArgs &args) { return (args._Ksize > 8 && args._Ksize <= 16) && (args._Nsize % 4)==0 && !args._indirect_input && !args._accumulate; }, nullptr, [](const GemmArgs &args) { return new GemmHybrid(args); } }, diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp index 89c2d5a23e..0cc4d4f3d9 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp @@ -530,7 +530,7 @@ public: (m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg, (this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr, last_pass ? _args._act : Activation(), - !first_pass, + !first_pass || _args._accumulate, // Quantization parameters _os, _col_bias+(multi * _args._Nsize), n0); } else if (_convolver) { @@ -563,7 +563,7 @@ public: (m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg, (this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr, last_pass ? _args._act : Activation(), - !first_pass, + !first_pass || _args._accumulate, // Quantization parameters _os, _col_bias+(multi * _args._Nsize), n0); } else { @@ -579,7 +579,7 @@ public: (m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg, (this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr, last_pass ? _args._act : Activation(), - !first_pass, + !first_pass || _args._accumulate, // Quantization parameters _os, _col_bias+(multi * _args._Nsize), n0); } diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp index fd20e53f60..0dc0d55b27 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020, 2022-2023 Arm Limited. + * Copyright (c) 2017-2020, 2022-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -128,14 +128,14 @@ GemmImplementation::with_estimate( { GemmMethod::GEMM_HYBRID, "a64_smallK_hybrid_s8s32_dot_8x4", - [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; }, + [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input && !args._accumulate; }, [](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); }, [](const GemmArgs &args) { return new GemmHybrid(args); } }, { GemmMethod::GEMM_HYBRID, "a64_smallK_hybrid_s8s32_dot_6x4", - [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; }, + [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input && !args._accumulate; }, [](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); }, [](const GemmArgs &args) { return new GemmHybrid(args); } }, diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp index 4f732f7d94..d8b464584a 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp @@ -350,6 +350,7 @@ class GemmInterleaved : public GemmCommon { const bool _thread_columns; const Activation _act; + const bool _accumulate; const int _maxthreads; int _nthreads; @@ -680,7 +681,7 @@ public: _Ksections(args._Ksections), _Ktotal(get_ktotal(args)), _rounded_Ksize(roundup(_Ksize, strategy::k_unroll())), _nbatches(args._nbatches), _nmulti(args._nmulti), _thread_columns(is_thread_columns(args)), - _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads), + _act(args._act), _accumulate(args._accumulate), _maxthreads(args._maxthreads), _nthreads(args._maxthreads), _k_block(get_k_block_size(args)), _x_block(get_x_block_size(args)), _Mround(roundup(args._Msize, strategy::out_height())), _os(os) { } @@ -690,7 +691,7 @@ public: _Ksections(args._Ksections), _Ktotal(get_ktotal(args)), _rounded_Ksize(roundup(_Ksize, strategy::k_unroll())), _nbatches(args._nbatches), _nmulti(args._nmulti), _thread_columns(is_thread_columns(args)), - _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads), + _act(args._act), _accumulate(args._accumulate), _maxthreads(args._maxthreads), _nthreads(args._maxthreads), _k_block(get_k_block_size(args)), _x_block(get_x_block_size(args)), _Mround(roundup(args._Msize, strategy::out_height())), _os() { } @@ -823,7 +824,7 @@ public: // Only do bias on the first pass ((first_pass && this->_bias) ? this->_bias + (multi * this->_bias_multi_stride) : nullptr), // Only do activation on the last pass, and accumulation on any non-first pass. - (last_pass ? _act : Activation()), !first_pass, + (last_pass ? _act : Activation()), (!first_pass || _accumulate), // Pass in quantization parameters for requantizing kernels (others will ignore) _os, col_bias + (multi * _Nsize), // Accumulation buffer @@ -971,7 +972,7 @@ public: // Only do bias on the first pass ((first_pass && this->_bias) ? this->_bias + (current.multi() * this->_bias_multi_stride) : nullptr), // Only do activation on the last pass, and accumulation on any non-first pass. - (last_pass ? _act : Activation()), !first_pass, + (last_pass ? _act : Activation()), (!first_pass || _accumulate), // Pass in quantization parameters for requantizing kernels (others will ignore) _os, col_bias + (current.multi() * _Nsize), // Accumulation buffer diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp index af5cfbbf2b..dfacb687a8 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020, 2022-2023 Arm Limited. + * Copyright (c) 2017-2020, 2022-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -94,14 +94,14 @@ GemmImplementation::with_estimate( { GemmMethod::GEMM_HYBRID, "a64_smallK_hybrid_u8u32_dot_8x4", - [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; }, + [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input && !args._accumulate; }, [](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); }, [](const GemmArgs &args) { return new GemmHybrid(args); } }, { GemmMethod::GEMM_HYBRID, "a64_smallK_hybrid_u8u32_dot_6x4", - [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; }, + [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input && !args._accumulate; }, [](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); }, [](const GemmArgs &args) { return new GemmHybrid(args); } }, diff --git a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp index 92c884ce18..dbada36052 100644 --- a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp @@ -180,7 +180,7 @@ public: this->_Cptr + (multi * this->_C_multi_stride) + n, (nmax - n), (kmax-k0), this->_bias ? this->_bias + (multi * this->_bias_multi_stride) + n : nullptr, - _args._act, (k0 != 0), + _args._act, (k0 != 0) || _args._accumulate, _os, col_bias, n + (_args._Nsize * multi)); } } diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp index 9a913c5c58..5d7cf79857 100644 --- a/src/cpu/kernels/assembly/arm_gemm.hpp +++ b/src/cpu/kernels/assembly/arm_gemm.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022 Arm Limited. + * Copyright (c) 2018-2022, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,6 +21,10 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ + +#ifndef ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP +#define ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP + #pragma once #include "arm_gemm_local.hpp" @@ -151,6 +155,7 @@ public: int _maxthreads; bool _fixed_format; bool _fast_mode; + bool _accumulate; const GemmConfig *_cfg; GemmArgs(const CPUInfo *ci, @@ -165,6 +170,7 @@ public: const int maxthreads, bool fixed_format = false, bool fast_mode = false, + bool accumulate = false, const GemmConfig *cfg = nullptr) : _ci(ci), _Msize(M), @@ -178,6 +184,7 @@ public: _maxthreads(maxthreads), _fixed_format(fixed_format), _fast_mode(fast_mode), + _accumulate(accumulate), _cfg(cfg) { } @@ -278,3 +285,5 @@ template bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {}); } // namespace arm_gemm + +#endif // ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp index e035de0131..905e86c185 100644 --- a/src/cpu/operators/CpuGemm.cpp +++ b/src/cpu/operators/CpuGemm.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023 Arm Limited. + * Copyright (c) 2021-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -53,6 +53,7 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info) asm_info.fast_mode = info.fast_math(); asm_info.fixed_format = info.fixed_format(); asm_info.weight_format = info.weight_format(); + asm_info.accumulate = info.accumulate(); asm_info.transpose_b = info.pretranspose_B(); // The "pretranspose_B" flag here is not the same as the pretranspose_B_array method. The flag here signals to pretranspose_B_array method if we want to perform additional transpose on B before the pretranspose_B_array method @@ -219,6 +220,16 @@ Status CpuGemm::validate(const ITensorInfo *a, const GEMMInfo &gemm_info) { ARM_COMPUTE_UNUSED(alpha); + // When using accumulation(in place summation), for now, the only supported values for alpha and beta are 1 respectively 0. + // Do the appropriate checks before proceeding. + if (gemm_info.accumulate()) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG(alpha != 1, "Accumulation is not supported when alpha is different from 1"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG( + (beta != 0 && c != nullptr), + "Accumulation is not supported when beta is different from 0 with a non-null bias matrix c"); + } + const bool is_c_bias = beta == 1 && c != nullptr; const bool run_addition = c != nullptr && beta != 0 && beta != 1; // Check if we should use the pretransposed_b or original b diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp index b25505a85d..94e86c6077 100644 --- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp +++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023 Arm Limited. + * Copyright (c) 2021-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -65,6 +65,7 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info) asm_info.activation_info = info.activation_info(); asm_info.output_stage = info.gemmlowp_output_stage(); asm_info.fast_mode = info.fast_math(); + asm_info.accumulate = info.accumulate(); return asm_info; } @@ -343,6 +344,13 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a, ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported"); + // When using accumulation(in place summation), for now, the only supported DataType for output is S32. + if (gemm_info.accumulate()) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE, + "Accumulation is not supported for output QASYMM8/QASYMM8_SIGNED"); + } + GEMMInfo info = gemm_info; const ITensorInfo *matrix_a_info = a; const ITensorInfo *matrix_b_info = b; diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp index efe2a7a67e..01a74a5a56 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp @@ -775,7 +775,7 @@ void create_arm_gemm(std::unique_ptr &arm_ge arm_gemm::GemmConfig cfg; cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, - info.fixed_format, info.fast_mode, &cfg); + info.fixed_format, info.fast_mode, info.accumulate, &cfg); // Create arm_gemm fallback auto fallback = std::make_unique>(); @@ -800,7 +800,7 @@ void create_arm_gemm_quant(std::unique_ptr & arm_gemm::GemmConfig cfg; cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, - info.fixed_format, info.fast_mode, &cfg); + info.fixed_format, info.fast_mode, info.accumulate, &cfg); // Create arm_gemm fallback auto fallback = std::make_unique>(); @@ -855,8 +855,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format); arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, - info.fixed_format, info.fast_mode, &cfg); - + info.fixed_format, info.fast_mode, info.accumulate, &cfg); // TODO: Incorporate info.transpose_b COMPMID-6595 switch (a->data_type()) { diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h index 671a222fed..44c5c189a5 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2023 Arm Limited. + * Copyright (c) 2018-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -57,6 +57,7 @@ struct AsmGemmInfo bool fixed_format{false}; arm_compute::WeightFormat weight_format{arm_compute::WeightFormat::UNSPECIFIED}; bool reshape_b_only_on_first_run{true}; + bool accumulate{false}; /** Whether we want to perform an additional transpose of b before passing it to gemm or pretranspose_B_array * @note This transpose b operation is also considered a form of "reshape" or "transform", so should be counted for * by the reshape_b_only_on_first_run flag -- cgit v1.2.1