aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp22
1 files changed, 14 insertions, 8 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index e7346e8039..442739b55e 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -267,18 +267,24 @@ public:
};
// We need a similar trick here to figure out what type the accumulator buffer should be.
-template<typename strategy, typename OutputStage>
+template<typename strategy, typename OutputStage, bool ForceFloat>
class accumulate_buffer_type {
public:
typedef typename strategy::result_type type;
};
template<typename strategy>
-class accumulate_buffer_type<strategy, Requantize32> {
+class accumulate_buffer_type<strategy, Requantize32, false> {
public:
typedef int32_t type;
};
+template<typename strategy, typename OutputStage>
+class accumulate_buffer_type<strategy, OutputStage, true> {
+public:
+ typedef float type;
+};
+
// Stripe width is a concept only needed for FixedFormat kernels. Use an accessor to avoid issues in other scenarios.
template<typename strategy, bool FixedFormat>
struct get_stripe_width {
@@ -321,11 +327,11 @@ struct get_kernel_weight_format<strategy, true, To> {
} // anonymous namespace
-template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing, bool MergeStep=true, bool FixedFormat=false, bool ForceThreadColumns=false>
+template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing, bool MergeStep=true, bool FixedFormat=false, bool ForceThreadColumns=false, bool ForceFloatAccumulate=false>
class GemmInterleaved : public GemmCommon<To, Tr> {
typedef typename strategy::operand_type Toi;
typedef typename strategy::result_type Tri;
- typedef typename accumulate_buffer_type<strategy, OutputStage>::type Tab;
+ typedef typename accumulate_buffer_type<strategy, OutputStage, ForceFloatAccumulate>::type Tab;
/* const properties set by constructor */
const CPUInfo * const _ci;
@@ -383,7 +389,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
class blockwalker {
private:
/* Size loops, etc. based on our parent's configuration */
- const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns> &_parent;
+ const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &_parent;
/* K, X and multi parameters for current iteration. */
unsigned int _k0=0, _x0=0, _multi=0;
@@ -398,9 +404,9 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
bool _newmulti=true;
public:
- blockwalker(const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns> &parent) : _parent(parent) { }
+ blockwalker(const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &parent) : _parent(parent) { }
- blockwalker(const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns> &parent,
+ blockwalker(const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &parent,
unsigned int x_start, unsigned int x_end) : _parent(parent), _x0 (x_start), _x_start(x_start), _x_end(x_end) { } // _x0 starts at the x_start argument; reading the _x_start member here would use it before it has been initialised.
unsigned int xmax() {