about | summary | refs | log | tree | commit | diff
path: root/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
diff options
context:
space:
mode:
author: David Mansell <David.Mansell@arm.com> 2023-03-10 13:48:50 +0000
committer: Viet-Hoa Do <viet-hoa.do@arm.com> 2023-03-13 16:40:36 +0000
commit: aaa9da1efa83911c7a67d50811ad669a92a7d12f (patch)
tree: 457c72960724b6e9b2e50e8b039735515c629216 /src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
parent: 0ffc88b7ae8f73fc66338a4eee5348bab634edf7 (diff)
download: ComputeLibrary-aaa9da1efa83911c7a67d50811ad669a92a7d12f.tar.gz
arm_gemm: Add SME2 FP16 kernels.
Resolves: COMPMID-5966
Change-Id: Ic0d694493178da029a297643855bd0cff01b174f
Signed-off-by: David Mansell <David.Mansell@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9302
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp')
-rw-r--r-- src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp 22
1 file changed, 14 insertions, 8 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index e7346e8039..442739b55e 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -267,18 +267,24 @@ public:
};
// We need a similar trick here to figure out what type the accumulator buffer should be.
-template<typename strategy, typename OutputStage>
+template<typename strategy, typename OutputStage, bool ForceFloat>
class accumulate_buffer_type {
public:
typedef typename strategy::result_type type;
};
template<typename strategy>
-class accumulate_buffer_type<strategy, Requantize32> {
+class accumulate_buffer_type<strategy, Requantize32, false> {
public:
typedef int32_t type;
};
+template<typename strategy, typename OutputStage>
+class accumulate_buffer_type<strategy, OutputStage, true> {
+public:
+ typedef float type;
+};
+
// Stripe width is a concept only needed for FixedFormat kernels. Use an accessor to avoid issues in other scenarios.
template<typename strategy, bool FixedFormat>
struct get_stripe_width {
@@ -321,11 +327,11 @@ struct get_kernel_weight_format<strategy, true, To> {
} // anonymous namespace
-template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing, bool MergeStep=true, bool FixedFormat=false, bool ForceThreadColumns=false>
+template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing, bool MergeStep=true, bool FixedFormat=false, bool ForceThreadColumns=false, bool ForceFloatAccumulate=false>
class GemmInterleaved : public GemmCommon<To, Tr> {
typedef typename strategy::operand_type Toi;
typedef typename strategy::result_type Tri;
- typedef typename accumulate_buffer_type<strategy, OutputStage>::type Tab;
+ typedef typename accumulate_buffer_type<strategy, OutputStage, ForceFloatAccumulate>::type Tab;
/* const properties set by constructor */
const CPUInfo * const _ci;
@@ -383,7 +389,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
class blockwalker {
private:
/* Size loops, etc. based on our parent's configuration */
- const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns> &_parent;
+ const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &_parent;
/* K, X and multi parameters for current iteration. */
unsigned int _k0=0, _x0=0, _multi=0;
@@ -398,9 +404,9 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
bool _newmulti=true;
public:
- blockwalker(const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns> &parent) : _parent(parent) { }
+ blockwalker(const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &parent) : _parent(parent) { }
- blockwalker(const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns> &parent,
+ blockwalker(const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &parent,
unsigned int x_start, unsigned int x_end) : _parent(parent), _x0 (_x_start), _x_start(x_start), _x_end(x_end) { }
unsigned int xmax() {