Diffstat (limited to 'src/cpu/operators')
-rw-r--r--  src/cpu/operators/CpuGemm.cpp                           | 13
-rw-r--r--  src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp     | 10
-rw-r--r--  src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp  |  7
-rw-r--r--  src/cpu/operators/internal/CpuGemmAssemblyDispatch.h    |  3
4 files changed, 26 insertions(+), 7 deletions(-)
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index e035de0131..905e86c185 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -53,6 +53,7 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
asm_info.fast_mode = info.fast_math();
asm_info.fixed_format = info.fixed_format();
asm_info.weight_format = info.weight_format();
+ asm_info.accumulate = info.accumulate();
asm_info.transpose_b =
info.pretranspose_B(); // The "pretranspose_B" flag is distinct from the pretranspose_B_array method: it tells pretranspose_B_array whether to perform an additional transpose of B first.
@@ -219,6 +220,16 @@ Status CpuGemm::validate(const ITensorInfo *a,
const GEMMInfo &gemm_info)
{
ARM_COMPUTE_UNUSED(alpha);
+ // When accumulating (in-place summation), the only supported values for alpha and beta are currently 1 and 0 respectively.
+ // Perform the appropriate checks before proceeding.
+ if (gemm_info.accumulate())
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(alpha != 1, "Accumulation is not supported when alpha is different from 1");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ (beta != 0 && c != nullptr),
+ "Accumulation is not supported when beta is different from 0 with a non-null bias matrix c");
+ }
+
const bool is_c_bias = beta == 1 && c != nullptr;
const bool run_addition = c != nullptr && beta != 0 && beta != 1;
// Check if we should use the pretransposed_b or original b
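
For context, a hedged sketch of what the new validate() constraint means for callers. The set_accumulate() setter on GEMMInfo is an assumption here (this diff only shows the accumulate() getter), and the tensor shapes are illustrative:

    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/core/Types.h"
    #include "src/cpu/operators/CpuGemm.h"

    using namespace arm_compute;

    Status check_accumulate_constraints()
    {
        TensorInfo a(TensorShape(64U, 32U), 1, DataType::F32); // A: M=32, K=64
        TensorInfo b(TensorShape(16U, 64U), 1, DataType::F32); // B: K=64, N=16
        TensorInfo d(TensorShape(16U, 32U), 1, DataType::F32); // D: M=32, N=16

        GEMMInfo ginfo;
        ginfo.set_accumulate(true); // assumed setter, pairing with the accumulate() getter used above

        // Accepted: alpha == 1, beta == 0, no bias c.
        return cpu::CpuGemm::validate(&a, &b, nullptr, &d, 1.f, 0.f, ginfo);
        // With alpha != 1 (or beta != 0 together with a non-null c) the new
        // checks would return an error instead.
    }
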
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
index b25505a85d..94e86c6077 100644
--- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
+++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -65,6 +65,7 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
asm_info.activation_info = info.activation_info();
asm_info.output_stage = info.gemmlowp_output_stage();
asm_info.fast_mode = info.fast_math();
+ asm_info.accumulate = info.accumulate();
return asm_info;
}
@@ -343,6 +344,13 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
+ // When accumulating (in-place summation), the only supported output DataType is currently S32.
+ if (gemm_info.accumulate())
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE,
+ "Accumulation is not supported for output QASYMM8/QASYMM8_SIGNED");
+ }
+
GEMMInfo info = gemm_info;
const ITensorInfo *matrix_a_info = a;
const ITensorInfo *matrix_b_info = b;
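
A similar hedged sketch for the quantized path: with accumulation requested, validate() only accepts a raw S32 destination, i.e. no GEMMLowp output stage may be configured (set_accumulate() is again an assumption; shapes and quantization parameters are illustrative):

    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/core/Types.h"
    #include "src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h"

    using namespace arm_compute;

    Status check_lowp_accumulate_constraints()
    {
        TensorInfo a(TensorShape(64U, 32U), 1, DataType::QASYMM8, QuantizationInfo(0.5f, 10));
        TensorInfo b(TensorShape(16U, 64U), 1, DataType::QASYMM8, QuantizationInfo(0.25f, 3));
        TensorInfo d(TensorShape(16U, 32U), 1, DataType::S32); // accumulation target stays S32

        GEMMInfo ginfo;
        ginfo.set_accumulate(true); // assumed setter
        // gemmlowp_output_stage().type defaults to GEMMLowpOutputStageType::NONE,
        // so this passes; a requantizing output stage (QASYMM8/QASYMM8_SIGNED dst)
        // would trip the new error above.
        return cpu::CpuGemmLowpMatrixMultiplyCore::validate(&a, &b, nullptr, &d, ginfo);
    }
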
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index efe2a7a67e..01a74a5a56 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -775,7 +775,7 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge
arm_gemm::GemmConfig cfg;
cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads,
- info.fixed_format, info.fast_mode, &cfg);
+ info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// Create arm_gemm fallback
auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>();
@@ -800,7 +800,7 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &
arm_gemm::GemmConfig cfg;
cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads,
- info.fixed_format, info.fast_mode, &cfg);
+ info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// Create arm_gemm fallback
auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>();
@@ -855,8 +855,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads,
- info.fixed_format, info.fast_mode, &cfg);
-
+ info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// TODO: Incorporate info.transpose_b COMPMID-6595
switch (a->data_type())
{
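
All three call sites make the same mechanical change: GemmArgs grows an accumulate argument just before the GemmConfig pointer. A hedged reconstruction of the patched call, with the assumed post-patch constructor shape shown alongside (parameter names are taken from the surrounding code and may differ in arm_gemm itself):

    // Assumed shape of the updated constructor (names illustrative):
    // GemmArgs(const CPUInfo *ci, unsigned int M, unsigned int N, unsigned int K,
    //          unsigned int Ksections, unsigned int nbatches, unsigned int nmulti,
    //          bool indirect_input, Activation act, const int maxthreads,
    //          bool fixed_format = false, bool fast_mode = false,
    //          bool accumulate = false, const GemmConfig *cfg = nullptr);
    arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis,
                            p.indirect, activation, num_threads,
                            info.fixed_format, info.fast_mode,
                            info.accumulate, // new: request D += A*B in the kernels
                            &cfg);
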
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index 671a222fed..44c5c189a5 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2023 Arm Limited.
+ * Copyright (c) 2018-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -57,6 +57,7 @@ struct AsmGemmInfo
bool fixed_format{false};
arm_compute::WeightFormat weight_format{arm_compute::WeightFormat::UNSPECIFIED};
bool reshape_b_only_on_first_run{true};
+ bool accumulate{false};
/** Whether to perform an additional transpose of b before passing it to gemm or pretranspose_B_array
* @note This transpose of b is also considered a form of "reshape" or "transform", so it should be accounted for
* by the reshape_b_only_on_first_run flag
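
A minimal sketch tying the new struct field back to the two init_assembly_metadata() hunks above; the wiring shown is assumed from those hunks, with field names taken from the struct:

    // The frontend GEMMInfo flag is copied into the assembly metadata.
    cpu::AsmGemmInfo asm_info{};
    asm_info.fast_mode  = gemm_info.fast_math();
    asm_info.accumulate = gemm_info.accumulate(); // new field, false unless accumulation was requested
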