diff options
-rw-r--r-- | arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h | 36 | ||||
-rw-r--r-- | src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp | 23 |
2 files changed, 43 insertions, 16 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h b/arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h index 9eaf6061d8..084c3f2401 100644 --- a/arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h +++ b/arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -25,8 +25,8 @@ #define __ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H__ #include "arm_compute/core/NEON/INEKernel.h" -#include "arm_compute/core/Validate.h" #include "arm_compute/core/Utils.h" +#include "arm_compute/core/Validate.h" #include "gemm_common.hpp" @@ -52,20 +52,28 @@ class NEGEMMAssemblyWrapperKernel final : public INEKernel public: /** Constructor */ - NEGEMMAssemblyWrapperKernel() : _kernel(nullptr) {} + NEGEMMAssemblyWrapperKernel() + : _kernel(nullptr), _kernel_name_tag() + { + } - NEGEMMAssemblyWrapperKernel(NEGEMMAssemblyWrapperKernel &) = delete; + NEGEMMAssemblyWrapperKernel(NEGEMMAssemblyWrapperKernel &) = delete; NEGEMMAssemblyWrapperKernel(NEGEMMAssemblyWrapperKernel &&) = default; - NEGEMMAssemblyWrapperKernel & operator=(NEGEMMAssemblyWrapperKernel &) = delete; + NEGEMMAssemblyWrapperKernel &operator=(NEGEMMAssemblyWrapperKernel &) = delete; const char *name() const override { - return "NEGEMMAssemblyWrapperKernel"; + std::string name = "NEGEMMAssemblyWrapperKernel"; + if(!_kernel_name_tag.empty()) + { + name += "/" + _kernel_name_tag; + } + return name.c_str(); } // Inherited methods overridden: void run(const Window &window, const ThreadInfo &info) override { - ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void*>(_kernel))); + ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void *>(_kernel))); ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); auto first = window.x().start(); auto last = window.x().end(); @@ -76,18 +84,20 @@ public: * @param[in] kernel Pointer to an assembly kernel implementation. * @param[in] num_threads Number of concurrent threads which will execute the kernel. */ - void configure(arm_gemm::GemmCommon<TypeInput, TypeOutput> *kernel) + void configure(arm_gemm::GemmCommon<TypeInput, TypeOutput> *kernel, std::string kernel_name_tag) { - ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void*>(kernel))); - _kernel = kernel; - auto win_last = _kernel->get_window_size(); + ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void *>(kernel))); + _kernel = kernel; + _kernel_name_tag = kernel_name_tag; + auto win_last = _kernel->get_window_size(); Window win; win.set(Window::DimX, Window::Dimension(0, win_last, 1)); INEKernel::configure(win); } + private: - arm_gemm::GemmCommon<TypeInput, TypeOutput>* _kernel; + arm_gemm::GemmCommon<TypeInput, TypeOutput> *_kernel; + std::string _kernel_name_tag; }; - } // namespace arm_compute #endif /* __ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H__ */ diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp index cd614ba582..470e9220ae 100644 --- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp +++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp @@ -77,7 +77,17 @@ template <typename TypeInput, typename TypeOutput> class Fallback : public NEGEMMAssemblyDispatch::IFallback { public: - void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group); + /** Initialise the functions's input and output. + * + * @param[in] a Input tensor containing the Matrix A. + * @param[in] b Input tensor containing the Matrix B. + * @param[out] d Output tensor to store the result of matrix multiplication. + * @param[in] args Matrix multiplication information. + * @param[in] memory_group Memory group to be used by the function. + */ + void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args, MemoryGroup &memory_group); + + // Inherited methods overridden: void run() override; void prepare() override; bool is_configured() const override; @@ -116,8 +126,15 @@ private: }; template <typename TypeInput, typename TypeOutput> -void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group) +void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args, MemoryGroup &memory_group) { + arm_gemm::GemmConfig gemm_cfg; + const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args); + if(gemm_kernel_info.method != arm_gemm::GemmMethod::GEMV_BATCHED) + { + gemm_cfg.filter = gemm_kernel_info.name; + args._cfg = &gemm_cfg; + } _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput>(args); if(_gemm_kernel_asm == nullptr) { @@ -128,7 +145,7 @@ void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor // arm_compute wrapper for the Gemm object (see above) std::unique_ptr<NEGEMMAssemblyWrapperKernel<TypeInput, TypeOutput>> acl_gemm_wrapper = support::cpp14::make_unique<NEGEMMAssemblyWrapperKernel<TypeInput, TypeOutput>>(); ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr); - acl_gemm_wrapper->configure(_gemm_kernel_asm.get()); + acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter); const size_t workspace_size = _gemm_kernel_asm->get_working_size(); if(workspace_size > 0) { |