diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp | 133 |
1 file changed, 67 insertions, 66 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp index 897ec9d05f..5214a71cce 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp @@ -62,12 +62,12 @@ namespace { template<bool MergeStep, bool FixedFormat, typename OutputStage> class kernel_and_merge { public: - template<typename strategy, typename To, typename Tr, typename Tri, typename Tab> + template<typename strategy, typename Tlo, typename Tro, typename Tr, typename Tri, typename Tab> static void run ( #ifdef CYCLE_PROFILING profiler &prof, #endif - strategy &strat, const To *a_ptr, const To *b_panel, size_t b_stride, Tri *c_panel, + strategy &strat, const Tlo *a_ptr, const Tro *b_panel, size_t b_stride, Tri *c_panel, Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *biasptr, const Activation &act, bool accumulate, const OutputStage &os, const int32_t *col_bias, @@ -76,12 +76,12 @@ public: // Run a kernel and call the separate merge step template<> -template<typename strategy, typename To, typename Tr, typename Tri, typename Tab> +template<typename strategy, typename Tlo, typename Tro, typename Tr, typename Tri, typename Tab> void kernel_and_merge<true, false, Nothing>::run( #ifdef CYCLE_PROFILING profiler &prof, #endif - strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *c_panel, + strategy &strat, const Tlo *a_ptr, const Tro *b_panel, size_t, Tri *c_panel, Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *biasptr, const Activation &act, bool accumulate, const Nothing &, const int32_t *, Tab *) @@ -106,12 +106,12 @@ void kernel_and_merge<true, false, Nothing>::run( // Run a fixed-format kernel and call the separate merge step template<> -template<typename strategy, typename To, typename Tr, typename Tri, typename Tab> 
+template<typename strategy, typename Tlo, typename Tro, typename Tr, typename Tri, typename Tab> void kernel_and_merge<true, true, Nothing>::run( #ifdef CYCLE_PROFILING profiler &prof, #endif - strategy &strat, const To *a_ptr, const To *b_panel, size_t b_stride, Tri *c_panel, + strategy &strat, const Tlo *a_ptr, const Tro *b_panel, size_t b_stride, Tri *c_panel, Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *biasptr, const Activation &act, bool accumulate, const Nothing &, const int32_t *, Tab *) @@ -136,12 +136,12 @@ void kernel_and_merge<true, true, Nothing>::run( // Run a kernel with integrated merge template<> -template<typename strategy, typename To, typename Tr, typename Tri, typename Tab> +template<typename strategy, typename Tlo, typename Tro, typename Tr, typename Tri, typename Tab> void kernel_and_merge<false, false, Nothing>::run( #ifdef CYCLE_PROFILING profiler &prof, #endif - strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *, + strategy &strat, const Tlo *a_ptr, const Tro *b_panel, size_t, Tri *, Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *biasptr, const Activation &act, bool accumulate, const Nothing &, const int32_t *, @@ -175,12 +175,12 @@ void kernel_and_merge<false, false, Nothing>::run( // Run a kernel with integrated merge, quantizing template<> -template<typename strategy, typename To, typename Tr, typename Tri, typename Tab> +template<typename strategy, typename Tlo, typename Tro, typename Tr, typename Tri, typename Tab> void kernel_and_merge<false, false, Requantize32>::run( #ifdef CYCLE_PROFILING profiler &prof, #endif - strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *, + strategy &strat, const Tlo *a_ptr, const Tro *b_panel, size_t, Tri *, Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *, 
const Activation &, bool accumulate, const Requantize32 &qp, const int32_t *col_bias, @@ -211,12 +211,12 @@ void kernel_and_merge<false, false, Requantize32>::run( // Run a kernel and call the separate quantize step template<> -template<typename strategy, typename To, typename Tr, typename Tri, typename Tab> +template<typename strategy, typename Tlo, typename Tro, typename Tr, typename Tri, typename Tab> void kernel_and_merge<true, false, Requantize32>::run( #ifdef CYCLE_PROFILING profiler &prof, #endif - strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *c_panel, + strategy &strat, const Tlo *a_ptr, const Tro *b_panel, size_t, Tri *c_panel, Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *, const Activation &, bool, const Requantize32 &qp, const int32_t *col_bias, @@ -257,12 +257,12 @@ void kernel_and_merge<true, false, Requantize32>::run( // Run a kernel with integrated merge, dequantizing to FP32 template<> -template<typename strategy, typename To, typename Tr, typename Tri, typename Tab> +template<typename strategy, typename Tlo, typename Tro, typename Tr, typename Tri, typename Tab> void kernel_and_merge<false, false, DequantizeFloat>::run( #ifdef CYCLE_PROFILING profiler &prof, #endif - strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *, + strategy &strat, const Tlo *a_ptr, const Tro *b_panel, size_t, Tri *, Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *bias, const Activation &act, bool accumulate, const DequantizeFloat &dq, const int32_t *col_bias, @@ -294,12 +294,12 @@ void kernel_and_merge<false, false, DequantizeFloat>::run( } template<> -template<typename strategy, typename To, typename Tr, typename Tri, typename Tab> +template<typename strategy, typename Tlo, typename Tro, typename Tr, typename Tri, typename Tab> void kernel_and_merge<true, false, DequantizeFloat>::run( #ifdef 
CYCLE_PROFILING profiler &prof, #endif - strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *c_panel, + strategy &strat, const Tlo *a_ptr, const Tro *b_panel, size_t, Tri *c_panel, Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *bias, const Activation &act, bool accumulate, const DequantizeFloat &qp, const int32_t *, @@ -394,21 +394,21 @@ struct get_stripe_width<strategy, true> { }; // KernelWeightFormat is a similar story. -template<typename strategy, bool FixedFormat, typename To> +template<typename strategy, bool FixedFormat, typename Tro> struct get_kernel_weight_format { static KernelWeightFormat get() { return KernelWeightFormat::NON_FIXED; } }; -template<typename strategy, typename To> -struct get_kernel_weight_format<strategy, true, To> { +template<typename strategy, typename Tro> +struct get_kernel_weight_format<strategy, true, Tro> { static KernelWeightFormat get() { KernelWeightFormat kwf = strategy::kernel_weight_format(); // If we are using a BF16 kernel to do an FP32 problem (fast mode) then we need to set the BF16 flag on the // weight format. 
- if (std::is_same<To, float>::value && std::is_same<typename strategy::operand_type, bfloat16>::value) { + if (std::is_same<Tro, float>::value && std::is_same<typename strategy::rhs_operand_type, bfloat16>::value) { uint32_t kwf_i = static_cast<uint32_t>(kwf); kwf_i |= 0x10; kwf = static_cast<KernelWeightFormat>(kwf_i); @@ -420,9 +420,10 @@ struct get_kernel_weight_format<strategy, true, To> { } // anonymous namespace -template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing, bool MergeStep=true, bool FixedFormat=false, bool ForceThreadColumns=false, bool ForceFloatAccumulate=false> -class GemmInterleaved : public GemmCommon<To, Tr> { - typedef typename strategy::operand_type Toi; +template<typename strategy, typename Tlo, typename Tro, typename Tr, typename OutputStage=Nothing, bool MergeStep=true, bool FixedFormat=false, bool ForceThreadColumns=false, bool ForceFloatAccumulate=false> +class GemmInterleaved : public GemmCommon<Tlo, Tro, Tr> { + typedef typename strategy::lhs_operand_type Tloi; + typedef typename strategy::rhs_operand_type Troi; typedef typename strategy::result_type Tri; typedef typename accumulate_buffer_type<strategy, OutputStage, ForceFloatAccumulate>::type Tab; @@ -453,7 +454,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> { unsigned int _Mround=0; /* Working space, pretransposed buffer, buffer manager */ - const Toi *_B_transposed=nullptr; + const Troi *_B_transposed=nullptr; void *_working_space=nullptr; Tab *_accumulation_buffer=nullptr; @@ -465,10 +466,10 @@ class GemmInterleaved : public GemmCommon<To, Tr> { int32_t *col_bias = nullptr; /* Indirect parameters. _indirect_buf doubles as a flag to indicate that "indirect" transform should be used. */ - const To * const * const * _indirect_buf = nullptr; + const Tlo * const * const * _indirect_buf = nullptr; /* Convolver - only set up for convolution problems, so also doubles as a flag. 
*/ - std::unique_ptr<convolver<To>> _convolver = nullptr; + std::unique_ptr<convolver<Tlo>> _convolver = nullptr; unsigned int get_col_sum_size() const { if (std::is_same<OutputStage, Requantize32>::value) { @@ -483,7 +484,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> { class blockwalker { private: /* Size loops, etc. based on our parent's configuration */ - const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &_parent; + const GemmInterleaved<strategy, Tlo, Tro, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &_parent; /* K, X and multi parameters for current iteration. */ unsigned int _k0=0, _x0=0, _multi=0; @@ -498,9 +499,9 @@ class GemmInterleaved : public GemmCommon<To, Tr> { bool _newmulti=true; public: - blockwalker(const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &parent) : _parent(parent) { } + blockwalker(const GemmInterleaved<strategy, Tlo, Tro, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &parent) : _parent(parent) { } - blockwalker(const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &parent, + blockwalker(const GemmInterleaved<strategy, Tlo, Tro, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &parent, unsigned int x_start, unsigned int x_end) : _parent(parent), _x0 (_x_start), _x_start(x_start), _x_end(x_end) { } unsigned int xmax() { @@ -554,7 +555,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> { unsigned int k_depth = _k_block; if (std::is_same<OutputStage, Requantize32>::value) { - k_depth += sizeof(int32_t) / sizeof(Toi); + k_depth += sizeof(int32_t) / sizeof(Tloi); } return k_depth; @@ -564,10 +565,10 @@ class GemmInterleaved : public GemmCommon<To, Tr> { size_t get_a_working_size() const { if (_thread_columns) { // For 2D 
threading: allocate a buffer of one block of rows per thread - return ROUND_UP(sizeof(Toi) * get_total_k_depth() * strategy::out_height() * _maxthreads); + return ROUND_UP(sizeof(Tloi) * get_total_k_depth() * strategy::out_height() * _maxthreads); } else { // For 1D threaded: one of these needed, regardless of thread count. Divided according to window. - return ROUND_UP(sizeof(Toi) * get_total_k_depth() * _Mround * _nbatches); + return ROUND_UP(sizeof(Tloi) * get_total_k_depth() * _Mround * _nbatches); } } @@ -692,7 +693,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> { } // Don't bother to block below this size threshold (1.25X target size) - unsigned int scaling_threshold = ((target_bytes_per_block * 5) / 4) / sizeof(Toi); + unsigned int scaling_threshold = ((target_bytes_per_block * 5) / 4) / sizeof(Tloi); if (get_ktotal(args) <= scaling_threshold) { return get_ktotal(args); @@ -700,7 +701,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> { // Once we are blocking, this (lower) threshold determines when we should use more blocks // NOTE: Could be that some factor-based solution would work better here. - unsigned int max_block_size = target_bytes_per_block / sizeof(Toi); + unsigned int max_block_size = target_bytes_per_block / sizeof(Tloi); unsigned int num_k_blocks = iceildiv(get_ktotal(args), max_block_size); @@ -713,7 +714,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> { // k_block: Find out how much of the larger array can be loaded into half the cache. // This should account for associative caches. - k_block = (L1_size / 2) / (sizeof(Toi) * (std::max(strategy::out_width(), strategy::out_height()))); + k_block = (L1_size / 2) / (sizeof(Tloi) * (std::max(strategy::out_width(), strategy::out_height()))); // Needs to be (at least a single) multiple of the K unroll level. 
k_block /= strategy::k_unroll(); @@ -761,14 +762,14 @@ class GemmInterleaved : public GemmCommon<To, Tr> { // x_block: Work out how many rows (of length k_block) will fit in the L2 // Don't allocate more than 90% of the L2 to allow for overheads, and subtract off the L1 contents. const unsigned int scaled_l2_size = (L2_size * 9) / 10; - const unsigned int k_block_area = k_block * sizeof(Toi) * (strategy::out_width() + strategy::out_height()); + const unsigned int k_block_area = k_block * sizeof(Tloi) * (strategy::out_width() + strategy::out_height()); // .. if the L1 contents is bigger than the L2, just return a minimal size block. if (k_block_area > scaled_l2_size) { return strategy::out_width(); } - x_block = (scaled_l2_size - k_block_area) / (sizeof(Toi) * k_block); + x_block = (scaled_l2_size - k_block_area) / (sizeof(Tloi) * k_block); // Needs to be (at least a single) multiple of the kernel output width. x_block /= strategy::out_width(); @@ -866,8 +867,8 @@ public: const auto end_x = std::min(work_range.get_position_end(1) * strategy::out_width(), _Nsize); Tri * const c_panel = reinterpret_cast<Tri *>(working_space_bytes + (threadid * get_c_working_size())); - Toi * const a_panel = reinterpret_cast<Toi *>(working_space_bytes + (_maxthreads * get_c_working_size()) + - (threadid * sizeof(Toi) * get_total_k_depth() * strategy::out_height())); + Tloi * const a_panel = reinterpret_cast<Tloi *>(working_space_bytes + (_maxthreads * get_c_working_size()) + + (threadid * sizeof(Tloi) * get_total_k_depth() * strategy::out_height())); for (unsigned int multi=0; multi<_nmulti; multi++) { for (unsigned int k0=0; k0<_Ktotal; k0+=_k_block) { @@ -884,8 +885,8 @@ public: // Figure out how many "K" the kernel will actually process. unsigned int kern_k = roundup(kmax - k0, strategy::k_unroll()); - const Toi *b_ptr = FixedFormat ? - reinterpret_cast<const Toi *>(this->_Bptr) + (multi * this->_B_multi_stride) + + const Troi *b_ptr = FixedFormat ? 
+ reinterpret_cast<const Troi *>(this->_Bptr) + (multi * this->_B_multi_stride) + ((start_x / get_stripe_width<strategy, FixedFormat>::get()) * this->_ldb) + (k0 * get_stripe_width<strategy, FixedFormat>::get()) : _B_transposed + (rounded_width * _Ktotal * multi) + (k0 * rounded_width) + (start_x * kern_k); @@ -899,7 +900,7 @@ public: // Set up transposed 'A' block { #ifdef CYCLE_PROFILING - auto p=prof.ScopedProfiler(PROFILE_PREPA, strategy::out_height() * (kmax-k0) * sizeof(Toi)); + auto p=prof.ScopedProfiler(PROFILE_PREPA, strategy::out_height() * (kmax-k0) * sizeof(Tloi)); #endif // See comment above on transform_type<> class: this extracts either 'transforms' or // 'transforms_quantized' as appropriate. @@ -967,10 +968,10 @@ public: // (one per thread) first, followed by the (window-divided) A // buffer. // Set a_panel to the base of the A buffers - compute offsets into it based on M/batches later. - Toi * const a_panel = reinterpret_cast<Toi *>(working_space_bytes + (_maxthreads * get_c_working_size())); + Tloi * const a_panel = reinterpret_cast<Tloi *>(working_space_bytes + (_maxthreads * get_c_working_size())); Tri * const c_panel = reinterpret_cast<Tri *>(working_space_bytes + (threadid * get_c_working_size())); - const Toi *b_panel; + const Troi *b_panel; b_panel = _B_transposed; // newkblock() is always true on the first iteration, so these will be set properly on the first loop. @@ -989,7 +990,7 @@ public: for (;!current.done();current.advance()) { if (current.newkblock()) { #ifdef CYCLE_PROFILING - auto p=prof.ScopedProfiler(PROFILE_PREPA, (end - start) * strategy::out_height() * (current.kmax()-current.k0()) * sizeof(Toi)); + auto p=prof.ScopedProfiler(PROFILE_PREPA, (end - start) * strategy::out_height() * (current.kmax()-current.k0()) * sizeof(Tloi)); #endif // See comment above on transform_type<> class: this extracts either 'transforms' or // 'transforms_quantized' as appropriate. @@ -1025,7 +1026,7 @@ public: // larger than the (rounded) K value. 
if(std::is_same<OutputStage, Requantize32>::value) { - a_panel_stride = kern_k + (sizeof(int32_t) / sizeof(Toi)); + a_panel_stride = kern_k + (sizeof(int32_t) / sizeof(Tloi)); } else { a_panel_stride = kern_k; } @@ -1033,7 +1034,7 @@ public: // For FixedFormat cases, figure out the B pointer. The loop below moves through batches and vertically through the output so this will be the same throughout. if (FixedFormat) { - b_panel = reinterpret_cast<const Toi *>(this->_Bptr) + (current.multi() * this->_B_multi_stride) + + b_panel = reinterpret_cast<const Troi *>(this->_Bptr) + (current.multi() * this->_B_multi_stride) + ((current.x0() / get_stripe_width<strategy, FixedFormat>::get()) * this->_ldb) + (current.k0() * get_stripe_width<strategy, FixedFormat>::get()); } @@ -1043,7 +1044,7 @@ public: unsigned int first_m = (batch == batch_0) ? m_0 : 0; unsigned int last_m = (batch == batch_end) ? m_max : _Msize; - const Toi *a_ptr = a_panel + (batch * _Mround + first_m) * get_total_k_depth(); + const Tloi *a_ptr = a_panel + (batch * _Mround + first_m) * get_total_k_depth(); if (first_m >= last_m) continue; @@ -1165,7 +1166,7 @@ public: unsigned int x_size = roundup(_Nsize, strategy::out_width()); - return (x_size * _Ktotal * _nmulti * sizeof(Toi)) + get_col_sum_size(); + return (x_size * _Ktotal * _nmulti * sizeof(Troi)) + get_col_sum_size(); } size_t get_B_pretranspose_window_size() const override { @@ -1175,7 +1176,7 @@ public: return n_blocks * k_blocks * _nmulti; } - void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { + void requantize_bias(void *in_buffer, const Tro *B, const int ldb, const int B_multi_stride) override { if (std::is_same<OutputStage, Requantize32>::value) { col_bias = reinterpret_cast<int32_t *>(in_buffer); @@ -1195,11 +1196,11 @@ public: return transforms.PrepareB_supports_transpose(); } - void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, const bool 
transposed) override { + void pretranspose_B_array(void *in_buffer, const Tro *B, const int ldb, const int B_multi_stride, const bool transposed) override { pretranspose_B_array_part(in_buffer, B, ldb, B_multi_stride, transposed, 0, get_B_pretranspose_window_size()); } - void pretranspose_B_array_part(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, const bool transposed, size_t start, size_t end) override { + void pretranspose_B_array_part(void *in_buffer, const Tro *B, const int ldb, const int B_multi_stride, const bool transposed, size_t start, size_t end) override { // Perform column sums etc as part of the last block. if (end >= get_B_pretranspose_window_size()) { requantize_bias(in_buffer, B, ldb, B_multi_stride); @@ -1207,7 +1208,7 @@ public: // Put the transposed data after the column sums - in non-quantized cases get_col_sum_size() == 0 uintptr_t buffer_int = reinterpret_cast<uintptr_t>(in_buffer); - Toi *buffer = reinterpret_cast<Toi *>(buffer_int + get_col_sum_size()); + Troi *buffer = reinterpret_cast<Troi *>(buffer_int + get_col_sum_size()); _B_transposed = buffer; blockwalker current(*this); @@ -1292,7 +1293,7 @@ public: void set_pretransposed_B_data(void *in_buffer) override { // Put the transposed data after the column sums - in non-quantized cases get_col_sum_size() == 0 uintptr_t buffer_int = reinterpret_cast<uintptr_t>(in_buffer); - _B_transposed = reinterpret_cast<Toi *>(buffer_int + get_col_sum_size()); + _B_transposed = reinterpret_cast<Troi *>(buffer_int + get_col_sum_size()); col_bias = reinterpret_cast<int32_t *>(in_buffer); } @@ -1312,14 +1313,14 @@ public: } } - void set_indirect_parameters(size_t string_len, const To * const * const *ptr) override { + void set_indirect_parameters(size_t string_len, const Tlo * const * const *ptr) override { assert(string_len == _Ksize); _indirect_buf = ptr; } void set_convolution_parameters(ConvolutionParameters parms) override { assert(parms.input_channels == _Ksize); - _convolver 
= std::unique_ptr<convolver<To>>(new convolver<To>(parms)); + _convolver = std::unique_ptr<convolver<Tlo>>(new convolver<Tlo>(parms)); } // Estimate cycles for given problem given provided parameters @@ -1330,7 +1331,7 @@ public: const PerformanceParameters ¶ms = strategy::template get_performance_parameters<perf_type>(args._ci); uint64_t total_macs = static_cast<uint64_t>(args._nbatches) * args._nmulti * roundup(args._Msize, strategy::out_height()) * roundup(args._Nsize, strategy::out_width()) * get_ktotal(args); - uint64_t prepare_bytes = static_cast<uint64_t>(args._nbatches) * args._nmulti * roundup(args._Msize, strategy::out_height()) * get_ktotal(args) * sizeof(Toi); + uint64_t prepare_bytes = static_cast<uint64_t>(args._nbatches) * args._nmulti * roundup(args._Msize, strategy::out_height()) * get_ktotal(args) * sizeof(Tloi); uint64_t merge_bytes = static_cast<uint64_t>(args._nbatches) * args._nmulti * k_blocks * args._Msize * roundup(args._Nsize, strategy::out_width()) * sizeof(Tr); float mac_cycles = static_cast<float>(total_macs) / params.kernel_macs_cycle; @@ -1357,7 +1358,7 @@ public: c.inner_block_size = _k_block; c.outer_block_size = _x_block; c.filter = get_type_name<strategy>(); - c.weight_format = get_weight_format(get_kernel_weight_format<strategy, FixedFormat, To>::get(), sizeof(To)); + c.weight_format = get_weight_format(get_kernel_weight_format<strategy, FixedFormat, Tro>::get(), sizeof(Tro)); return c; } @@ -1365,21 +1366,21 @@ public: // Aliases for the variations template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing> -using GemmInterleavedNoMerge = GemmInterleaved<strategy, To, Tr, OutputStage, false>; +using GemmInterleavedNoMerge = GemmInterleaved<strategy, To, To, Tr, OutputStage, false>; template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing> -using GemmInterleavedFixedFormat = GemmInterleaved<strategy, To, Tr, OutputStage, true, true>; +using GemmInterleavedFixedFormat = 
GemmInterleaved<strategy, To, To, Tr, OutputStage, true, true>; template<typename strategy, typename To, typename Tr> -using GemmInterleavedPretransposedNoMergeQuantizedInline = GemmInterleaved<strategy, To, Tr, Requantize32, false>; +using GemmInterleavedPretransposedNoMergeQuantizedInline = GemmInterleaved<strategy, To, To, Tr, Requantize32, false>; -template<typename strategy, typename To, typename Tr> -using GemmInterleavedQuantized = GemmInterleaved<strategy, To, Tr, Requantize32>; +template<typename strategy, typename Tlo, typename Tro, typename Tr> +using GemmInterleavedQuantized = GemmInterleaved<strategy, Tlo, Tro, Tr, Requantize32>; template<typename strategy, typename To, typename Tr> -using GemmInterleavedNoMergeDequantized = GemmInterleaved<strategy, To, Tr, DequantizeFloat, false>; +using GemmInterleavedNoMergeDequantized = GemmInterleaved<strategy, To, To, Tr, DequantizeFloat, false>; -template<typename strategy, typename To, typename Tr> -using GemmInterleavedDequantized = GemmInterleaved<strategy, To, Tr, DequantizeFloat>; +template<typename strategy, typename Tlo, typename Tro, typename Tr> +using GemmInterleavedDequantized = GemmInterleaved<strategy, Tlo, Tro, Tr, DequantizeFloat>; } // namespace arm_gemm |