diff options
Diffstat (limited to 'arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp')
-rw-r--r-- | arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp | 34 |
1 files changed, 24 insertions, 10 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp index 828b0f20a7..17faab18fd 100644 --- a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp +++ b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp @@ -65,7 +65,21 @@ struct GemmConfig GemmConfig() { } }; -template<typename T> +struct Activation +{ + enum class Type { + None, + ReLU, + BoundedReLU + }; + + Type type; + float param1; + float param2; + + Activation(Type type=Type::None, float p1=0.0f, float p2=0.0f) : type(type), param1(p1), param2(p2) { } +}; + struct GemmArgs { public: @@ -77,8 +91,7 @@ public: unsigned int _nmulti; bool _trA; bool _trB; - T _alpha; - T _beta; + Activation _act; int _maxthreads; bool _pretransposed_hint; const GemmConfig *_cfg; @@ -86,10 +99,10 @@ public: GemmArgs(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB, - const T alpha, const T beta, const int maxthreads, + Activation act, const int maxthreads, const bool pretransposed_hint, const GemmConfig *cfg=nullptr ) : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti), - _trA(trA), _trB(trB), _alpha(alpha), _beta(beta), _maxthreads(maxthreads), + _trA(trA), _trB(trB), _act(act), _maxthreads(maxthreads), _pretransposed_hint(pretransposed_hint), _cfg(cfg) { } @@ -99,6 +112,7 @@ struct ARequantizeLayer32 { public: const int32_t *bias; + size_t bias_multi_stride; int32_t a_offset; int32_t b_offset; int32_t c_offset; @@ -109,8 +123,8 @@ public: ARequantizeLayer32() = default; - ARequantizeLayer32(int32_t *b, int32_t ao, int32_t bo, int32_t co, int32_t rs, int32_t rm, int32_t minv, int32_t maxv) : - bias(b), a_offset(ao), b_offset(bo), c_offset(co), requant_shift(rs), requant_mul(rm), minval(minv), maxval(maxv) + ARequantizeLayer32(const int32_t *b, size_t bms, int32_t ao, int32_t bo, int32_t co, int32_t rs, int32_t rm, int32_t minv, int32_t maxv) : + bias(b), bias_multi_stride(bms), a_offset(ao), b_offset(bo), c_offset(co), requant_shift(rs), requant_mul(rm), minval(minv), maxval(maxv) { } }; @@ -128,12 +142,12 @@ using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret> >; /* get_gemm_method(): Given the templated types and provided parameters, * which is the preferred method to implement this GEMM? */ template<typename Top, typename Tret, class OutputStage = Nothing> -KernelDescription get_gemm_method(const GemmArgs<Tret> &args, const OutputStage & ={}); +KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & ={}); template<typename Top, typename Tret, class OutputStage = Nothing> -UniqueGemmCommon<Top, Tret> gemm(const GemmArgs<Tret> &args, const OutputStage & ={}); +UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage & ={}); template<typename Top, typename Tret, class OutputStage = Nothing> -std::vector<KernelDescription> get_compatible_kernels(const GemmArgs<Tret> &args, const OutputStage & ={}); +std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & ={}); } // namespace arm_gemm |