aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRadu Salavat <radu.salavat@arm.com>2024-02-27 18:32:26 +0000
committerRadu Salavat <radu.salavat@arm.com>2024-04-11 08:47:50 +0000
commitf1f1f87132690a8061801ef1a4638d637c780df7 (patch)
tree8ad4c3739217b3bc6281f4e0b9a7a63fe6c3f9bb
parent1322065a3fbd15b00dbfb0969d6b438b5ba15530 (diff)
downloadComputeLibrary-f1f1f87132690a8061801ef1a4638d637c780df7.tar.gz
Add in place summation to CPU GEMM kernels
Instead of dispatching the sum postop for GEMM kernels to a separate kernel + add, that requires an extra destination sized allocation, plus 3 extra load/stores per element, just do it in the GEMM kernel. Resolves: ONCPUML-1442 Signed-off-by: Radu Salavat <radu.salavat@arm.com> Co-authored-by: Milos Puzovic <milos.puzovic@arm.com> Change-Id: I7a1f2da3300875fa1ac88b705a34390969518077 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11298 Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--arm_compute/function_info/GEMMInfo.h31
-rw-r--r--docs/user_guide/release_version_and_change_log.dox1
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp4
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_int8.cpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp9
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp2
-rw-r--r--src/cpu/kernels/assembly/arm_gemm.hpp11
-rw-r--r--src/cpu/operators/CpuGemm.cpp13
-rw-r--r--src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp10
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp7
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.h3
-rw-r--r--tests/datasets/LargeGEMMDataset.h21
-rw-r--r--tests/datasets/SmallGEMMDataset.h19
-rw-r--r--tests/validation/CL/GEMMLowp.cpp13
-rw-r--r--tests/validation/NEON/GEMM.cpp145
-rw-r--r--tests/validation/NEON/GEMMLowp.cpp89
-rw-r--r--tests/validation/fixtures/GEMMFixture.h60
-rw-r--r--tests/validation/fixtures/GEMMLowpFixture.h127
-rw-r--r--tests/validation/reference/GEMM.cpp30
-rw-r--r--tests/validation/reference/GEMM.h11
-rw-r--r--tests/validation/reference/GEMMLowp.cpp12
-rw-r--r--tests/validation/reference/GEMMLowp.h11
24 files changed, 471 insertions, 176 deletions
diff --git a/arm_compute/function_info/GEMMInfo.h b/arm_compute/function_info/GEMMInfo.h
index a827c79fda..74fe30454e 100644
--- a/arm_compute/function_info/GEMMInfo.h
+++ b/arm_compute/function_info/GEMMInfo.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2023 Arm Limited.
+ * Copyright (c) 2016-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -85,7 +85,8 @@ public:
_pretranspose_B(false),
_activation_info(),
_fixed_format(false),
- _weight_format(arm_compute::WeightFormat::UNSPECIFIED)
+ _weight_format(arm_compute::WeightFormat::UNSPECIFIED),
+ _accumulate(false)
{
}
/** Constructor
@@ -106,6 +107,7 @@ public:
* @param[in] fixed_format (Optional) Specify the selection of fixed format kernels for variable weights support in GEMM. These kernels expect the weights tensor to be in amemory format that is fixed by the kernel itself. For more information, see arm_compute::WeightFormat.
* @param[in] weight_format (Optional) arm_gemm:WeightFormat enumeration requested by the user. Default is arm_compute::WeightFormat::UNSPECIFIED.
* @param[in] pretranspose_B (Optional) Pretranspose matrix B (transposition of its lowest 2 dimensions), in addition to and before, any further transformations of B
+ * @param[in] accumulate (Optional) Whether to accumulate in destination or not
*/
GEMMInfo(bool is_a_reshaped,
bool is_b_reshaped,
@@ -120,7 +122,8 @@ public:
const ActivationLayerInfo &activation_info = ActivationLayerInfo(),
bool fixed_format = false,
arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED,
- bool pretranspose_B = false) noexcept
+ bool pretranspose_B = false,
+ bool accumulate = false) noexcept
: _is_a_reshaped(is_a_reshaped),
_is_b_reshaped(is_b_reshaped),
_reshape_b_only_on_first_run(reshape_b_only_on_first_run),
@@ -135,7 +138,8 @@ public:
_pretranspose_B(pretranspose_B),
_activation_info(activation_info),
_fixed_format(fixed_format),
- _weight_format(weight_format)
+ _weight_format(weight_format),
+ _accumulate(accumulate)
{
}
/** Flag which specifies if the matrix A has been reshaped
@@ -294,7 +298,14 @@ public:
{
return _fixed_format;
}
-
+ /** Flag which specifies if GEMM should accumulate the result in destination or not.
+ *
+ * @return True if GEMM is accumulating the result.
+ */
+ bool accumulate() const
+ {
+ return _accumulate;
+ }
/** Set fixed-format flag
*
* @param[in] fixed_format sets whether or not to use fixed-format kernels
@@ -303,12 +314,19 @@ public:
{
_fixed_format = fixed_format;
}
+ /** Set accumulate flag
+ *
+ * @param[in] accumulate sets whether or not to use accumulation
+ */
+ void set_accumulate(bool accumulate)
+ {
+ _accumulate = accumulate;
+ }
arm_compute::WeightFormat weight_format() const
{
return _weight_format;
}
-
/** Set weight format to be used
*
* @param[in] weight_format arm_compute::WeightFormat enumeration
@@ -334,6 +352,7 @@ private:
ActivationLayerInfo _activation_info;
bool _fixed_format;
arm_compute::WeightFormat _weight_format;
+ bool _accumulate;
};
} //namespace arm_compute
#endif // ACL_ARM_COMPUTE_FUNCTION_INFO_GEMMINFO_H
diff --git a/docs/user_guide/release_version_and_change_log.dox b/docs/user_guide/release_version_and_change_log.dox
index aa27c2b44c..b8910c9237 100644
--- a/docs/user_guide/release_version_and_change_log.dox
+++ b/docs/user_guide/release_version_and_change_log.dox
@@ -48,6 +48,7 @@ v24.04 Public major release
- Add support for SoftMax in SME2 for FP32.
- Performance optimizations:
- Optimize @ref NESoftmaxLayer for axis != 0 by natively supporting higher axes up to axis 3.
+ - Add support for in place accumulation to CPU GEMM kernels.
v24.02.1 Public patch release
- Fix performance regression in fixed-format kernels
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index e85dd59425..290fe87230 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -293,14 +293,14 @@ GemmImplementation<float, float>::with_estimate(
{
GemmMethod::GEMM_HYBRID,
"a64_smallK_hybrid_fp32_mla_8x4",
- [](const GemmArgs &args) { return args._Ksize <= 8 && (args._Nsize % 4)==0 && !args._indirect_input; },
+ [](const GemmArgs &args) { return args._Ksize <= 8 && (args._Nsize % 4)==0 && !args._indirect_input && !args._accumulate; },
nullptr,
[](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_fp32_mla_8x4, float, float>(args); }
},
{
GemmMethod::GEMM_HYBRID,
"a64_smallK_hybrid_fp32_mla_6x4",
- [](const GemmArgs &args) { return (args._Ksize > 8 && args._Ksize <= 16) && (args._Nsize % 4)==0 && !args._indirect_input; },
+ [](const GemmArgs &args) { return (args._Ksize > 8 && args._Ksize <= 16) && (args._Nsize % 4)==0 && !args._indirect_input && !args._accumulate; },
nullptr,
[](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_fp32_mla_6x4, float, float>(args); }
},
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
index 89c2d5a23e..0cc4d4f3d9 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
@@ -530,7 +530,7 @@ public:
(m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg,
(this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr,
last_pass ? _args._act : Activation(),
- !first_pass,
+ !first_pass || _args._accumulate,
// Quantization parameters
_os, _col_bias+(multi * _args._Nsize), n0);
} else if (_convolver) {
@@ -563,7 +563,7 @@ public:
(m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg,
(this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr,
last_pass ? _args._act : Activation(),
- !first_pass,
+ !first_pass || _args._accumulate,
// Quantization parameters
_os, _col_bias+(multi * _args._Nsize), n0);
} else {
@@ -579,7 +579,7 @@ public:
(m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg,
(this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr,
last_pass ? _args._act : Activation(),
- !first_pass,
+ !first_pass || _args._accumulate,
// Quantization parameters
_os, _col_bias+(multi * _args._Nsize), n0);
}
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
index fd20e53f60..0dc0d55b27 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020, 2022-2023 Arm Limited.
+ * Copyright (c) 2017-2020, 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -128,14 +128,14 @@ GemmImplementation<int8_t, int32_t>::with_estimate(
{
GemmMethod::GEMM_HYBRID,
"a64_smallK_hybrid_s8s32_dot_8x4",
- [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; },
+ [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input && !args._accumulate; },
[](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); },
[](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_s8s32_dot_8x4, int8_t, int32_t>(args); }
},
{
GemmMethod::GEMM_HYBRID,
"a64_smallK_hybrid_s8s32_dot_6x4",
- [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; },
+ [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input && !args._accumulate; },
[](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); },
[](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_s8s32_dot_6x4, int8_t, int32_t>(args); }
},
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index 4f732f7d94..d8b464584a 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -350,6 +350,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
const bool _thread_columns;
const Activation _act;
+ const bool _accumulate;
const int _maxthreads;
int _nthreads;
@@ -680,7 +681,7 @@ public:
_Ksections(args._Ksections), _Ktotal(get_ktotal(args)),
_rounded_Ksize(roundup(_Ksize, strategy::k_unroll())),
_nbatches(args._nbatches), _nmulti(args._nmulti), _thread_columns(is_thread_columns(args)),
- _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
+ _act(args._act), _accumulate(args._accumulate), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
_k_block(get_k_block_size(args)), _x_block(get_x_block_size(args)), _Mround(roundup(args._Msize, strategy::out_height())),
_os(os) { }
@@ -690,7 +691,7 @@ public:
_Ksections(args._Ksections), _Ktotal(get_ktotal(args)),
_rounded_Ksize(roundup(_Ksize, strategy::k_unroll())),
_nbatches(args._nbatches), _nmulti(args._nmulti), _thread_columns(is_thread_columns(args)),
- _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
+ _act(args._act), _accumulate(args._accumulate), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
_k_block(get_k_block_size(args)), _x_block(get_x_block_size(args)), _Mround(roundup(args._Msize, strategy::out_height())),
_os() { }
@@ -823,7 +824,7 @@ public:
// Only do bias on the first pass
((first_pass && this->_bias) ? this->_bias + (multi * this->_bias_multi_stride) : nullptr),
// Only do activation on the last pass, and accumulation on any non-first pass.
- (last_pass ? _act : Activation()), !first_pass,
+ (last_pass ? _act : Activation()), (!first_pass || _accumulate),
// Pass in quantization parameters for requantizing kernels (others will ignore)
_os, col_bias + (multi * _Nsize),
// Accumulation buffer
@@ -971,7 +972,7 @@ public:
// Only do bias on the first pass
((first_pass && this->_bias) ? this->_bias + (current.multi() * this->_bias_multi_stride) : nullptr),
// Only do activation on the last pass, and accumulation on any non-first pass.
- (last_pass ? _act : Activation()), !first_pass,
+ (last_pass ? _act : Activation()), (!first_pass || _accumulate),
// Pass in quantization parameters for requantizing kernels (others will ignore)
_os, col_bias + (current.multi() * _Nsize),
// Accumulation buffer
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
index af5cfbbf2b..dfacb687a8 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020, 2022-2023 Arm Limited.
+ * Copyright (c) 2017-2020, 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -94,14 +94,14 @@ GemmImplementation<uint8_t, uint32_t>::with_estimate(
{
GemmMethod::GEMM_HYBRID,
"a64_smallK_hybrid_u8u32_dot_8x4",
- [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; },
+ [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input && !args._accumulate; },
[](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); },
[](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_u8u32_dot_8x4, uint8_t, uint32_t>(args); }
},
{
GemmMethod::GEMM_HYBRID,
"a64_smallK_hybrid_u8u32_dot_6x4",
- [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; },
+ [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input && !args._accumulate; },
[](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); },
[](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_u8u32_dot_6x4, uint8_t, uint32_t>(args); }
},
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
index 92c884ce18..dbada36052 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
@@ -180,7 +180,7 @@ public:
this->_Cptr + (multi * this->_C_multi_stride) + n,
(nmax - n), (kmax-k0),
this->_bias ? this->_bias + (multi * this->_bias_multi_stride) + n : nullptr,
- _args._act, (k0 != 0),
+ _args._act, (k0 != 0) || _args._accumulate,
_os, col_bias, n + (_args._Nsize * multi));
}
}
diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp
index 9a913c5c58..5d7cf79857 100644
--- a/src/cpu/kernels/assembly/arm_gemm.hpp
+++ b/src/cpu/kernels/assembly/arm_gemm.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2022 Arm Limited.
+ * Copyright (c) 2018-2022, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,6 +21,10 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
+
+#ifndef ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP
+#define ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP
+
#pragma once
#include "arm_gemm_local.hpp"
@@ -151,6 +155,7 @@ public:
int _maxthreads;
bool _fixed_format;
bool _fast_mode;
+ bool _accumulate;
const GemmConfig *_cfg;
GemmArgs(const CPUInfo *ci,
@@ -165,6 +170,7 @@ public:
const int maxthreads,
bool fixed_format = false,
bool fast_mode = false,
+ bool accumulate = false,
const GemmConfig *cfg = nullptr)
: _ci(ci),
_Msize(M),
@@ -178,6 +184,7 @@ public:
_maxthreads(maxthreads),
_fixed_format(fixed_format),
_fast_mode(fast_mode),
+ _accumulate(accumulate),
_cfg(cfg)
{
}
@@ -278,3 +285,5 @@ template <typename Top, typename Tret, class OutputStage = Nothing>
bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {});
} // namespace arm_gemm
+
+#endif // ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index e035de0131..905e86c185 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -53,6 +53,7 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
asm_info.fast_mode = info.fast_math();
asm_info.fixed_format = info.fixed_format();
asm_info.weight_format = info.weight_format();
+ asm_info.accumulate = info.accumulate();
asm_info.transpose_b =
info.pretranspose_B(); // The "pretranspose_B" flag here is not the same as the pretranspose_B_array method. The flag here signals to pretranspose_B_array method if we want to perform additional transpose on B before the pretranspose_B_array method
@@ -219,6 +220,16 @@ Status CpuGemm::validate(const ITensorInfo *a,
const GEMMInfo &gemm_info)
{
ARM_COMPUTE_UNUSED(alpha);
+ // When using accumulation(in place summation), for now, the only supported values for alpha and beta are 1 respectively 0.
+ // Do the appropriate checks before proceeding.
+ if (gemm_info.accumulate())
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(alpha != 1, "Accumulation is not supported when alpha is different from 1");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ (beta != 0 && c != nullptr),
+ "Accumulation is not supported when beta is different from 0 with a non-null bias matrix c");
+ }
+
const bool is_c_bias = beta == 1 && c != nullptr;
const bool run_addition = c != nullptr && beta != 0 && beta != 1;
// Check if we should use the pretransposed_b or original b
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
index b25505a85d..94e86c6077 100644
--- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
+++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -65,6 +65,7 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
asm_info.activation_info = info.activation_info();
asm_info.output_stage = info.gemmlowp_output_stage();
asm_info.fast_mode = info.fast_math();
+ asm_info.accumulate = info.accumulate();
return asm_info;
}
@@ -343,6 +344,13 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
+ // When using accumulation(in place summation), for now, the only supported DataType for output is S32.
+ if (gemm_info.accumulate())
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE,
+ "Accumulation is not supported for output QASYMM8/QASYMM8_SIGNED");
+ }
+
GEMMInfo info = gemm_info;
const ITensorInfo *matrix_a_info = a;
const ITensorInfo *matrix_b_info = b;
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index efe2a7a67e..01a74a5a56 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -775,7 +775,7 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge
arm_gemm::GemmConfig cfg;
cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads,
- info.fixed_format, info.fast_mode, &cfg);
+ info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// Create arm_gemm fallback
auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>();
@@ -800,7 +800,7 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &
arm_gemm::GemmConfig cfg;
cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads,
- info.fixed_format, info.fast_mode, &cfg);
+ info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// Create arm_gemm fallback
auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>();
@@ -855,8 +855,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads,
- info.fixed_format, info.fast_mode, &cfg);
-
+ info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// TODO: Incorporate info.transpose_b COMPMID-6595
switch (a->data_type())
{
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index 671a222fed..44c5c189a5 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2023 Arm Limited.
+ * Copyright (c) 2018-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -57,6 +57,7 @@ struct AsmGemmInfo
bool fixed_format{false};
arm_compute::WeightFormat weight_format{arm_compute::WeightFormat::UNSPECIFIED};
bool reshape_b_only_on_first_run{true};
+ bool accumulate{false};
/** Whether we want to perform an additional transpose of b before passing it to gemm or pretranspose_B_array
* @note This transpose b operation is also considered a form of "reshape" or "transform", so should be counted for
* by the reshape_b_only_on_first_run flag
diff --git a/tests/datasets/LargeGEMMDataset.h b/tests/datasets/LargeGEMMDataset.h
index 6cdff7f559..e45319ef57 100644
--- a/tests/datasets/LargeGEMMDataset.h
+++ b/tests/datasets/LargeGEMMDataset.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 Arm Limited.
+ * Copyright (c) 2017-2019, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_LARGE_GEMM_DATASET
-#define ARM_COMPUTE_TEST_LARGE_GEMM_DATASET
+#ifndef ACL_TESTS_DATASETS_LARGEGEMMDATASET_H
+#define ACL_TESTS_DATASETS_LARGEGEMMDATASET_H
#include "tests/datasets/GEMMDataset.h"
@@ -79,7 +79,20 @@ public:
add_config(TensorShape(1729U, 17U, 10U, 3U), TensorShape(128U, 1729U), TensorShape(128U), TensorShape(128U, 17U, 10U, 3U), 1.0f, 0.3f);
}
};
+
+class LargeAccumulateGEMMDataset final : public GEMMDataset
+{
+public:
+ LargeAccumulateGEMMDataset()
+ {
+ add_config(TensorShape(923U, 429U), TensorShape(871U, 923U), TensorShape(871U, 429U), TensorShape(871U, 429U), 1.0f, 0.0f);
+ add_config(TensorShape(1021U, 1U), TensorShape(783U, 1021U), TensorShape(783U, 1U), TensorShape(783U, 1U), 1.0f, 0.0f);
+ add_config(TensorShape(1021U, 1U), TensorShape(783U, 1021U), TensorShape(783U, 1U), TensorShape(783U, 1U), 1.0f, 0.0f);
+ add_config(TensorShape(941U, 1U), TensorShape(623U, 941U), TensorShape(623U, 1U), TensorShape(623U, 1U), 1.0f, 0.0f);
+ }
+};
+
} // namespace datasets
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_LARGE_GEMM_DATASET */
+#endif // ACL_TESTS_DATASETS_LARGEGEMMDATASET_H
diff --git a/tests/datasets/SmallGEMMDataset.h b/tests/datasets/SmallGEMMDataset.h
index c12f57b266..99c7abbf64 100644
--- a/tests/datasets/SmallGEMMDataset.h
+++ b/tests/datasets/SmallGEMMDataset.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_SMALL_GEMM_DATASET
-#define ARM_COMPUTE_TEST_SMALL_GEMM_DATASET
+#ifndef ACL_TESTS_DATASETS_SMALLGEMMDATASET_H
+#define ACL_TESTS_DATASETS_SMALLGEMMDATASET_H
#include "tests/datasets/GEMMDataset.h"
@@ -97,7 +97,18 @@ public:
}
};
+class SmallAccumulateGEMMDataset final : public GEMMDataset
+{
+public:
+ SmallAccumulateGEMMDataset()
+ {
+ add_config(TensorShape(8U, 2U), TensorShape(16U, 8U), TensorShape(16U, 2U), TensorShape(16U, 2U), 1.0f, 0.0f);
+ add_config(TensorShape(31U, 1U), TensorShape(23U, 31U), TensorShape(23U, 1U), TensorShape(23U, 1U), 1.0f, 0.0f);
+ add_config(TensorShape(21U, 13U), TensorShape(33U, 21U), TensorShape(33U, 13U), TensorShape(33U, 13U), 1.0f, 0.0f);
+ }
+};
+
} // namespace datasets
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_SMALL_GEMM_DATASET */
+#endif // ACL_TESTS_DATASETS_SMALLGEMMDATASET_H
diff --git a/tests/validation/CL/GEMMLowp.cpp b/tests/validation/CL/GEMMLowp.cpp
index 1ae9e96626..78d794a9bb 100644
--- a/tests/validation/CL/GEMMLowp.cpp
+++ b/tests/validation/CL/GEMMLowp.cpp
@@ -71,7 +71,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMLowpMatrixMultiplyCoreFixture, framework:
}
using CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedUnsigned =
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, false, false, uint8_t, uint8_t, true>;
+ GEMMLowpBatchedMatrixMultiplyCoreFusedOffsetOutputFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, false, false, uint8_t, uint8_t, true>;
TEST_SUITE(BatchedMatMul)
TEST_SUITE(QASYMM8)
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedUnsigned, framework::DatasetMode::ALL,
@@ -84,7 +84,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFi
TEST_SUITE_END() // QASYMM8
using CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedSigned =
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, false, false, int8_t, int8_t, true>;
+ GEMMLowpBatchedMatrixMultiplyCoreFusedOffsetOutputFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, false, false, int8_t, int8_t, true>;
TEST_SUITE(QASYMM8_SIGNED)
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedSigned, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedBatchedMatMulDataset(),
@@ -98,7 +98,7 @@ TEST_SUITE_END() // BatchedMatMul
TEST_SUITE(FusedOffsetOutput)
TEST_SUITE(QASYMM8)
-using CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputUint8Fixture = GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore>;
+using CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputUint8Fixture = GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore>;
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputUint8Fixture, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedOffsetOutputUint8Dataset(),
make("DataType", { DataType::QASYMM8 }),
@@ -110,7 +110,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputUi
TEST_SUITE(Output3D)
using CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputOutput3DUint8Fixture =
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, false, true>;
+ GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, false, true>;
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputOutput3DUint8Fixture, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedOffsetOutputOutput3DUint8Dataset(),
make("DataType", { DataType::QASYMM8 }),
@@ -123,7 +123,7 @@ TEST_SUITE_END() // Output3D
TEST_SUITE(InputOutput3D)
using CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputInputOutput3DUint8Fixture =
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, true, true>;
+ GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, true, true>;
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputInputOutput3DUint8Fixture, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedOffsetOutputInputOutput3DUint8Dataset(),
make("DataType", { DataType::QASYMM8 }),
@@ -148,7 +148,8 @@ using CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputInt8Fixture =
GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, false, false, int8_t, int8_t>;
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputInt8Fixture, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedOffsetOutputInt8Dataset(),
- make("DataType", { DataType::QASYMM8_SIGNED })))
+ make("DataType", { DataType::QASYMM8_SIGNED }),
+ make("reshape_b_only_on_first_run", { false })))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_quant);
diff --git a/tests/validation/NEON/GEMM.cpp b/tests/validation/NEON/GEMM.cpp
index f956cdfeda..5f6a402204 100644
--- a/tests/validation/NEON/GEMM.cpp
+++ b/tests/validation/NEON/GEMM.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -51,6 +51,8 @@ namespace test
{
namespace validation
{
+using framework::dataset::make;
+
namespace
{
constexpr AbsoluteTolerance<float> tolerance_f(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for FP32 data types */
@@ -60,7 +62,7 @@ const AbsoluteTolerance<float> abs_tolerance_f16(0.2f); /**< Absolute
constexpr float tolerance_num = 0.07f; /**< Tolerance number for FP16 data types */
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
/** CNN data types */
-const auto CNNDataTypes = framework::dataset::make("DataType",
+const auto CNNDataTypes = make("DataType",
{
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
DataType::F16,
@@ -68,8 +70,8 @@ const auto CNNDataTypes = framework::dataset::make("DataType",
DataType::F32,
});
-const auto data_interleave = framework::dataset::make("M", 8, 12) * framework::dataset::make("N", 8, 12);
-const auto data_transpose = framework::dataset::make("M", 8, 14) * framework::dataset::make("N", 7, 14);
+const auto data_interleave = make("M", 8, 12) * make("N", 8, 12);
+const auto data_transpose = make("M", 8, 14) * make("N", 7, 14);
/** Zero padding test */
template <typename FunctionType>
@@ -204,16 +206,16 @@ TEST_CASE(MultipleExecutionWithConfigure, framework::DatasetMode::ALL)
// *INDENT-OFF*
// clang-format off
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
- framework::dataset::make("LhsInfo", { TensorInfo(TensorShape(27U, 13U), 1, DataType::S32), // Unsupported data type
+ make("LhsInfo", { TensorInfo(TensorShape(27U, 13U), 1, DataType::S32), // Unsupported data type
TensorInfo(TensorShape(27U, 13U), 1, DataType::F32),
}),
- framework::dataset::make("RhsInfo",{ TensorInfo(TensorShape(8U, 27U), 1, DataType::S32),
+ make("RhsInfo",{ TensorInfo(TensorShape(8U, 27U), 1, DataType::S32),
TensorInfo(TensorShape(8U, 27U), 1, DataType::F32),
})),
- framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(8U, 13U), 1, DataType::S32),
+ make("OutputInfo",{ TensorInfo(TensorShape(8U, 13U), 1, DataType::S32),
TensorInfo(TensorShape(8U, 13U), 1, DataType::F32),
})),
- framework::dataset::make("Expected", { false, true })),
+ make("Expected", { false, true })),
lhs_info, rhs_info, output_info, expected)
{
constexpr float alpha = 1.0;
@@ -226,8 +228,8 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
// *INDENT-ON*
TEST_SUITE(KERNEL_SELECTION)
DATA_TEST_CASE(KernelSelection_mul_and_add, framework::DatasetMode::ALL,
- combine(framework::dataset::make("CpuExt", std::string("NEON")),
- framework::dataset::make("DataType", { DataType::F32,
+ combine(make("CpuExt", std::string("NEON")),
+ make("DataType", { DataType::F32,
DataType::F16
})),
cpu_ext, data_type)
@@ -261,8 +263,8 @@ TEST_SUITE_END() // KERNEL_SELECTION
TEST_SUITE(TRANSPOSE_1XW)
using CpuGemmTranspose1xW = NESynthetizeFunctionWithZeroConstantKernelBorder<cpu::kernels::CpuGemmTranspose1xWKernel>;
DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(
- framework::dataset::make("N", { 1, 23, 63, 101 }),
- framework::dataset::make("K", { 1, 47, 29, 27 })),
+ make("N", { 1, 23, 63, 101 }),
+ make("K", { 1, 47, 29, 27 })),
n_value, k_value)
{
bool status = validate_zero_padding<CpuGemmTranspose1xW>(n_value, k_value);
@@ -271,7 +273,7 @@ DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(
TEST_SUITE(U32)
using CpuGemmTranspose1xWFixture = GEMMTranspose1xWValidationFixture<Tensor, Accessor, CpuGemmTranspose1xW, uint32_t>;
-FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * framework::dataset::make("DataType", DataType::U32))
+FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * make("DataType", DataType::U32))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -280,7 +282,7 @@ TEST_SUITE_END() // U32
TEST_SUITE(U16)
using CpuGemmTranspose1xWFixture = GEMMTranspose1xWValidationFixture<Tensor, Accessor, CpuGemmTranspose1xW, uint16_t>;
-FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * framework::dataset::make("DataType", DataType::U16))
+FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * make("DataType", DataType::U16))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -289,7 +291,7 @@ TEST_SUITE_END() // U16
TEST_SUITE(U8)
using CpuGemmTranspose1xWFixture = GEMMTranspose1xWValidationFixture<Tensor, Accessor, CpuGemmTranspose1xW, uint8_t>;
-FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * framework::dataset::make("DataType", DataType::U8))
+FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * make("DataType", DataType::U8))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -302,8 +304,8 @@ TEST_SUITE(INTERLEAVE_4X4)
using CpuGemmInterleave4x4 = NESynthetizeFunctionWithZeroConstantKernelBorder<cpu::kernels::CpuGemmInterleave4x4Kernel>;
DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(
- framework::dataset::make("M", { 1, 23, 63, 101 }),
- framework::dataset::make("K", { 1, 47, 29, 27 })),
+ make("M", { 1, 23, 63, 101 }),
+ make("K", { 1, 47, 29, 27 })),
m_value, k_value)
{
bool status = validate_zero_padding<cpu::kernels::CpuGemmInterleave4x4Kernel>(m_value, k_value);
@@ -312,7 +314,7 @@ DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(
TEST_SUITE(U32)
using CpuGemmInterleave4x4Fixture = GEMMInterleave4x4ValidationFixture<Tensor, Accessor, CpuGemmInterleave4x4, uint32_t>;
-FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave * framework::dataset::make("DataType", DataType::U32))
+FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave * make("DataType", DataType::U32))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -321,7 +323,7 @@ TEST_SUITE_END() // U32
TEST_SUITE(U16)
using CpuGemmInterleave4x4Fixture = GEMMInterleave4x4ValidationFixture<Tensor, Accessor, CpuGemmInterleave4x4, uint16_t>;
-FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave * framework::dataset::make("DataType", DataType::U16))
+FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave * make("DataType", DataType::U16))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -330,7 +332,7 @@ TEST_SUITE_END() // U16
TEST_SUITE(U8)
using CpuGemmInterleave4x4Fixture = GEMMInterleave4x4ValidationFixture<Tensor, Accessor, CpuGemmInterleave4x4, uint8_t>;
-FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave * framework::dataset::make("DataType", DataType::QASYMM8))
+FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave * make("DataType", DataType::QASYMM8))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -345,15 +347,18 @@ using NEGEMMFixture = GEMMValidationFixture<Tensor, Accessor, NEGEMM, T>;
template <typename T>
using NEBatchedMatMulFixture = GEMMValidationFixture<Tensor, Accessor, NEGEMM, T, true, false, false, false, false, true>;
+template <typename T>
+using NEGEMMAccumulateFixture = GEMMAccumulateValidationFixture<Tensor, Accessor, NEGEMM, T>;
+
TEST_SUITE(Float)
-DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(framework::dataset::make("In0", { TensorShape(21U, 13U),
+DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(make("In0", { TensorShape(21U, 13U),
TensorShape(31U, 1U),
TensorShape(31U, 1U),
TensorShape(8U, 2U),
TensorShape(38U, 12U),
TensorShape(32U, 1U)
}),
- framework::dataset::make("In1", { TensorShape(33U, 21U),
+ make("In1", { TensorShape(33U, 21U),
TensorShape(23U, 31U),
TensorShape(23U, 31U),
TensorShape(16U, 8U),
@@ -366,75 +371,111 @@ DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(framework::
ARM_COMPUTE_EXPECT(status, framework::LogLevel::ERRORS);
}
+DATA_TEST_CASE(ValidateAccumulate, framework::DatasetMode::ALL, combine(
+ zip(make("In0",{ TensorShape(21U, 13U) }),
+ make("In1", { TensorShape(33U, 21U) }),
+ make("Dst", { TensorShape(33U, 13U) })),
+ zip(
+ make("alpha", { 1.0, 100.0, 1.0, 1.0 }),
+ make("beta", { 0.0, 0.0, 1.0, 1.0 }),
+ make("is_c_null", { false, false, false, true }),
+ make("Expected", { true, false, false, true }))),
+ shape_a, shape_b, shape_dst, alpha, beta, is_c_null, expected)
+{
+ /* Accumulation test for GEMM kernels */
+ // Create tensors
+ TensorInfo in_a(shape_a, 1, DataType::F32);
+ TensorInfo in_b(shape_b, 1, DataType::F32);
+ TensorInfo in_c(shape_dst, 1, DataType::F32);
+ TensorInfo dst(shape_dst, 1, DataType::F32);
+
+ GEMMInfo gemm_info = GEMMInfo();
+ gemm_info.set_accumulate(true);
+
+ // Validate accumulation
+ cpu::CpuGemm gemm;
+ Status status = gemm.validate(&in_a, &in_b, (is_c_null ? nullptr : &in_c), &dst, alpha, beta, gemm_info);
+ ARM_COMPUTE_EXPECT((expected == bool(status)), framework::LogLevel::ERRORS);
+}
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_SUITE(FP16)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
- framework::dataset::make("ReshapeWeights", { true, false })),
- framework::dataset::make("DataType", DataType::F16)))
+ make("ReshapeWeights", { true, false })),
+ make("DataType", DataType::F16)))
{
// Validate output
validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16);
}
-
-TEST_SUITE(BATCHED_MATMUL)
-
-FIXTURE_DATA_TEST_CASE(RunSmall, NEBatchedMatMulFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallBatchedMatMulDataset(),
- framework::dataset::make("ReshapeWeights", { false })),
- framework::dataset::make("DataType", DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
+ make("ReshapeWeights", { true, false })),
+ make("DataType", DataType::F16)))
{
// Validate output
validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16);
}
-TEST_SUITE_END()
-FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
- framework::dataset::make("ReshapeWeights", { true, false })),
-
- framework::dataset::make("DataType", DataType::F16)))
+TEST_SUITE(BATCHED_MATMUL)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEBatchedMatMulFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallBatchedMatMulDataset(),
+ make("ReshapeWeights", { false })),
+ make("DataType", DataType::F16)))
{
// Validate output
validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16);
}
-TEST_SUITE_END()
+TEST_SUITE_END() // BATCHED_MATMUL
+
+TEST_SUITE_END() // FP16
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
TEST_SUITE(FP32)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
- framework::dataset::make("ReshapeWeights", { true, false })),
-
- framework::dataset::make("DataType", DataType::F32)))
+ make("ReshapeWeights", { true, false })),
+ make("DataType", DataType::F32)))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_f);
}
FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
- framework::dataset::make("ReshapeWeights", { true, false })),
-
- framework::dataset::make("DataType", DataType::F32)))
+ make("ReshapeWeights", { true, false })),
+ make("DataType", DataType::F32)))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_f);
}
TEST_SUITE(BATCHED_MATMUL)
-
-TEST_SUITE(FP32)
FIXTURE_DATA_TEST_CASE(RunSmall, NEBatchedMatMulFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallBatchedMatMulDataset(),
- framework::dataset::make("ReshapeWeights", { false })),
- framework::dataset::make("DataType", DataType::F32)))
+ make("ReshapeWeights", { false })),
+ make("DataType", DataType::F32)))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_f);
}
-TEST_SUITE_END()
+TEST_SUITE_END() // BATCHED_MATMUL
-TEST_SUITE_END()
+TEST_SUITE(ACCUMULATE)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMAccumulateFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallAccumulateGEMMDataset(),
+ make("ReshapeWeights", { false }),
+ make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_f);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMAccumulateFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeAccumulateGEMMDataset(),
+ make("ReshapeWeights", { false }),
+ make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_f);
+}
+TEST_SUITE_END() // ACCUMULATE
-TEST_SUITE_END()
-TEST_SUITE_END()
+TEST_SUITE_END() // FP32
-TEST_SUITE_END()
-TEST_SUITE_END()
+TEST_SUITE_END() // Float
+TEST_SUITE_END() // GEMM
+TEST_SUITE_END() // NEON
} // namespace validation
} // namespace test
} // namespace arm_compute
diff --git a/tests/validation/NEON/GEMMLowp.cpp b/tests/validation/NEON/GEMMLowp.cpp
index 9c4d1741eb..1b07975bb3 100644
--- a/tests/validation/NEON/GEMMLowp.cpp
+++ b/tests/validation/NEON/GEMMLowp.cpp
@@ -47,14 +47,21 @@ namespace test
{
namespace validation
{
+using framework::dataset::make;
+
+namespace
+{
+ constexpr AbsoluteTolerance<float> tolerance_batched(1);
+ constexpr AbsoluteTolerance<float> tolerance_quant(1);
+} // namespace
+
+
TEST_SUITE(NEON)
TEST_SUITE(GEMMLowp)
TEST_SUITE(MatrixMultiplyCore)
using NEGEMMLowpMatrixMultiplyCoreFixture = GEMMLowpMatrixMultiplyCoreValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore>;
-using NEGEMMLowpBatchedMatMulFixture = GEMMLowpMatrixMultiplyCoreValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore, false, false, true>;
-
-using framework::dataset::make;
+using NEGEMMLowpMatrixMultiplyCoreAccumulateFixture = GEMMLowpMatrixMultiplyAccumulateValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore>;
DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, framework::dataset::concat(datasets::SmallGEMMLowpDataset(), datasets::LargeGEMMLowpDataset()),
shape_a, shape_b, shape_c, a_offset, b_offset)
@@ -81,6 +88,44 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, framework::dataset::c
validate(c.info()->padding(), PaddingSize());
}
+DATA_TEST_CASE(ValidateAccumulate, framework::DatasetMode::ALL, combine(
+ zip(
+ make("In0",{ TensorShape(21U, 1U) }),
+ make("In1", { TensorShape(1U, 21U) }),
+ make("Dst", { TensorShape(1U, 1U) }),
+ make("a_offset", { -2 }),
+ make("a_offset", { 13 })
+ ),
+ zip(
+ make("OutputDataType", { DataType::S32, DataType::QASYMM8, DataType::QASYMM8_SIGNED}),
+ make("Expected", { true, false, false })
+ )),
+ shape_a, shape_b, shape_dst, a_offset, b_offset, output_data_type, expected)
+{
+ DataType input_data_type = (output_data_type == DataType::S32 ? DataType::QASYMM8 : output_data_type);
+ // Accumulation test for GEMM kernels
+ TensorInfo a(shape_a, 1, input_data_type, QuantizationInfo(1.0f / 255, a_offset));
+ TensorInfo b(shape_b, 1, input_data_type, QuantizationInfo(1.0f / 255, b_offset));
+ TensorInfo dst(shape_dst, 1, output_data_type, QuantizationInfo());
+
+ // Create and configure function
+ GEMMInfo gemm_info = GEMMInfo();
+ gemm_info.set_accumulate(true);
+
+ if (is_data_type_quantized(output_data_type))
+ {
+ GEMMLowpOutputStageInfo gemmLowpOutputStageInfo = GEMMLowpOutputStageInfo();
+ gemmLowpOutputStageInfo.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
+
+ gemm_info.set_gemmlowp_output_stage(gemmLowpOutputStageInfo);
+ }
+
+ cpu::CpuGemmLowpMatrixMultiplyCore gemmlowp_mm;
+ Status status = gemmlowp_mm.validate(&a, &b, nullptr, &dst, gemm_info);
+
+ ARM_COMPUTE_EXPECT((expected == bool(status)), framework::LogLevel::ERRORS);
+}
+
// *INDENT-OFF*
// clang-format off
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(
@@ -226,13 +271,10 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMLowpMatrixMultiplyCoreFixture, framework:
validate(Accessor(_target), _reference);
}
-constexpr AbsoluteTolerance<float> tolerance_batched(1);
-
-using NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedUnsigned =
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore, false, false, uint8_t, uint8_t, true>;
-
TEST_SUITE(BatchedMatMul)
TEST_SUITE(QASYMM8)
+using NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedUnsigned =
+ GEMMLowpBatchedMatrixMultiplyCoreFusedOffsetOutputFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore, false, false, uint8_t, uint8_t, true>;
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedUnsigned, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedBatchedMatMulDataset(),
make("DataType", { DataType::QASYMM8 }),
@@ -242,9 +284,9 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFi
}
TEST_SUITE_END() // QASYMM8
-using NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedSigned =
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore, false, false, int8_t, int8_t, true>;
TEST_SUITE(QASYMM8_SIGNED)
+using NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedSigned =
+ GEMMLowpBatchedMatrixMultiplyCoreFusedOffsetOutputFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore, false, false, int8_t, int8_t, true>;
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedSigned, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedBatchedMatMulDataset(),
make("DataType", { DataType::QASYMM8_SIGNED }),
@@ -255,26 +297,41 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFi
TEST_SUITE_END() // QASYMM8_SIGNED
TEST_SUITE_END() // BatchedMatMul
-using NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixture = GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore>;
-constexpr AbsoluteTolerance<float> tolerance_quant(1);
-
TEST_SUITE(FusedOffsetOutput)
+using NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixture = GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore>;
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixture, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedOffsetOutputUint8Dataset(),
- make("DataType", { DataType::QASYMM8 })))
+ make("DataType", { DataType::QASYMM8 }),
+ make("reshape_b_only_on_first_run", { false })))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_quant);
}
-
FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixture, framework::DatasetMode::NIGHTLY,
combine(datasets::LargeGEMMLowpFusedOffsetOutputUint8Dataset(),
- make("DataType", { DataType::QASYMM8 })))
+ make("DataType", { DataType::QASYMM8 }),
+ make("reshape_b_only_on_first_run", { false })))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_quant);
}
TEST_SUITE_END() // FusedOffsetOutput
+
+TEST_SUITE(ACCUMULATION)
+TEST_SUITE(S32)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreAccumulateFixture, framework::DatasetMode::ALL, datasets::SmallGEMMLowpDataset())
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMLowpMatrixMultiplyCoreAccumulateFixture, framework::DatasetMode::NIGHTLY, datasets::LargeGEMMLowpDataset())
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+TEST_SUITE_END() // S32
+TEST_SUITE_END() // ACCUMULATION
+
TEST_SUITE_END() // MatrixMultiplyCore
TEST_SUITE_END() // GEMMLowp
TEST_SUITE_END() // NEON
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h
index afde3d8067..94bedc83e1 100644
--- a/tests/validation/fixtures/GEMMFixture.h
+++ b/tests/validation/fixtures/GEMMFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -46,14 +46,14 @@ namespace test
namespace validation
{
template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool disable_c = false, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool pretranspose_a = false, bool pretranspose_b = false, bool run_twice = false>
-class GEMMValidationFixture : public framework::Fixture
+class GEMMGenericValidationFixture : public framework::Fixture
{
public:
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_c, TensorShape output_shape, float alpha, float beta, bool pretranspose, DataType data_type)
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_c, TensorShape output_shape, float alpha, float beta, bool pretranspose, DataType data_type, bool accumulate=false)
{
ARM_COMPUTE_UNUSED(pretranspose);
- _target = compute_target(shape_a, shape_b, shape_c, output_shape, alpha, beta, data_type);
- _reference = compute_reference(shape_a, shape_b, output_shape, alpha, beta, data_type);
+ _target = compute_target(shape_a, shape_b, shape_c, output_shape, alpha, beta, data_type, accumulate);
+ _reference = compute_reference(shape_a, shape_b, output_shape, alpha, beta, data_type, accumulate);
}
protected:
@@ -80,7 +80,7 @@ protected:
}
TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, const TensorShape &output_shape, float alpha, float beta,
- DataType data_type)
+ DataType data_type, bool accumulate=false)
{
// Create tensors
TensorType a = create_tensor<TensorType>(shape_a, data_type, 1);
@@ -99,7 +99,7 @@ protected:
&dst,
alpha, beta,
GEMMInfo(false, false, false, (reinterpret_output_as_3d ? output_shape[2] : 0), reinterpret_input_as_3d, false, GEMMLowpOutputStageInfo(), false, false, (reinterpret_input_as_3d
- || reinterpret_output_as_3d)));
+ || reinterpret_output_as_3d), arm_compute::ActivationLayerInfo(), false /* fixed_format */, arm_compute::WeightFormat::UNSPECIFIED, false /* pretranspose_B */, accumulate));
ARM_COMPUTE_ASSERT(a.info()->is_resizable());
ARM_COMPUTE_ASSERT(b.info()->is_resizable());
ARM_COMPUTE_ASSERT(c.info()->is_resizable());
@@ -121,11 +121,14 @@ protected:
// Fill tensors
fill(AccessorType(a), 0);
fill(AccessorType(b), 1);
+ if (accumulate)
+ {
+ fill(AccessorType(dst), 6);
+ }
if(!disable_c)
{
fill(AccessorType(c), 2);
}
-
// Run with variable inputs.
if(run_twice)
{
@@ -145,7 +148,7 @@ protected:
}
SimpleTensor<T> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &output_shape, float alpha, float beta,
- DataType data_type)
+ DataType data_type, bool accumulate=false)
{
TensorShape shape_a_to_use = shape_a;
if(reinterpret_input_as_3d)
@@ -158,6 +161,7 @@ protected:
SimpleTensor<T> a{ shape_a_to_use, data_type, 1 };
SimpleTensor<T> b{ shape_b, data_type, 1 };
SimpleTensor<T> c{ output_shape, data_type, 1 };
+ SimpleTensor<T> dst{ output_shape, data_type, 1 };
// Fill reference
fill(a, 0);
@@ -211,17 +215,51 @@ protected:
fill(c, 5);
}
+ // Do in place summation
+ if (accumulate)
+ {
+ fill(dst, 6);
+ }
+
// Setting beta to 0 will effectively disable C for the
// computation of the reference: alpha * A * B + 0 * C
// Use transposed tensors if boolean enabled else use original tensors
- auto r = reference::gemm<T>((pretranspose_a) ? a_transposed : a, (pretranspose_b) ? b_transposed : b, c, alpha, disable_c ? 0.f : beta);
- return r;
+ if (accumulate)
+ {
+ reference::gemm_accumulate<T>((pretranspose_a) ? a_transposed : a, (pretranspose_b) ? b_transposed : b, c, alpha, disable_c ? 0.f : beta, dst);
+ return dst;
+ }
+ else
+ {
+ return reference::gemm<T>((pretranspose_a) ? a_transposed : a, (pretranspose_b) ? b_transposed : b, c, alpha, disable_c ? 0.f : beta);
+ }
}
TensorType _target{};
SimpleTensor<T> _reference{};
};
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool disable_c = false, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool pretranspose_a = false, bool pretranspose_b = false, bool run_twice = false>
+class GEMMValidationFixture : protected GEMMGenericValidationFixture<TensorType, AccessorType, FunctionType, T, disable_c, reinterpret_input_as_3d, reinterpret_output_as_3d, pretranspose_a, pretranspose_b, run_twice>
+{
+public:
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_c, TensorShape output_shape, float alpha, float beta, bool pretranspose, DataType data_type)
+ {
+ GEMMGenericValidationFixture<TensorType, AccessorType, FunctionType, T, disable_c, reinterpret_input_as_3d, reinterpret_output_as_3d, pretranspose_a, pretranspose_b, run_twice>::setup(shape_a, shape_b, shape_c, output_shape, alpha, beta, pretranspose, data_type, false /*accumulate*/);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool disable_c = false, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool pretranspose_a = false, bool pretranspose_b = false, bool run_twice = false>
+class GEMMAccumulateValidationFixture : protected GEMMGenericValidationFixture<TensorType, AccessorType, FunctionType, T, disable_c, reinterpret_input_as_3d, reinterpret_output_as_3d, pretranspose_a, pretranspose_b, run_twice>
+{
+public:
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_c, TensorShape output_shape, float alpha, float beta, bool pretranspose, DataType data_type)
+ {
+ bool accumulate = true;
+ GEMMGenericValidationFixture<TensorType, AccessorType, FunctionType, T, disable_c, reinterpret_input_as_3d, reinterpret_output_as_3d, pretranspose_a, pretranspose_b, run_twice>::setup(shape_a, shape_b, shape_c, output_shape, alpha, beta, pretranspose, data_type, accumulate);
+ }
+};
+
template <typename TensorType, typename AccessorType, typename T, typename GEMMOperatorType>
class GEMMMatrixMultiplyValidationFixture : public framework::Fixture
{
diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h
index a65a1e6bd8..0f6908a468 100644
--- a/tests/validation/fixtures/GEMMLowpFixture.h
+++ b/tests/validation/fixtures/GEMMLowpFixture.h
@@ -30,6 +30,7 @@
#include "tests/framework/Fixture.h"
#include "tests/validation/Validation.h"
#include "tests/validation/reference/GEMMLowp.h"
+#include "tests/validation/reference/ArithmeticOperations.h"
#include <cstdint>
#include <vector>
@@ -42,16 +43,21 @@ namespace validation
{
namespace
{
-
template <typename U>
void fill(U &&tensor, int i)
{
+ library->fill_tensor_uniform(tensor, i);
+}
+
+template <typename U>
+void fill_quantized(U &&tensor, int i)
+{
ARM_COMPUTE_ASSERT(is_data_type_quantized(tensor.data_type()));
library->fill_tensor_uniform(tensor, i);
}
template <typename U>
-void fill_bias_s32(U &&tensor, int i, int32_t min, int32_t max)
+void fill_s32(U &&tensor, int i, int32_t min, int32_t max)
{
ARM_COMPUTE_ASSERT(tensor.data_type() == DataType::S32);
std::uniform_int_distribution<int32_t> distribution(min, max);
@@ -64,6 +70,11 @@ struct TensorFillInfo
// Bias fill range. Default values are arbitrary
int32_t min_bias {-20000};
int32_t max_bias {20000};
+
+ // Output fill range. Default values are arbitrary
+ int32_t min_output {-20000};
+ int32_t max_output {20000};
+
// Optional extra hash to randomize tensor filling
int32_t hash {0};
};
@@ -71,7 +82,8 @@ struct TensorFillInfo
template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d, bool reinterpret_output_as_3d, typename OutputType, bool is_fused = false, bool run_twice = false>
TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo,
const QuantizationInfo& output_qinfo, DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8,
- GEMMLowpOutputStageInfo output_stage = GEMMLowpOutputStageInfo(), bool reshape_b_only_on_first_run = false, const TensorFillInfo& finfo = TensorFillInfo() )
+ GEMMLowpOutputStageInfo output_stage = GEMMLowpOutputStageInfo(), bool reshape_b_only_on_first_run = false, const TensorFillInfo& finfo = TensorFillInfo(),
+ bool accumulate = false)
{
ARM_COMPUTE_ASSERT(is_data_type_quantized_asymmetric(data_type_a));
ARM_COMPUTE_ASSERT(data_type_a == data_type_b);
@@ -93,7 +105,9 @@ TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape
// The GEMMinfo includes the values of the depth in case of reinterpreted 3d input/output
FunctionType gemmlowp;
gemmlowp.configure(&a, &b, is_fused ? &bias : nullptr, &output, GEMMInfo(false, false, reshape_b_only_on_first_run, (reinterpret_output_as_3d ? shape_output[2] : 0), reinterpret_input_as_3d, false,
- output_stage));
+ output_stage, false /*fp_mixed_precision*/, false /*fast_math*/, false /*broadcast_bias*/,
+ arm_compute::ActivationLayerInfo(), false /* fixed_format */, arm_compute::WeightFormat::UNSPECIFIED,
+ false /* pretranspose_B */, accumulate));
ARM_COMPUTE_ASSERT(a.info()->is_resizable());
ARM_COMPUTE_ASSERT(b.info()->is_resizable());
@@ -111,26 +125,32 @@ TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape
ARM_COMPUTE_ASSERT(!output.info()->is_resizable());
// Fill tensors
- fill(AccessorType(a), 0 + finfo.hash);
- fill(AccessorType(b), 1 + finfo.hash);
+ fill_quantized(AccessorType(a), 0 + finfo.hash);
+ fill_quantized(AccessorType(b), 1 + finfo.hash);
+
+ if (accumulate)
+ {
+ ARM_COMPUTE_ASSERT(accumulate != run_twice);
+ fill_s32(AccessorType(output), 6 + finfo.hash, finfo.min_output, finfo.max_output);
+ }
if(is_fused)
{
ARM_COMPUTE_ASSERT(bias.info()->is_resizable());
bias.allocator()->allocate();
ARM_COMPUTE_ASSERT(!bias.info()->is_resizable());
- fill_bias_s32(AccessorType(bias), 2 + finfo.hash, finfo.min_bias, finfo.max_bias);
+ fill_s32(AccessorType(bias), 2 + finfo.hash, finfo.min_bias, finfo.max_bias);
}
// Run with variable inputs.
if(run_twice)
{
gemmlowp.run();
- fill(AccessorType(a), 3 + finfo.hash); // Fill tensors with new seed after run
- fill(AccessorType(b), 4 + finfo.hash);
+ fill_quantized(AccessorType(a), 3 + finfo.hash); // Fill tensors with new seed after run
+ fill_quantized(AccessorType(b), 4 + finfo.hash);
if(is_fused)
{
- fill_bias_s32(AccessorType(bias), 5 + finfo.hash, finfo.min_bias, finfo.max_bias);
+ fill_s32(AccessorType(bias), 5 + finfo.hash, finfo.min_bias, finfo.max_bias);
}
}
@@ -168,8 +188,8 @@ SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, con
SimpleTensor<TW> b_transposed{ shape_b_transposed, data_type_b, 1, b_qinfo };
// Fill reference
- fill(a, 0 + finfo.hash);
- fill(b, 1 + finfo.hash);
+ fill_quantized(a, 0 + finfo.hash);
+ fill_quantized(b, 1 + finfo.hash);
// Transpose reference if required
/* Note: Assuming the usual batch matmul dimensions A = (B x M x K), B = (B x K x N), if pretranspose_A is set to true, then A is assumed to be (B x K x M),
@@ -189,11 +209,12 @@ SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, con
// Run with variable inputs.
const int32_t a_offset = a_qinfo.uniform().offset;
const int32_t b_offset = b_qinfo.uniform().offset;
+
if(run_twice)
{
reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>((pretranspose_A ? a_transposed : a), (pretranspose_B ? b_transposed : b), shape_output, a_offset, b_offset);
- fill((pretranspose_A) ? a_transposed : a, 3 + finfo.hash);
- fill((pretranspose_B) ? b_transposed : b, 4 + finfo.hash);
+ fill_quantized((pretranspose_A) ? a_transposed : a, 3 + finfo.hash);
+ fill_quantized((pretranspose_B) ? b_transposed : b, 4 + finfo.hash);
}
return reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>((pretranspose_A ? a_transposed : a), (pretranspose_B ? b_transposed : b), shape_output, a_offset, b_offset);
@@ -201,35 +222,68 @@ SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, con
} // namespace
template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool run_twice = false>
-class GEMMLowpMatrixMultiplyCoreValidationFixture : public framework::Fixture
+class GEMMLowpGenericMatrixMultiplyCoreValidationFixture : public framework::Fixture
{
public:
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset)
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset, bool accumulate=false)
{
const auto a_qinfo = QuantizationInfo(1.0f / 255, a_offset);
const auto b_qinfo = QuantizationInfo(1.0f / 255, b_offset);
- _target = compute_target(shape_a, shape_b, shape_output, a_qinfo, b_qinfo);
- _reference = compute_reference(shape_a, shape_b, shape_output, a_qinfo, b_qinfo);
+ TensorFillInfo finfo;
+ _target = compute_target(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, finfo, accumulate);
+ _reference = compute_reference(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, finfo, accumulate);
}
protected:
- TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo)
+ TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, const TensorFillInfo& finfo, const bool accumulate)
{
const auto output_qinfo = QuantizationInfo(); // No output stage
- return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, output_qinfo);
+ return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, output_qinfo,
+ DataType::QASYMM8, DataType::QASYMM8, GEMMLowpOutputStageInfo(), false, finfo, accumulate);
}
- SimpleTensor<int32_t> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo)
+ SimpleTensor<int32_t> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, const TensorFillInfo& finfo, bool accumulate)
{
- return compute_gemmlowp_reference<reinterpret_input_as_3d, uint8_t, uint8_t, false, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo);
+ SimpleTensor<int32_t> ref_output = compute_gemmlowp_reference<reinterpret_input_as_3d, uint8_t, uint8_t, false, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo,
+ DataType::QASYMM8, DataType::QASYMM8, finfo);
+
+ if (accumulate)
+ {
+ SimpleTensor<int32_t> output{ shape_output, DataType::S32, 1 };
+ fill_s32(output, 6 + finfo.hash, finfo.min_output, finfo.max_output);
+ reference::arithmetic_operation<int32_t>(reference::ArithmeticOperation::ADD, output, ref_output, output, ConvertPolicy::SATURATE);
+ return output;
+ }
+
+ return ref_output;
}
TensorType _target{};
SimpleTensor<int32_t> _reference{};
};
+template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool run_twice = false>
+class GEMMLowpMatrixMultiplyCoreValidationFixture : protected GEMMLowpGenericMatrixMultiplyCoreValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, run_twice>
+{
+public:
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset)
+ {
+ GEMMLowpGenericMatrixMultiplyCoreValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, run_twice>::setup(shape_a, shape_b, shape_output, a_offset, b_offset, false /* accumulate */);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool run_twice = false>
+class GEMMLowpMatrixMultiplyAccumulateValidationFixture : protected GEMMLowpGenericMatrixMultiplyCoreValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, run_twice>
+{
+public:
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset)
+ {
+ GEMMLowpGenericMatrixMultiplyCoreValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, run_twice>::setup(shape_a, shape_b, shape_output, a_offset, b_offset, true /* accumulate */);
+ }
+};
+
template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, typename TI = uint8_t, typename TW = uint8_t, bool run_twice = false>
-class GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture : public framework::Fixture
+class GEMMLowpGenericMatrixMultiplyCoreFusedOffsetOutputValidationFixture : public framework::Fixture
{
public:
/** Dynamically initialize the quantization info with saturation awareness
@@ -363,16 +417,16 @@ protected:
TensorShape bias_shape(shape_b[0]);
SimpleTensor<int32_t> bias{ bias_shape, DataType::S32, 1 };
- (run_twice) ? fill_bias_s32(bias, 5 + finfo.hash, finfo.min_bias, finfo.max_bias) : fill_bias_s32(bias, 2 + finfo.hash, finfo.min_bias, finfo.max_bias); // Fill bias with same seed as last run of gemmlowp_target
+ (run_twice) ? fill_s32(bias, 5 + finfo.hash, finfo.min_bias, finfo.max_bias) : fill_s32(bias, 2 + finfo.hash, finfo.min_bias, finfo.max_bias); // Fill bias with same seed as last run of gemmlowp_target
switch(output_stage.type)
{
case GEMMLowpOutputStageType::QUANTIZE_DOWN:
- return reference::gemmlowp_quantize_down_scale<int32_t, TW>(output, bias,
+ return reference::gemmlowp_quantize_down_scale<int32_t, TI>(output, bias,
output_stage.gemmlowp_offset, output_stage.gemmlowp_multipliers, output_stage.gemmlowp_shifts, output_stage.gemmlowp_min_bound, output_stage.gemmlowp_max_bound);
break;
case GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT:
- return reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, TW>(output, bias,
+ return reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, TI>(output, bias,
output_stage.gemmlowp_multipliers, output_stage.gemmlowp_shifts, output_stage.gemmlowp_offset, output_stage.gemmlowp_min_bound, output_stage.gemmlowp_max_bound);
break;
default:
@@ -384,15 +438,24 @@ protected:
SimpleTensor<TI> _reference{};
};
-template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, typename TI = uint8_t, typename TW = uint8_t>
-class GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture : public
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW>
+template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, typename TI = uint8_t, typename TW = uint8_t, bool run_twice = false>
+class GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture : public GEMMLowpGenericMatrixMultiplyCoreFusedOffsetOutputValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW, run_twice>
+{
+public:
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, GEMMLowpOutputStageType output_stage_type, DataType data_type, bool reshape_b_only_on_first_run)
+ {
+ GEMMLowpGenericMatrixMultiplyCoreFusedOffsetOutputValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW, run_twice>::setup(shape_a, shape_b,
+ shape_output, output_stage_type, data_type, reshape_b_only_on_first_run);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, typename TI = uint8_t, typename TW = uint8_t, bool run_twice = false>
+class GEMMLowpBatchedMatrixMultiplyCoreFusedOffsetOutputFixture : public GEMMLowpGenericMatrixMultiplyCoreFusedOffsetOutputValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW, run_twice>
{
public:
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, GEMMLowpOutputStageType output_stage_type, DataType data_type)
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, GEMMLowpOutputStageType output_stage_type, DataType data_type, bool reshape_b_only_on_first_run)
{
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW>::setup(shape_a, shape_b,
- shape_output, output_stage_type, data_type, false /* reshape_b_only_on_first_run */);
+ GEMMLowpGenericMatrixMultiplyCoreFusedOffsetOutputValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW, run_twice>::setup(shape_a, shape_b, shape_output, output_stage_type, data_type, reshape_b_only_on_first_run);
}
};
diff --git a/tests/validation/reference/GEMM.cpp b/tests/validation/reference/GEMM.cpp
index 20f1139a02..d513343796 100644
--- a/tests/validation/reference/GEMM.cpp
+++ b/tests/validation/reference/GEMM.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021,2024 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,6 +25,7 @@
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Types.h"
+#include "tests/validation/reference/ArithmeticOperations.h"
namespace arm_compute
{
@@ -180,17 +181,22 @@ SimpleTensor<T> gemm_mixed_precision(
return dst;
}
-template SimpleTensor<float>
-gemm(const SimpleTensor<float> &a, const SimpleTensor<float> &b, const SimpleTensor<float> &c, float alpha, float beta);
-template SimpleTensor<bfloat16> gemm(const SimpleTensor<bfloat16> &a,
- const SimpleTensor<bfloat16> &b,
- const SimpleTensor<bfloat16> &c,
- float alpha,
- float beta);
-template SimpleTensor<half>
-gemm(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta);
-template SimpleTensor<half> gemm_mixed_precision(
- const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta);
+template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type>
+void gemm_accumulate(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta, SimpleTensor<T> &dst)
+{
+ // Compute reference
+ SimpleTensor<T> dst_gemm = gemm(a, b, c, alpha, beta);
+ reference::arithmetic_operation<T>(reference::ArithmeticOperation::ADD, dst, dst_gemm, dst, ConvertPolicy::SATURATE);
+}
+
+template SimpleTensor<bfloat16> gemm(const SimpleTensor<bfloat16> &a, const SimpleTensor<bfloat16> &b, const SimpleTensor<bfloat16> &c, float alpha, float beta);
+template SimpleTensor<float> gemm(const SimpleTensor<float> &a, const SimpleTensor<float> &b, const SimpleTensor<float> &c, float alpha, float beta);
+template SimpleTensor<half> gemm(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta);
+
+template void gemm_accumulate(const SimpleTensor<float> &a, const SimpleTensor<float> &b, const SimpleTensor<float> &c, float alpha, float beta, SimpleTensor<float> &dst);
+template void gemm_accumulate(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta, SimpleTensor<half> &dst);
+
+template SimpleTensor<half> gemm_mixed_precision(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta);
} // namespace reference
} // namespace validation
} // namespace test
diff --git a/tests/validation/reference/GEMM.h b/tests/validation/reference/GEMM.h
index 5feaeda584..1b97570122 100644
--- a/tests/validation/reference/GEMM.h
+++ b/tests/validation/reference/GEMM.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 Arm Limited.
+ * Copyright (c) 2017-2019, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_GEMM_H
-#define ARM_COMPUTE_TEST_GEMM_H
+#ifndef ACL_TESTS_VALIDATION_REFERENCE_GEMM_H
+#define ACL_TESTS_VALIDATION_REFERENCE_GEMM_H
#include "tests/SimpleTensor.h"
#include "tests/validation/Helpers.h"
@@ -41,8 +41,11 @@ SimpleTensor<T> gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const S
template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type = 0>
SimpleTensor<T> gemm_mixed_precision(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta);
+template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type = 0>
+void gemm_accumulate(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta, SimpleTensor<T> &dst);
+
} // namespace reference
} // namespace validation
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_GEMM_H */
+#endif // ACL_TESTS_VALIDATION_REFERENCE_GEMM_H
diff --git a/tests/validation/reference/GEMMLowp.cpp b/tests/validation/reference/GEMMLowp.cpp
index 1615b51e73..30c577d850 100644
--- a/tests/validation/reference/GEMMLowp.cpp
+++ b/tests/validation/reference/GEMMLowp.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2020, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,6 +24,7 @@
#include "GEMMLowp.h"
#include "arm_compute/core/Types.h"
+#include "tests/validation/reference/ArithmeticOperations.h"
#include "tests/validation/reference/UtilsQuantizedAsymm.h"
#include "support/ToolchainSupport.h"
@@ -230,6 +231,13 @@ SimpleTensor<T_out> gemmlowp_matrix_multiply_core(const SimpleTensor<T_in> &a, c
return c;
}
+template <typename T_out, typename T_in, typename T_in_1>
+void gemmlowp_matrix_multiply_core_accumulate(const SimpleTensor<T_in> &a, const SimpleTensor<T_in_1> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset, SimpleTensor<T_out> &dst)
+{
+ SimpleTensor<T_out> dst_gemm = gemmlowp_matrix_multiply_core<T_out, T_in, T_in_1>(a, b, shape_c, a_offset, b_offset);
+ reference::arithmetic_operation<T_out>(reference::ArithmeticOperation::ADD, dst, dst_gemm, dst, ConvertPolicy::SATURATE);
+}
+
// used to validate assembly kernels which don't know anything about offsets
template <typename T1, typename T2, typename T3>
SimpleTensor<T1> gemmlowp(const SimpleTensor<T2> &a, const SimpleTensor<T3> &b, TensorShape shape_c)
@@ -336,6 +344,8 @@ template SimpleTensor<int8_t> gemmlowp_quantize_down_scale(const SimpleTensor<in
std::vector<int32_t> result_shift, int32_t min, int32_t max);
template SimpleTensor<int32_t> gemmlowp_matrix_multiply_core(const SimpleTensor<int8_t> &a, const SimpleTensor<int8_t> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset);
template SimpleTensor<int32_t> gemmlowp_matrix_multiply_core(const SimpleTensor<uint8_t> &a, const SimpleTensor<uint8_t> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset);
+template void gemmlowp_matrix_multiply_core_accumulate(const SimpleTensor<int8_t> &a, const SimpleTensor<int8_t> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset, SimpleTensor<int32_t> &dst);
+template void gemmlowp_matrix_multiply_core_accumulate(const SimpleTensor<uint8_t> &a, const SimpleTensor<uint8_t> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset, SimpleTensor<int32_t> &dst);
template SimpleTensor<int32_t> gemmlowp<int32_t, int8_t, int8_t>(const SimpleTensor<int8_t> &a, const SimpleTensor<int8_t> &b, TensorShape shape_c);
template SimpleTensor<int32_t> gemmlowp<int32_t, uint8_t, uint8_t>(const SimpleTensor<uint8_t> &a, const SimpleTensor<uint8_t> &b, TensorShape shape_c);
template SimpleTensor<int32_t> gemmlowp<int32_t, uint8_t, int8_t>(const SimpleTensor<uint8_t> &a, const SimpleTensor<int8_t> &b, TensorShape shape_c);
diff --git a/tests/validation/reference/GEMMLowp.h b/tests/validation/reference/GEMMLowp.h
index 99015d71fb..6e471fdad1 100644
--- a/tests/validation/reference/GEMMLowp.h
+++ b/tests/validation/reference/GEMMLowp.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2020, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_GEMMLOWP_H
-#define ARM_COMPUTE_TEST_GEMMLOWP_H
+#ifndef ACL_TESTS_VALIDATION_REFERENCE_GEMMLOWP_H
+#define ACL_TESTS_VALIDATION_REFERENCE_GEMMLOWP_H
#include "tests/SimpleTensor.h"
#include "tests/validation/Helpers.h"
@@ -38,6 +38,9 @@ namespace reference
template <typename T1, typename T2, typename T3>
SimpleTensor<T1> gemmlowp_matrix_multiply_core(const SimpleTensor<T2> &a, const SimpleTensor<T3> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset);
+template <typename T1, typename T2, typename T3>
+void gemmlowp_matrix_multiply_core_accumulate(const SimpleTensor<T2> &a, const SimpleTensor<T3> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset, SimpleTensor<T1> &dst_);
+
template <typename T1, typename T2, typename T3 = T2>
SimpleTensor<T1> gemmlowp(const SimpleTensor<T2> &a, const SimpleTensor<T3> &b, TensorShape shape_c);
@@ -71,4 +74,4 @@ SimpleTensor<TOut> gemmlowp_quantize_down_scale_by_float(const SimpleTensor<TIn>
} // namespace validation
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_GEMMLOWP_H */
+#endif // ACL_TESTS_VALIDATION_REFERENCE_GEMMLOWP_H