diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp | 22 |
1 files changed, 14 insertions, 8 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp index e7346e8039..442739b55e 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2022 Arm Limited. + * Copyright (c) 2017-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -267,18 +267,24 @@ public: }; // We need a similar trick here to figure out what type the accumulator buffer should be. -template<typename strategy, typename OutputStage> +template<typename strategy, typename OutputStage, bool ForceFloat> class accumulate_buffer_type { public: typedef typename strategy::result_type type; }; template<typename strategy> -class accumulate_buffer_type<strategy, Requantize32> { +class accumulate_buffer_type<strategy, Requantize32, false> { public: typedef int32_t type; }; +template<typename strategy, typename OutputStage> +class accumulate_buffer_type<strategy, OutputStage, true> { +public: + typedef float type; +}; + // Stripe width is a concept only needed for FixedFormat kernels. Use an accessor to avoid issues in other scenarios. template<typename strategy, bool FixedFormat> struct get_stripe_width { @@ -321,11 +327,11 @@ struct get_kernel_weight_format<strategy, true, To> { } // anonymous namespace -template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing, bool MergeStep=true, bool FixedFormat=false, bool ForceThreadColumns=false> +template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing, bool MergeStep=true, bool FixedFormat=false, bool ForceThreadColumns=false, bool ForceFloatAccumulate=false> class GemmInterleaved : public GemmCommon<To, Tr> { typedef typename strategy::operand_type Toi; typedef typename strategy::result_type Tri; - typedef typename accumulate_buffer_type<strategy, OutputStage>::type Tab; + typedef typename accumulate_buffer_type<strategy, OutputStage, ForceFloatAccumulate>::type Tab; /* const properties set by constructor */ const CPUInfo * const _ci; @@ -383,7 +389,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> { class blockwalker { private: /* Size loops, etc. based on our parent's configuration */ - const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns> &_parent; + const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &_parent; /* K, X and multi parameters for current iteration. */ unsigned int _k0=0, _x0=0, _multi=0; @@ -398,9 +404,9 @@ class GemmInterleaved : public GemmCommon<To, Tr> { bool _newmulti=true; public: - blockwalker(const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns> &parent) : _parent(parent) { } + blockwalker(const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &parent) : _parent(parent) { } - blockwalker(const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns> &parent, + blockwalker(const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &parent, unsigned int x_start, unsigned int x_end) : _parent(parent), _x0 (_x_start), _x_start(x_start), _x_end(x_end) { } unsigned int xmax() { |