Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp')
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp | 133
 1 file changed, 67 insertions, 66 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index 897ec9d05f..5214a71cce 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -62,12 +62,12 @@ namespace {
template<bool MergeStep, bool FixedFormat, typename OutputStage>
class kernel_and_merge {
public:
- template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
+ template<typename strategy, typename Tlo, typename Tro, typename Tr, typename Tri, typename Tab>
static void run (
#ifdef CYCLE_PROFILING
profiler &prof,
#endif
- strategy &strat, const To *a_ptr, const To *b_panel, size_t b_stride, Tri *c_panel,
+ strategy &strat, const Tlo *a_ptr, const Tro *b_panel, size_t b_stride, Tri *c_panel,
Tr *c_ptr, int ldc, int kern_k, unsigned int m_0,
unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *biasptr,
const Activation &act, bool accumulate, const OutputStage &os, const int32_t *col_bias,
@@ -76,12 +76,12 @@ public:
// Run a kernel and call the separate merge step
template<>
-template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
+template<typename strategy, typename Tlo, typename Tro, typename Tr, typename Tri, typename Tab>
void kernel_and_merge<true, false, Nothing>::run(
#ifdef CYCLE_PROFILING
profiler &prof,
#endif
- strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *c_panel,
+ strategy &strat, const Tlo *a_ptr, const Tro *b_panel, size_t, Tri *c_panel,
Tr *c_ptr, int ldc, int kern_k, unsigned int m_0,
unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *biasptr,
const Activation &act, bool accumulate, const Nothing &, const int32_t *, Tab *)
@@ -106,12 +106,12 @@ void kernel_and_merge<true, false, Nothing>::run(
// Run a fixed-format kernel and call the separate merge step
template<>
-template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
+template<typename strategy, typename Tlo, typename Tro, typename Tr, typename Tri, typename Tab>
void kernel_and_merge<true, true, Nothing>::run(
#ifdef CYCLE_PROFILING
profiler &prof,
#endif
- strategy &strat, const To *a_ptr, const To *b_panel, size_t b_stride, Tri *c_panel,
+ strategy &strat, const Tlo *a_ptr, const Tro *b_panel, size_t b_stride, Tri *c_panel,
Tr *c_ptr, int ldc, int kern_k, unsigned int m_0,
unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *biasptr,
const Activation &act, bool accumulate, const Nothing &, const int32_t *, Tab *)
@@ -136,12 +136,12 @@ void kernel_and_merge<true, true, Nothing>::run(
// Run a kernel with integrated merge
template<>
-template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
+template<typename strategy, typename Tlo, typename Tro, typename Tr, typename Tri, typename Tab>
void kernel_and_merge<false, false, Nothing>::run(
#ifdef CYCLE_PROFILING
profiler &prof,
#endif
- strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *,
+ strategy &strat, const Tlo *a_ptr, const Tro *b_panel, size_t, Tri *,
Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max,
unsigned int n_0, unsigned int n_max, const Tr *biasptr,
const Activation &act, bool accumulate, const Nothing &, const int32_t *,
@@ -175,12 +175,12 @@ void kernel_and_merge<false, false, Nothing>::run(
// Run a kernel with integrated merge, quantizing
template<>
-template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
+template<typename strategy, typename Tlo, typename Tro, typename Tr, typename Tri, typename Tab>
void kernel_and_merge<false, false, Requantize32>::run(
#ifdef CYCLE_PROFILING
profiler &prof,
#endif
- strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *,
+ strategy &strat, const Tlo *a_ptr, const Tro *b_panel, size_t, Tri *,
Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max,
unsigned int n_0, unsigned int n_max, const Tr *,
const Activation &, bool accumulate, const Requantize32 &qp, const int32_t *col_bias,
@@ -211,12 +211,12 @@ void kernel_and_merge<false, false, Requantize32>::run(
// Run a kernel and call the separate quantize step
template<>
-template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
+template<typename strategy, typename Tlo, typename Tro, typename Tr, typename Tri, typename Tab>
void kernel_and_merge<true, false, Requantize32>::run(
#ifdef CYCLE_PROFILING
profiler &prof,
#endif
- strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *c_panel,
+ strategy &strat, const Tlo *a_ptr, const Tro *b_panel, size_t, Tri *c_panel,
Tr *c_ptr, int ldc, int kern_k, unsigned int m_0,
unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *,
const Activation &, bool, const Requantize32 &qp, const int32_t *col_bias,
@@ -257,12 +257,12 @@ void kernel_and_merge<true, false, Requantize32>::run(
// Run a kernel with integrated merge, dequantizing to FP32
template<>
-template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
+template<typename strategy, typename Tlo, typename Tro, typename Tr, typename Tri, typename Tab>
void kernel_and_merge<false, false, DequantizeFloat>::run(
#ifdef CYCLE_PROFILING
profiler &prof,
#endif
- strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *,
+ strategy &strat, const Tlo *a_ptr, const Tro *b_panel, size_t, Tri *,
Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max,
unsigned int n_0, unsigned int n_max, const Tr *bias,
const Activation &act, bool accumulate, const DequantizeFloat &dq, const int32_t *col_bias,
@@ -294,12 +294,12 @@ void kernel_and_merge<false, false, DequantizeFloat>::run(
}
template<>
-template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
+template<typename strategy, typename Tlo, typename Tro, typename Tr, typename Tri, typename Tab>
void kernel_and_merge<true, false, DequantizeFloat>::run(
#ifdef CYCLE_PROFILING
profiler &prof,
#endif
- strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *c_panel,
+ strategy &strat, const Tlo *a_ptr, const Tro *b_panel, size_t, Tri *c_panel,
Tr *c_ptr, int ldc, int kern_k, unsigned int m_0,
unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *bias,
const Activation &act, bool accumulate, const DequantizeFloat &qp, const int32_t *,
@@ -394,21 +394,21 @@ struct get_stripe_width<strategy, true> {
};
// KernelWeightFormat is a similar story.
-template<typename strategy, bool FixedFormat, typename To>
+template<typename strategy, bool FixedFormat, typename Tro>
struct get_kernel_weight_format {
static KernelWeightFormat get() {
return KernelWeightFormat::NON_FIXED;
}
};
-template<typename strategy, typename To>
-struct get_kernel_weight_format<strategy, true, To> {
+template<typename strategy, typename Tro>
+struct get_kernel_weight_format<strategy, true, Tro> {
static KernelWeightFormat get() {
KernelWeightFormat kwf = strategy::kernel_weight_format();
// If we are using a BF16 kernel to do an FP32 problem (fast mode) then we need to set the BF16 flag on the
// weight format.
- if (std::is_same<To, float>::value && std::is_same<typename strategy::operand_type, bfloat16>::value) {
+ if (std::is_same<Tro, float>::value && std::is_same<typename strategy::rhs_operand_type, bfloat16>::value) {
uint32_t kwf_i = static_cast<uint32_t>(kwf);
kwf_i |= 0x10;
kwf = static_cast<KernelWeightFormat>(kwf_i);
@@ -420,9 +420,10 @@ struct get_kernel_weight_format<strategy, true, To> {
} // anonymous namespace
-template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing, bool MergeStep=true, bool FixedFormat=false, bool ForceThreadColumns=false, bool ForceFloatAccumulate=false>
-class GemmInterleaved : public GemmCommon<To, Tr> {
- typedef typename strategy::operand_type Toi;
+template<typename strategy, typename Tlo, typename Tro, typename Tr, typename OutputStage=Nothing, bool MergeStep=true, bool FixedFormat=false, bool ForceThreadColumns=false, bool ForceFloatAccumulate=false>
+class GemmInterleaved : public GemmCommon<Tlo, Tro, Tr> {
+ typedef typename strategy::lhs_operand_type Tloi;
+ typedef typename strategy::rhs_operand_type Troi;
typedef typename strategy::result_type Tri;
typedef typename accumulate_buffer_type<strategy, OutputStage, ForceFloatAccumulate>::type Tab;
@@ -453,7 +454,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
unsigned int _Mround=0;
/* Working space, pretransposed buffer, buffer manager */
- const Toi *_B_transposed=nullptr;
+ const Troi *_B_transposed=nullptr;
void *_working_space=nullptr;
Tab *_accumulation_buffer=nullptr;
@@ -465,10 +466,10 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
int32_t *col_bias = nullptr;
/* Indirect parameters. _indirect_buf doubles as a flag to indicate that "indirect" transform should be used. */
- const To * const * const * _indirect_buf = nullptr;
+ const Tlo * const * const * _indirect_buf = nullptr;
/* Convolver - only set up for convolution problems, so also doubles as a flag. */
- std::unique_ptr<convolver<To>> _convolver = nullptr;
+ std::unique_ptr<convolver<Tlo>> _convolver = nullptr;
unsigned int get_col_sum_size() const {
if (std::is_same<OutputStage, Requantize32>::value) {
@@ -483,7 +484,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
class blockwalker {
private:
/* Size loops, etc. based on our parent's configuration */
- const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &_parent;
+ const GemmInterleaved<strategy, Tlo, Tro, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &_parent;
/* K, X and multi parameters for current iteration. */
unsigned int _k0=0, _x0=0, _multi=0;
@@ -498,9 +499,9 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
bool _newmulti=true;
public:
- blockwalker(const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &parent) : _parent(parent) { }
+ blockwalker(const GemmInterleaved<strategy, Tlo, Tro, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &parent) : _parent(parent) { }
- blockwalker(const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &parent,
+ blockwalker(const GemmInterleaved<strategy, Tlo, Tro, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &parent,
unsigned int x_start, unsigned int x_end) : _parent(parent), _x0 (_x_start), _x_start(x_start), _x_end(x_end) { }
unsigned int xmax() {
@@ -554,7 +555,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
unsigned int k_depth = _k_block;
if (std::is_same<OutputStage, Requantize32>::value) {
- k_depth += sizeof(int32_t) / sizeof(Toi);
+ k_depth += sizeof(int32_t) / sizeof(Tloi);
}
return k_depth;
@@ -564,10 +565,10 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
size_t get_a_working_size() const {
if (_thread_columns) {
// For 2D threading: allocate a buffer of one block of rows per thread
- return ROUND_UP(sizeof(Toi) * get_total_k_depth() * strategy::out_height() * _maxthreads);
+ return ROUND_UP(sizeof(Tloi) * get_total_k_depth() * strategy::out_height() * _maxthreads);
} else {
// For 1D threaded: one of these needed, regardless of thread count. Divided according to window.
- return ROUND_UP(sizeof(Toi) * get_total_k_depth() * _Mround * _nbatches);
+ return ROUND_UP(sizeof(Tloi) * get_total_k_depth() * _Mround * _nbatches);
}
}
@@ -692,7 +693,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
}
// Don't bother to block below this size threshold (1.25X target size)
- unsigned int scaling_threshold = ((target_bytes_per_block * 5) / 4) / sizeof(Toi);
+ unsigned int scaling_threshold = ((target_bytes_per_block * 5) / 4) / sizeof(Tloi);
if (get_ktotal(args) <= scaling_threshold) {
return get_ktotal(args);
@@ -700,7 +701,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
// Once we are blocking, this (lower) threshold determines when we should use more blocks
// NOTE: Could be that some factor-based solution would work better here.
- unsigned int max_block_size = target_bytes_per_block / sizeof(Toi);
+ unsigned int max_block_size = target_bytes_per_block / sizeof(Tloi);
unsigned int num_k_blocks = iceildiv(get_ktotal(args), max_block_size);
@@ -713,7 +714,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
// k_block: Find out how much of the larger array can be loaded into half the cache.
// This should account for associative caches.
- k_block = (L1_size / 2) / (sizeof(Toi) * (std::max(strategy::out_width(), strategy::out_height())));
+ k_block = (L1_size / 2) / (sizeof(Tloi) * (std::max(strategy::out_width(), strategy::out_height())));
// Needs to be (at least a single) multiple of the K unroll level.
k_block /= strategy::k_unroll();
@@ -761,14 +762,14 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
// x_block: Work out how many rows (of length k_block) will fit in the L2
// Don't allocate more than 90% of the L2 to allow for overheads, and subtract off the L1 contents.
const unsigned int scaled_l2_size = (L2_size * 9) / 10;
- const unsigned int k_block_area = k_block * sizeof(Toi) * (strategy::out_width() + strategy::out_height());
+ const unsigned int k_block_area = k_block * sizeof(Tloi) * (strategy::out_width() + strategy::out_height());
// .. if the L1 contents is bigger than the L2, just return a minimal size block.
if (k_block_area > scaled_l2_size) {
return strategy::out_width();
}
- x_block = (scaled_l2_size - k_block_area) / (sizeof(Toi) * k_block);
+ x_block = (scaled_l2_size - k_block_area) / (sizeof(Tloi) * k_block);
// Needs to be (at least a single) multiple of the kernel output width.
x_block /= strategy::out_width();
@@ -866,8 +867,8 @@ public:
const auto end_x = std::min(work_range.get_position_end(1) * strategy::out_width(), _Nsize);
Tri * const c_panel = reinterpret_cast<Tri *>(working_space_bytes + (threadid * get_c_working_size()));
- Toi * const a_panel = reinterpret_cast<Toi *>(working_space_bytes + (_maxthreads * get_c_working_size()) +
- (threadid * sizeof(Toi) * get_total_k_depth() * strategy::out_height()));
+ Tloi * const a_panel = reinterpret_cast<Tloi *>(working_space_bytes + (_maxthreads * get_c_working_size()) +
+ (threadid * sizeof(Tloi) * get_total_k_depth() * strategy::out_height()));
for (unsigned int multi=0; multi<_nmulti; multi++) {
for (unsigned int k0=0; k0<_Ktotal; k0+=_k_block) {
@@ -884,8 +885,8 @@ public:
// Figure out how many "K" the kernel will actually process.
unsigned int kern_k = roundup(kmax - k0, strategy::k_unroll());
- const Toi *b_ptr = FixedFormat ?
- reinterpret_cast<const Toi *>(this->_Bptr) + (multi * this->_B_multi_stride) +
+ const Troi *b_ptr = FixedFormat ?
+ reinterpret_cast<const Troi *>(this->_Bptr) + (multi * this->_B_multi_stride) +
((start_x / get_stripe_width<strategy, FixedFormat>::get()) * this->_ldb) +
(k0 * get_stripe_width<strategy, FixedFormat>::get()) :
_B_transposed + (rounded_width * _Ktotal * multi) + (k0 * rounded_width) + (start_x * kern_k);
@@ -899,7 +900,7 @@ public:
// Set up transposed 'A' block
{
#ifdef CYCLE_PROFILING
- auto p=prof.ScopedProfiler(PROFILE_PREPA, strategy::out_height() * (kmax-k0) * sizeof(Toi));
+ auto p=prof.ScopedProfiler(PROFILE_PREPA, strategy::out_height() * (kmax-k0) * sizeof(Tloi));
#endif
// See comment above on transform_type<> class: this extracts either 'transforms' or
// 'transforms_quantized' as appropriate.
@@ -967,10 +968,10 @@ public:
// (one per thread) first, followed by the (window-divided) A
// buffer.
// Set a_panel to the base of the A buffers - compute offsets into it based on M/batches later.
- Toi * const a_panel = reinterpret_cast<Toi *>(working_space_bytes + (_maxthreads * get_c_working_size()));
+ Tloi * const a_panel = reinterpret_cast<Tloi *>(working_space_bytes + (_maxthreads * get_c_working_size()));
Tri * const c_panel = reinterpret_cast<Tri *>(working_space_bytes + (threadid * get_c_working_size()));
- const Toi *b_panel;
+ const Troi *b_panel;
b_panel = _B_transposed;
// newkblock() is always true on the first iteration, so these will be set properly on the first loop.
@@ -989,7 +990,7 @@ public:
for (;!current.done();current.advance()) {
if (current.newkblock()) {
#ifdef CYCLE_PROFILING
- auto p=prof.ScopedProfiler(PROFILE_PREPA, (end - start) * strategy::out_height() * (current.kmax()-current.k0()) * sizeof(Toi));
+ auto p=prof.ScopedProfiler(PROFILE_PREPA, (end - start) * strategy::out_height() * (current.kmax()-current.k0()) * sizeof(Tloi));
#endif
// See comment above on transform_type<> class: this extracts either 'transforms' or
// 'transforms_quantized' as appropriate.
@@ -1025,7 +1026,7 @@ public:
// larger than the (rounded) K value.
if(std::is_same<OutputStage, Requantize32>::value) {
- a_panel_stride = kern_k + (sizeof(int32_t) / sizeof(Toi));
+ a_panel_stride = kern_k + (sizeof(int32_t) / sizeof(Tloi));
} else {
a_panel_stride = kern_k;
}
@@ -1033,7 +1034,7 @@ public:
// For FixedFormat cases, figure out the B pointer. The loop below moves through batches and vertically through the output so this will be the same throughout.
if (FixedFormat) {
- b_panel = reinterpret_cast<const Toi *>(this->_Bptr) + (current.multi() * this->_B_multi_stride) +
+ b_panel = reinterpret_cast<const Troi *>(this->_Bptr) + (current.multi() * this->_B_multi_stride) +
((current.x0() / get_stripe_width<strategy, FixedFormat>::get()) * this->_ldb) +
(current.k0() * get_stripe_width<strategy, FixedFormat>::get());
}
@@ -1043,7 +1044,7 @@ public:
unsigned int first_m = (batch == batch_0) ? m_0 : 0;
unsigned int last_m = (batch == batch_end) ? m_max : _Msize;
- const Toi *a_ptr = a_panel + (batch * _Mround + first_m) * get_total_k_depth();
+ const Tloi *a_ptr = a_panel + (batch * _Mround + first_m) * get_total_k_depth();
if (first_m >= last_m)
continue;
@@ -1165,7 +1166,7 @@ public:
unsigned int x_size = roundup(_Nsize, strategy::out_width());
- return (x_size * _Ktotal * _nmulti * sizeof(Toi)) + get_col_sum_size();
+ return (x_size * _Ktotal * _nmulti * sizeof(Troi)) + get_col_sum_size();
}
size_t get_B_pretranspose_window_size() const override {
@@ -1175,7 +1176,7 @@ public:
return n_blocks * k_blocks * _nmulti;
}
- void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+ void requantize_bias(void *in_buffer, const Tro *B, const int ldb, const int B_multi_stride) override {
if (std::is_same<OutputStage, Requantize32>::value) {
col_bias = reinterpret_cast<int32_t *>(in_buffer);
@@ -1195,11 +1196,11 @@ public:
return transforms.PrepareB_supports_transpose();
}
- void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, const bool transposed) override {
+ void pretranspose_B_array(void *in_buffer, const Tro *B, const int ldb, const int B_multi_stride, const bool transposed) override {
pretranspose_B_array_part(in_buffer, B, ldb, B_multi_stride, transposed, 0, get_B_pretranspose_window_size());
}
- void pretranspose_B_array_part(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, const bool transposed, size_t start, size_t end) override {
+ void pretranspose_B_array_part(void *in_buffer, const Tro *B, const int ldb, const int B_multi_stride, const bool transposed, size_t start, size_t end) override {
// Perform column sums etc as part of the last block.
if (end >= get_B_pretranspose_window_size()) {
requantize_bias(in_buffer, B, ldb, B_multi_stride);
@@ -1207,7 +1208,7 @@ public:
// Put the transposed data after the column sums - in non-quantized cases get_col_sum_size() == 0
uintptr_t buffer_int = reinterpret_cast<uintptr_t>(in_buffer);
- Toi *buffer = reinterpret_cast<Toi *>(buffer_int + get_col_sum_size());
+ Troi *buffer = reinterpret_cast<Troi *>(buffer_int + get_col_sum_size());
_B_transposed = buffer;
blockwalker current(*this);
@@ -1292,7 +1293,7 @@ public:
void set_pretransposed_B_data(void *in_buffer) override {
// Put the transposed data after the column sums - in non-quantized cases get_col_sum_size() == 0
uintptr_t buffer_int = reinterpret_cast<uintptr_t>(in_buffer);
- _B_transposed = reinterpret_cast<Toi *>(buffer_int + get_col_sum_size());
+ _B_transposed = reinterpret_cast<Troi *>(buffer_int + get_col_sum_size());
col_bias = reinterpret_cast<int32_t *>(in_buffer);
}
@@ -1312,14 +1313,14 @@ public:
}
}
- void set_indirect_parameters(size_t string_len, const To * const * const *ptr) override {
+ void set_indirect_parameters(size_t string_len, const Tlo * const * const *ptr) override {
assert(string_len == _Ksize);
_indirect_buf = ptr;
}
void set_convolution_parameters(ConvolutionParameters parms) override {
assert(parms.input_channels == _Ksize);
- _convolver = std::unique_ptr<convolver<To>>(new convolver<To>(parms));
+ _convolver = std::unique_ptr<convolver<Tlo>>(new convolver<Tlo>(parms));
}
// Estimate cycles for given problem given provided parameters
@@ -1330,7 +1331,7 @@ public:
const PerformanceParameters &params = strategy::template get_performance_parameters<perf_type>(args._ci);
uint64_t total_macs = static_cast<uint64_t>(args._nbatches) * args._nmulti * roundup(args._Msize, strategy::out_height()) * roundup(args._Nsize, strategy::out_width()) * get_ktotal(args);
- uint64_t prepare_bytes = static_cast<uint64_t>(args._nbatches) * args._nmulti * roundup(args._Msize, strategy::out_height()) * get_ktotal(args) * sizeof(Toi);
+ uint64_t prepare_bytes = static_cast<uint64_t>(args._nbatches) * args._nmulti * roundup(args._Msize, strategy::out_height()) * get_ktotal(args) * sizeof(Tloi);
uint64_t merge_bytes = static_cast<uint64_t>(args._nbatches) * args._nmulti * k_blocks * args._Msize * roundup(args._Nsize, strategy::out_width()) * sizeof(Tr);
float mac_cycles = static_cast<float>(total_macs) / params.kernel_macs_cycle;
@@ -1357,7 +1358,7 @@ public:
c.inner_block_size = _k_block;
c.outer_block_size = _x_block;
c.filter = get_type_name<strategy>();
- c.weight_format = get_weight_format(get_kernel_weight_format<strategy, FixedFormat, To>::get(), sizeof(To));
+ c.weight_format = get_weight_format(get_kernel_weight_format<strategy, FixedFormat, Tro>::get(), sizeof(Tro));
return c;
}
@@ -1365,21 +1366,21 @@ public:
// Aliases for the variations
template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing>
-using GemmInterleavedNoMerge = GemmInterleaved<strategy, To, Tr, OutputStage, false>;
+using GemmInterleavedNoMerge = GemmInterleaved<strategy, To, To, Tr, OutputStage, false>;
template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing>
-using GemmInterleavedFixedFormat = GemmInterleaved<strategy, To, Tr, OutputStage, true, true>;
+using GemmInterleavedFixedFormat = GemmInterleaved<strategy, To, To, Tr, OutputStage, true, true>;
template<typename strategy, typename To, typename Tr>
-using GemmInterleavedPretransposedNoMergeQuantizedInline = GemmInterleaved<strategy, To, Tr, Requantize32, false>;
+using GemmInterleavedPretransposedNoMergeQuantizedInline = GemmInterleaved<strategy, To, To, Tr, Requantize32, false>;
-template<typename strategy, typename To, typename Tr>
-using GemmInterleavedQuantized = GemmInterleaved<strategy, To, Tr, Requantize32>;
+template<typename strategy, typename Tlo, typename Tro, typename Tr>
+using GemmInterleavedQuantized = GemmInterleaved<strategy, Tlo, Tro, Tr, Requantize32>;
template<typename strategy, typename To, typename Tr>
-using GemmInterleavedNoMergeDequantized = GemmInterleaved<strategy, To, Tr, DequantizeFloat, false>;
+using GemmInterleavedNoMergeDequantized = GemmInterleaved<strategy, To, To, Tr, DequantizeFloat, false>;
-template<typename strategy, typename To, typename Tr>
-using GemmInterleavedDequantized = GemmInterleaved<strategy, To, Tr, DequantizeFloat>;
+template<typename strategy, typename Tlo, typename Tro, typename Tr>
+using GemmInterleavedDequantized = GemmInterleaved<strategy, Tlo, Tro, Tr, DequantizeFloat>;
} // namespace arm_gemm
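
The patch splits the single operand type `To` into separate LHS (`Tlo`) and RHS (`Tro`) template parameters, and likewise replaces the strategy's `operand_type` with `lhs_operand_type`/`rhs_operand_type`; the existing single-type aliases stay source compatible by passing `To` for both slots. The standalone sketch below (not ACL code; the mock strategy, names and type choices are illustrative assumptions) shows how the widened signature and the updated aliases forward the two operand types independently, which is what makes mixed LHS/RHS element types expressible.

// Minimal standalone sketch (not part of the patch): a mock strategy exposing
// the new lhs_operand_type / rhs_operand_type typedefs, plus a skeletal
// GemmInterleaved with the widened <strategy, Tlo, Tro, Tr, ...> parameter
// list and aliases shaped like the ones above. All names here are mocks.
#include <cstdint>
#include <cstdio>

struct Nothing {};
struct Requantize32 {};

// Mock strategy: int8_t elements on both the LHS and RHS panels.
struct mock_s8_strategy {
    typedef int8_t  lhs_operand_type;  // was: operand_type
    typedef int8_t  rhs_operand_type;  // was: operand_type
    typedef int32_t result_type;
};

// Skeletal stand-in for GemmInterleaved with the new parameter list.
template<typename strategy, typename Tlo, typename Tro, typename Tr,
         typename OutputStage=Nothing, bool MergeStep=true>
class GemmInterleaved {
    typedef typename strategy::lhs_operand_type Tloi;  // interleaved LHS element
    typedef typename strategy::rhs_operand_type Troi;  // pretransposed RHS element
public:
    static void describe() {
        std::printf("LHS in=%zu interleaved=%zu | RHS in=%zu pretransposed=%zu | out=%zu\n",
                    sizeof(Tlo), sizeof(Tloi), sizeof(Tro), sizeof(Troi), sizeof(Tr));
    }
};

// Aliases mirroring the patch: single-type wrappers repeat To for both operand
// slots, while the quantized alias now takes Tlo and Tro independently.
template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing>
using GemmInterleavedNoMerge = GemmInterleaved<strategy, To, To, Tr, OutputStage, false>;

template<typename strategy, typename Tlo, typename Tro, typename Tr>
using GemmInterleavedQuantized = GemmInterleaved<strategy, Tlo, Tro, Tr, Requantize32>;

int main() {
    // Same element type on both sides (legacy behaviour via the two-type form).
    GemmInterleavedQuantized<mock_s8_strategy, int8_t, int8_t, int8_t>::describe();
    // Mixed LHS/RHS input types are now expressible in the signature.
    GemmInterleavedQuantized<mock_s8_strategy, uint8_t, int8_t, int8_t>::describe();
    return 0;
}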