From f1f1f87132690a8061801ef1a4638d637c780df7 Mon Sep 17 00:00:00 2001 From: Radu Salavat Date: Tue, 27 Feb 2024 18:32:26 +0000 Subject: Add in place summation to CPU GEMM kernels Instead of dispatching the sum postop for GEMM kernels to a separate kernel + add, that requires an extra destination sized allocation, plus 3 extra load/stores per element, just do it in the GEMM kernel. Resolves: ONCPUML-1442 Signed-off-by: Radu Salavat Co-authored-by: Milos Puzovic Change-Id: I7a1f2da3300875fa1ac88b705a34390969518077 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11298 Reviewed-by: Gunes Bayir Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- arm_compute/function_info/GEMMInfo.h | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) (limited to 'arm_compute') diff --git a/arm_compute/function_info/GEMMInfo.h b/arm_compute/function_info/GEMMInfo.h index a827c79fda..74fe30454e 100644 --- a/arm_compute/function_info/GEMMInfo.h +++ b/arm_compute/function_info/GEMMInfo.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2023 Arm Limited. + * Copyright (c) 2016-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -85,7 +85,8 @@ public: _pretranspose_B(false), _activation_info(), _fixed_format(false), - _weight_format(arm_compute::WeightFormat::UNSPECIFIED) + _weight_format(arm_compute::WeightFormat::UNSPECIFIED), + _accumulate(false) { } /** Constructor @@ -106,6 +107,7 @@ public: * @param[in] fixed_format (Optional) Specify the selection of fixed format kernels for variable weights support in GEMM. These kernels expect the weights tensor to be in amemory format that is fixed by the kernel itself. For more information, see arm_compute::WeightFormat. * @param[in] weight_format (Optional) arm_gemm:WeightFormat enumeration requested by the user. Default is arm_compute::WeightFormat::UNSPECIFIED. * @param[in] pretranspose_B (Optional) Pretranspose matrix B (transposition of its lowest 2 dimensions), in addition to and before, any further transformations of B + * @param[in] accumulate (Optional) Whether to accumulate in destination or not */ GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, @@ -120,7 +122,8 @@ public: const ActivationLayerInfo &activation_info = ActivationLayerInfo(), bool fixed_format = false, arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED, - bool pretranspose_B = false) noexcept + bool pretranspose_B = false, + bool accumulate = false) noexcept : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), @@ -135,7 +138,8 @@ public: _pretranspose_B(pretranspose_B), _activation_info(activation_info), _fixed_format(fixed_format), - _weight_format(weight_format) + _weight_format(weight_format), + _accumulate(accumulate) { } /** Flag which specifies if the matrix A has been reshaped @@ -294,7 +298,14 @@ public: { return _fixed_format; } - + /** Flag which specifies if GEMM should accumulate the result in destination or not. + * + * @return True if GEMM is accumulating the result. + */ + bool accumulate() const + { + return _accumulate; + } /** Set fixed-format flag * * @param[in] fixed_format sets whether or not to use fixed-format kernels @@ -303,12 +314,19 @@ public: { _fixed_format = fixed_format; } + /** Set accumulate flag + * + * @param[in] accumulate sets whether or not to use accumulation + */ + void set_accumulate(bool accumulate) + { + _accumulate = accumulate; + } arm_compute::WeightFormat weight_format() const { return _weight_format; } - /** Set weight format to be used * * @param[in] weight_format arm_compute::WeightFormat enumeration @@ -334,6 +352,7 @@ private: ActivationLayerInfo _activation_info; bool _fixed_format; arm_compute::WeightFormat _weight_format; + bool _accumulate; }; } //namespace arm_compute #endif // ACL_ARM_COMPUTE_FUNCTION_INFO_GEMMINFO_H -- cgit v1.2.1