diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp | 39 |
1 files changed, 36 insertions, 3 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp index c2fd0b0e8c..13f548e39e 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp @@ -31,6 +31,7 @@ #include "convolver.hpp" #include "kernel_weight_format.hpp" #include "kernel_traits.hpp" +#include "kernel_weight_format.hpp" #include "mergeresults.hpp" #include "performance_parameters.hpp" #include "quantized.hpp" @@ -1039,6 +1040,13 @@ public: return (x_size * _Ktotal * _nmulti * sizeof(Toi)) + get_col_sum_size(); } + size_t get_B_pretranspose_window_size() const override { + size_t n_blocks = iceildiv(_Nsize, _x_block); + size_t k_blocks = iceildiv(_Ktotal, _k_block); + + return n_blocks * k_blocks * _nmulti; + } + void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { if (std::is_same<OutputStage, Requantize32>::value) { col_bias = reinterpret_cast<int32_t *>(in_buffer); @@ -1053,7 +1061,14 @@ public: } void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { - requantize_bias(in_buffer, B, ldb, B_multi_stride); + pretranspose_B_array_part(in_buffer, B, ldb, B_multi_stride, 0, get_B_pretranspose_window_size()); + } + + void pretranspose_B_array_part(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, size_t start, size_t end) override { + // Perform column sums etc as part of the last block. + if (end >= get_B_pretranspose_window_size()) { + requantize_bias(in_buffer, B, ldb, B_multi_stride); + } // Put the transposed data after the column sums - in non-quantized cases get_col_sum_size() == 0 uintptr_t buffer_int = reinterpret_cast<uintptr_t>(in_buffer); @@ -1063,7 +1078,20 @@ public: blockwalker current(*this); strategy strat(_ci); - do { + // Skip over blocks we aren't doing + for(size_t i = 0; i < start; i++) { + buffer += roundup(current.xmax() - current.x0(), strategy::out_width()) * roundup(current.kmax() - current.k0(), strategy::k_unroll()); + current.advance(); + } + + size_t blocks_left = (end - start); + + // Double check that we haven't run out of work + if (current.done()) { + blocks_left = 0; + } + + for (/* blocks_left initialized above */; blocks_left > 0; blocks_left--) { /* Figure out the size of each block. */ unsigned int k_size = (current.kmax() - current.k0()); @@ -1117,7 +1145,12 @@ public: current.x0(), current.xmax(), current.k0(), std::min(current.kmax(), _Ksize)); buffer += roundup(current.xmax() - current.x0(), strategy::out_width()) * roundup(current.kmax() - current.k0(), strategy::k_unroll()); } - } while (current.advance()); + + // Advance to the next block, break if we run off the end. + if (!current.advance()) { + break; + } + } } void set_pretransposed_B_data(void *in_buffer) override { |