Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp | 208
1 file changed, 137 insertions(+), 71 deletions(-)
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
index 7f52ac5a14..dbada36052 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2020 ARM Limited.
+ * Copyright (c) 2017-2022, 2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -36,69 +36,121 @@
 namespace arm_gemm {
 
+namespace {
+
+template<typename OutputStage>
+class run_gemv_kernel {
+public:
+    template<typename strategy, typename Tlo, typename Tro, typename Tr>
+    static void run (
+        const strategy &strat,
+        const Tlo *A_ptr, const Tro *B_ptr, Tr *c_ptr,
+        size_t N, size_t K,
+        const Tr *bias, const Activation &act, bool Accumulate,
+        const OutputStage &os, const int32_t *col_bias, unsigned int col_base
+    );
+};
+
+template<>
+template<typename strategy, typename Tlo, typename Tro, typename Tr>
+void run_gemv_kernel<Nothing>::run(
+        const strategy &strat,
+        const Tlo *A_ptr, const Tro *B_ptr, Tr *C_ptr,
+        size_t N, size_t K,
+        const Tr *bias, const Activation &act, bool Accumulate,
+        const Nothing &, const int32_t *, unsigned int
+    ) {
+
+    strat.kernel(A_ptr, B_ptr, C_ptr, N, K, bias, act, Accumulate);
+}
+
+template<>
+template<typename strategy, typename Tlo, typename Tro, typename Tr>
+void run_gemv_kernel<Requantize32>::run(
+        const strategy &strat,
+        const Tlo *A_ptr, const Tro *B_ptr, Tr *C_ptr,
+        size_t N, size_t K,
+        const Tr *, const Activation &, bool,
+        const Requantize32 &qp, const int32_t *col_bias, unsigned int col_base
+    ) {
+
+    strat.kernel(A_ptr, B_ptr, C_ptr, N, K, &qp, col_bias + col_base, col_base);
+}
+
+} // anonymous namespace
+
 // Implementation of the GemmCommon abstract class.
 //
 // This is implementation is for GEMV with pretransposition.
 //
 // batches are not supported as a batched GEMV makes no sense (can be converted to a GEMM).
-template<typename strategy, typename To, typename Tr>
+template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing>
 class GemvPretransposed : public GemmCommon<To, Tr> {
     typedef typename strategy::operand_type Toi;
     typedef typename strategy::result_type Tri;
 
-    const unsigned int _Nsize;
-    const unsigned int _Ksize;
+    const GemmArgs     _args;
 
-    const unsigned int _nmultis;
+    const unsigned int _buffer_per_multi;
 
-    const bool _trB;
+    unsigned int k_block=0;
+    unsigned int n_block=0;
 
-    const Activation _act;
+    const Toi *_B_pretransposed = nullptr;
 
-    const CPUInfo * const _ci;
+    OutputStage _os;
 
-    const unsigned int _buffer_per_multi;
-
-    unsigned int m_block=0;
-    unsigned int n_block=0;
+    // Pointer to the column sums (for quantized cases)
+    int32_t *col_bias = nullptr;
 
-    const Toi *_A_pretransposed = nullptr;
+    // Get size of the column sums
+    unsigned int get_col_sum_size() const {
+        if(std::is_same<OutputStage, Requantize32>::value) {
+            return _args._Nsize * _args._nmulti * sizeof(int32_t);
+        } else {
+            return 0;
+        }
+    }
 
 public:
     GemvPretransposed(GemvPretransposed &) = delete;
     GemvPretransposed & operator= (GemvPretransposed &) = delete;
 
-    GemvPretransposed(const GemmArgs &args)
-        : _Nsize(args._Nsize), _Ksize(args._Ksize), _nmultis(args._nmulti), _trB(args._trB), _act(args._act), _ci(args._ci),
-          _buffer_per_multi(_Ksize * iceildiv(_Nsize, strategy::A_interleave()) * strategy::A_interleave()) {
+    GemvPretransposed(const GemmArgs &args, const OutputStage &os = {})
+        : _args(args),
+          _buffer_per_multi(roundup(args._Ksize, strategy::k_unroll()) * roundup(args._Nsize, strategy::out_width())),
+          _os(os) {
         /* For now don't do any blocking. TODO: figure out if we should. */
-        if (args._cfg && args._cfg->inner_block_size) {
-            m_block = args._cfg->inner_block_size;
+        if (strategy::supports_accumulate() && args._cfg && args._cfg->inner_block_size) {
+            k_block = args._cfg->inner_block_size;
         } else {
-            m_block = _Ksize;
+            k_block = args._Ksize;
         }
 
         if (args._cfg && args._cfg->outer_block_size) {
            n_block = args._cfg->outer_block_size;
         } else {
-            n_block = _Nsize;
+            n_block = args._Nsize;
         }
     }
 
     // Window is number of out_width blocks, times number of multis.
     ndrange_t get_window_size() const override {
-        return { iceildiv(_Nsize, strategy::out_width()) * _nmultis, 1u, 1u, 1u, 1u, 1u };
+        return { iceildiv(_args._Nsize, strategy::out_width()) * _args._nmulti };
     }
 
     // Actually execute the GEMV.
-    void execute_1d(unsigned int start, unsigned int end, int) {
+    void execute(const ndcoord_t &work_range, const ndcoord_t &, int) override {
 #ifdef CYCLE_PROFILING
         profiler prof;
 #endif
-        strategy strat(_ci);
+        strategy strat(_args._ci);
+
+        const auto start = work_range.get_position(0);
+        const auto end = work_range.get_position_end(0);
 
         /* Break the window values down into multis of interest... */
-        const unsigned int window_per_multi = iceildiv(_Nsize, strategy::out_width());
+        const unsigned int window_per_multi = iceildiv(_args._Nsize, strategy::out_width());
 
         const unsigned int multi_0   = start / window_per_multi;
         const unsigned int multi_end = end / window_per_multi;
 
@@ -110,85 +162,99 @@ public:
         for (unsigned int multi=multi_0; multi<=multi_end; multi++) {
             const unsigned int n_start = (multi==multi_0) ? n_0 : 0;
-            const unsigned int n_end = (multi==multi_end) ? n_max : _Nsize;
+            const unsigned int n_end = (multi==multi_end) ? n_max : _args._Nsize;
 
             if (n_end <= n_start)
                 continue;
 
-            for (unsigned int m0=0; m0<_Ksize; m0+=m_block) {
-                unsigned int mmax = std::min(m0 + m_block, _Ksize);
+            for (unsigned int k0=0; k0<_args._Ksize; k0+=k_block) {
+                unsigned int kmax = std::min(k0 + k_block, _args._Ksize);
 
                 for (unsigned int n=n_start; n<n_end; n+=n_block) {
                     unsigned int nmax = std::min(n + n_block, n_end);
 #ifdef CYCLE_PROFILING
-                    auto p = prof.ScopedProfiler(PROFILE_KERNEL, (mmax-m0) * (nmax-n));
+                    auto p = prof.ScopedProfiler(PROFILE_KERNEL, (kmax-k0) * (nmax-n));
 #endif
-                    /* This assumes that the underlying call was a GEMM with M=1; for the N=1 case we would have to pick up this->_Bptr below instead */
-                    strat.kernel(_A_pretransposed + (multi * _buffer_per_multi) + (n * _Ksize) + (m0 * strategy::A_interleave()),
-                                 (_Ksize * strategy::A_interleave()),
-                                 this->_Aptr + (multi * this->_A_multi_stride) + m0,
+                    run_gemv_kernel<OutputStage>::run(strat, this->_Aptr + (multi * this->_A_multi_stride) + k0,
+                                 _B_pretransposed + (multi * _buffer_per_multi) + (n * roundup(_args._Ksize, strategy::k_unroll())) + (k0 * strategy::out_width()),
                                  this->_Cptr + (multi * this->_C_multi_stride) + n,
-                                 static_cast<Tr>(0), (mmax-m0), (nmax-n));
-
-                    // Handle activation separately for now
-                    if (this->_bias) {
-                        activator<true>(this->_Cptr + (multi * this->_C_multi_stride) + n, 0,
-                                        this->_bias + (multi * this->_bias_multi_stride) + n,
-                                        _act, 1, (nmax-n));
-                    } else {
-                        activator<false>(this->_Cptr + (multi * this->_C_multi_stride) + n, 0,
-                                        static_cast<const Tr *>(nullptr),
-                                        _act, 1, (nmax-n));
-                    }
+                                 (nmax - n), (kmax-k0),
+                                 this->_bias ? this->_bias + (multi * this->_bias_multi_stride) + n : nullptr,
+                                 _args._act, (k0 != 0) || _args._accumulate,
+                                 _os, col_bias, n + (_args._Nsize * multi));
                 }
             }
         }
     }
 
-    // Execute
-    void execute(const ndcoord_t& work_range, const ndcoord_t& thread_locator, int threadid) override {
-        UNUSED(thread_locator);
-
-        const auto start = work_range.get_position(0);
-        const auto size = work_range.get_size(0);
-        const auto stop = start + size;
-
-        execute_1d(start, stop, threadid);
-    }
-
     /* Pretransposed interface implementation */
     bool B_is_pretransposed() const override {
         return true;
     }
 
     bool B_pretranspose_required() const override {
-        /* Transpose is required if _A_pretransposed is still nullptr */
-        return (_A_pretransposed == nullptr);
+        /* Transpose is required if _B_pretransposed is still nullptr */
+        return (_B_pretransposed == nullptr);
     }
 
     size_t get_B_pretransposed_array_size() const override {
-        return _buffer_per_multi * _nmultis * sizeof(To);
+        return _buffer_per_multi * _args._nmulti * sizeof(To) + get_col_sum_size();
     }
 
-    void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override {
-        Toi *A_buffer = reinterpret_cast<Toi *>(buffer);
-
-        for (unsigned int multi=0; multi<_nmultis; multi++) {
-            /* Reverse sense here as we are dealing with B rather than A.  So if
-             * strategy::A_transpose is false and _trB is false, we still
-             * transpose.  */
-            if (_trB ^ strategy::A_transpose()) {
-                Transform<strategy::A_interleave(), strategy::A_block(), false>(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize);
-            } else {
-                Transform<strategy::A_interleave(), strategy::A_block(), true>(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize);
+    void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+        // Column sums go on the front of the pretransposed buffer in requantized cases.
+        // We could optimize here in case we don't actually need to sum the columns, but this code is only run on setup.
+        if (std::is_same<OutputStage, Requantize32>::value) {
+            col_bias = reinterpret_cast<int32_t *>(in_buffer);
+
+            Requantize32 *qp_ptr = reinterpret_cast<Requantize32 *>(&_os);
+
+            for (unsigned int i=0; i<_args._nmulti; i++) {
+                compute_col_sums(*qp_ptr, _args._Nsize, _args._Ksize, B + (i * B_multi_stride), ldb, col_bias + (i * _args._Nsize), _args._Ksize, i, 0);
             }
         }
+    }
+
+    void set_quantized_bias(const int32_t *bias, size_t bias_multi_stride) override {
+        if (std::is_same<OutputStage, Requantize32>::value) {
+            Requantize32 *qp = reinterpret_cast<Requantize32 *>(&_os);
+
+            qp->bias = bias;
+            qp->bias_multi_stride = bias_multi_stride;
+        }
+    }
+
+    void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override {
+        assert(!transposed);
+
+        requantize_bias(buffer, B, ldb, B_multi_stride);
+
+        // The actual transposed buffer goes after the column sums (if any)
+        uintptr_t buffer_int = reinterpret_cast<uintptr_t>(buffer);
+        Toi *B_buffer = reinterpret_cast<Toi *>(buffer_int + get_col_sum_size());
+
+        strategy strat(_args._ci);
+
+        for (unsigned int multi=0; multi<_args._nmulti; multi++) {
+            strat.transforms.PrepareB(B_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _args._Nsize, 0, _args._Ksize, false);
+        }
 
-        _A_pretransposed = A_buffer;
+        _B_pretransposed = B_buffer;
     }
 
     void set_pretransposed_B_data(void *buffer) override {
-        _A_pretransposed = reinterpret_cast<Toi *>(buffer);
+        _B_pretransposed = reinterpret_cast<Toi *>(buffer);
+    }
+
+    GemmConfig get_config() override {
+        GemmConfig c;
+
+        c.method = GemmMethod::GEMV_PRETRANSPOSED;
+        c.inner_block_size = k_block;
+        c.outer_block_size = n_block;
+        c.filter = get_type_name<strategy>();
+
+        return c;
    }
 };
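Note on the core mechanism of this change: the patch replaces the old direct `strat.kernel(...)` call plus separate `activator<>` pass with a `run_gemv_kernel<OutputStage>` class template, specialized for `Nothing` (plain float/int paths) and `Requantize32` (quantized paths), so the correct kernel call shape is selected at compile time from the `GemvPretransposed` template parameter. Below is a minimal, self-contained sketch of that dispatch pattern. It is not the library's real code: `toy_strategy`, the float/int8_t kernel overloads, and the trimmed parameter lists (no `Activation`, merged A/B operand types) are illustrative stand-ins under assumption, and the specializations are written as full struct specializations rather than the patch's member-template specializations, for brevity.

    #include <cstdint>
    #include <cstdio>

    // Simplified stand-ins for the arm_gemm output-stage tag types.
    struct Nothing {};
    struct Requantize32 {
        const int32_t *bias = nullptr;
        size_t bias_multi_stride = 0;
    };

    // A toy "strategy" whose kernel() overloads mirror the two call shapes
    // the patch dispatches between (plain vs. requantizing).
    struct toy_strategy {
        void kernel(const float *, const float *, float *, size_t N, size_t,
                    const float *, bool accumulate) const {
            std::printf("plain kernel: N=%zu accumulate=%d\n", N, (int)accumulate);
        }
        void kernel(const int8_t *, const int8_t *, int8_t *, size_t N, size_t,
                    const Requantize32 *, const int32_t *, unsigned int col_base) const {
            std::printf("requantizing kernel: N=%zu col_base=%u\n", N, col_base);
        }
    };

    // The dispatch pattern: a class template specialized on the output stage,
    // so each specialization forwards only the parameters its kernel needs.
    template<typename OutputStage>
    struct run_gemv_kernel;

    template<>
    struct run_gemv_kernel<Nothing> {
        template<typename strategy, typename To, typename Tr>
        static void run(const strategy &strat, const To *A, const To *B, Tr *C,
                        size_t N, size_t K, const Tr *bias, bool accumulate,
                        const Nothing &, const int32_t *, unsigned int) {
            strat.kernel(A, B, C, N, K, bias, accumulate);
        }
    };

    template<>
    struct run_gemv_kernel<Requantize32> {
        template<typename strategy, typename To, typename Tr>
        static void run(const strategy &strat, const To *A, const To *B, Tr *C,
                        size_t N, size_t K, const Tr *, bool,
                        const Requantize32 &qp, const int32_t *col_bias,
                        unsigned int col_base) {
            // Quantized path: pass requantize parameters and this block's
            // slice of the column sums instead of bias/activation.
            strat.kernel(A, B, C, N, K, &qp, col_bias + col_base, col_base);
        }
    };

    int main() {
        toy_strategy strat;
        float   Af[4] = {}, Bf[4] = {}, Cf[4] = {};
        int8_t  Aq[4] = {}, Bq[4] = {}, Cq[4] = {};
        int32_t col_sums[4] = {};
        Requantize32 qp;

        run_gemv_kernel<Nothing>::run(strat, Af, Bf, Cf, 4, 4,
                                      static_cast<const float *>(nullptr), false,
                                      Nothing{}, nullptr, 0);
        run_gemv_kernel<Requantize32>::run(strat, Aq, Bq, Cq, 4, 4,
                                           static_cast<const int8_t *>(nullptr), false,
                                           qp, col_sums, 0);
    }

The unused-parameter trick is what lets the caller in `execute()` stay uniform: it always passes bias, activation, accumulate flag, output stage, column sums, and column base, and each specialization silently drops whatever its kernel does not consume.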