From e7e96e09ff0d3e47797adf197aff2bc39671788c Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Fri, 13 Apr 2018 13:44:10 +0100 Subject: COMPMID-1054 Update RSH's GEMM to add batch+multi support Change-Id: Ib9d91b77f1d51976da4449fa1e6eeeffae307353 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/127876 Tested-by: Jenkins Reviewed-by: Pablo Tello Reviewed-by: Anthony Barbier --- arm_compute/runtime/NEON/AssemblyHelper.h | 32 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 17 deletions(-) (limited to 'arm_compute/runtime/NEON/AssemblyHelper.h') diff --git a/arm_compute/runtime/NEON/AssemblyHelper.h b/arm_compute/runtime/NEON/AssemblyHelper.h index 40f28587c2..2b4f35f2e1 100644 --- a/arm_compute/runtime/NEON/AssemblyHelper.h +++ b/arm_compute/runtime/NEON/AssemblyHelper.h @@ -82,24 +82,19 @@ public: const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput); const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput); - // Configure kernel window - Window window = calculate_max_window(*_d->info()); + const int batch_stride_a = _a->info()->strides_in_bytes().z() / sizeof(TypeInput); + const int batch_stride_d = _d->info()->strides_in_bytes().z() / sizeof(TypeOutput); + + const int multi_stride_a = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput); + const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput); + const int multi_stride_d = _d->info()->strides_in_bytes()[3] / sizeof(TypeOutput); + + const auto in0_ptr = reinterpret_cast(_a->buffer()); const auto in1_ptr = reinterpret_cast(_b->buffer()); + auto out_ptr = reinterpret_cast(_d->buffer()); - // Only iterate over batches - Window win(window); - win.set(0, Window::Dimension(0, 1, 1)); - win.set(1, Window::Dimension(0, 1, 1)); - Iterator in0(_a, window); - Iterator out(_d, window); - execute_window_loop(win, [&](const Coordinates &) - { - const auto in0_ptr = reinterpret_cast(in0.ptr()); - auto out_ptr = reinterpret_cast(out.ptr()); - _gemm_kernel_asm->set_arrays(in0_ptr, lda, in1_ptr, ldb, out_ptr, ldd); - NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX); - }, - in0, out); + _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d); + NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX); } }; @@ -146,10 +141,13 @@ inline bool setup_assembly_kernel(const ITensor *a, const ITensor *b, ITensor *d const int M = d->info()->tensor_shape().y(); const int N = d->info()->tensor_shape().x(); const int K = a->info()->tensor_shape().x(); + const int batches = a->info()->tensor_shape().total_size_upper(2); + const int multis = b->info()->tensor_shape().z(); unsigned int num_threads = NEScheduler::get().num_threads(); + // unique_ptr to a Gemm object std::unique_ptr - asm_gemm(arm_gemm::gemm(ci, M, N, K, false, false, alpha, beta, num_threads, false)); + asm_gemm(arm_gemm::gemm(ci, M, N, K, batches, multis, false, false, alpha, beta, num_threads, false)); // arm_compute wrapper for the Gemm object (see above) std::unique_ptr> acl_gemm_wrapper = support::cpp14::make_unique>(); -- cgit v1.2.1