diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2019-10-14 19:03:09 +0100 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2019-10-23 12:08:12 +0000 |
commit | 48b3ef89de5f21a0169d8416e3d54081f82c7bf8 (patch) | |
tree | f857d733ccf446c704823dc7ac796a96eb55095e /arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp | |
parent | 1dce3101ef8d77c8cf0af7dfd4af6595a0136b91 (diff) | |
download | ComputeLibrary-48b3ef89de5f21a0169d8416e3d54081f82c7bf8.tar.gz |
COMPMID-2577: Fuse bias addition and activation in gemm assembly kernels
Change-Id: I7f52112d2d05b1ea3d3f3d4b19b8eafab05d6c44
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2141
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Marquez <pablo.tello@arm.com>
Diffstat (limited to 'arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp')
-rw-r--r-- | arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp | 34 |
1 files changed, 24 insertions, 10 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp index 828b0f20a7..17faab18fd 100644 --- a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp +++ b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp @@ -65,7 +65,21 @@ struct GemmConfig GemmConfig() { } }; -template<typename T> +struct Activation +{ + enum class Type { + None, + ReLU, + BoundedReLU + }; + + Type type; + float param1; + float param2; + + Activation(Type type=Type::None, float p1=0.0f, float p2=0.0f) : type(type), param1(p1), param2(p2) { } +}; + struct GemmArgs { public: @@ -77,8 +91,7 @@ public: unsigned int _nmulti; bool _trA; bool _trB; - T _alpha; - T _beta; + Activation _act; int _maxthreads; bool _pretransposed_hint; const GemmConfig *_cfg; @@ -86,10 +99,10 @@ public: GemmArgs(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB, - const T alpha, const T beta, const int maxthreads, + Activation act, const int maxthreads, const bool pretransposed_hint, const GemmConfig *cfg=nullptr ) : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti), - _trA(trA), _trB(trB), _alpha(alpha), _beta(beta), _maxthreads(maxthreads), + _trA(trA), _trB(trB), _act(act), _maxthreads(maxthreads), _pretransposed_hint(pretransposed_hint), _cfg(cfg) { } @@ -99,6 +112,7 @@ struct ARequantizeLayer32 { public: const int32_t *bias; + size_t bias_multi_stride; int32_t a_offset; int32_t b_offset; int32_t c_offset; @@ -109,8 +123,8 @@ public: ARequantizeLayer32() = default; - ARequantizeLayer32(int32_t *b, int32_t ao, int32_t bo, int32_t co, int32_t rs, int32_t rm, int32_t minv, int32_t maxv) : - bias(b), a_offset(ao), b_offset(bo), c_offset(co), requant_shift(rs), requant_mul(rm), minval(minv), maxval(maxv) + ARequantizeLayer32(const int32_t *b, size_t bms, int32_t ao, int32_t bo, int32_t co, int32_t rs, int32_t rm, int32_t minv, int32_t maxv) : + bias(b), bias_multi_stride(bms), a_offset(ao), b_offset(bo), c_offset(co), requant_shift(rs), requant_mul(rm), minval(minv), maxval(maxv) { } }; @@ -128,12 +142,12 @@ using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret> >; /* get_gemm_method(): Given the templated types and provided parameters, * which is the preferred method to implement this GEMM? */ template<typename Top, typename Tret, class OutputStage = Nothing> -KernelDescription get_gemm_method(const GemmArgs<Tret> &args, const OutputStage & ={}); +KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & ={}); template<typename Top, typename Tret, class OutputStage = Nothing> -UniqueGemmCommon<Top, Tret> gemm(const GemmArgs<Tret> &args, const OutputStage & ={}); +UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage & ={}); template<typename Top, typename Tret, class OutputStage = Nothing> -std::vector<KernelDescription> get_compatible_kernels(const GemmArgs<Tret> &args, const OutputStage & ={}); +std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & ={}); } // namespace arm_gemm |