aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.cpp')
-rw-r--r--src/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.cpp27
1 files changed, 22 insertions, 5 deletions
diff --git a/src/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.cpp b/src/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.cpp
index 41a031c1c7..7fc57f3c02 100644
--- a/src/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.cpp
+++ b/src/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.cpp
@@ -35,10 +35,18 @@ namespace arm_compute
namespace
{
// Call the lambda function for each workload generated by the passed window.
-template <typename To, bool use_dot, typename Lambda>
+template <typename To, bool use_dot, bool use_buffer_manager, typename Lambda>
void for_each_element_in_window(const Window &window, const ITensor *b, ITensor *transformed_b, unsigned int N, unsigned int K, Lambda &&lambda)
{
- using strategy = typename Kernel<To, use_dot>::strategy;
+ using strategy = typename Kernel<To, use_dot>::strategy;
+ unsigned int wl_index = 0;
+ unsigned int num_buffers = 0, reshaped_block_size = 0;
+
+ if(use_buffer_manager)
+ {
+ num_buffers = transformed_b->info()->tensor_shape()[1];
+ reshaped_block_size = transformed_b->info()->strides_in_bytes().y();
+ }
unsigned int offset_transformed_b = transformed_b->info()->offset_first_element_in_bytes();
execute_window_loop(window, [&](const Coordinates & coordinates)
@@ -62,7 +70,16 @@ void for_each_element_in_window(const Window &window, const ITensor *b, ITensor
lambda(PrepareBWorkload(offset_b, offset_transformed_b, x0, xmax, k0, kmax));
//Each workload represents one block:
- offset_transformed_b += (x_size * k_size * sizeof(To));
+ if(use_buffer_manager)
+ {
+ // Rotate through the BufferManager's buffers:
+ wl_index++;
+ offset_transformed_b = (wl_index % num_buffers) * reshaped_block_size;
+ }
+ else
+ {
+ offset_transformed_b += (x_size * k_size * sizeof(To));
+ }
});
}
@@ -142,7 +159,7 @@ void NEGEMMInterleavedPrepareBWrapperKernelTemplate<To, use_dot>::transform(cons
template <typename To, bool use_dot>
void NEGEMMInterleavedPrepareBWrapperKernelTemplate<To, use_dot>::create_workloads(std::vector<PrepareBWorkload> &workloads)
{
- for_each_element_in_window<To, use_dot>(window(), _b, _transformed_b, _Nsize, _Ksize, [&workloads](PrepareBWorkload && wl)
+ for_each_element_in_window<To, use_dot, true>(window(), _b, _transformed_b, _Nsize, _Ksize, [&workloads](PrepareBWorkload && wl)
{
workloads.push_back(std::move(wl));
});
@@ -152,7 +169,7 @@ template <typename To, bool use_dot>
void NEGEMMInterleavedPrepareBWrapperKernelTemplate<To, use_dot>::run(const Window &window, const ThreadInfo &info)
{
ARM_COMPUTE_ERROR_ON_MISMATCHING_WINDOWS(window, INEKernel::window());
- for_each_element_in_window<To, use_dot>(window, _b, _transformed_b, _Nsize, _Ksize, [&](PrepareBWorkload && wl)
+ for_each_element_in_window<To, use_dot, false>(window, _b, _transformed_b, _Nsize, _Ksize, [&](PrepareBWorkload && wl)
{
this->transform(wl, info);
});