diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp | 73 |
1 files changed, 46 insertions, 27 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp index 6897e64d4b..f12efe4282 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 ARM Limited. + * Copyright (c) 2017-2021, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -28,10 +28,9 @@ #include <algorithm> #include "arm_gemm.hpp" +#include "ndrange.hpp" #include "utils.hpp" -#include "arm_compute/core/NEON/kernels/arm_gemm/ndrange.hpp" - #include "mergeresults.hpp" #include "transform.hpp" @@ -57,8 +56,6 @@ class GemmHybridQuantized : public GemmCommon<To, Tr> { const unsigned int _nbatches; const unsigned int _nmulti; - const bool _trB; - /* Blocking info */ const unsigned int _k_block; const unsigned int _n_block; @@ -113,7 +110,13 @@ class GemmHybridQuantized : public GemmCommon<To, Tr> { static unsigned int compute_n_block(const GemmArgs &args) { if (args._cfg && args._cfg->outer_block_size) { - return args._cfg->outer_block_size; + unsigned int n_block = args._cfg->outer_block_size; + + // Needs to be (at least a single) multiple of the kernel output width. + n_block /= strategy::out_width(); + n_block = std::max(n_block, 1u) * strategy::out_width(); + + return n_block; } const unsigned int k_block = compute_k_block(args); @@ -121,18 +124,27 @@ class GemmHybridQuantized : public GemmCommon<To, Tr> { // n_block: Work out how many rows (of length k_block) will fit in the L2 // Don't allocate more than 90% of the L2 to allow for overheads, and subtract off the L1 contents. - unsigned int n_block = (((L2_size * 9) / 10) - (k_block * sizeof(Toi) * (strategy::out_width() + strategy::out_height()))) / - (sizeof(Toi) * k_block); + const unsigned int scaled_l2_size = (L2_size * 9) / 10; + const unsigned int k_block_area = k_block * sizeof(Toi) * (strategy::out_width() + strategy::out_height()); + + // .. if the L1 contents is bigger than the L2, just return a minimal size block. + if (k_block_area > scaled_l2_size) { + return strategy::out_width(); + } + + unsigned int n_block = (scaled_l2_size - k_block_area) / (sizeof(Toi) * k_block); // Needs to be (at least a single) multiple of the kernel output width. n_block /= strategy::out_width(); - n_block = std::max(n_block, 1U) * strategy::out_width(); + n_block = std::max(n_block, 1u) * strategy::out_width(); // And tune to the presented problem size. unsigned int numblocks = iceildiv(args._Nsize, n_block); n_block = iceildiv(args._Nsize, numblocks); n_block = roundup(n_block, strategy::out_width()); + assert(n_block > 0); + return n_block; } @@ -143,7 +155,7 @@ public: /* Constructor */ GemmHybridQuantized(const GemmArgs &args, const Requantize32 &qp) : _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize), - _nbatches(args._nbatches), _nmulti(args._nmulti), _trB(args._trB), + _nbatches(args._nbatches), _nmulti(args._nmulti), _k_block(compute_k_block(args)), _n_block(compute_n_block(args)), _Mround(roundup(args._Msize, strategy::out_height())), _window_range(iceildiv(args._Msize, strategy::out_height()), _nbatches, iceildiv(_Nsize, _n_block), _nmulti), @@ -151,7 +163,7 @@ public: // Interface implementation - Compulsory functions ndrange_t get_window_size() const override { - return { _window_range.total_size(), 1u, 1u, 1u, 1u, 1u }; + return { _window_range.total_size() }; } // This kernel can always be dynamically scheduled. @@ -159,7 +171,8 @@ public: return true; } - void execute_1d(unsigned int start, unsigned int end, int threadid) { + // Execute + void execute(const ndcoord_t &work_range, const ndcoord_t &, int threadid) override { #ifdef CYCLE_PROFILING profiler prof; #endif @@ -180,7 +193,7 @@ public: unsigned int kmax = std::min(k0 + _k_block, _Ksize); unsigned int kern_k = roundup(kmax-k0, strategy::k_unroll()); - auto p = _window_range.iterator(start, end); + auto p = _window_range.iterator(work_range.get_position(0), work_range.get_position_end(0)); if (p.done()) { return; @@ -228,23 +241,12 @@ public: requantize_block_32(_qp, (nmax - n0), (m_end - m_start), result_buffer, (nmax - n0), this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (m_start * this->_ldc) + n0, this->_ldc, - local_row_sums, col_bias + (multi * _Nsize) + n0); + local_row_sums, col_bias + (multi * _Nsize) + n0, n0); } } while (p.next_dim0()); } } - // Execute - void execute(const ndcoord_t& work_range, const ndcoord_t& thread_locator, int threadid) override { - UNUSED(thread_locator); - - const auto start = work_range.get_position(0); - const auto size = work_range.get_size(0); - const auto stop = start + size; - - execute_1d(start, stop, threadid); - } - // Working space needed for intermediate result buffers. size_t get_working_size() const override { return (_nthreads * strategy::out_height() * _Nsize * sizeof(Tri)); @@ -267,12 +269,18 @@ public: return get_col_sum_size() + (roundup(_Nsize, strategy::out_width()) * roundup(_Ksize, strategy::k_unroll()) * _nmulti * sizeof(Toi)); } - void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { + void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { col_bias = reinterpret_cast<int32_t *>(in_buffer); for (unsigned int i=0; i<_nmulti; i++) { compute_col_sums(_qp, _Nsize, _Ksize, B + (i * B_multi_stride), ldb, col_bias + (i * _Nsize), _Ksize, i, 0); } + } + + void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override { + assert(!transposed); + + requantize_bias(in_buffer, B, ldb, B_multi_stride); uintptr_t buffer_int = reinterpret_cast<uintptr_t>(in_buffer); Toi *buffer = reinterpret_cast<Toi *>(buffer_int + get_col_sum_size()); @@ -290,7 +298,7 @@ public: const unsigned int size = roundup(xmax-x0, strategy::out_width()) * k_size; strat.transforms.PrepareB( buffer, B + (multi * B_multi_stride), ldb, - x0, xmax, k0, kmax, _trB); + x0, xmax, k0, kmax, false); buffer += size; } @@ -308,6 +316,17 @@ public: _qp.bias = bias; _qp.bias_multi_stride = bias_multi_stride; } + + GemmConfig get_config() override { + GemmConfig c; + + c.method = GemmMethod::GEMM_HYBRID; + c.inner_block_size = _k_block; + c.outer_block_size = _n_block; + c.filter = get_type_name<strategy>(); + + return c; + } }; } // namespace arm_gemm |