diff options
author | Radu Salavat <radu.salavat@arm.com> | 2024-02-27 18:32:26 +0000 |
---|---|---|
committer | Radu Salavat <radu.salavat@arm.com> | 2024-04-11 08:47:50 +0000 |
commit | f1f1f87132690a8061801ef1a4638d637c780df7 (patch) | |
tree | 8ad4c3739217b3bc6281f4e0b9a7a63fe6c3f9bb /src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp | |
parent | 1322065a3fbd15b00dbfb0969d6b438b5ba15530 (diff) | |
download | ComputeLibrary-f1f1f87132690a8061801ef1a4638d637c780df7.tar.gz |
Add in place summation to CPU GEMM kernels
Instead of dispatching the sum postop for GEMM kernels to a
separate kernel + add, that requires an extra destination sized
allocation, plus 3 extra load/stores per element,
just do it in the GEMM kernel.
Resolves: ONCPUML-1442
Signed-off-by: Radu Salavat <radu.salavat@arm.com>
Co-authored-by: Milos Puzovic <milos.puzovic@arm.com>
Change-Id: I7a1f2da3300875fa1ac88b705a34390969518077
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11298
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp')
-rw-r--r-- | src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp | 7 |
1 files changed, 3 insertions, 4 deletions
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp index efe2a7a67e..01a74a5a56 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp @@ -775,7 +775,7 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge arm_gemm::GemmConfig cfg; cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, - info.fixed_format, info.fast_mode, &cfg); + info.fixed_format, info.fast_mode, info.accumulate, &cfg); // Create arm_gemm fallback auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>(); @@ -800,7 +800,7 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> & arm_gemm::GemmConfig cfg; cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, - info.fixed_format, info.fast_mode, &cfg); + info.fixed_format, info.fast_mode, info.accumulate, &cfg); // Create arm_gemm fallback auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>(); @@ -855,8 +855,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format); arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, - info.fixed_format, info.fast_mode, &cfg); - + info.fixed_format, info.fast_mode, info.accumulate, &cfg); // TODO: Incorporate info.transpose_b COMPMID-6595 switch (a->data_type()) { |