aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/assembly/arm_gemm.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/assembly/arm_gemm.hpp')
-rw-r--r--src/core/NEON/kernels/assembly/arm_gemm.hpp18
1 files changed, 11 insertions, 7 deletions
diff --git a/src/core/NEON/kernels/assembly/arm_gemm.hpp b/src/core/NEON/kernels/assembly/arm_gemm.hpp
index f6421c12ab..3088b080d6 100644
--- a/src/core/NEON/kernels/assembly/arm_gemm.hpp
+++ b/src/core/NEON/kernels/assembly/arm_gemm.hpp
@@ -43,7 +43,9 @@ enum class GemmMethod
GEMM_INTERLEAVED_2D,
QUANTIZE_WRAPPER,
QUANTIZE_WRAPPER_2D,
- GEMM_HYBRID_QUANTIZED
+ GEMM_HYBRID_QUANTIZED,
+ INDIRECT_GEMM,
+ CONVOLUTION_GEMM
};
struct KernelDescription
@@ -104,17 +106,19 @@ public:
unsigned int _Msize;
unsigned int _Nsize;
unsigned int _Ksize;
+ unsigned int _Ksections;
unsigned int _nbatches;
unsigned int _nmulti;
+ bool _indirect_input;
Activation _act;
int _maxthreads;
const GemmConfig *_cfg;
- GemmArgs(const CPUInfo *ci, const unsigned int M, const unsigned int N,
- const unsigned int K, const unsigned int nbatches,
- const unsigned int nmulti, Activation act, const int maxthreads,
+ GemmArgs(const CPUInfo *ci, unsigned int M, unsigned int N,
+ unsigned int K, unsigned int Ksections, unsigned int nbatches,
+ unsigned int nmulti, bool indirect_input, Activation act, const int maxthreads,
const GemmConfig *cfg = nullptr)
- : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti), _act(act), _maxthreads(maxthreads), _cfg(cfg)
+ : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _Ksections(Ksections), _nbatches(nbatches), _nmulti(nmulti), _indirect_input(indirect_input), _act(act), _maxthreads(maxthreads), _cfg(cfg)
{
}
};
@@ -143,8 +147,8 @@ public:
Requantize32(const int32_t *bias, size_t bias_multi_stride,
int32_t a_offset, int32_t b_offset, int32_t c_offset,
int32_t requant_shift, int32_t requant_mul, int32_t minv, int32_t maxv)
- : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(false), per_layer_left_shift(std::max(requant_shift, int32_t(0))),
- per_layer_right_shift(std::min(requant_shift, int32_t(0))), per_layer_mul(requant_mul), minval(minv), maxval(maxv)
+ : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(false), per_layer_left_shift(std::max<int32_t>(requant_shift, 0)),
+ per_layer_right_shift(std::min<int32_t>(requant_shift, 0)), per_layer_mul(requant_mul), minval(minv), maxval(maxv)
{
}