From 3dbfd23d68b9450fc3e8bf97d92bc211e78e8979 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Wed, 30 Jan 2019 17:17:16 +0000 Subject: COMPMID-1710: Introduce GEMM strategy name in GEMMAssemblyWrapper. Change-Id: I0fd1a313c051849572367e46e7aa64b1adee5763 Signed-off-by: Georgios Pinitas Reviewed-on: https://review.mlplatform.org/604 Tested-by: Arm Jenkins Reviewed-by: Isabella Gottardi Reviewed-by: Gian Marco Iodice --- .../NEON/functions/NEGEMMAssemblyDispatch.cpp | 23 +++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) (limited to 'src/runtime/NEON/functions') diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp index cd614ba582..470e9220ae 100644 --- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp +++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp @@ -77,7 +77,17 @@ template class Fallback : public NEGEMMAssemblyDispatch::IFallback { public: - void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs &args, MemoryGroup &memory_group); + /** Initialise the functions's input and output. + * + * @param[in] a Input tensor containing the Matrix A. + * @param[in] b Input tensor containing the Matrix B. + * @param[out] d Output tensor to store the result of matrix multiplication. + * @param[in] args Matrix multiplication information. + * @param[in] memory_group Memory group to be used by the function. + */ + void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs args, MemoryGroup &memory_group); + + // Inherited methods overridden: void run() override; void prepare() override; bool is_configured() const override; @@ -116,8 +126,15 @@ private: }; template -void Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs &args, MemoryGroup &memory_group) +void Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs args, MemoryGroup &memory_group) { + arm_gemm::GemmConfig gemm_cfg; + const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method(args); + if(gemm_kernel_info.method != arm_gemm::GemmMethod::GEMV_BATCHED) + { + gemm_cfg.filter = gemm_kernel_info.name; + args._cfg = &gemm_cfg; + } _gemm_kernel_asm = arm_gemm::gemm(args); if(_gemm_kernel_asm == nullptr) { @@ -128,7 +145,7 @@ void Fallback::configure(const ITensor *a, const ITensor // arm_compute wrapper for the Gemm object (see above) std::unique_ptr> acl_gemm_wrapper = support::cpp14::make_unique>(); ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr); - acl_gemm_wrapper->configure(_gemm_kernel_asm.get()); + acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter); const size_t workspace_size = _gemm_kernel_asm->get_working_size(); if(workspace_size > 0) { -- cgit v1.2.1