From f1f1f87132690a8061801ef1a4638d637c780df7 Mon Sep 17 00:00:00 2001 From: Radu Salavat Date: Tue, 27 Feb 2024 18:32:26 +0000 Subject: Add in place summation to CPU GEMM kernels Instead of dispatching the sum post-op for GEMM kernels to a separate kernel + add, which requires an extra destination-sized allocation plus 3 extra loads/stores per element, perform the summation inside the GEMM kernel itself. Resolves: ONCPUML-1442 Signed-off-by: Radu Salavat Co-authored-by: Milos Puzovic Change-Id: I7a1f2da3300875fa1ac88b705a34390969518077 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11298 Reviewed-by: Gunes Bayir Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- tests/validation/reference/GEMM.h | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'tests/validation/reference/GEMM.h') diff --git a/tests/validation/reference/GEMM.h b/tests/validation/reference/GEMM.h index 5feaeda584..1b97570122 100644 --- a/tests/validation/reference/GEMM.h +++ b/tests/validation/reference/GEMM.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 Arm Limited. + * Copyright (c) 2017-2019, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE.
*/ -#ifndef ARM_COMPUTE_TEST_GEMM_H -#define ARM_COMPUTE_TEST_GEMM_H +#ifndef ACL_TESTS_VALIDATION_REFERENCE_GEMM_H +#define ACL_TESTS_VALIDATION_REFERENCE_GEMM_H #include "tests/SimpleTensor.h" #include "tests/validation/Helpers.h" @@ -41,8 +41,11 @@ SimpleTensor<T> gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const S template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type = 0> SimpleTensor<T> gemm_mixed_precision(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta); +template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type = 0> +void gemm_accumulate(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta, SimpleTensor<T> &dst); + } // namespace reference } // namespace validation } // namespace test } // namespace arm_compute -#endif /* ARM_COMPUTE_TEST_GEMM_H */ +#endif // ACL_TESTS_VALIDATION_REFERENCE_GEMM_H -- cgit v1.2.1