Diffstat (limited to 'arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp')
-rw-r--r--  arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp  34
1 file changed, 24 insertions(+), 10 deletions(-)
diff --git a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp
index 828b0f20a7..17faab18fd 100644
--- a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp
+++ b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp
@@ -65,7 +65,21 @@ struct GemmConfig
GemmConfig() { }
};
-template<typename T>
+struct Activation
+{
+ enum class Type {
+ None,
+ ReLU,
+ BoundedReLU
+ };
+
+ Type type;
+ float param1;
+ float param2;
+
+ Activation(Type type=Type::None, float p1=0.0f, float p2=0.0f) : type(type), param1(p1), param2(p2) { }
+};
+
struct GemmArgs
{
public:
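
The Activation struct added in this hunk replaces the old alpha/beta scalars. A minimal usage sketch follows; the diff does not document param1/param2, so the comments assume param1 carries the upper clamp of BoundedReLU, and the variable names are illustrative only.

#include "arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp"

// Sketch only: assumes param1 is the BoundedReLU upper bound (not stated in the diff).
arm_gemm::Activation none_act;                                              // Type::None by default
arm_gemm::Activation relu(arm_gemm::Activation::Type::ReLU);                // unbounded ReLU
arm_gemm::Activation relu6(arm_gemm::Activation::Type::BoundedReLU, 6.0f);  // assumed: clamp output at 6.0
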
@@ -77,8 +91,7 @@ public:
unsigned int _nmulti;
bool _trA;
bool _trB;
- T _alpha;
- T _beta;
+ Activation _act;
int _maxthreads;
bool _pretransposed_hint;
const GemmConfig *_cfg;
@@ -86,10 +99,10 @@ public:
GemmArgs(const CPUInfo *ci, const unsigned int M, const unsigned int N,
const unsigned int K, const unsigned int nbatches,
const unsigned int nmulti, const bool trA, const bool trB,
- const T alpha, const T beta, const int maxthreads,
+ Activation act, const int maxthreads,
const bool pretransposed_hint, const GemmConfig *cfg=nullptr ) :
_ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti),
- _trA(trA), _trB(trB), _alpha(alpha), _beta(beta), _maxthreads(maxthreads),
+ _trA(trA), _trB(trB), _act(act), _maxthreads(maxthreads),
_pretransposed_hint(pretransposed_hint), _cfg(cfg)
{
}
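
With GemmArgs no longer templated and the alpha/beta pair replaced by a single Activation argument, a caller builds the argument block roughly as in the sketch below. The function name, the sizes and the thread count are illustrative placeholders; ci is whatever CPUInfo pointer the calling framework already holds.

#include "arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp"

using namespace arm_gemm;

// Sketch only: build_args_example is a hypothetical caller; every size below is a placeholder.
void build_args_example(const CPUInfo *ci)
{
    Activation act(Activation::Type::BoundedReLU, 6.0f);  // assumed: param1 = upper bound
    GemmArgs args(ci, /*M=*/128, /*N=*/128, /*K=*/64,
                  /*nbatches=*/1, /*nmulti=*/1,
                  /*trA=*/false, /*trB=*/false,
                  act, /*maxthreads=*/1,
                  /*pretransposed_hint=*/false);
    (void)args;
}
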
@@ -99,6 +112,7 @@ struct ARequantizeLayer32
{
public:
const int32_t *bias;
+ size_t bias_multi_stride;
int32_t a_offset;
int32_t b_offset;
int32_t c_offset;
@@ -109,8 +123,8 @@ public:
ARequantizeLayer32() = default;
- ARequantizeLayer32(int32_t *b, int32_t ao, int32_t bo, int32_t co, int32_t rs, int32_t rm, int32_t minv, int32_t maxv) :
- bias(b), a_offset(ao), b_offset(bo), c_offset(co), requant_shift(rs), requant_mul(rm), minval(minv), maxval(maxv)
+ ARequantizeLayer32(const int32_t *b, size_t bms, int32_t ao, int32_t bo, int32_t co, int32_t rs, int32_t rm, int32_t minv, int32_t maxv) :
+ bias(b), bias_multi_stride(bms), a_offset(ao), b_offset(bo), c_offset(co), requant_shift(rs), requant_mul(rm), minval(minv), maxval(maxv)
{
}
};
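
Callers constructing the quantisation parameters now pass the per-multi bias stride immediately after the (now const-qualified) bias pointer. A sketch under placeholder values; make_requant_example is a hypothetical helper and none of the offsets or requantisation constants come from the diff.

#include <cstddef>
#include <cstdint>
#include <vector>
#include "arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp"

// Sketch only: all offsets and requantisation constants below are placeholders.
arm_gemm::ARequantizeLayer32 make_requant_example(const std::vector<int32_t> &bias, size_t bias_multi_stride)
{
    return arm_gemm::ARequantizeLayer32(bias.data(), bias_multi_stride,
                                        /*a_offset=*/0, /*b_offset=*/0, /*c_offset=*/0,
                                        /*requant_shift=*/8, /*requant_mul=*/1 << 30,
                                        /*minval=*/0, /*maxval=*/255);
}
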
@@ -128,12 +142,12 @@ using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret> >;
/* get_gemm_method(): Given the templated types and provided parameters,
* which is the preferred method to implement this GEMM? */
template<typename Top, typename Tret, class OutputStage = Nothing>
-KernelDescription get_gemm_method(const GemmArgs<Tret> &args, const OutputStage & ={});
+KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & ={});
template<typename Top, typename Tret, class OutputStage = Nothing>
-UniqueGemmCommon<Top, Tret> gemm(const GemmArgs<Tret> &args, const OutputStage & ={});
+UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage & ={});
template<typename Top, typename Tret, class OutputStage = Nothing>
-std::vector<KernelDescription> get_compatible_kernels(const GemmArgs<Tret> &args, const OutputStage & ={});
+std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & ={});
} // namespace arm_gemm
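
Since GemmArgs has lost its template parameter, the selection and factory helpers are now instantiated on the operand and result types only. A sketch of an fp32 caller; query_example is a hypothetical function name.

#include "arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp"

// Sketch only: shows the helpers taking a plain (non-templated) GemmArgs.
void query_example(const arm_gemm::GemmArgs &args)
{
    // Preferred kernel for an fp32 GEMM with these arguments.
    arm_gemm::KernelDescription kd = arm_gemm::get_gemm_method<float, float>(args);

    // Concrete GEMM object for the same arguments (default OutputStage = Nothing).
    arm_gemm::UniqueGemmCommon<float, float> gemm_obj = arm_gemm::gemm<float, float>(args);

    (void)kd;
    (void)gemm_obj;
}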