about | summary | refs | log | tree | commit | diff
path: root/src/core/NEON/kernels/assembly/arm_gemm.hpp
diff options
context:
space:
mode:
author Georgios Pinitas <georgios.pinitas@arm.com> 2020-07-02 20:02:20 +0100
committer Georgios Pinitas <georgios.pinitas@arm.com> 2020-07-06 16:51:32 +0000
commit 5aa1a0b7ca5eed010e4b297a95b1c4851f741328 (patch)
tree ba882de9e86589dfdd33937d538a89bbdf01c40e /src/core/NEON/kernels/assembly/arm_gemm.hpp
parent 42550c039105597ff6acd4e5efc0ee3c7c20b08e (diff)
download ComputeLibrary-5aa1a0b7ca5eed010e4b297a95b1c4851f741328.tar.gz
COMPID-3324: Clean GEMM kernels
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: I170de1671e061a78740caee31fb4a1b8642c1369
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3505
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/assembly/arm_gemm.hpp')
-rw-r--r-- src/core/NEON/kernels/assembly/arm_gemm.hpp 106
1 files changed, 57 insertions, 49 deletions
diff --git a/src/core/NEON/kernels/assembly/arm_gemm.hpp b/src/core/NEON/kernels/assembly/arm_gemm.hpp
index 7723224ec8..2df7132500 100644
--- a/src/core/NEON/kernels/assembly/arm_gemm.hpp
+++ b/src/core/NEON/kernels/assembly/arm_gemm.hpp
@@ -23,14 +23,14 @@
*/
#pragma once
-#include <memory>
#include <cstring>
+#include <memory>
#include "arm_gemm_local.hpp"
#include "gemm_common.hpp"
-namespace arm_gemm {
-
+namespace arm_gemm
+{
enum class GemmMethod
{
DEFAULT,
@@ -47,12 +47,17 @@ enum class GemmMethod
struct KernelDescription
{
- GemmMethod method = GemmMethod::DEFAULT;
- std::string name = "";
- bool is_default = false;
+ GemmMethod method = GemmMethod::DEFAULT;
+ std::string name = "";
+ bool is_default = false;
- KernelDescription(GemmMethod m, std::string n, bool d=false) : method(m), name(n), is_default(d) { }
- KernelDescription() noexcept { }
+ KernelDescription(GemmMethod m, std::string n, bool d = false)
+ : method(m), name(n), is_default(d)
+ {
+ }
+ KernelDescription() noexcept
+ {
+ }
};
struct GemmConfig
@@ -62,23 +67,32 @@ struct GemmConfig
unsigned int inner_block_size = 0;
unsigned int outer_block_size = 0;
- GemmConfig(GemmMethod method) : method(method) { }
- GemmConfig() { }
+ GemmConfig(GemmMethod method)
+ : method(method)
+ {
+ }
+ GemmConfig()
+ {
+ }
};
struct Activation
{
- enum class Type {
+ enum class Type
+ {
None,
ReLU,
BoundedReLU
};
- Type type;
- float param1;
- float param2;
+ Type type;
+ float param1;
+ float param2;
- Activation(Type type=Type::None, float p1=0.0f, float p2=0.0f) : type(type), param1(p1), param2(p2) { }
+ Activation(Type type = Type::None, float p1 = 0.0f, float p2 = 0.0f)
+ : type(type), param1(p1), param2(p2)
+ {
+ }
};
struct GemmArgs
@@ -101,10 +115,8 @@ public:
const unsigned int K, const unsigned int nbatches,
const unsigned int nmulti, const bool trA, const bool trB,
Activation act, const int maxthreads,
- const bool pretransposed_hint, const GemmConfig *cfg=nullptr ) :
- _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti),
- _trA(trA), _trB(trB), _act(act), _maxthreads(maxthreads),
- _pretransposed_hint(pretransposed_hint), _cfg(cfg)
+ const bool pretransposed_hint, const GemmConfig *cfg = nullptr)
+ : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti), _trA(trA), _trB(trB), _act(act), _maxthreads(maxthreads), _pretransposed_hint(pretransposed_hint), _cfg(cfg)
{
}
};
@@ -112,18 +124,18 @@ public:
struct Requantize32
{
public:
- const int32_t *bias = nullptr;
- size_t bias_multi_stride = 0;
- int32_t a_offset = 0;
- int32_t b_offset = 0;
- int32_t c_offset = 0;
- bool per_channel_requant = false;
- int32_t per_layer_shift = 0;
- int32_t per_layer_mul = 0;
- const int32_t *per_channel_shifts = nullptr;
- const int32_t *per_channel_muls = nullptr;
- int32_t minval = 0;
- int32_t maxval = 0;
+ const int32_t *bias = nullptr;
+ size_t bias_multi_stride = 0;
+ int32_t a_offset = 0;
+ int32_t b_offset = 0;
+ int32_t c_offset = 0;
+ bool per_channel_requant = false;
+ int32_t per_layer_shift = 0;
+ int32_t per_layer_mul = 0;
+ const int32_t *per_channel_shifts = nullptr;
+ const int32_t *per_channel_muls = nullptr;
+ int32_t minval = 0;
+ int32_t maxval = 0;
Requantize32() = default;
@@ -131,11 +143,9 @@ public:
Requantize32(const int32_t *bias, size_t bias_multi_stride,
int32_t a_offset, int32_t b_offset, int32_t c_offset,
int32_t requant_shift, int32_t requant_mul,
- int32_t minv, int32_t maxv) :
- bias(bias), bias_multi_stride(bias_multi_stride),
- a_offset(a_offset), b_offset(b_offset), c_offset(c_offset),
- per_channel_requant(false), per_layer_shift(requant_shift), per_layer_mul(requant_mul),
- minval(minv), maxval(maxv)
+ int32_t minv, int32_t maxv)
+ : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(false), per_layer_shift(requant_shift), per_layer_mul(requant_mul),
+ minval(minv), maxval(maxv)
{
}
@@ -143,11 +153,9 @@ public:
Requantize32(const int32_t *bias, size_t bias_multi_stride,
int32_t a_offset, int32_t b_offset, int32_t c_offset,
const int32_t *requant_shifts, const int32_t *requant_muls,
- int32_t minv, int32_t maxv) :
- bias(bias), bias_multi_stride(bias_multi_stride),
- a_offset(a_offset), b_offset(b_offset), c_offset(c_offset),
- per_channel_requant(true), per_channel_shifts(requant_shifts), per_channel_muls(requant_muls),
- minval(minv), maxval(maxv)
+ int32_t minv, int32_t maxv)
+ : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(true), per_channel_shifts(requant_shifts),
+ per_channel_muls(requant_muls), minval(minv), maxval(maxv)
{
}
};
@@ -156,21 +164,21 @@ struct Nothing
{
};
-template<typename Top, typename Tret>
-using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret> >;
+template <typename Top, typename Tret>
+using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret>>;
/* Low level API calls.
* These are implemented as 'GemmArgs' versions, or with the arguments explicitly listed. */
/* get_gemm_method(): Given the templated types and provided parameters,
* which is the preferred method to implement this GEMM? */
-template<typename Top, typename Tret, class OutputStage = Nothing>
-KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & ={});
+template <typename Top, typename Tret, class OutputStage = Nothing>
+KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & = {});
-template<typename Top, typename Tret, class OutputStage = Nothing>
-UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage & ={});
+template <typename Top, typename Tret, class OutputStage = Nothing>
+UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage & = {});
-template<typename Top, typename Tret, class OutputStage = Nothing>
-std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & ={});
+template <typename Top, typename Tret, class OutputStage = Nothing>
+std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {});
} // namespace arm_gemm