diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp | 140 |
1 files changed, 116 insertions, 24 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp index 5b3ef4203d..c41b0a5b3e 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp @@ -33,6 +33,7 @@ #include "arm_gemm.hpp" #include "bias_adder.hpp" #include "convolver.hpp" +#include "kernel_weight_format.hpp" #include "ndrange.hpp" #include "performance_parameters.hpp" #include "transform.hpp" @@ -54,7 +55,7 @@ namespace { // We need to invoke the kernel differently for quantizing and non-quantizing cases, so here is a shim class to do // that. -template<typename OutputStage, bool SeparateQuantize = false> +template<typename OutputStage, bool SeparateQuantize, bool FixedFormat> class run_hybrid_kernel { public: template<typename strategy, typename Tlo, typename Tro, typename Tr> @@ -63,18 +64,18 @@ public: profiler &prof, #endif const strategy &strat, unsigned int num_strings, const unsigned int *string_ptr, IndirectInputArg<Tlo> A_arg, unsigned int M, unsigned int N, - unsigned int kern_k, const Tro *b_ptr, IndirectOutputArg<Tr> output_arg, const Tr *bias_ptr, Activation act, bool accumulate, + unsigned int kern_k, const Tro *b_ptr, size_t b_stride, IndirectOutputArg<Tr> output_arg, const Tr *bias_ptr, Activation act, bool accumulate, const OutputStage &os, const int32_t *col_bias, unsigned int n_0 ); }; template<> template<typename strategy, typename Tlo, typename Tro, typename Tr> -inline void run_hybrid_kernel<Nothing, false>::run( +inline void run_hybrid_kernel<Nothing, false, false>::run( #ifdef CYCLE_PROFILING profiler &prof, #endif const strategy &strat, unsigned int num_strings, const unsigned int *string_ptr, IndirectInputArg<Tlo> A_arg, unsigned int M, unsigned int N, - unsigned int kern_k, const Tro *b_ptr, IndirectOutputArg<Tr> output_arg, const Tr *bias_ptr, Activation act, bool accumulate, + unsigned int kern_k, const Tro *b_ptr, size_t, IndirectOutputArg<Tr> output_arg, const Tr *bias_ptr, Activation act, bool accumulate, const Nothing &, const int32_t *, unsigned int) { #ifdef CYCLE_PROFILING auto p = prof.ScopedProfiler(PROFILE_KERNEL, (unsigned long)M * kern_k * roundup(N, strategy::out_width())); @@ -115,12 +116,60 @@ inline void run_hybrid_kernel<Nothing, false>::run( template<> template<typename strategy, typename Tlo, typename Tro, typename Tr> -inline void run_hybrid_kernel<Requantize32, false>::run( +inline void run_hybrid_kernel<Nothing, false, true>::run( #ifdef CYCLE_PROFILING profiler &prof, #endif const strategy &strat, unsigned int num_strings, const unsigned int *string_ptr, IndirectInputArg<Tlo> A_arg, unsigned int M, unsigned int N, - unsigned int kern_k, const Tro *b_ptr, IndirectOutputArg<Tr> output_arg, const Tr *, Activation, bool, + unsigned int kern_k, const Tro *b_ptr, size_t b_stride, IndirectOutputArg<Tr> output_arg, const Tr *bias_ptr, Activation act, bool accumulate, + const Nothing &, const int32_t *, unsigned int) { +#ifdef CYCLE_PROFILING + auto p = prof.ScopedProfiler(PROFILE_KERNEL, (unsigned long)M * kern_k * roundup(N, strategy::out_width())); +#endif + UNUSED(kern_k); + + /* Indirect hybrid kernels read the full width of the bias. So we need to detect the case where we are writing + * a partial block and pad the bias for that block. */ + if (bias_ptr && !accumulate && (N % strategy::out_width() != 0)) { + /* Break N into "N_bulk" (a multiple of output width) and "N_remainder" */ + unsigned int N_remainder = N % strategy::out_width(); + unsigned int N_bulk = N - N_remainder; + + /* Output argument to be used for the tail */ + IndirectOutputArg<Tr> offset_output = output_arg; + + /* If there is a "bulk" to be processed, handle that and update "offset_output" appropriately. */ + if (N_bulk > 0) { + strat.kernel(num_strings, string_ptr, A_arg, M, N_bulk, b_ptr, b_stride, output_arg, bias_ptr, act, accumulate); + + if (output_arg.is_indirect) { + offset_output = IndirectOutputArg<Tr>(output_arg.indirect.ptr, output_arg.indirect.offset + N_bulk); + } else { + offset_output = IndirectOutputArg<Tr>(output_arg.direct.base + N_bulk, output_arg.direct.stride); + } + } + + /* Pad the bias buffer for the remainder */ + Tr *bias_pad_buffer = reinterpret_cast<Tr *>(alloca(strategy::out_width() * sizeof(Tr))); + memcpy(bias_pad_buffer, bias_ptr + N_bulk, N_remainder * sizeof(Tr)); + + /* Process the remainder, offsetting the B pointer as needed. */ + strat.kernel(num_strings, string_ptr, A_arg, M, N_remainder, + b_ptr + (N_bulk / strategy::stripe_width()) * b_stride, b_stride, offset_output, + bias_pad_buffer, act, accumulate); + } else { + strat.kernel(num_strings, string_ptr, A_arg, M, N, b_ptr, b_stride, output_arg, bias_ptr, act, accumulate); + } +} + +template<> +template<typename strategy, typename Tlo, typename Tro, typename Tr> +inline void run_hybrid_kernel<Requantize32, false, false>::run( +#ifdef CYCLE_PROFILING + profiler &prof, +#endif + const strategy &strat, unsigned int num_strings, const unsigned int *string_ptr, IndirectInputArg<Tlo> A_arg, unsigned int M, unsigned int N, + unsigned int kern_k, const Tro *b_ptr, size_t, IndirectOutputArg<Tr> output_arg, const Tr *, Activation, bool, const Requantize32 &os, const int32_t *col_bias, unsigned int n_0 ) { #ifdef CYCLE_PROFILING auto p = prof.ScopedProfiler(PROFILE_KERNEL, (unsigned long)M * kern_k * roundup(N, strategy::out_width())); @@ -132,12 +181,12 @@ inline void run_hybrid_kernel<Requantize32, false>::run( template<> template<typename strategy, typename Tlo, typename Tro, typename Tr> -inline void run_hybrid_kernel<Requantize32, true>::run( +inline void run_hybrid_kernel<Requantize32, true, false>::run( #ifdef CYCLE_PROFILING profiler &prof, #endif const strategy &strat, unsigned int num_strings, const unsigned int *string_ptr, IndirectInputArg<Tlo> A_arg, unsigned int M, unsigned int N, - unsigned int kern_k, const Tro *b_ptr, IndirectOutputArg<Tr> output_arg, const Tr *, Activation, bool, + unsigned int kern_k, const Tro *b_ptr, size_t, IndirectOutputArg<Tr> output_arg, const Tr *, Activation, bool, const Requantize32 &os, const int32_t *col_bias, unsigned int n_0 ) { UNUSED(kern_k); // On this route we will only process one kernel height at a time and will make sure this happens in the driver loop. @@ -180,10 +229,38 @@ inline void run_hybrid_kernel<Requantize32, true>::run( } } +template<typename strategy, bool FixedFormat> +struct stripe_width { + static unsigned int get() { + return strategy::stripe_width(); + } +}; + +template<typename strategy> +struct stripe_width<strategy, false> { + static unsigned int get() { + return 0; + } +}; + +template<typename strategy, bool FixedFormat> +struct kernel_weight_format { + static KernelWeightFormat get() { + return strategy::kernel_weight_format(); + } +}; + +template<typename strategy> +struct kernel_weight_format<strategy, false> { + static KernelWeightFormat get() { + return KernelWeightFormat::NON_FIXED; + } +}; + } // anonymous namespace // Implementation of the GemmCommon abstract class. -template<typename strategy, typename To, typename Tr, typename OutputStage = Nothing, bool SeparateQuantize = false> +template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing, bool SeparateQuantize=false, bool FixedFormat=false> class GemmHybridIndirect : public GemmCommon<To, Tr> { typedef typename strategy::lhs_operand_type Tloi; typedef typename strategy::rhs_operand_type Troi; @@ -425,24 +502,32 @@ public: const unsigned int nmax = std::min(n0 + _n_block, _args._Nsize); const unsigned int multi = p.dim(3); - const Troi *b_panel = _B_transposed + - (multi * roundup(_args._Nsize, strategy::out_width()) * _Ktotal) + - (k0 * roundup(_args._Nsize, strategy::out_width())) + - (n0 * kern_k); + const Troi *b_panel; + if (FixedFormat) { + b_panel = reinterpret_cast<const Troi *>(this->_Bptr) + + (multi * this->_B_multi_stride) + + ((n0 / stripe_width<strategy, FixedFormat>::get()) * this->_ldb) + + (k0 * stripe_width<strategy, FixedFormat>::get()); + } else { + b_panel = _B_transposed + + (multi * roundup(_args._Nsize, strategy::out_width()) * _Ktotal) + + (k0 * roundup(_args._Nsize, strategy::out_width())) + + (n0 * kern_k); + } - IndirectOutputArg<Tr> out_arg(this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (m_start * this->_ldc) + n0, this->_ldc); + IndirectOutputArg<Tr> out_arg(this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (m_start * this->_ldc) + n0, this->_ldc); #ifdef CYCLE_PROFILING auto p = prof.ScopedProfiler(PROFILE_KERNEL, (unsigned long)(m_end - m_start) * kern_k * roundup(nmax-n0, strategy::out_width())); #endif if (_indirect_buf) { - run_hybrid_kernel<OutputStage, SeparateQuantize>::run( + run_hybrid_kernel<OutputStage, SeparateQuantize, FixedFormat>::run( #ifdef CYCLE_PROFILING prof, #endif strat, sections, string_lengths.data(), IndirectInputArg<To>(_indirect_buf + (multi * _args._nbatches * _args._Ksections) + (batch * _args._Ksections) + first_section, m_start, first_offset), - (m_end - m_start), (nmax - n0), kern_k, b_panel, out_arg, + (m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg, (this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr, last_pass ? _args._act : Activation(), !first_pass, @@ -469,13 +554,13 @@ public: } assert(pos == sections); - run_hybrid_kernel<OutputStage, SeparateQuantize>::run( + run_hybrid_kernel<OutputStage, SeparateQuantize, FixedFormat>::run( #ifdef CYCLE_PROFILING prof, #endif strat, sections, string_lengths.data(), IndirectInputArg<To>(in_row_strings.data(), 0, first_offset), - (m_end - m_start), (nmax - n0), kern_k, b_panel, out_arg, + (m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg, (this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr, last_pass ? _args._act : Activation(), !first_pass, @@ -485,13 +570,13 @@ public: // Length to process. This needs to exclude padding, but 'kmax' potentially includes it. const unsigned int len = (std::min(_args._Ksize, kmax) - k0); - run_hybrid_kernel<OutputStage, SeparateQuantize>::run( + run_hybrid_kernel<OutputStage, SeparateQuantize, FixedFormat>::run( #ifdef CYCLE_PROFILING prof, #endif strat, 1, &len, IndirectInputArg<To>(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + m_start * this->_lda + k0, this->_lda), - (m_end - m_start), (nmax - n0), kern_k, b_panel, out_arg, + (m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg, (this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr, last_pass ? _args._act : Activation(), !first_pass, @@ -504,14 +589,18 @@ public: // Interface implementation - pretransposed bool B_is_pretransposed() const override { - return true; + return (FixedFormat == false); } bool B_pretranspose_required() const override { - return (_B_transposed==nullptr); + return (FixedFormat == false) && (_B_transposed==nullptr); } size_t get_B_pretransposed_array_size() const override { + if (FixedFormat) { + return 0; + } + // Start with actual pretransposed buffer... size_t size = roundup(_args._Nsize, strategy::out_width()) * _Ktotal * _args._nmulti * sizeof(Troi); @@ -599,8 +688,7 @@ public: } } } else { - // In the single K section case, can process the whole lot in one go. - // Caution: 'blockwalker::kmax()' rounds up, so clamp to valid _Ksize. + // In the single K section case, can process the whole lot in one go. strat.transforms.PrepareB(buffer, B + (multi * B_multi_stride), ldb, 0, _args._Nsize, k0, std::min(kmax, _args._Ksize)); buffer += roundup(_args._Nsize, strategy::out_width()) * roundup(kmax-k0, strategy::k_unroll()); @@ -694,11 +782,15 @@ public: c.inner_block_size = _k_block; c.outer_block_size = _n_block; c.filter = get_type_name<strategy>(); + c.weight_format = get_weight_format(kernel_weight_format<strategy, FixedFormat>::get(), sizeof(To)); return c; } }; +template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing> +using GemmHybridIndirectFixedFormat = GemmHybridIndirect<strategy, To, Tr, OutputStage, false, true>; + } // namespace arm_gemm #ifdef __I_DEFINED_UNUSED |