Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp')
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp  486
1 file changed, 407 insertions(+), 79 deletions(-)
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index 7f870b83d7..897ec9d05f 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,7 +27,10 @@
#include <cassert>
#include "arm_gemm.hpp"
+#include "bfloat.hpp"
#include "convolver.hpp"
+#include "kernel_traits.hpp"
+#include "kernel_weight_format.hpp"
#include "mergeresults.hpp"
#include "performance_parameters.hpp"
#include "quantized.hpp"
@@ -56,7 +59,7 @@ namespace {
// Others output directly to the matrix result. This helper class calls the
// appropriate functions, using templating to avoid calling non-existent
// functions.
-template<bool MergeStep, typename OutputStage>
+template<bool MergeStep, bool FixedFormat, typename OutputStage>
class kernel_and_merge {
public:
template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
@@ -64,7 +67,7 @@ public:
#ifdef CYCLE_PROFILING
profiler &prof,
#endif
- strategy &strat, const To *a_ptr, const To *b_panel, Tri *c_panel,
+ strategy &strat, const To *a_ptr, const To *b_panel, size_t b_stride, Tri *c_panel,
Tr *c_ptr, int ldc, int kern_k, unsigned int m_0,
unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *biasptr,
const Activation &act, bool accumulate, const OutputStage &os, const int32_t *col_bias,
@@ -74,11 +77,11 @@ public:
// Run a kernel and call the separate merge step
template<>
template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
-void kernel_and_merge<true, Nothing>::run(
+void kernel_and_merge<true, false, Nothing>::run(
#ifdef CYCLE_PROFILING
profiler &prof,
#endif
- strategy &strat, const To *a_ptr, const To *b_panel, Tri *c_panel,
+ strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *c_panel,
Tr *c_ptr, int ldc, int kern_k, unsigned int m_0,
unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *biasptr,
const Activation &act, bool accumulate, const Nothing &, const int32_t *, Tab *)
@@ -101,14 +104,44 @@ void kernel_and_merge<true, Nothing>::run(
}
}
+// Run a fixed-format kernel and call the separate merge step
+template<>
+template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
+void kernel_and_merge<true, true, Nothing>::run(
+#ifdef CYCLE_PROFILING
+ profiler &prof,
+#endif
+ strategy &strat, const To *a_ptr, const To *b_panel, size_t b_stride, Tri *c_panel,
+ Tr *c_ptr, int ldc, int kern_k, unsigned int m_0,
+ unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *biasptr,
+ const Activation &act, bool accumulate, const Nothing &, const int32_t *, Tab *)
+{
+ {
+#ifdef CYCLE_PROFILING
+ const int bblocks = iceildiv(n_max - n_0, strategy::out_width());
+ auto p=prof.ScopedProfiler(PROFILE_KERNEL, (strategy::out_height() * bblocks * strategy::out_width() * kern_k));
+#endif
+
+ strat.kernel(a_ptr, b_panel, b_stride, c_panel, 1, (n_max - n_0), kern_k);
+ }
+
+ {
+#ifdef CYCLE_PROFILING
+ const int bblocks = iceildiv(n_max - n_0, strategy::out_width());
+ auto p=prof.ScopedProfiler(PROFILE_MERGE, (strategy::out_height() * bblocks * strategy::out_width() * sizeof(Tr)));
+#endif
+ strat.transforms.Merge(c_ptr, c_panel, ldc, m_0, m_max, n_0, n_max, biasptr, act, accumulate);
+ }
+}
+
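The kernel_and_merge specializations above rely on explicit specialization over the <MergeStep, FixedFormat, OutputStage> parameters, so each path only instantiates the run() overload its strategy actually supports. A minimal, self-contained sketch of the same dispatch pattern (all names here are illustrative, not taken from the library):

    #include <cstdio>

    // Forward-declare the primary template; only the specializations below are defined.
    template<bool MergeStep, bool FixedFormat>
    struct kernel_and_merge_demo;

    template<> struct kernel_and_merge_demo<true, false> {
        static void run() { std::printf("kernel, then separate merge\n"); }
    };

    template<> struct kernel_and_merge_demo<true, true> {
        static void run() { std::printf("fixed-format kernel, then merge\n"); }
    };

    template<> struct kernel_and_merge_demo<false, false> {
        static void run() { std::printf("kernel with integrated merge\n"); }
    };

    int main() {
        kernel_and_merge_demo<true, true>::run();   // combination picked at compile time
    }

A combination with no matching specialization simply fails to compile, rather than calling a non-existent kernel entry point.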
// Run a kernel with integrated merge
template<>
template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
-void kernel_and_merge<false, Nothing>::run(
+void kernel_and_merge<false, false, Nothing>::run(
#ifdef CYCLE_PROFILING
profiler &prof,
#endif
- strategy &strat, const To *a_ptr, const To *b_panel, Tri *,
+ strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *,
Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max,
unsigned int n_0, unsigned int n_max, const Tr *biasptr,
const Activation &act, bool accumulate, const Nothing &, const int32_t *,
@@ -143,11 +176,11 @@ void kernel_and_merge<false, Nothing>::run(
// Run a kernel with integrated merge, quantizing
template<>
template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
-void kernel_and_merge<false, Requantize32>::run(
+void kernel_and_merge<false, false, Requantize32>::run(
#ifdef CYCLE_PROFILING
profiler &prof,
#endif
- strategy &strat, const To *a_ptr, const To *b_panel, Tri *,
+ strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *,
Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max,
unsigned int n_0, unsigned int n_max, const Tr *,
const Activation &, bool accumulate, const Requantize32 &qp, const int32_t *col_bias,
@@ -157,10 +190,19 @@ void kernel_and_merge<false, Requantize32>::run(
auto p=prof.ScopedProfiler(PROFILE_KERNEL, (m_max - m_0) * (n_max - n_0) * kern_k);
#endif
+ // Offset C pointer in a similar way to non-quantized case above.
+ Tri *offset_c_ptr;
+
+ if (c_ptr == nullptr) {
+ offset_c_ptr = nullptr;
+ } else {
+ offset_c_ptr = c_ptr + m_0 * ldc + n_0;
+ }
+
strat.kernel(// A and B pointers are just the packed panels.
a_ptr, b_panel,
// Provide relevant part of output array and row stride.
- c_ptr + m_0 * ldc + n_0, ldc,
+ offset_c_ptr, ldc,
// M, N, K sizes
m_max-m_0, n_max - n_0, kern_k,
// Bias, activation, accumulation. Need to offset the bias as needed.
@@ -170,11 +212,11 @@ void kernel_and_merge<false, Requantize32>::run(
// Run a kernel and call the separate quantize step
template<>
template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
-void kernel_and_merge<true, Requantize32>::run(
+void kernel_and_merge<true, false, Requantize32>::run(
#ifdef CYCLE_PROFILING
profiler &prof,
#endif
- strategy &strat, const To *a_ptr, const To *b_panel, Tri *c_panel,
+ strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *c_panel,
Tr *c_ptr, int ldc, int kern_k, unsigned int m_0,
unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *,
const Activation &, bool, const Requantize32 &qp, const int32_t *col_bias,
@@ -192,7 +234,7 @@ void kernel_and_merge<true, Requantize32>::run(
{
#ifdef CYCLE_PROFILING
- auto p=prof.ScopedProfiler(PROFILE_QUANTIZE, (strategy::out_height() * bblocks * strategy::out_width() * sizeof(Tr)));
+ auto p=prof.ScopedProfiler(PROFILE_QUANTIZE, ((m_max-m_0) * bblocks * strategy::out_width() * sizeof(Tr)));
#endif
// The interleaved kernel outputs in blocks - each block is a
// row-major matrix of size out_width * out_height. The merge
@@ -213,6 +255,84 @@ void kernel_and_merge<true, Requantize32>::run(
}
}
+// Run a kernel with integrated merge, dequantizing to FP32
+template<>
+template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
+void kernel_and_merge<false, false, DequantizeFloat>::run(
+#ifdef CYCLE_PROFILING
+ profiler &prof,
+#endif
+ strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *,
+ Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max,
+ unsigned int n_0, unsigned int n_max, const Tr *bias,
+ const Activation &act, bool accumulate, const DequantizeFloat &dq, const int32_t *col_bias,
+ Tab *acc_buff)
+{
+#ifdef CYCLE_PROFILING
+ auto p=prof.ScopedProfiler(PROFILE_KERNEL, (m_max - m_0) * (n_max - n_0) * kern_k);
+#endif
+
+ const int32_t *offset_col_bias = nullptr;
+ const Tr *offset_bias = nullptr;
+
+ if (col_bias) {
+ offset_col_bias = col_bias + n_0;
+ }
+
+ if (bias) {
+ offset_bias = bias + n_0;
+ }
+
+ strat.kernel(// A and B pointers are just the packed panels.
+ a_ptr, b_panel,
+ // Provide relevant part of output array and row stride.
+ c_ptr ? (c_ptr + m_0 * ldc + n_0) : nullptr, ldc,
+ // M, N, K sizes
+ m_max-m_0, n_max - n_0, kern_k,
+ // Bias, activation, accumulation. Need to offset the bias as needed.
+ offset_col_bias, dq, offset_bias, act, accumulate, acc_buff);
+}
+
+template<>
+template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
+void kernel_and_merge<true, false, DequantizeFloat>::run(
+#ifdef CYCLE_PROFILING
+ profiler &prof,
+#endif
+ strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *c_panel,
+ Tr *c_ptr, int ldc, int kern_k, unsigned int m_0,
+ unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *bias,
+ const Activation &act, bool accumulate, const DequantizeFloat &qp, const int32_t *,
+ Tab *)
+{
+ const int bblocks = iceildiv(n_max - n_0, strategy::out_width());
+
+ {
+#ifdef CYCLE_PROFILING
+ auto p=prof.ScopedProfiler(PROFILE_KERNEL, (strategy::out_height() * bblocks * strategy::out_width() * kern_k));
+#endif
+
+ strat.kernel(a_ptr, b_panel, c_panel, 1, bblocks, kern_k);
+ }
+
+ {
+#ifdef CYCLE_PROFILING
+ auto p=prof.ScopedProfiler(PROFILE_QUANTIZE, ((m_max-m_0) * bblocks * strategy::out_width() * sizeof(Tr)));
+#endif
+ auto out_area = strategy::out_width() * strategy::out_height();
+ for (int i=0; i<bblocks; i++) {
+ const unsigned int n_start = n_0 + (strategy::out_width() * i);
+ const unsigned int n_end = std::min(n_start + strategy::out_width(), n_max);
+
+ dequantize_block_32(qp, (n_end - n_start), (m_max - m_0),
+ c_panel + (i * out_area), strategy::out_width(),
+ c_ptr + m_0 * ldc + n_start, ldc,
+ bias != nullptr ? bias + n_start : nullptr, accumulate, act);
+
+ }
+ }
+}
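In the separate-merge dequantizing path above, the interleaved kernel leaves int32 results in c_panel as consecutive column blocks (each out_width wide), which dequantize_block_32 then scales into the float output at the right column offset. A simplified sketch of that walk, ignoring bias, activation and the panel's out_height padding (the flat m-row block layout and all names are assumptions for illustration only):

    #include <algorithm>
    #include <cstdint>

    // Convert one interleaved int32 panel into a float output tile.
    // 'c_panel' holds bblocks consecutive blocks, each out_width columns wide
    // and (simplified here) m rows tall, stored row-major with stride out_width.
    void dequantize_panel(const int32_t *c_panel, float *c_ptr, int ldc,
                          unsigned int m, unsigned int n_0, unsigned int n_max,
                          unsigned int out_width, float scale) {
        const unsigned int bblocks  = (n_max - n_0 + out_width - 1) / out_width;
        const unsigned int out_area = out_width * m;

        for (unsigned int i = 0; i < bblocks; i++) {
            const unsigned int n_start = n_0 + (i * out_width);
            const unsigned int n_end   = std::min(n_start + out_width, n_max);
            const int32_t *block = c_panel + (i * out_area);

            for (unsigned int row = 0; row < m; row++) {
                for (unsigned int col = 0; col < (n_end - n_start); col++) {
                    c_ptr[row * ldc + n_start + col] =
                        static_cast<float>(block[row * out_width + col]) * scale;
                }
            }
        }
    }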
+
// Integer GEMMs can be used in two contexts - "normal" where the full 32-bit output is required, or in
// "requantizing" context where the output will be requantized.
//
@@ -234,25 +354,77 @@ public:
};
// We need a similar trick here to figure out what type the accumulator buffer should be.
-template<typename strategy, typename OutputStage>
+template<typename strategy, typename OutputStage, bool ForceFloat>
class accumulate_buffer_type {
public:
typedef typename strategy::result_type type;
};
template<typename strategy>
-class accumulate_buffer_type<strategy, Requantize32> {
+class accumulate_buffer_type<strategy, Requantize32, false> {
public:
typedef int32_t type;
};
+template<typename strategy>
+class accumulate_buffer_type<strategy, DequantizeFloat, false> {
+public:
+ typedef int32_t type;
+};
+
+template<typename strategy, typename OutputStage>
+class accumulate_buffer_type<strategy, OutputStage, true> {
+public:
+ typedef float type;
+};
+
+// Stripe width is a concept only needed for FixedFormat kernels. Use an accessor to avoid issues in other scenarios.
+template<typename strategy, bool FixedFormat>
+struct get_stripe_width {
+ static unsigned int get() {
+ return 0;
+ }
+};
+
+template<typename strategy>
+struct get_stripe_width<strategy, true> {
+ static unsigned int get() {
+ return strategy::stripe_width();
+ }
+};
+
+// KernelWeightFormat is a similar story.
+template<typename strategy, bool FixedFormat, typename To>
+struct get_kernel_weight_format {
+ static KernelWeightFormat get() {
+ return KernelWeightFormat::NON_FIXED;
+ }
+};
+
+template<typename strategy, typename To>
+struct get_kernel_weight_format<strategy, true, To> {
+ static KernelWeightFormat get() {
+ KernelWeightFormat kwf = strategy::kernel_weight_format();
+
+ // If we are using a BF16 kernel to do an FP32 problem (fast mode) then we need to set the BF16 flag on the
+ // weight format.
+ if (std::is_same<To, float>::value && std::is_same<typename strategy::operand_type, bfloat16>::value) {
+ uint32_t kwf_i = static_cast<uint32_t>(kwf);
+ kwf_i |= 0x10;
+ kwf = static_cast<KernelWeightFormat>(kwf_i);
+ }
+
+ return kwf;
+ }
+};
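get_stripe_width and get_kernel_weight_format exist so that strategies without the relevant static members still compile: only the FixedFormat=true specializations ever name stripe_width() or kernel_weight_format(). A stand-alone illustration of that accessor trick (the strategy types and the width of 4 are made up):

    struct PlainStrategy { };                              // no stripe_width() member
    struct FixedFormatStrategy {
        static constexpr unsigned int stripe_width() { return 4; }
    };

    template<typename strategy, bool FixedFormat>
    struct stripe_width_demo {
        static constexpr unsigned int get() { return 0; }  // safe default, never touches the member
    };

    template<typename strategy>
    struct stripe_width_demo<strategy, true> {
        static constexpr unsigned int get() { return strategy::stripe_width(); }
    };

    static_assert(stripe_width_demo<PlainStrategy, false>::get() == 0, "default accessor");
    static_assert(stripe_width_demo<FixedFormatStrategy, true>::get() == 4, "fixed-format accessor");

The 0x10 OR in the real accessor marks the weight format as BF16 when a bfloat16 fast-mode kernel serves an FP32 problem; that bit value comes from the patch itself.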
+
} // anonymous namespace
-template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing, bool MergeStep=true, bool ForceThreadColumns=false>
+template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing, bool MergeStep=true, bool FixedFormat=false, bool ForceThreadColumns=false, bool ForceFloatAccumulate=false>
class GemmInterleaved : public GemmCommon<To, Tr> {
typedef typename strategy::operand_type Toi;
typedef typename strategy::result_type Tri;
- typedef typename accumulate_buffer_type<strategy, OutputStage>::type Tab;
+ typedef typename accumulate_buffer_type<strategy, OutputStage, ForceFloatAccumulate>::type Tab;
/* const properties set by constructor */
const CPUInfo * const _ci;
@@ -270,6 +442,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
const bool _thread_columns;
const Activation _act;
+ const bool _accumulate;
const int _maxthreads;
int _nthreads;
@@ -310,7 +483,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
class blockwalker {
private:
/* Size loops, etc. based on our parent's configuration */
- const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, ForceThreadColumns> &_parent;
+ const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &_parent;
/* K, X and multi parameters for current iteration. */
unsigned int _k0=0, _x0=0, _multi=0;
@@ -325,9 +498,9 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
bool _newmulti=true;
public:
- blockwalker(const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, ForceThreadColumns> &parent) : _parent(parent) { }
+ blockwalker(const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &parent) : _parent(parent) { }
- blockwalker(const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, ForceThreadColumns> &parent,
+ blockwalker(const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &parent,
unsigned int x_start, unsigned int x_end) : _parent(parent), _x0 (_x_start), _x_start(x_start), _x_end(x_end) { }
unsigned int xmax() {
@@ -496,15 +669,46 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
static unsigned int get_k_block_size(const GemmArgs &args) {
if (args._cfg && args._cfg->inner_block_size) {
- return args._cfg->inner_block_size;
+ return roundup(args._cfg->inner_block_size, strategy::k_unroll());
}
- // K blocking not supported if we are requantizing.
- if (std::is_same<OutputStage, Requantize32>::value) {
+ // K blocking not supported if we are requantizing with the merging
+ // kernels.
+ if (std::is_same<OutputStage, Requantize32>::value && MergeStep) {
return get_ktotal(args);
}
const unsigned int L1_size = args._ci->get_L1_cache_size();
+
+ // Special blocking for SME
+ if (is_sme<strategy>::value) {
+ // Target 512 bytes for 64kB L1, or 1024 bytes for 128kB L1.
+ unsigned int target_bytes_per_block = L1_size / 128;
+
+ // Default cache size in gemm-linux is 32kB though - so make
+ // sure minimum is 512
+ if (target_bytes_per_block < 512) {
+ target_bytes_per_block = 512;
+ }
+
+ // Don't bother to block below this size threshold (1.25X target size)
+ unsigned int scaling_threshold = ((target_bytes_per_block * 5) / 4) / sizeof(Toi);
+
+ if (get_ktotal(args) <= scaling_threshold) {
+ return get_ktotal(args);
+ }
+
+ // Once we are blocking, this (lower) threshold determines when we should use more blocks
+ // NOTE: Could be that some factor-based solution would work better here.
+ unsigned int max_block_size = target_bytes_per_block / sizeof(Toi);
+
+ unsigned int num_k_blocks = iceildiv(get_ktotal(args), max_block_size);
+
+ unsigned int k_block = roundup(iceildiv(get_ktotal(args), num_k_blocks), strategy::k_unroll());
+
+ return k_block;
+ }
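To make the SME K-blocking heuristic above concrete, here is a worked example with assumed numbers (64 kB L1, a 2-byte operand type, k_unroll() == 2, total K of 1000); the helpers mirror the iceildiv/roundup behaviour used in the patch:

    #include <cstdio>

    static unsigned int iceildiv(unsigned int a, unsigned int b) { return (a + b - 1) / b; }
    static unsigned int roundup(unsigned int a, unsigned int b)  { return iceildiv(a, b) * b; }

    int main() {
        const unsigned int L1_size = 64 * 1024, sizeof_Toi = 2, k_unroll = 2, Ktotal = 1000;

        unsigned int target_bytes_per_block = L1_size / 128;             // 512
        if (target_bytes_per_block < 512) target_bytes_per_block = 512;

        unsigned int scaling_threshold = ((target_bytes_per_block * 5) / 4) / sizeof_Toi;  // 320 elements
        if (Ktotal <= scaling_threshold) { std::printf("k_block=%u\n", Ktotal); return 0; }

        unsigned int max_block_size = target_bytes_per_block / sizeof_Toi;                 // 256 elements
        unsigned int num_k_blocks   = iceildiv(Ktotal, max_block_size);                    // 4
        unsigned int k_block        = roundup(iceildiv(Ktotal, num_k_blocks), k_unroll);   // 250

        std::printf("k_block=%u over %u blocks\n", k_block, num_k_blocks);
        return 0;
    }

So K=1000 is split into four blocks of 250, while any K up to 320 elements (1.25x the 256-element target) would have been left as a single block.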
+
unsigned int k_block;
// k_block: Find out how much of the larger array can be loaded into half the cache.
@@ -539,6 +743,17 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
return roundup(args._cfg->outer_block_size, strategy::out_width());
}
+ // Special blocking for SME
+ if (is_sme<strategy>::value) {
+ // If total width is less than 4x kernel width, return the entire width.
+ if (args._Nsize < strategy::out_width()*4) {
+ return roundup(args._Nsize, strategy::out_width());
+ }
+
+ // Otherwise block to single kernel width.
+ return strategy::out_width();
+ }
+
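As a concrete illustration of the SME N-blocking rule above (with a hypothetical out_width() of 16): an N of 40 is below the 4x threshold of 64 and is handled as a single block rounded up to 48 columns, whereas an N of 200 is walked 16 columns at a time.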
unsigned int x_block;
const unsigned int L2_size = args._ci->get_L2_cache_size();
const unsigned int k_block = get_k_block_size(args);
@@ -580,7 +795,7 @@ public:
_Ksections(args._Ksections), _Ktotal(get_ktotal(args)),
_rounded_Ksize(roundup(_Ksize, strategy::k_unroll())),
_nbatches(args._nbatches), _nmulti(args._nmulti), _thread_columns(is_thread_columns(args)),
- _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
+ _act(args._act), _accumulate(args._accumulate), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
_k_block(get_k_block_size(args)), _x_block(get_x_block_size(args)), _Mround(roundup(args._Msize, strategy::out_height())),
_os(os) { }
@@ -590,7 +805,7 @@ public:
_Ksections(args._Ksections), _Ktotal(get_ktotal(args)),
_rounded_Ksize(roundup(_Ksize, strategy::k_unroll())),
_nbatches(args._nbatches), _nmulti(args._nmulti), _thread_columns(is_thread_columns(args)),
- _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
+ _act(args._act), _accumulate(args._accumulate), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
_k_block(get_k_block_size(args)), _x_block(get_x_block_size(args)), _Mround(roundup(args._Msize, strategy::out_height())),
_os() { }
@@ -623,7 +838,7 @@ public:
#endif
/* Make sure we've been set up correctly. */
- assert(_B_transposed);
+ assert(FixedFormat || _B_transposed);
assert(_working_space);
int8_t *working_space_bytes = reinterpret_cast<int8_t *>(_working_space);
@@ -663,10 +878,17 @@ public:
const bool first_pass = (k0==0);
const bool last_pass = (kmax==_Ktotal);
+ // Bias is passed for the first pass only, except for dequantizefloat nomerge cases where it's the last pass.
+ const bool bias_pass = (std::is_same<OutputStage, DequantizeFloat>::value && !MergeStep) ? last_pass : first_pass;
+
// Figure out how many "K" the kernel will actually process.
unsigned int kern_k = roundup(kmax - k0, strategy::k_unroll());
- const Toi *b_ptr = _B_transposed + (rounded_width * _Ktotal * multi) + (k0 * rounded_width) + (start_x * kern_k);
+ const Toi *b_ptr = FixedFormat ?
+ reinterpret_cast<const Toi *>(this->_Bptr) + (multi * this->_B_multi_stride) +
+ ((start_x / get_stripe_width<strategy, FixedFormat>::get()) * this->_ldb) +
+ (k0 * get_stripe_width<strategy, FixedFormat>::get()) :
+ _B_transposed + (rounded_width * _Ktotal * multi) + (k0 * rounded_width) + (start_x * kern_k);
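The FixedFormat branch of the b_ptr computation above assumes weights stored as stripes of stripe_width columns, each stripe laid out K-major with leading dimension ldb. A small hypothetical helper making the same arithmetic explicit (not part of the library's interface):

    #include <cstddef>

    template<typename Toi>
    const Toi *fixed_format_b_ptr(const Toi *Bptr, size_t B_multi_stride, size_t ldb,
                                  unsigned int multi, unsigned int start_x,
                                  unsigned int k0, unsigned int stripe_width) {
        const size_t stripe_index = start_x / stripe_width;   // which stripe of columns
        return Bptr + (multi * B_multi_stride)                 // select the weight tensor for this multi
                    + (stripe_index * ldb)                     // jump to the stripe containing start_x
                    + (k0 * stripe_width);                     // rows (K) within the stripe
    }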
unsigned int batch = batch_0;
unsigned int start_row = (start - (batch_0 * window_per_batch)) * strategy::out_height();
@@ -698,25 +920,32 @@ public:
}
}
+ Tr *result_ptr = this->_Cptr + (batch * this->_C_batch_stride) + (multi * this->_C_multi_stride);
+
+ // If we are using an accumulation buffer and this isn't the last pass, don't pass a result pointer.
+ if (_accumulation_buffer && !last_pass) {
+ result_ptr = nullptr;
+ }
+
// Perform the kernel and merge step, either separately or together as required.
- kernel_and_merge<MergeStep, OutputStage>::run(
+ kernel_and_merge<MergeStep, FixedFormat, OutputStage>::run(
#ifdef CYCLE_PROFILING
prof,
#endif
// Strategy and panel pointers
- strat, a_panel, b_ptr, c_panel,
+ strat, a_panel, b_ptr, this->_ldb, c_panel,
// Result buffer pointers
- this->_Cptr + (batch * this->_C_batch_stride) + (multi * this->_C_multi_stride), this->_ldc,
+ result_ptr, this->_ldc,
// K size, and M/N ranges
kern_k, start_row, end_row, start_x, end_x,
// Only do bias on the first pass
- ((first_pass && this->_bias) ? this->_bias + (multi * this->_bias_multi_stride) : nullptr),
+ ((bias_pass && this->_bias) ? this->_bias + (multi * this->_bias_multi_stride) : nullptr),
// Only do activation on the last pass, and accumulation on any non-first pass.
- (last_pass ? _act : Activation()), !first_pass,
+ (last_pass ? _act : Activation()), (!first_pass || _accumulate),
// Pass in quantization parameters for requantizing kernels (others will ignore)
_os, col_bias + (multi * _Nsize),
- // Accumulation buffer (not yet implemented on this path)
- static_cast<Tab *>(nullptr));
+ // Accumulation buffer
+ get_accumulation_buffer(start_row, start_x, batch, multi));
/* Increment to the next block */
start_row += strategy::out_height();
@@ -802,6 +1031,13 @@ public:
}
}
+ // For FixedFormat cases, figure out the B pointer. The loop below moves through batches and vertically through the output so this will be the same throughout.
+ if (FixedFormat) {
+ b_panel = reinterpret_cast<const Toi *>(this->_Bptr) + (current.multi() * this->_B_multi_stride) +
+ ((current.x0() / get_stripe_width<strategy, FixedFormat>::get()) * this->_ldb) +
+ (current.k0() * get_stripe_width<strategy, FixedFormat>::get());
+ }
+
/* Do the actual work. */
for (unsigned int batch = batch_0; batch <= batch_end; batch++) {
unsigned int first_m = (batch == batch_0) ? m_0 : 0;
@@ -830,6 +1066,9 @@ public:
const bool first_pass = (current.k0() == 0);
const bool last_pass = (current.kmax() == _Ktotal);
+ // Bias is passed for the first pass only, except for dequantizefloat nomerge cases where it's the last pass.
+ const bool bias_pass = (std::is_same<OutputStage, DequantizeFloat>::value && !MergeStep) ? last_pass : first_pass;
+
// Pointer to appropriate part of result array.
Tr *result_ptr = this->_Cptr + (batch * this->_C_batch_stride) + (current.multi() * this->_C_multi_stride);
@@ -840,20 +1079,20 @@ public:
}
// Perform the kernel and merge step, either separately or together as required.
- kernel_and_merge<MergeStep, OutputStage>::run(
+ kernel_and_merge<MergeStep, FixedFormat, OutputStage>::run(
#ifdef CYCLE_PROFILING
prof,
#endif
// Strategy and panel pointers
- strat, a_ptr, b_panel, c_panel,
+ strat, a_ptr, b_panel, this->_ldb, c_panel,
// Result buffer pointers
result_ptr, this->_ldc,
// K size, and M/N ranges
kern_k, y, ymax, current.x0(), current.xmax(),
// Only do bias on the first pass
- ((first_pass && this->_bias) ? this->_bias + (current.multi() * this->_bias_multi_stride) : nullptr),
+ ((bias_pass && this->_bias) ? this->_bias + (current.multi() * this->_bias_multi_stride) : nullptr),
// Only do activation on the last pass, and accumulation on any non-first pass.
- (last_pass ? _act : Activation()), !first_pass,
+ (last_pass ? _act : Activation()), (!first_pass || _accumulate),
// Pass in quantization parameters for requantizing kernels (others will ignore)
_os, col_bias + (current.multi() * _Nsize),
// Accumulation buffer
@@ -863,7 +1102,9 @@ public:
}
}
- b_panel += (roundup(current.xmax() - current.x0(), strategy::out_width()) * kern_k);
+ if (FixedFormat == false) {
+ b_panel += (roundup(current.xmax() - current.x0(), strategy::out_width()) * kern_k);
+ }
}
}
}
@@ -910,20 +1151,31 @@ public:
// Interface implementation - pretransposed
bool B_is_pretransposed() const override {
- return true;
+ return (FixedFormat == false);
}
bool B_pretranspose_required() const override {
- return (_B_transposed==nullptr);
+ return (FixedFormat == false) && (_B_transposed==nullptr);
}
size_t get_B_pretransposed_array_size() const override {
+ if (FixedFormat) {
+ return 0;
+ }
+
unsigned int x_size = roundup(_Nsize, strategy::out_width());
return (x_size * _Ktotal * _nmulti * sizeof(Toi)) + get_col_sum_size();
}
- void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+ size_t get_B_pretranspose_window_size() const override {
+ size_t n_blocks = iceildiv(_Nsize, _x_block);
+ size_t k_blocks = iceildiv(_Ktotal, _k_block);
+
+ return n_blocks * k_blocks * _nmulti;
+ }
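For example (numbers chosen purely for illustration): with _Nsize=512 and _x_block=192, n_blocks = ceil(512/192) = 3; with _Ktotal=1024 and _k_block=256, k_blocks = 4; so for a single multi the pretranspose window is 3 x 4 x 1 = 12 blocks, which callers can split into independent [start, end) ranges for pretranspose_B_array_part().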
+
+ void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
if (std::is_same<OutputStage, Requantize32>::value) {
col_bias = reinterpret_cast<int32_t *>(in_buffer);
@@ -934,8 +1186,26 @@ public:
compute_col_sums(*qp_ptr, _Nsize, _Ksize * _Ksections, B + (i * B_multi_stride), ldb, col_bias + (i * _Nsize), _Ksize * _Ksections, i, 0);
}
}
+ }
+
+ // Support for transposed B is a property of the strategy::transpose type
+ bool B_pretranspose_supports_transpose() const override {
+ typename transform_type<strategy, MergeStep && std::is_same<OutputStage, Requantize32>::value>::type transforms;
+
+ return transforms.PrepareB_supports_transpose();
+ }
- // Put the transposed data after the column sums - in non-transposing cases get_col_sum_size() == 0
+ void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, const bool transposed) override {
+ pretranspose_B_array_part(in_buffer, B, ldb, B_multi_stride, transposed, 0, get_B_pretranspose_window_size());
+ }
+
+ void pretranspose_B_array_part(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, const bool transposed, size_t start, size_t end) override {
+ // Perform column sums etc as part of the last block.
+ if (end >= get_B_pretranspose_window_size()) {
+ requantize_bias(in_buffer, B, ldb, B_multi_stride);
+ }
+
+ // Put the transposed data after the column sums - in non-quantized cases get_col_sum_size() == 0
uintptr_t buffer_int = reinterpret_cast<uintptr_t>(in_buffer);
Toi *buffer = reinterpret_cast<Toi *>(buffer_int + get_col_sum_size());
_B_transposed = buffer;
@@ -943,57 +1213,84 @@ public:
blockwalker current(*this);
strategy strat(_ci);
- do {
+ // Skip over blocks we aren't doing
+ for(size_t i = 0; i < start; i++) {
+ buffer += roundup(current.xmax() - current.x0(), strategy::out_width()) * roundup(current.kmax() - current.k0(), strategy::k_unroll());
+ current.advance();
+ }
+
+ size_t blocks_left = (end - start);
+
+ // Double check that we haven't run out of work
+ if (current.done()) {
+ blocks_left = 0;
+ }
+
+ for (/* blocks_left initialized above */; blocks_left > 0; blocks_left--) {
/* Figure out the size of each block. */
unsigned int k_size = (current.kmax() - current.k0());
- // We need to insert padding at the end of each K section.
- // The computation needed is a little delicate - the coordinates from the block walker are expressed in
- // terms of the full, padded, _Ktotal.
- // But we need to transform each section with reference to the original, unpadded, input, letting the
- // transform pad each section as needed.
+ if (_Ksections > 1) {
+ // We need to insert padding at the end of each K section.
+ // The computation needed is a little delicate - the coordinates from the block walker are expressed in
+ // terms of the full, padded, _Ktotal.
+ // But we need to transform each section with reference to the original, unpadded, input, letting the
+ // transform pad each section as needed.
- // This is needed for computations below.
- const unsigned int rounded_section_size = roundup(_Ksize, strategy::k_unroll());
+ // This is needed for computations below.
+ const unsigned int rounded_section_size = roundup(_Ksize, strategy::k_unroll());
- // The expected output format is also an entire <out_width> columns interleaved, then the next set of
- // columns, and so on. This means, as we are breaking it up vertically, we have to do it one column at
- // a time.
- for (unsigned int x0=current.x0(); x0 < current.xmax(); x0 += strategy::out_width() ){
- unsigned int xmax = std::min(x0 + strategy::out_width(), current.xmax());
+ // The expected output format is also an entire <out_width> columns interleaved, then the next set of
+ // columns, and so on. This means, as we are breaking it up vertically, we have to do it one column at
+ // a time.
+ for (unsigned int x0=current.x0(); x0 < current.xmax(); x0 += strategy::out_width() ) {
+ unsigned int xmax = std::min(x0 + strategy::out_width(), current.xmax());
- // Track where we are and how much work is left.
- unsigned int kpos = current.k0();
- unsigned int kleft = k_size;
+ // Track where we are and how much work is left.
+ unsigned int kpos = current.k0();
+ unsigned int kleft = k_size;
- while (kleft) {
- // Which section are we in? Based on the rounded-up section size.
- unsigned int k_section_base = kpos / rounded_section_size;
- // How far into the section are we?
- unsigned int k_offset = kpos - (k_section_base * rounded_section_size);
+ while (kleft) {
+ // Which section are we in? Based on the rounded-up section size.
+ unsigned int k_section_base = kpos / rounded_section_size;
+ // How far into the section are we?
+ unsigned int k_offset = kpos - (k_section_base * rounded_section_size);
- // We will either copy the rest of this section, or to the end of the requested length.
- unsigned int k_length = std::min(_Ksize - k_offset, kleft);
+ // We will either copy the rest of this section, or to the end of the requested length.
+ unsigned int k_length = std::min(_Ksize - k_offset, kleft);
- strat.transforms.PrepareB(buffer, B + (current.multi() * B_multi_stride), ldb,
- x0, xmax,
- (k_section_base * _Ksize) + k_offset, // K starting point - compute row to read based on our section and the true section length.
- (k_section_base * _Ksize) + k_offset + k_length); // K end point - starting point plus length computed above.
+ strat.transforms.PrepareB(buffer, B + (current.multi() * B_multi_stride), ldb,
+ x0, xmax,
+ (k_section_base * _Ksize) + k_offset, // K starting point - compute row to read based on our section and the true section length.
+ (k_section_base * _Ksize) + k_offset + k_length, // K end point - starting point plus length computed above.
+ transposed);
- // We need to modify our position based on the ROUNDED version of what we just did.
- unsigned int padded_length = roundup(k_length, strategy::k_unroll());
+ // We need to modify our position based on the ROUNDED version of what we just did.
+ unsigned int padded_length = roundup(k_length, strategy::k_unroll());
- buffer += strategy::out_width() * padded_length;
+ buffer += strategy::out_width() * padded_length;
- kpos += padded_length;
- kleft -= padded_length;
+ kpos += padded_length;
+ kleft -= padded_length;
+ }
}
+ } else {
+ // In the single K section case, can process the whole lot in one go.
+ // Caution: 'blockwalker::kmax()' rounds up, so clamp to valid _Ksize.
+ strat.transforms.PrepareB(buffer, B + (current.multi() * B_multi_stride), ldb,
+ current.x0(), current.xmax(), current.k0(), std::min(current.kmax(), _Ksize), transposed);
+ buffer += roundup(current.xmax() - current.x0(), strategy::out_width()) * roundup(current.kmax() - current.k0(), strategy::k_unroll());
}
- } while (current.advance());
+
+ // Advance to the next block, break if we run off the end.
+ if (!current.advance()) {
+ break;
+ }
+ }
}
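The windowed loop above first skips 'start' blocks (advancing both the block walker and the output buffer by each skipped block's packed size) and then transforms end - start blocks, bailing out early if the walker runs dry. A stripped-down sketch of that control flow, with a hypothetical walker/transform interface:

    #include <cstddef>

    // Process this worker's [start, end) share of the pretranspose window.
    // The real code also advances the destination buffer while skipping;
    // that bookkeeping is folded into 'transform' here for brevity.
    template<typename Walker, typename TransformFn>
    void pretranspose_window(Walker current, size_t start, size_t end, TransformFn transform) {
        for (size_t i = 0; i < start; i++) {
            current.advance();                 // skip blocks owned by other workers
        }

        for (size_t blocks_left = end - start; blocks_left > 0 && !current.done(); blocks_left--) {
            transform(current);                // PrepareB for this (x, k, multi) block
            if (!current.advance()) {
                break;                         // ran off the end of the matrix
            }
        }
    }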
void set_pretransposed_B_data(void *in_buffer) override {
- // Put the transposed data after the column sums - in non-transposing cases get_col_sum_size() == 0
+ // Put the transposed data after the column sums - in non-quantized cases get_col_sum_size() == 0
uintptr_t buffer_int = reinterpret_cast<uintptr_t>(in_buffer);
_B_transposed = reinterpret_cast<Toi *>(buffer_int + get_col_sum_size());
col_bias = reinterpret_cast<int32_t *>(in_buffer);
@@ -1008,6 +1305,13 @@ public:
}
}
+ void set_dequantize_scale(const float scale) override {
+ if(std::is_same<OutputStage, DequantizeFloat>::value) {
+ DequantizeFloat* df = reinterpret_cast<DequantizeFloat *>(&_os);
+ df->scale = scale;
+ }
+ }
+
void set_indirect_parameters(size_t string_len, const To * const * const *ptr) override {
assert(string_len == _Ksize);
_indirect_buf = ptr;
@@ -1019,12 +1323,15 @@ public:
}
// Estimate cycles for given problem given provided parameters
- static uint64_t estimate_cycles(const GemmArgs &args, const PerformanceParameters &params) {
+ template<typename perf_type>
+ static uint64_t estimate_cycles(const GemmArgs &args) {
unsigned int k_blocks = iceildiv(args._Ksize, get_k_block_size(args));
+ const PerformanceParameters &params = strategy::template get_performance_parameters<perf_type>(args._ci);
+
uint64_t total_macs = static_cast<uint64_t>(args._nbatches) * args._nmulti * roundup(args._Msize, strategy::out_height()) * roundup(args._Nsize, strategy::out_width()) * get_ktotal(args);
uint64_t prepare_bytes = static_cast<uint64_t>(args._nbatches) * args._nmulti * roundup(args._Msize, strategy::out_height()) * get_ktotal(args) * sizeof(Toi);
- uint64_t merge_bytes = static_cast<uint16_t>(args._nbatches) * args._nmulti * k_blocks * roundup(args._Msize, strategy::out_height()) * roundup(args._Nsize, strategy::out_width()) * sizeof(Tr);
+ uint64_t merge_bytes = static_cast<uint64_t>(args._nbatches) * args._nmulti * k_blocks * args._Msize * roundup(args._Nsize, strategy::out_width()) * sizeof(Tr);
float mac_cycles = static_cast<float>(total_macs) / params.kernel_macs_cycle;
float prepare_cycles = static_cast<float>(prepare_bytes) / params.prepare_bytes_cycle;
@@ -1042,16 +1349,37 @@ public:
return static_cast<uint64_t>(total_cycles);
}
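The tail of estimate_cycles() falls outside this hunk, so the following stand-alone cost-model sketch only mirrors the visible structure: MACs costed against kernel throughput, prepare and merge traffic against bytes-per-cycle figures, combined additively here as an assumption; all numbers are invented:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint64_t total_macs    = 1ULL << 30;   // assumed problem size
        const uint64_t prepare_bytes = 64ULL << 20;
        const uint64_t merge_bytes   = 32ULL << 20;

        const float kernel_macs_cycle   = 32.0f;     // assumed PerformanceParameters
        const float prepare_bytes_cycle = 8.0f;
        const float merge_bytes_cycle   = 8.0f;

        const float cycles = static_cast<float>(total_macs)    / kernel_macs_cycle
                           + static_cast<float>(prepare_bytes) / prepare_bytes_cycle
                           + static_cast<float>(merge_bytes)   / merge_bytes_cycle;

        std::printf("estimated cycles: %.0f\n", cycles);
        return 0;
    }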
+
+ GemmConfig get_config() override {
+ GemmConfig c;
+
+ c.method = GemmMethod::GEMM_INTERLEAVED;
+ c.inner_block_size = _k_block;
+ c.outer_block_size = _x_block;
+ c.filter = get_type_name<strategy>();
+ c.weight_format = get_weight_format(get_kernel_weight_format<strategy, FixedFormat, To>::get(), sizeof(To));
+
+ return c;
+ }
};
// Aliases for the variations
template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing>
using GemmInterleavedNoMerge = GemmInterleaved<strategy, To, Tr, OutputStage, false>;
+template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing>
+using GemmInterleavedFixedFormat = GemmInterleaved<strategy, To, Tr, OutputStage, true, true>;
+
template<typename strategy, typename To, typename Tr>
using GemmInterleavedPretransposedNoMergeQuantizedInline = GemmInterleaved<strategy, To, Tr, Requantize32, false>;
template<typename strategy, typename To, typename Tr>
using GemmInterleavedQuantized = GemmInterleaved<strategy, To, Tr, Requantize32>;
+template<typename strategy, typename To, typename Tr>
+using GemmInterleavedNoMergeDequantized = GemmInterleaved<strategy, To, Tr, DequantizeFloat, false>;
+
+template<typename strategy, typename To, typename Tr>
+using GemmInterleavedDequantized = GemmInterleaved<strategy, To, Tr, DequantizeFloat>;
+
} // namespace arm_gemm
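Read together, the aliases pin specific template parameters of GemmInterleaved: the NoMerge variants set MergeStep=false, GemmInterleavedFixedFormat sets MergeStep=true and FixedFormat=true, and the Quantized/Dequantized variants fix OutputStage to Requantize32 or DequantizeFloat respectively, leaving ForceThreadColumns and ForceFloatAccumulate at their default of false.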