about | summary | refs | log | tree | commit | diff
path: root/arm_compute/runtime/NEON/AssemblyHelper.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/runtime/NEON/AssemblyHelper.h')
-rw-r--r-- arm_compute/runtime/NEON/AssemblyHelper.h 32
1 files changed, 15 insertions, 17 deletions
diff --git a/arm_compute/runtime/NEON/AssemblyHelper.h b/arm_compute/runtime/NEON/AssemblyHelper.h
index 40f28587c2..2b4f35f2e1 100644
--- a/arm_compute/runtime/NEON/AssemblyHelper.h
+++ b/arm_compute/runtime/NEON/AssemblyHelper.h
@@ -82,24 +82,19 @@ public:
const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
- // Configure kernel window
- Window window = calculate_max_window(*_d->info());
+ const int batch_stride_a = _a->info()->strides_in_bytes().z() / sizeof(TypeInput);
+ const int batch_stride_d = _d->info()->strides_in_bytes().z() / sizeof(TypeOutput);
+
+ const int multi_stride_a = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
+ const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
+ const int multi_stride_d = _d->info()->strides_in_bytes()[3] / sizeof(TypeOutput);
+
+ const auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer());
const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer());
+ auto out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer());
- // Only iterate over batches
- Window win(window);
- win.set(0, Window::Dimension(0, 1, 1));
- win.set(1, Window::Dimension(0, 1, 1));
- Iterator in0(_a, window);
- Iterator out(_d, window);
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto in0_ptr = reinterpret_cast<const TypeInput *>(in0.ptr());
- auto out_ptr = reinterpret_cast<TypeOutput *>(out.ptr());
- _gemm_kernel_asm->set_arrays(in0_ptr, lda, in1_ptr, ldb, out_ptr, ldd);
- NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX);
- },
- in0, out);
+ _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d);
+ NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX);
}
};
@@ -146,10 +141,13 @@ inline bool setup_assembly_kernel(const ITensor *a, const ITensor *b, ITensor *d
const int M = d->info()->tensor_shape().y();
const int N = d->info()->tensor_shape().x();
const int K = a->info()->tensor_shape().x();
+ const int batches = a->info()->tensor_shape().total_size_upper(2);
+ const int multis = b->info()->tensor_shape().z();
unsigned int num_threads = NEScheduler::get().num_threads();
+
// unique_ptr to a Gemm object
std::unique_ptr<typename T::AssemblyGemm>
- asm_gemm(arm_gemm::gemm<typename T::TypeOperator, typename T::TypeResult>(ci, M, N, K, false, false, alpha, beta, num_threads, false));
+ asm_gemm(arm_gemm::gemm<typename T::TypeOperator, typename T::TypeResult>(ci, M, N, K, batches, multis, false, false, alpha, beta, num_threads, false));
// arm_compute wrapper for the Gemm object (see above)
std::unique_ptr<NEGEMMAssemblyWrapper<typename T::AssemblyGemm>>
acl_gemm_wrapper = support::cpp14::make_unique<NEGEMMAssemblyWrapper<typename T::AssemblyGemm>>();