From 5aa1a0b7ca5eed010e4b297a95b1c4851f741328 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Thu, 2 Jul 2020 20:02:20 +0100 Subject: COMPID-3324: Clean GEMM kernels Signed-off-by: Georgios Pinitas Change-Id: I170de1671e061a78740caee31fb4a1b8642c1369 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3505 Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Reviewed-by: Michele Di Giorgio --- src/core/NEON/kernels/assembly/arm_gemm.hpp | 106 +++++++++++++++------------- 1 file changed, 57 insertions(+), 49 deletions(-) (limited to 'src/core/NEON/kernels/assembly/arm_gemm.hpp') diff --git a/src/core/NEON/kernels/assembly/arm_gemm.hpp b/src/core/NEON/kernels/assembly/arm_gemm.hpp index 7723224ec8..2df7132500 100644 --- a/src/core/NEON/kernels/assembly/arm_gemm.hpp +++ b/src/core/NEON/kernels/assembly/arm_gemm.hpp @@ -23,14 +23,14 @@ */ #pragma once -#include #include +#include #include "arm_gemm_local.hpp" #include "gemm_common.hpp" -namespace arm_gemm { - +namespace arm_gemm +{ enum class GemmMethod { DEFAULT, @@ -47,12 +47,17 @@ enum class GemmMethod struct KernelDescription { - GemmMethod method = GemmMethod::DEFAULT; - std::string name = ""; - bool is_default = false; + GemmMethod method = GemmMethod::DEFAULT; + std::string name = ""; + bool is_default = false; - KernelDescription(GemmMethod m, std::string n, bool d=false) : method(m), name(n), is_default(d) { } - KernelDescription() noexcept { } + KernelDescription(GemmMethod m, std::string n, bool d = false) + : method(m), name(n), is_default(d) + { + } + KernelDescription() noexcept + { + } }; struct GemmConfig @@ -62,23 +67,32 @@ struct GemmConfig unsigned int inner_block_size = 0; unsigned int outer_block_size = 0; - GemmConfig(GemmMethod method) : method(method) { } - GemmConfig() { } + GemmConfig(GemmMethod method) + : method(method) + { + } + GemmConfig() + { + } }; struct Activation { - enum class Type { + enum class Type + { None, ReLU, BoundedReLU }; - Type type; - float param1; - float param2; + Type type; + float param1; + float param2; - Activation(Type type=Type::None, float p1=0.0f, float p2=0.0f) : type(type), param1(p1), param2(p2) { } + Activation(Type type = Type::None, float p1 = 0.0f, float p2 = 0.0f) + : type(type), param1(p1), param2(p2) + { + } }; struct GemmArgs @@ -101,10 +115,8 @@ public: const unsigned int K, const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB, Activation act, const int maxthreads, - const bool pretransposed_hint, const GemmConfig *cfg=nullptr ) : - _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti), - _trA(trA), _trB(trB), _act(act), _maxthreads(maxthreads), - _pretransposed_hint(pretransposed_hint), _cfg(cfg) + const bool pretransposed_hint, const GemmConfig *cfg = nullptr) + : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti), _trA(trA), _trB(trB), _act(act), _maxthreads(maxthreads), _pretransposed_hint(pretransposed_hint), _cfg(cfg) { } }; @@ -112,18 +124,18 @@ public: struct Requantize32 { public: - const int32_t *bias = nullptr; - size_t bias_multi_stride = 0; - int32_t a_offset = 0; - int32_t b_offset = 0; - int32_t c_offset = 0; - bool per_channel_requant = false; - int32_t per_layer_shift = 0; - int32_t per_layer_mul = 0; - const int32_t *per_channel_shifts = nullptr; - const int32_t *per_channel_muls = nullptr; - int32_t minval = 0; - int32_t maxval = 0; + const int32_t *bias = nullptr; + size_t bias_multi_stride = 0; + int32_t a_offset = 0; + int32_t b_offset = 0; + int32_t c_offset = 0; + bool per_channel_requant = false; + int32_t per_layer_shift = 0; + int32_t per_layer_mul = 0; + const int32_t *per_channel_shifts = nullptr; + const int32_t *per_channel_muls = nullptr; + int32_t minval = 0; + int32_t maxval = 0; Requantize32() = default; @@ -131,11 +143,9 @@ public: Requantize32(const int32_t *bias, size_t bias_multi_stride, int32_t a_offset, int32_t b_offset, int32_t c_offset, int32_t requant_shift, int32_t requant_mul, - int32_t minv, int32_t maxv) : - bias(bias), bias_multi_stride(bias_multi_stride), - a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), - per_channel_requant(false), per_layer_shift(requant_shift), per_layer_mul(requant_mul), - minval(minv), maxval(maxv) + int32_t minv, int32_t maxv) + : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(false), per_layer_shift(requant_shift), per_layer_mul(requant_mul), + minval(minv), maxval(maxv) { } @@ -143,11 +153,9 @@ public: Requantize32(const int32_t *bias, size_t bias_multi_stride, int32_t a_offset, int32_t b_offset, int32_t c_offset, const int32_t *requant_shifts, const int32_t *requant_muls, - int32_t minv, int32_t maxv) : - bias(bias), bias_multi_stride(bias_multi_stride), - a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), - per_channel_requant(true), per_channel_shifts(requant_shifts), per_channel_muls(requant_muls), - minval(minv), maxval(maxv) + int32_t minv, int32_t maxv) + : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(true), per_channel_shifts(requant_shifts), + per_channel_muls(requant_muls), minval(minv), maxval(maxv) { } }; @@ -156,21 +164,21 @@ struct Nothing { }; -template -using UniqueGemmCommon = std::unique_ptr >; +template +using UniqueGemmCommon = std::unique_ptr>; /* Low level API calls. * These are implemented as 'GemmArgs' versions, or with the arguments explicitly listed. */ /* get_gemm_method(): Given the templated types and provided parameters, * which is the preferred method to implement this GEMM? */ -template -KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & ={}); +template +KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & = {}); -template -UniqueGemmCommon gemm(const GemmArgs &args, const OutputStage & ={}); +template +UniqueGemmCommon gemm(const GemmArgs &args, const OutputStage & = {}); -template -std::vector get_compatible_kernels(const GemmArgs &args, const OutputStage & ={}); +template +std::vector get_compatible_kernels(const GemmArgs &args, const OutputStage & = {}); } // namespace arm_gemm -- cgit v1.2.1