From 3dbfd23d68b9450fc3e8bf97d92bc211e78e8979 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Wed, 30 Jan 2019 17:17:16 +0000 Subject: COMPMID-1710: Introduce GEMM strategy name in GEMMAssemblyWrapper. Change-Id: I0fd1a313c051849572367e46e7aa64b1adee5763 Signed-off-by: Georgios Pinitas Reviewed-on: https://review.mlplatform.org/604 Tested-by: Arm Jenkins Reviewed-by: Isabella Gottardi Reviewed-by: Gian Marco Iodice --- .../kernels/assembly/NEGEMMAssemblyWrapperKernel.h | 36 ++++++++++++++-------- .../NEON/functions/NEGEMMAssemblyDispatch.cpp | 23 ++++++++++++-- 2 files changed, 43 insertions(+), 16 deletions(-) diff --git a/arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h b/arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h index 9eaf6061d8..084c3f2401 100644 --- a/arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h +++ b/arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. 
* * SPDX-License-Identifier: MIT * @@ -25,8 +25,8 @@ #define __ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H__ #include "arm_compute/core/NEON/INEKernel.h" -#include "arm_compute/core/Validate.h" #include "arm_compute/core/Utils.h" +#include "arm_compute/core/Validate.h" #include "gemm_common.hpp" @@ -52,20 +52,28 @@ class NEGEMMAssemblyWrapperKernel final : public INEKernel public: /** Constructor */ - NEGEMMAssemblyWrapperKernel() : _kernel(nullptr) {} + NEGEMMAssemblyWrapperKernel() + : _kernel(nullptr), _kernel_name_tag() + { + } - NEGEMMAssemblyWrapperKernel(NEGEMMAssemblyWrapperKernel &) = delete; + NEGEMMAssemblyWrapperKernel(NEGEMMAssemblyWrapperKernel &) = delete; NEGEMMAssemblyWrapperKernel(NEGEMMAssemblyWrapperKernel &&) = default; - NEGEMMAssemblyWrapperKernel & operator=(NEGEMMAssemblyWrapperKernel &) = delete; + NEGEMMAssemblyWrapperKernel &operator=(NEGEMMAssemblyWrapperKernel &) = delete; const char *name() const override { - return "NEGEMMAssemblyWrapperKernel"; + std::string name = "NEGEMMAssemblyWrapperKernel"; + if(!_kernel_name_tag.empty()) + { + name += "/" + _kernel_name_tag; + } + return name.c_str(); } // Inherited methods overridden: void run(const Window &window, const ThreadInfo &info) override { - ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast(_kernel))); + ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast(_kernel))); ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); auto first = window.x().start(); auto last = window.x().end(); @@ -76,18 +84,20 @@ public: * @param[in] kernel Pointer to an assembly kernel implementation. * @param[in] num_threads Number of concurrent threads which will execute the kernel. 
*/ - void configure(arm_gemm::GemmCommon *kernel) + void configure(arm_gemm::GemmCommon *kernel, std::string kernel_name_tag) { - ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast(kernel))); - _kernel = kernel; - auto win_last = _kernel->get_window_size(); + ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast(kernel))); + _kernel = kernel; + _kernel_name_tag = kernel_name_tag; + auto win_last = _kernel->get_window_size(); Window win; win.set(Window::DimX, Window::Dimension(0, win_last, 1)); INEKernel::configure(win); } + private: - arm_gemm::GemmCommon* _kernel; + arm_gemm::GemmCommon *_kernel; + std::string _kernel_name_tag; }; - } // namespace arm_compute #endif /* __ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H__ */ diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp index cd614ba582..470e9220ae 100644 --- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp +++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp @@ -77,7 +77,17 @@ template class Fallback : public NEGEMMAssemblyDispatch::IFallback { public: - void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs &args, MemoryGroup &memory_group); + /** Initialise the function's input and output. + * + * @param[in] a Input tensor containing the Matrix A. + * @param[in] b Input tensor containing the Matrix B. + * @param[out] d Output tensor to store the result of matrix multiplication. + * @param[in] args Matrix multiplication information. + * @param[in] memory_group Memory group to be used by the function. 
+ */ + void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs args, MemoryGroup &memory_group); + + // Inherited methods overridden: void run() override; void prepare() override; bool is_configured() const override; @@ -116,8 +126,15 @@ private: }; template -void Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs &args, MemoryGroup &memory_group) +void Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs args, MemoryGroup &memory_group) { + arm_gemm::GemmConfig gemm_cfg; + const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method(args); + if(gemm_kernel_info.method != arm_gemm::GemmMethod::GEMV_BATCHED) + { + gemm_cfg.filter = gemm_kernel_info.name; + args._cfg = &gemm_cfg; + } _gemm_kernel_asm = arm_gemm::gemm(args); if(_gemm_kernel_asm == nullptr) { @@ -128,7 +145,7 @@ void Fallback::configure(const ITensor *a, const ITensor // arm_compute wrapper for the Gemm object (see above) std::unique_ptr> acl_gemm_wrapper = support::cpp14::make_unique>(); ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr); - acl_gemm_wrapper->configure(_gemm_kernel_asm.get()); + acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter); const size_t workspace_size = _gemm_kernel_asm->get_working_size(); if(workspace_size > 0) { -- cgit v1.2.1