diff options
Diffstat (limited to 'src/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.cpp')
-rw-r--r-- | src/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.cpp | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/src/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.cpp b/src/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.cpp index 2c9cd320f0..3b2975dd80 100644 --- a/src/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.cpp +++ b/src/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.cpp @@ -101,6 +101,14 @@ void NEGEMMInterleavedMatrixMultiplyWrapperTemplate<To, Tr, use_dot>::create_wor using strategy = typename Kernel<To, use_dot>::strategy; unsigned int offset_transformed_b = 0; + unsigned int wl_index = 0; + unsigned int num_buffers = 0, reshaped_block_size = 0; + + if(!_b_is_pretransposed) + { + num_buffers = _transformed_b->info()->tensor_shape()[1]; + reshaped_block_size = _transformed_b->info()->tensor_shape()[0]; + } execute_window_loop(_block_walker, [&](const Coordinates & id) { const unsigned int x0 = id.x(); @@ -122,7 +130,9 @@ void NEGEMMInterleavedMatrixMultiplyWrapperTemplate<To, Tr, use_dot>::create_wor } else { - ARM_COMPUTE_ERROR("Not supported"); + // Rotate through the BufferManager's buffers: + wl_index++; + offset_transformed_b = (wl_index % num_buffers) * reshaped_block_size; } }); } |