diff options
Diffstat (limited to 'arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h')
-rw-r--r-- | arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h | 35 |
1 files changed, 27 insertions, 8 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h b/arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h index d612681c41..0e3dd74577 100644 --- a/arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h +++ b/arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -24,6 +24,7 @@ #ifndef ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H #define ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H +#include "arm_compute/core/NEON/kernels/assembly/arm_gemm_compute_iface.hpp" #include "arm_compute/core/NEON/INEKernel.h" #include "arm_compute/core/Utils.h" #include "arm_compute/core/Validate.h" @@ -65,15 +66,33 @@ public: { return _name.c_str(); } - // Inherited methods overridden: + + void run(const Window &window, const ThreadInfo &info) override { ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void *>(_kernel))); ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); - auto first = window.x().start(); - auto last = window.x().end(); - _kernel->execute(first, last, info.thread_id); + + auto win=arm_gemm::to_ndcoord(window); + + arm_gemm::ndcoord_t thread_locator { }; + + _kernel->execute(win, thread_locator, info.thread_id); } + + // Inherited methods overridden: + void run_nd(const Window &window, const ThreadInfo &info, const Window &thread_locator) override + { + ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void *>(_kernel))); + ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); + + //convert between arm_compute and arm_gemm types + auto ndc_win = arm_gemm::to_ndcoord(window); + auto ndc_tlc = arm_gemm::to_ndcoord(thread_locator); + + _kernel->execute(ndc_win, ndc_tlc, info.thread_id); + } + /** Initialise the kernel's input and output. * * @param[in] kernel Pointer to an assembly kernel implementation. @@ -83,9 +102,9 @@ public: { ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void *>(kernel))); _kernel = kernel; - auto win_last = _kernel->get_window_size(); - Window win; - win.set(Window::DimX, Window::Dimension(0, win_last, 1)); + + Window win = to_window(kernel->get_window_size()); + INEKernel::configure(win); if(!kernel_name_tag.empty()) |