diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp | 83 |
1 files changed, 31 insertions, 52 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp index 18f030fec0..d35825c428 100644 --- a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp +++ b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020 ARM Limited. + * Copyright (c) 2019-2021, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -61,16 +61,9 @@ private: } /* Local working space: We need space for the subgemm output (above) and - * the row sums. If the GEMM is not pretransposed we need to store the - * column sums here too. */ + * the row sums. */ size_t local_working_size() const { - size_t sz = subgemm_output_size() + row_sum_size(); - - if (_args._pretransposed_hint) { - return sz; - } - - return sz + col_sum_size(); + return subgemm_output_size() + row_sum_size(); } void set_child_arrays() { @@ -90,15 +83,6 @@ private: } } - void col_sums_runtime(unsigned int threadid) { - unsigned int first_col = (threadid * _args._Nsize) / _args._maxthreads; - unsigned int last_col = ((threadid + 1) * _args._Nsize) / _args._maxthreads; - - for (unsigned int multi=0; multi<_args._nmulti; multi++) { - compute_col_sums(_params, (last_col - first_col), _args._Ksize, this->_Bptr + (multi * this->_B_multi_stride) + first_col, this->_ldb, _col_sums + (multi * _args._Nsize) + first_col, _args._Ksize, multi, first_col); - } - } - void requantize_runtime(unsigned int threadid) { unsigned int first_row = (threadid * _args._Msize) / _args._maxthreads; unsigned int last_row = ((threadid+1) * _args._Msize) / _args._maxthreads; @@ -115,7 +99,7 @@ private: _args._Nsize, this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (first_row * this->_ldc), this->_ldc, _row_sums + (multi * _args._nbatches * _args._Msize) + (batch * _args._Msize) + first_row, - _col_sums + (multi * _args._Nsize)); + _col_sums + (multi * _args._Nsize), 0); } } } @@ -126,16 +110,12 @@ public: QuantizeWrapper operator=(const QuantizeWrapper &) = delete; QuantizeWrapper(const GemmArgs &args, const Requantize32 &qp) : _params(qp), _args(args), _barrier(args._maxthreads) { - GemmArgs newargs = GemmArgs(args._ci, args._Msize, args._Nsize, args._Ksize, args._nbatches, args._nmulti, args._trA, args._trB, Activation(), args._maxthreads, args._pretransposed_hint, nullptr); + GemmArgs newargs = GemmArgs(args._ci, args._Msize, args._Nsize, args._Ksize, args._Ksections, args._nbatches, args._nmulti, args._indirect_input, Activation(), args._maxthreads); _subgemm = gemm<To, Tgemm>(newargs); if (_subgemm == nullptr) { return; } - - if (!_subgemm->B_is_pretransposed()) { - _args._pretransposed_hint = false; - } } void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride, @@ -149,7 +129,7 @@ public: } ndrange_t get_window_size() const override { - return _subgemm->get_window_size(); + return { _subgemm->get_window_size() }; } void set_nthreads(int nthreads) override { @@ -158,12 +138,8 @@ public: _args._maxthreads = nthreads; } - // Execute - void execute(const ndcoord_t& work_range, const ndcoord_t& thread_locator, int threadid) override { + void execute(const ndcoord_t &work_range, const ndcoord_t &thread_locator, int threadid) override { _subgemm->execute(work_range, thread_locator, threadid); - if (!_args._pretransposed_hint) { - col_sums_runtime(threadid); - } _barrier.arrive_and_wait(); @@ -178,7 +154,7 @@ public: // ptr // V - // | subgemm output | row_sums | col_sums (if not pretransposed | subgemm working space | + // | subgemm output | row_sums | subgemm working space | void set_working_space(void *space) override { uintptr_t space_int = reinterpret_cast<uintptr_t>(space); @@ -186,16 +162,13 @@ public: _subgemm->set_working_space(reinterpret_cast<void *>(space_int + local_working_size())); _row_sums = reinterpret_cast<int32_t *>(space_int + subgemm_output_size()); - if (!_args._pretransposed_hint) { - _col_sums = reinterpret_cast<int32_t *>(space_int + subgemm_output_size() + row_sum_size()); - } set_child_arrays(); } bool B_is_pretransposed() const override { /* We clear this flag if the subgemm isn't pretransposed, so just return its value */ - return _args._pretransposed_hint; + return _subgemm->B_is_pretransposed(); } bool B_pretranspose_required() const override { @@ -203,31 +176,24 @@ public: } size_t get_B_pretransposed_array_size() const override { - if (_args._pretransposed_hint) { - return _subgemm->get_B_pretransposed_array_size() + col_sum_size(); - } + return _subgemm->get_B_pretransposed_array_size() + col_sum_size(); + } - return 0; + void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { + _col_sums = reinterpret_cast<int32_t *>(in_buffer); + col_sums_pretransposed(B, ldb, B_multi_stride); } - void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override { - if (!_args._pretransposed_hint) { - return; - } + void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override { + assert(!transposed); uintptr_t buffer_int = reinterpret_cast<uintptr_t>(buffer); - _subgemm->pretranspose_B_array(reinterpret_cast<void *>(buffer_int + col_sum_size()), B, ldb, B_multi_stride); + _subgemm->pretranspose_B_array(reinterpret_cast<void *>(buffer_int + col_sum_size()), B, ldb, B_multi_stride, transposed); - _col_sums = reinterpret_cast<int32_t *>(buffer); - - col_sums_pretransposed(B, ldb, B_multi_stride); + requantize_bias(buffer, B, ldb, B_multi_stride); } void set_pretransposed_B_data(void *buffer) override { - if (!_args._pretransposed_hint) { - return; - } - uintptr_t buffer_int = reinterpret_cast<uintptr_t>(buffer); _subgemm->set_pretransposed_B_data(reinterpret_cast<void *>(buffer_int + col_sum_size())); _col_sums = reinterpret_cast<int32_t *>(buffer); @@ -237,6 +203,19 @@ public: _params.bias = bias; _params.bias_multi_stride = bias_multi_stride; } + + GemmConfig get_config() override { + GemmConfig c = _subgemm->get_config(); + + std::string n = "quantize_wrapper["; + n.append(c.filter); + n.append("]"); + + c.method = GemmMethod::QUANTIZE_WRAPPER; + c.filter = n; + + return c; + } }; } // namespace arm_gemm |