about | summary | refs | log | tree | commit | diff
path: root/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp')
-rw-r--r-- src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp | 39
1 file changed, 36 insertions, 3 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index c2fd0b0e8c..13f548e39e 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -31,6 +31,7 @@
#include "convolver.hpp"
#include "kernel_weight_format.hpp"
#include "kernel_traits.hpp"
+#include "kernel_weight_format.hpp"
#include "mergeresults.hpp"
#include "performance_parameters.hpp"
#include "quantized.hpp"
@@ -1039,6 +1040,13 @@ public:
return (x_size * _Ktotal * _nmulti * sizeof(Toi)) + get_col_sum_size();
}
+ size_t get_B_pretranspose_window_size() const override { // Total number of blocks in the B pretranspose work window; callers pass [start, end) sub-ranges of this count to pretranspose_B_array_part().
+ size_t n_blocks = iceildiv(_Nsize, _x_block); // Number of _x_block-sized pieces covering _Nsize (iceildiv presumably rounds up - confirm against its definition).
+ size_t k_blocks = iceildiv(_Ktotal, _k_block); // Number of _k_block-sized pieces covering _Ktotal.
+
+ return n_blocks * k_blocks * _nmulti; // One unit of work per (N piece, K piece) pair, scaled by _nmulti.
+ }
+
void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
if (std::is_same<OutputStage, Requantize32>::value) {
col_bias = reinterpret_cast<int32_t *>(in_buffer);
@@ -1053,7 +1061,14 @@ public:
}
void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
- requantize_bias(in_buffer, B, ldb, B_multi_stride);
+ pretranspose_B_array_part(in_buffer, B, ldb, B_multi_stride, 0, get_B_pretranspose_window_size()); // Whole window in one call: start = 0, end = total block count, so the partial routine's "last block" check (end >= window size) also fires and requantize_bias() still runs.
+ }
+
+ void pretranspose_B_array_part(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, size_t start, size_t end) override {
+ // Perform column sums etc as part of the last block.
+ if (end >= get_B_pretranspose_window_size()) {
+ requantize_bias(in_buffer, B, ldb, B_multi_stride);
+ }
// Put the transposed data after the column sums - in non-quantized cases get_col_sum_size() == 0
uintptr_t buffer_int = reinterpret_cast<uintptr_t>(in_buffer);
@@ -1063,7 +1078,20 @@ public:
blockwalker current(*this);
strategy strat(_ci);
- do {
+ // Skip over blocks we aren't doing
+ for(size_t i = 0; i < start; i++) {
+ buffer += roundup(current.xmax() - current.x0(), strategy::out_width()) * roundup(current.kmax() - current.k0(), strategy::k_unroll());
+ current.advance();
+ }
+
+ size_t blocks_left = (end - start);
+
+ // Double check that we haven't run out of work
+ if (current.done()) {
+ blocks_left = 0;
+ }
+
+ for (/* blocks_left initialized above */; blocks_left > 0; blocks_left--) {
/* Figure out the size of each block. */
unsigned int k_size = (current.kmax() - current.k0());
@@ -1117,7 +1145,12 @@ public:
current.x0(), current.xmax(), current.k0(), std::min(current.kmax(), _Ksize));
buffer += roundup(current.xmax() - current.x0(), strategy::out_width()) * roundup(current.kmax() - current.k0(), strategy::k_unroll());
}
- } while (current.advance());
+
+ // Advance to the next block, break if we run off the end.
+ if (!current.advance()) {
+ break;
+ }
+ }
}
void set_pretransposed_B_data(void *in_buffer) override {