diff options
Diffstat (limited to 'src/cpu/kernels/assembly/arm_gemm.hpp')
-rw-r--r-- | src/cpu/kernels/assembly/arm_gemm.hpp | 91 |
1 files changed, 68 insertions, 23 deletions
diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp index 4c127b4ec3..9a913c5c58 100644 --- a/src/cpu/kernels/assembly/arm_gemm.hpp +++ b/src/cpu/kernels/assembly/arm_gemm.hpp @@ -23,13 +23,12 @@ */ #pragma once +#include "arm_gemm_local.hpp" +#include "gemm_common.hpp" #include <cstring> #include <memory> #include <vector> -#include "arm_gemm_local.hpp" -#include "gemm_common.hpp" - namespace arm_gemm { enum class GemmMethod @@ -111,8 +110,7 @@ struct GemmConfig unsigned int outer_block_size = 0; WeightFormat weight_format = WeightFormat::ANY; - GemmConfig(GemmMethod method) - : method(method) + GemmConfig(GemmMethod method) : method(method) { } GemmConfig() @@ -133,8 +131,7 @@ struct Activation float param1; float param2; - Activation(Type type = Type::None, float p1 = 0.0f, float p2 = 0.0f) - : type(type), param1(p1), param2(p2) + Activation(Type type = Type::None, float p1 = 0.0f, float p2 = 0.0f) : type(type), param1(p1), param2(p2) { } }; @@ -156,12 +153,32 @@ public: bool _fast_mode; const GemmConfig *_cfg; - GemmArgs(const CPUInfo *ci, unsigned int M, unsigned int N, - unsigned int K, unsigned int Ksections, unsigned int nbatches, - unsigned int nmulti, bool indirect_input, Activation act, const int maxthreads, - bool fixed_format = false, bool fast_mode = false, const GemmConfig *cfg = nullptr) - : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _Ksections(Ksections), _nbatches(nbatches), _nmulti(nmulti), _indirect_input(indirect_input), _act(act), _maxthreads(maxthreads), - _fixed_format(fixed_format), _fast_mode(fast_mode), _cfg(cfg) + GemmArgs(const CPUInfo *ci, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int Ksections, + unsigned int nbatches, + unsigned int nmulti, + bool indirect_input, + Activation act, + const int maxthreads, + bool fixed_format = false, + bool fast_mode = false, + const GemmConfig *cfg = nullptr) + : _ci(ci), + _Msize(M), + _Nsize(N), + _Ksize(K), + _Ksections(Ksections), + _nbatches(nbatches), + _nmulti(nmulti), + _indirect_input(indirect_input), + _act(act), + _maxthreads(maxthreads), + _fixed_format(fixed_format), + _fast_mode(fast_mode), + _cfg(cfg) { } }; @@ -187,23 +204,51 @@ public: Requantize32() = default; // Constructor for per-tensor quantization - Requantize32(const int32_t *bias, size_t bias_multi_stride, - int32_t a_offset, int32_t b_offset, int32_t c_offset, - int32_t requant_shift, int32_t requant_mul, int32_t minv, int32_t maxv) - : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(false), per_layer_left_shift(std::max<int32_t>(requant_shift, 0)), - per_layer_right_shift(std::min<int32_t>(requant_shift, 0)), per_layer_mul(requant_mul), minval(minv), maxval(maxv) + Requantize32(const int32_t *bias, + size_t bias_multi_stride, + int32_t a_offset, + int32_t b_offset, + int32_t c_offset, + int32_t requant_shift, + int32_t requant_mul, + int32_t minv, + int32_t maxv) + : bias(bias), + bias_multi_stride(bias_multi_stride), + a_offset(a_offset), + b_offset(b_offset), + c_offset(c_offset), + per_channel_requant(false), + per_layer_left_shift(std::max<int32_t>(requant_shift, 0)), + per_layer_right_shift(std::min<int32_t>(requant_shift, 0)), + per_layer_mul(requant_mul), + minval(minv), + maxval(maxv) { } // Constructor for per-channel quantization - Requantize32(const int32_t *bias, size_t bias_multi_stride, - int32_t a_offset, int32_t b_offset, int32_t c_offset, + Requantize32(const int32_t *bias, + size_t bias_multi_stride, + int32_t a_offset, + int32_t b_offset, + int32_t c_offset, const int32_t *requant_left_shifts, const int32_t *requant_right_shifts, const int32_t *requant_muls, - int32_t minv, int32_t maxv) - : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(true), per_channel_left_shifts(requant_left_shifts), - per_channel_right_shifts(requant_right_shifts), per_channel_muls(requant_muls), minval(minv), maxval(maxv) + int32_t minv, + int32_t maxv) + : bias(bias), + bias_multi_stride(bias_multi_stride), + a_offset(a_offset), + b_offset(b_offset), + c_offset(c_offset), + per_channel_requant(true), + per_channel_left_shifts(requant_left_shifts), + per_channel_right_shifts(requant_right_shifts), + per_channel_muls(requant_muls), + minval(minv), + maxval(maxv) { } }; |