aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h')
-rw-r--r--arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h35
1 files changed, 27 insertions, 8 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h b/arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h
index d612681c41..0e3dd74577 100644
--- a/arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h
+++ b/arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,6 +24,7 @@
#ifndef ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H
#define ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H
+#include "arm_compute/core/NEON/kernels/assembly/arm_gemm_compute_iface.hpp"
#include "arm_compute/core/NEON/INEKernel.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
@@ -65,15 +66,33 @@ public:
{
return _name.c_str();
}
- // Inherited methods overridden:
+
+
void run(const Window &window, const ThreadInfo &info) override
{
ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void *>(_kernel)));
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
- auto first = window.x().start();
- auto last = window.x().end();
- _kernel->execute(first, last, info.thread_id);
+
+ auto win=arm_gemm::to_ndcoord(window);
+
+ arm_gemm::ndcoord_t thread_locator { };
+
+ _kernel->execute(win, thread_locator, info.thread_id);
}
+
+ // Inherited methods overridden:
+ void run_nd(const Window &window, const ThreadInfo &info, const Window &thread_locator) override
+ {
+ ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void *>(_kernel)));
+ ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+
+ //convert between arm_compute and arm_gemm types
+ auto ndc_win = arm_gemm::to_ndcoord(window);
+ auto ndc_tlc = arm_gemm::to_ndcoord(thread_locator);
+
+ _kernel->execute(ndc_win, ndc_tlc, info.thread_id);
+ }
+
/** Initialise the kernel's input and output.
*
* @param[in] kernel Pointer to an assembly kernel implementation.
@@ -83,9 +102,9 @@ public:
{
ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void *>(kernel)));
_kernel = kernel;
- auto win_last = _kernel->get_window_size();
- Window win;
- win.set(Window::DimX, Window::Dimension(0, win_last, 1));
+
+ Window win = to_window(kernel->get_window_size());
+
INEKernel::configure(win);
if(!kernel_name_tag.empty())