aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.cpp')
-rw-r--r--src/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.cpp12
1 files changed, 11 insertions, 1 deletions
diff --git a/src/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.cpp b/src/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.cpp
index 2c9cd320f0..3b2975dd80 100644
--- a/src/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.cpp
+++ b/src/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.cpp
@@ -101,6 +101,14 @@ void NEGEMMInterleavedMatrixMultiplyWrapperTemplate<To, Tr, use_dot>::create_wor
using strategy = typename Kernel<To, use_dot>::strategy;
unsigned int offset_transformed_b = 0;
+ unsigned int wl_index = 0;
+ unsigned int num_buffers = 0, reshaped_block_size = 0;
+
+ if(!_b_is_pretransposed)
+ {
+ num_buffers = _transformed_b->info()->tensor_shape()[1];
+ reshaped_block_size = _transformed_b->info()->tensor_shape()[0];
+ }
execute_window_loop(_block_walker, [&](const Coordinates & id)
{
const unsigned int x0 = id.x();
@@ -122,7 +130,9 @@ void NEGEMMInterleavedMatrixMultiplyWrapperTemplate<To, Tr, use_dot>::create_wor
}
else
{
- ARM_COMPUTE_ERROR("Not supported");
+ // Rotate through the BufferManager's buffers:
+ wl_index++;
+ offset_transformed_b = (wl_index % num_buffers) * reshaped_block_size;
}
});
}