aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels
diff options
context:
space:
mode:
authorDavid Mansell <David.Mansell@arm.com>2023-09-19 15:49:10 +0100
committerDavid Mansell <David.Mansell@arm.com>2023-10-10 15:03:19 +0000
commitfb9c25d27791d934300581596cce7c5875a79a80 (patch)
treee29a9f5e151e26145589390897c3d8314edd18c1 /src/core/NEON/kernels
parent9aa153ae5d60fd08ec165280621f1e4fa7602048 (diff)
downloadComputeLibrary-fb9c25d27791d934300581596cce7c5875a79a80.tar.gz
arm_gemm: fix 2D threading mode for SME2
"2D" threading mode was not setting the result pointer correctly for SME2 kernels with K blocking - for non-final blocks the result pointer should be NULL so that the intermediate results get written in the accumulator buffer by the kernel. Signed-off-by: David Mansell <David.Mansell@arm.com> Change-Id: Idefa538e190a086e1e44a91998ab7e949e3989e4 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10342 Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp9
1 files changed, 8 insertions, 1 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index 13f548e39e..362a3e30ea 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -802,6 +802,13 @@ public:
}
}
+ Tr *result_ptr = this->_Cptr + (batch * this->_C_batch_stride) + (multi * this->_C_multi_stride);
+
+ // If we are using an accumulation buffer and this isn't the last pass, don't pass a result pointer.
+ if (_accumulation_buffer && !last_pass) {
+ result_ptr = nullptr;
+ }
+
// Perform the kernel and merge step, either separately or together as required.
kernel_and_merge<MergeStep, FixedFormat, OutputStage>::run(
#ifdef CYCLE_PROFILING
@@ -810,7 +817,7 @@ public:
// Strategy and panel pointers
strat, a_panel, b_ptr, this->_ldb, c_panel,
// Result buffer pointers
- this->_Cptr + (batch * this->_C_batch_stride) + (multi * this->_C_multi_stride), this->_ldc,
+ result_ptr, this->_ldc,
// K size, and M/N ranges
kern_k, start_row, end_row, start_x, end_x,
// Only do bias on the first pass