aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2019-01-30 17:17:16 +0000
committerGeorgios Pinitas <georgios.pinitas@arm.com>2019-02-04 12:45:51 +0000
commit3dbfd23d68b9450fc3e8bf97d92bc211e78e8979 (patch)
tree5ed9fa0b99ddaac47688af280b9f5415fffb0fbd
parent62c3639b086d768661edc04b9b7e01a54edf486b (diff)
downloadComputeLibrary-3dbfd23d68b9450fc3e8bf97d92bc211e78e8979.tar.gz
COMPMID-1710: Introduce GEMM strategy name in GEMMAssemblyWrapper.
Change-Id: I0fd1a313c051849572367e46e7aa64b1adee5763 Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Reviewed-on: https://review.mlplatform.org/604 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Isabella Gottardi <isabella.gottardi@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
-rw-r--r--arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h36
-rw-r--r--src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp23
2 files changed, 43 insertions, 16 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h b/arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h
index 9eaf6061d8..084c3f2401 100644
--- a/arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h
+++ b/arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,8 +25,8 @@
#define __ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H__
#include "arm_compute/core/NEON/INEKernel.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Utils.h"
+#include "arm_compute/core/Validate.h"
#include "gemm_common.hpp"
@@ -52,20 +52,28 @@ class NEGEMMAssemblyWrapperKernel final : public INEKernel
public:
/** Constructor
*/
- NEGEMMAssemblyWrapperKernel() : _kernel(nullptr) {}
+ NEGEMMAssemblyWrapperKernel()
+ : _kernel(nullptr), _kernel_name_tag() // starts with no kernel attached and an empty name tag; both are assigned in configure()
+ {
+ }
- NEGEMMAssemblyWrapperKernel(NEGEMMAssemblyWrapperKernel &) = delete;
+ NEGEMMAssemblyWrapperKernel(NEGEMMAssemblyWrapperKernel &) = delete;
NEGEMMAssemblyWrapperKernel(NEGEMMAssemblyWrapperKernel &&) = default;
- NEGEMMAssemblyWrapperKernel & operator=(NEGEMMAssemblyWrapperKernel &) = delete;
+ NEGEMMAssemblyWrapperKernel &operator=(NEGEMMAssemblyWrapperKernel &) = delete;
const char *name() const override
{
- return "NEGEMMAssemblyWrapperKernel";
+ std::string name = "NEGEMMAssemblyWrapperKernel"; // base name, held in a function-local string
+ if(!_kernel_name_tag.empty())
+ {
+ name += "/" + _kernel_name_tag; // becomes "NEGEMMAssemblyWrapperKernel/<tag>" when a tag was supplied
+ }
+ return name.c_str(); // NOTE(review): this points into the local `name`, which is destroyed when the function returns, so the caller receives a dangling pointer; the composed string needs storage that outlives the call (e.g. a member built once in configure())
}
// Inherited methods overridden:
void run(const Window &window, const ThreadInfo &info) override
{
- ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void*>(_kernel)));
+ ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void *>(_kernel)));
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
auto first = window.x().start();
auto last = window.x().end();
@@ -76,18 +84,20 @@ public:
* @param[in] kernel Pointer to an assembly kernel implementation.
* @param[in] num_threads Number of concurrent threads which will execute the kernel.
*/
- void configure(arm_gemm::GemmCommon<TypeInput, TypeOutput> *kernel)
+ void configure(arm_gemm::GemmCommon<TypeInput, TypeOutput> *kernel, std::string kernel_name_tag) // kernel_name_tag: suffix that name() appends after "NEGEMMAssemblyWrapperKernel/"; may be empty
{
- ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void*>(kernel)));
- _kernel = kernel;
- auto win_last = _kernel->get_window_size();
+ ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void *>(kernel)));
+ _kernel = kernel; // only the raw pointer is stored; presumably the caller keeps the kernel object alive - confirm
+ _kernel_name_tag = kernel_name_tag;
+ auto win_last = _kernel->get_window_size(); // upper bound of the execution window, as reported by the assembly kernel
Window win;
win.set(Window::DimX, Window::Dimension(0, win_last, 1)); // X dimension built from (0, win_last, 1); third argument is presumably the step - confirm
INEKernel::configure(win);
}
+
private:
- arm_gemm::GemmCommon<TypeInput, TypeOutput>* _kernel;
+ arm_gemm::GemmCommon<TypeInput, TypeOutput> *_kernel;
+ std::string _kernel_name_tag;
};
-
} // namespace arm_compute
#endif /* __ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H__ */
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index cd614ba582..470e9220ae 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -77,7 +77,17 @@ template <typename TypeInput, typename TypeOutput>
class Fallback : public NEGEMMAssemblyDispatch::IFallback
{
public:
- void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group);
+ /** Initialise the function's input and output.
+ *
+ * @param[in] a Input tensor containing the Matrix A.
+ * @param[in] b Input tensor containing the Matrix B.
+ * @param[out] d Output tensor to store the result of matrix multiplication.
+ * @param[in] args Matrix multiplication information.
+ * @param[in] memory_group Memory group to be used by the function.
+ */
+ void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args, MemoryGroup &memory_group);
+
+ // Inherited methods overridden:
void run() override;
void prepare() override;
bool is_configured() const override;
@@ -116,8 +126,15 @@ private:
};
template <typename TypeInput, typename TypeOutput>
-void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group)
+void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args, MemoryGroup &memory_group)
{
+ arm_gemm::GemmConfig gemm_cfg;
+ const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args);
+ if(gemm_kernel_info.method != arm_gemm::GemmMethod::GEMV_BATCHED)
+ {
+ gemm_cfg.filter = gemm_kernel_info.name;
+ args._cfg = &gemm_cfg;
+ }
_gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput>(args);
if(_gemm_kernel_asm == nullptr)
{
@@ -128,7 +145,7 @@ void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor
// arm_compute wrapper for the Gemm object (see above)
std::unique_ptr<NEGEMMAssemblyWrapperKernel<TypeInput, TypeOutput>> acl_gemm_wrapper = support::cpp14::make_unique<NEGEMMAssemblyWrapperKernel<TypeInput, TypeOutput>>();
ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr);
- acl_gemm_wrapper->configure(_gemm_kernel_asm.get());
+ acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter);
const size_t workspace_size = _gemm_kernel_asm->get_working_size();
if(workspace_size > 0)
{