diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp | 19 |
1 files changed, 9 insertions, 10 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp index b4edece8d5..574ecef5b2 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp @@ -58,8 +58,6 @@ class GemmHybridQuantized : public GemmCommon<To, Tr> { const bool _trB; - const Tr _beta; - /* Blocking info */ const unsigned int _k_block; const unsigned int _n_block; @@ -82,7 +80,7 @@ class GemmHybridQuantized : public GemmCommon<To, Tr> { return _Nsize * _nmulti * sizeof(int32_t); } - static unsigned int compute_k_block(const GemmArgs<Tr> &args) { + static unsigned int compute_k_block(const GemmArgs &args) { // We don't support K blocks as we only temporarily store 32 bit results. return args._Ksize; @@ -112,7 +110,7 @@ class GemmHybridQuantized : public GemmCommon<To, Tr> { return k_block; } - static unsigned int compute_n_block(const GemmArgs<Tr> &args) { + static unsigned int compute_n_block(const GemmArgs &args) { if (args._cfg && args._cfg->outer_block_size) { return args._cfg->outer_block_size; } @@ -142,9 +140,9 @@ public: GemmHybridQuantized & operator= (GemmHybridQuantized &) = delete; /* Constructor */ - GemmHybridQuantized(const GemmArgs<Tr> &args, const ARequantizeLayer32 &qp) + GemmHybridQuantized(const GemmArgs &args, const ARequantizeLayer32 &qp) : _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize), - _nbatches(args._nbatches), _nmulti(args._nmulti), _trB(args._trB), _beta(args._beta), + _nbatches(args._nbatches), _nmulti(args._nmulti), _trB(args._trB), _k_block(compute_k_block(args)), _n_block(compute_n_block(args)), _Mround(roundup(args._Msize, strategy::out_height())), _window_range(iceildiv(args._Msize, strategy::out_height()), _nbatches, iceildiv(_Nsize, _n_block), _nmulti), @@ -210,8 +208,8 @@ public: strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (m_start * this->_lda) + k0, this->_lda, b_panel, result_buffer, (nmax-n0), - (k0 == 0) ? _beta : static_cast<Tr>(1), - (m_end - m_start), (nmax - n0), kern_k); + (m_end - m_start), (nmax - n0), kern_k, + nullptr, Activation(), false); } { @@ -262,7 +260,7 @@ public: col_bias = reinterpret_cast<int32_t *>(in_buffer); for (unsigned int i=0; i<_nmulti; i++) { - compute_col_sums(_qp, _Nsize, _Ksize, B + (i * B_multi_stride), ldb, col_bias + (i * _Nsize), _Ksize, 0); + compute_col_sums(_qp, _Nsize, _Ksize, B + (i * B_multi_stride), ldb, col_bias + (i * _Nsize), _Ksize, i, 0); } uintptr_t buffer_int = reinterpret_cast<uintptr_t>(in_buffer); @@ -295,8 +293,9 @@ public: col_bias = reinterpret_cast<int32_t *>(in_buffer); } - void set_quantized_bias(const int32_t *bias) override { + void set_quantized_bias(const int32_t *bias, size_t bias_multi_stride) override { _qp.bias = bias; + _qp.bias_multi_stride = bias_multi_stride; } }; |