From c0b6f76561580414f08633a804fc548ccad65659 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Mon, 2 Nov 2020 01:37:17 +0000 Subject: COMPMID-3776: Indirect GEMM Signed-off-by: Georgios Pinitas Change-Id: I51a1b0f098bc3a8c408c50c92221e4df3061e12c Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4343 Tested-by: Arm Jenkins Reviewed-by: Sang-Hoon Park Reviewed-by: Michele Di Giorgio Comments-Addressed: Arm Jenkins --- .../NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp | 621 +++++++++++++++++++++ 1 file changed, 621 insertions(+) create mode 100644 src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp (limited to 'src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp') diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp new file mode 100644 index 0000000000..eede1a4f76 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp @@ -0,0 +1,621 @@ +/* + * Copyright (c) 2017-2020 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#include + +#include +#include + +#include "arm_gemm.hpp" +#include "bias_adder.hpp" +#include "convolver.hpp" +#include "ndrange.hpp" +#include "performance_parameters.hpp" +#include "transform.hpp" +#include "utils.hpp" + +#ifdef CYCLE_PROFILING +#include "profiler.hpp" +#endif + +#ifndef UNUSED +#define __I_DEFINED_UNUSED +#define UNUSED(x) ((void)(x)) +#endif + +namespace arm_gemm { + +namespace { + +// We need to invoke the kernel differently for quantizing and non-quantizing cases, so here is a shim class to do +// that. + +template +class run_hybrid_kernel { +public: + template + static void run ( +#ifdef CYCLE_PROFILING + profiler &prof, +#endif + const strategy &strat, unsigned int num_strings, const unsigned int *string_ptr, IndirectInputArg A_arg, unsigned int M, unsigned int N, + unsigned int kern_k, const To *b_ptr, IndirectOutputArg output_arg, const Tr *bias_ptr, Activation act, bool accumulate, + const OutputStage &os, const int32_t *col_bias, unsigned int n_0 ); +}; + +template<> +template +void run_hybrid_kernel::run( +#ifdef CYCLE_PROFILING + profiler &prof, +#endif + const strategy &strat, unsigned int num_strings, const unsigned int *string_ptr, IndirectInputArg A_arg, unsigned int M, unsigned int N, + unsigned int kern_k, const To *b_ptr, IndirectOutputArg output_arg, const Tr *bias_ptr, Activation act, bool accumulate, + const Nothing &, const int32_t *, unsigned int) { +#ifdef CYCLE_PROFILING + auto p = prof.ScopedProfiler(PROFILE_KERNEL, (unsigned long)M * kern_k * roundup(N, strategy::out_width())); +#endif + UNUSED(kern_k); + + strat.kernel(num_strings, string_ptr, A_arg, M, N, b_ptr, output_arg, bias_ptr, act, accumulate); +} + +template<> +template +void run_hybrid_kernel::run( +#ifdef CYCLE_PROFILING + profiler &prof, +#endif + const strategy &strat, unsigned int num_strings, const unsigned int *string_ptr, IndirectInputArg A_arg, unsigned int M, unsigned int N, + unsigned int kern_k, const To *b_ptr, IndirectOutputArg output_arg, const Tr *, Activation, bool, + const Requantize32 &os, const int32_t *col_bias, unsigned int n_0 ) { +#ifdef CYCLE_PROFILING + auto p = prof.ScopedProfiler(PROFILE_KERNEL, (unsigned long)M * kern_k * roundup(N, strategy::out_width())); +#endif + UNUSED(kern_k); + + strat.kernel(num_strings, string_ptr, A_arg, M, N, b_ptr, output_arg, &os, col_bias + n_0, n_0); +} + +template<> +template +void run_hybrid_kernel::run( +#ifdef CYCLE_PROFILING + profiler &prof, +#endif + const strategy &strat, unsigned int num_strings, const unsigned int *string_ptr, IndirectInputArg A_arg, unsigned int M, unsigned int N, + unsigned int kern_k, const To *b_ptr, IndirectOutputArg output_arg, const Tr *, Activation, bool, + const Requantize32 &os, const int32_t *col_bias, unsigned int n_0 ) { + UNUSED(kern_k); + // On this route we will only process one kernel height at a time and will make sure this happens in the driver loop. + assert(M <= strategy::out_height()); + // We don't yet support indirect output (as the quantizer can't do it). + assert(output_arg.is_indirect == false); + + // We need a row sum buffer and intermediate output buffer. + // These go on the stack as they are not too large, using an automatic array and alloca() respectively. + int32_t row_sums[strategy::out_height()]; + typename strategy::result_type *result_buffer; + + unsigned int output_width = roundup(N, strategy::out_width()); + + result_buffer = reinterpret_cast(alloca(output_width * strategy::out_height() * sizeof(typename strategy::result_type))); + + { +#ifdef CYCLE_PROFILING + auto p = prof.ScopedProfiler(PROFILE_KERNEL, (unsigned long)M * kern_k * roundup(N, strategy::out_width())); +#endif + // Perform the GEMM, into the output buffer. + strat.kernel(num_strings, string_ptr, A_arg, M, N, b_ptr, IndirectOutputArg(result_buffer, output_width), nullptr, Activation(), false); + } + + if (os.b_offset != 0) { +#ifdef CYCLE_PROFILING + auto p = prof.ScopedProfiler(PROFILE_ROWSUMS, (unsigned long)M * kern_k); +#endif + row_sums_indirect(num_strings, string_ptr, A_arg, M, row_sums, &os); + } else { + memset(row_sums, 0, sizeof(int32_t) * strategy::out_height()); + } + + { +#ifdef CYCLE_PROFILING + auto p = prof.ScopedProfiler(PROFILE_QUANTIZE, (unsigned long)M * N); +#endif + // Quantize + requantize_block_32(os, N, M, result_buffer, output_width, output_arg.direct.base, output_arg.direct.stride, row_sums, col_bias + n_0, n_0); + } +} + +} // anonymous namespace + +// Implementation of the GemmCommon abstract class. +template +class GemmHybridIndirect : public GemmCommon { + typedef typename strategy::operand_type Toi; + typedef typename strategy::result_type Tri; + + GemmArgs _args; + OutputStage _os = {}; + + /* Quantized support (in addition to 'output stage' above) */ + int32_t *_col_bias = nullptr; + + const unsigned int _Ktotal; + const unsigned int _rounded_Ksize; + + /* Blocking info */ + const unsigned int _k_block; + const unsigned int _n_block; + const unsigned int _Mround; + + /* Pretransposed buffer. */ + const Toi *_B_transposed=nullptr; + + /* Indirect parameters. _indirect_buf doubles as a flag to indicate that "indirect" transform should be used. */ + const To * const * const * _indirect_buf = nullptr; + + /* Convolver - only set up for convolution problems, so also doubles as a flag. */ + std::unique_ptr> _convolver = nullptr; + + // Array of pointers to output rows +// Tr * const * _output_ptrs; + + const NDRange<4> _window_range; + + unsigned int get_col_sum_size() const { + if (std::is_same::value) { + return _args._Nsize * _args._nmulti * sizeof(int32_t); + } else { + return 0; + } + } + + static unsigned int get_ktotal(const GemmArgs &args) { + return args._Ksections * roundup(args._Ksize, strategy::k_unroll()); + } + + static unsigned int compute_k_block(const GemmArgs &args) { + // Some kernels don't support accumulate mode - these can't do K blocking at all. + if (!strategy::supports_accumulate() || std::is_same::value) { + return get_ktotal(args); + } + + if (args._cfg && args._cfg->inner_block_size) { + return args._cfg->inner_block_size; + } + + // Experimental data suggests an optimal block size of 512 for FP32 (scaling accordingly for other + // datatypes); but don't divide into blocks until we hit 1.5X this size. + unsigned int target_block_size = 2048 / sizeof(To); + auto ktotal = get_ktotal(args); + + if (ktotal > ((target_block_size*3)/2)) { + unsigned int target_blocks = iceildiv(ktotal, target_block_size); + + unsigned int block_size = iceildiv(ktotal, target_blocks); + + block_size = roundup(block_size, strategy::k_unroll()); + + return block_size; + } + + return ktotal; + } + + // New N blocking strategy: if it's narrow, or much taller than it is wide, do the full width. Otherwise do a + // single block. + static unsigned int compute_n_block(const GemmArgs &args, const OutputStage os = {}) { + if (args._cfg && args._cfg->outer_block_size) { + return args._cfg->outer_block_size; + } + + if (args._Nsize <= 64) { + return args._Nsize; + } + + if ((args._Msize / args._Nsize) > 155) { + return args._Nsize; + } + + // "Asymmetric" quantizing GEMMs require a different approach - the tall skinny blocks we would otherwise + // use imply a great deal of repeated work performing the row sums. If row sums are involved, work out how + // much "column" parallelism is going to be required and set the block size accordingly. + if (std::is_same::value) { + const Requantize32 *qp = reinterpret_cast(&os); + + // Row sums only needed if b_offset isn't 0 + if (qp->b_offset != 0) { + // We can already parallelize across batches, multis and rows (in units of 'out_height') + int multi_row_parallelism = args._nmulti * args._nbatches * iceildiv(args._Msize, strategy::out_height()); + + // If this isn't enough, we will need to split up the columns too. + if (multi_row_parallelism < args._maxthreads) { + unsigned int columns_needed = iceildiv(args._maxthreads, multi_row_parallelism); + + unsigned int n_block = iceildiv(args._Nsize, columns_needed); + + return roundup(n_block, strategy::out_width()); + } + + // Multi/Batch/Row parallelism is enough - don't split up the columns. + return args._Nsize; + } + } + + if (args._Ksize <= 128 && args._maxthreads <= 16) { + return strategy::out_width() * 3; + } + + return strategy::out_width(); + } + +public: + GemmHybridIndirect(GemmHybridIndirect &) = delete; + GemmHybridIndirect & operator= (GemmHybridIndirect &) = delete; + + /* Constructor */ + GemmHybridIndirect(const GemmArgs &args, const OutputStage &os) + : _args(args), _os(os), _Ktotal(get_ktotal(args)), + _rounded_Ksize(roundup(args._Ksize, strategy::k_unroll())), + _k_block(compute_k_block(args)), _n_block(compute_n_block(args, os)), + _Mround(roundup(args._Msize, strategy::out_height())), + _window_range(iceildiv(args._Msize, strategy::out_height()), args._nbatches, + iceildiv(args._Nsize, _n_block), args._nmulti) + { + // We take a copy of the arguments (not a pointer or reference), but there is no lifetime requirement on the + // GemmConfig. Clear out the pointer to avoid accidents. + _args._cfg = nullptr; + } + + /* Constructor without OutputStage */ + GemmHybridIndirect(const GemmArgs &args) + : _args(args), _Ktotal(get_ktotal(args)), + _rounded_Ksize(roundup(args._Ksize, strategy::k_unroll())), + _k_block(compute_k_block(args)), _n_block(compute_n_block(args)), + _Mround(roundup(args._Msize, strategy::out_height())), + _window_range(iceildiv(args._Msize, strategy::out_height()), args._nbatches, + iceildiv(args._Nsize, _n_block), args._nmulti) + { + // We take a copy of the arguments (not a pointer or reference), but there is no lifetime requirement on the + // GemmConfig. Clear out the pointer to avoid accidents. + _args._cfg = nullptr; + } + + // Interface implementation - Compulsory functions + ndrange_t get_window_size() const override { + return { _window_range.total_size() }; + } + + // This kernel can always be dynamically scheduled. + bool supports_dynamic_scheduling() const override { + return true; + } + + // Execute + void execute(const ndcoord_t &work_range, const ndcoord_t &, int) override { +#ifdef CYCLE_PROFILING + profiler prof; +#endif + strategy strat(_args._ci); + + std::vector in_row_ptrs; + std::vector in_row_strings; + std::vector string_lengths; + + // In convolution mode, we need input pointers. + if (_convolver) { + in_row_ptrs = std::vector(strategy::out_height() * _args._Ksections, nullptr); + in_row_strings = std::vector(_args._Ksections, nullptr); + + for (unsigned int i=0; i<_args._Ksections; i++) { + in_row_strings[i] = &(in_row_ptrs[i * strategy::out_height()]); + } + } + + // In any indirect mode, we need the string lengths. + if (_args._indirect_input) { + string_lengths = std::vector(_args._Ksections, 0); + } + + /* Make sure we've been set up correctly. */ + assert(_B_transposed); + static_assert(std::is_same::value, "gemm_native: Operand types must be the same."); +// static_assert(std::is_same::value, "gemm_native: Result types must be the same."); + + /* For now, each work item implies all the K for a given output + * pixel (so we don't need to synchronize access to the output + * array). So separate the loop over K blocks here. */ + for (unsigned int k0=0; k0<_Ktotal; k0+=_k_block) { + unsigned int kmax = std::min(k0 + _k_block, _Ktotal); + unsigned int kern_k = roundup(kmax-k0, strategy::k_unroll()); + + const bool first_pass = (k0 == 0); + const bool last_pass = (kmax == _Ktotal); + + unsigned int first_section = (k0 / _rounded_Ksize); + unsigned int first_offset = (k0 % _rounded_Ksize); + unsigned int kleft = kern_k; + unsigned int sections=0; + unsigned int offset = first_offset; + + if (_args._indirect_input) { + while (kleft) { + // When chopping into sections: the amount that goes into 'string_lengths' is the amount to be + // processed (excluding padding). But the amount we subtract from 'kleft' takes account of any + // padding applied. + string_lengths[sections] = std::min(kleft, _args._Ksize - offset); + kleft -= std::min(kleft, _rounded_Ksize - offset); + sections++; + offset=0; + } + } + + auto p = _window_range.iterator(work_range.get_position(0), work_range.get_position_end(0)); + + if (p.done()) { + return; + } + + // Process rows either 'out_height' rows at a time, or do all valid rows at once with a single kernel call. + // The separate quantizer path only handles one block of rows at a time (as it has to store sums and intermediate results). + // THe convolution path only generates the pointers for one block of rows at a time. + const bool process_all_rows = (!SeparateQuantize && !_convolver); + + do { + const unsigned int m_start = p.dim(0) * strategy::out_height(); + const unsigned int m_end = process_all_rows ? std::min(p.dim0_max() * strategy::out_height(), _args._Msize) : std::min(m_start + strategy::out_height(), _args._Msize); +// const unsigned int m_end = std::min(m_start + strategy::out_height(), _args._Msize); + const unsigned int batch = p.dim(1); + const unsigned int n0 = p.dim(2) * _n_block; + const unsigned int nmax = std::min(n0 + _n_block, _args._Nsize); + const unsigned int multi = p.dim(3); + + const Toi *b_panel = _B_transposed + + (multi * roundup(_args._Nsize, strategy::out_width()) * _Ktotal) + + (k0 * roundup(_args._Nsize, strategy::out_width())) + + (n0 * kern_k); + + IndirectOutputArg out_arg(this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (m_start * this->_ldc) + n0, this->_ldc); + +#ifdef CYCLE_PROFILING + auto p = prof.ScopedProfiler(PROFILE_KERNEL, (unsigned long)(m_end - m_start) * kern_k * roundup(nmax-n0, strategy::out_width())); +#endif + if (_indirect_buf) { + run_hybrid_kernel::run( +#ifdef CYCLE_PROFILING + prof, +#endif + strat, sections, string_lengths.data(), + IndirectInputArg(_indirect_buf + (multi * _args._nbatches * _args._Ksections) + (batch * _args._Ksections) + first_section, m_start, first_offset), + (m_end - m_start), (nmax - n0), kern_k, b_panel, out_arg, + (this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr, + last_pass ? _args._act : Activation(), + !first_pass, + // Quantization parameters + _os, _col_bias+(multi * _args._Nsize), n0); + } else if (_convolver) { + auto conv_cols = _convolver->process_columns(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride), this->_lda, k0, kmax, _rounded_Ksize); + + unsigned int pos=0; + auto conv_rows = conv_cols.process_rows(m_start, m_end - m_start); + + while (!conv_rows.finished()) { + unsigned int width, conv_offset; + + assert(pos < sections); + + std::tie(width, conv_offset) = conv_rows.next_block(&(in_row_ptrs[pos * strategy::out_height()])); + + if (pos==0) { + assert(conv_offset == first_offset); + } + assert(width == string_lengths[pos]); + pos++; + } + assert(pos == sections); + + run_hybrid_kernel::run( +#ifdef CYCLE_PROFILING + prof, +#endif + strat, sections, string_lengths.data(), + IndirectInputArg(in_row_strings.data(), 0, first_offset), + (m_end - m_start), (nmax - n0), kern_k, b_panel, out_arg, + (this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr, + last_pass ? _args._act : Activation(), + !first_pass, + // Quantization parameters + _os, _col_bias+(multi * _args._Nsize), n0); + } else { + // Length to process. This needs to exclude padding, but 'kmax' potentially includes it. + const unsigned int len = (std::min(_args._Ksize, kmax) - k0); + + run_hybrid_kernel::run( +#ifdef CYCLE_PROFILING + prof, +#endif + strat, 1, &len, + IndirectInputArg(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + m_start * this->_lda + k0, this->_lda), + (m_end - m_start), (nmax - n0), kern_k, b_panel, out_arg, + (this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr, + last_pass ? _args._act : Activation(), + !first_pass, + // Quantization parameters + _os, _col_bias+(multi * _args._Nsize), n0); + } + } while (process_all_rows ? p.next_dim1() : p.next_dim0()); + } + } + + // Interface implementation - pretransposed + bool B_is_pretransposed() const override { + return true; + } + + bool B_pretranspose_required() const override { + return (_B_transposed==nullptr); + } + + size_t get_B_pretransposed_array_size() const override { + // Start with actual pretransposed buffer... + size_t size = roundup(_args._Nsize, strategy::out_width()) * _Ktotal * _args._nmulti * sizeof(Toi); + + // Space for result row pointers (not strictly needed any more but retained for indirect output testing) + size += _args._Msize * _args._nbatches * _args._nmulti * sizeof(const Tr *); + + if (std::is_same::value) { + size += get_col_sum_size(); + } + + return size; + } + + void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { + if (std::is_same::value) { + _col_bias = reinterpret_cast(in_buffer); + + Requantize32 *qp_ptr = reinterpret_cast(&_os); + + for (unsigned int i=0; i<_args._nmulti; i++) { + // The input is assumed not to have any padding between sections, so straightforward Ksize * Ksections computation gets the total size. + compute_col_sums(*qp_ptr, _args._Nsize, _args._Ksize * _args._Ksections, B + (i * B_multi_stride), ldb, _col_bias + (i * _args._Nsize), _args._Ksize * _args._Ksections, i, 0); + } + } + + // Put the transposed data after the column sums - in non-transposing cases get_col_sum_size() == 0 + uintptr_t buffer_int = reinterpret_cast(in_buffer); + Toi *buffer = reinterpret_cast(buffer_int + get_col_sum_size()); + _B_transposed = buffer; + + strategy strat(_args._ci); + + for (unsigned int multi=0; multi<_args._nmulti; multi++) { + for (unsigned int k0=0; k0<_Ktotal; k0+=_k_block) { + const unsigned int kmax=std::min(k0 + _k_block, _Ktotal); + + /* Figure out the size of each block. */ + unsigned int k_size = kmax - k0; + + // We need to insert padding at the end of each K section. + // The computation needed is a little delicate - the coordinates from the block walker are expressed in + // terms of the full, padded, _Ktotal. + // But we need to transform each section with reference to the original, unpadded, input, letting the + // transform pad each section as needed. + + // This is needed for computations below. + const unsigned int rounded_section_size = roundup(_args._Ksize, strategy::k_unroll()); + + // The expected output format is also an entire columns interleaved, then the next set of + // columns, and so on. This means, as we are breaking it up vertically, we have to do it one column at + // a time. + for (unsigned int x0=0; x0 < _args._Nsize; x0 += strategy::out_width() ){ + unsigned int xmax = std::min(x0 + strategy::out_width(), _args._Nsize); + + // Track where we are and how much work is left. + unsigned int kpos = k0; + unsigned int kleft = k_size; + + while (kleft) { + // Which section are we in? Based on the rounded-up section size. + unsigned int k_section_base = kpos / rounded_section_size; + // How far into the section are we? + unsigned int k_offset = kpos - (k_section_base * rounded_section_size); + + // We will either copy the rest of this section, or to the end of the requested length. + unsigned int k_length = std::min(_args._Ksize - k_offset, kleft); + + strat.transforms.PrepareB(buffer, B + (multi * B_multi_stride), ldb, + x0, xmax, + (k_section_base * _args._Ksize) + k_offset, // K starting point - compute row to read based on our section and the true section length. + (k_section_base * _args._Ksize) + k_offset + k_length); // K end point - starting point plus length computed above. + + // We need to modify our position based on the ROUNDED version of what we just did. + unsigned int padded_length = roundup(k_length, strategy::k_unroll()); + + buffer += strategy::out_width() * padded_length; + + kpos += padded_length; + kleft -= padded_length; + } + } + } + } + } + + void set_pretransposed_B_data(void *in_buffer) override { + // Put the transposed data after the column sums - in non-transposing cases get_col_sum_size() == 0 + uintptr_t buffer_int = reinterpret_cast(in_buffer); + _B_transposed = reinterpret_cast(buffer_int + get_col_sum_size()); + _col_bias = reinterpret_cast(in_buffer); + } + + // Estimate cycles for given problem given provided parameters + static uint64_t estimate_cycles(const GemmArgs &args, const PerformanceParameters ¶ms) { + // Note: Current hybrid kernels don't actually round up height (they + // have paths for each possible height). Might need to make this + // configurable in future. + uint64_t total_macs = static_cast(args._nbatches) * args._nmulti * args._Msize * roundup(args._Nsize, strategy::out_width()) * roundup(args._Ksize, strategy::k_unroll()); + + float mac_cycles = static_cast(total_macs) / params.kernel_macs_cycle; + + // TODO: A bit of a kludge here: current hybrid kernels incur extra + // overhead where the width is not a multiple of kernel width. It's + // most noticable where the overall width is quite low, so add 15% + // penalty for such widths. + if ((args._Nsize < strategy::out_width()) || (args._Nsize > strategy::out_width() && args._Nsize < 2*strategy::out_width())) { + mac_cycles *= 1.15f; + } + + uint64_t total_cycles = mac_cycles; + + return total_cycles; + } + + void set_quantized_bias(const int32_t *bias, size_t bias_multi_stride) override { + if (std::is_same::value) { + Requantize32 *qp = reinterpret_cast(&_os); + + qp->bias = bias; + qp->bias_multi_stride = bias_multi_stride; + } + } + + void set_indirect_parameters(size_t string_len, const To * const * const *ptr) override { + assert(string_len == _args._Ksize); + _indirect_buf = ptr; + } + + void set_convolution_parameters(ConvolutionParameters parms) override { + assert(parms.input_channels == _args._Ksize); + _convolver = std::unique_ptr>(new convolver(parms)); + } +}; + +} // namespace arm_gemm + +#ifdef __I_DEFINED_UNUSED +#undef UNUSED +#endif -- cgit v1.2.1