From 7cd26d4a1b14bc4bf7c61496803416ab3d84791f Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Wed, 9 Jan 2019 18:35:17 +0000 Subject: COMPMID-1867: Add NEON/SVE GEMM Hybrid kernels. Change-Id: Ib40a9921e7f9a6a8be6c38872d6b3a0f24ed0cd3 Reviewed-on: https://review.mlplatform.org/515 Reviewed-by: Anthony Barbier Tested-by: Arm Jenkins --- .../NEON/functions/NEGEMMAssemblyDispatch.cpp | 107 +++------------------ 1 file changed, 16 insertions(+), 91 deletions(-) (limited to 'src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp') diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp index 25be4a5349..cd614ba582 100644 --- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp +++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -24,9 +24,6 @@ #include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h" #include "arm_compute/core/CPP/Validate.h" -#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h" -#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.h" -#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h" #include "arm_compute/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.h" #include "arm_compute/runtime/NEON/NEScheduler.h" #include "arm_compute/runtime/NEON/functions/NESimpleAssemblyFunction.h" @@ -38,14 +35,14 @@ namespace arm_compute { namespace { -std::unique_ptr create_function_all_types(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, +std::unique_ptr create_function_all_types(arm_gemm::KernelDescription gemm_kernel_info, + const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, std::shared_ptr memory_manager) { //Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure() - switch(method) + switch(gemm_kernel_info.method) { - case arm_gemm::GemmMethod::GEMM_INTERLEAVED_FP16: case arm_gemm::GemmMethod::GEMM_INTERLEAVED: { if(!pretranspose_hint) @@ -56,92 +53,24 @@ std::unique_ptr create_function_all_types(arm_gemm::GemmMethod method function->configure(a, b, d, alpha, beta, pretranspose_hint); return std::move(function); } - default: - return nullptr; - } -} - -template -std::unique_ptr create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, - std::shared_ptr memory_manager) -{ - ARM_COMPUTE_UNUSED(method); - ARM_COMPUTE_UNUSED(a); - ARM_COMPUTE_UNUSED(b); - ARM_COMPUTE_UNUSED(d); - ARM_COMPUTE_UNUSED(alpha); - ARM_COMPUTE_UNUSED(beta); - ARM_COMPUTE_UNUSED(pretranspose_hint); - ARM_COMPUTE_UNUSED(memory_manager); - return nullptr; -} - -#ifdef __aarch64__ -template <> -std::unique_ptr create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, - std::shared_ptr memory_manager) -{ - switch(method) - { - case arm_gemm::GemmMethod::GEMM_INTERLEAVED_DOT: - { - if(!pretranspose_hint) - { - return nullptr; - } - auto function = support::cpp14::make_unique(memory_manager); - function->configure(a, b, d, alpha, beta, pretranspose_hint, true /* use_dot */); - return std::move(function); - } - default: - return nullptr; - } - return nullptr; -} - -template <> -std::unique_ptr create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, - std::shared_ptr memory_manager) -{ - switch(method) - { - case arm_gemm::GemmMethod::GEMM_INTERLEAVED_DOT: +#if defined(__aarch64__) + case arm_gemm::GemmMethod::GEMM_NATIVE: { - if(!pretranspose_hint) + if(gemm_kernel_info.name.find("sgemm_native_16x4") != std::string::npos) { - return nullptr; + auto kernel = support::cpp14::make_unique>(); + kernel->configure(a, b, d, alpha, beta); + auto function = support::cpp14::make_unique(); + function->configure(std::move(kernel)); + return std::move(function); } - auto function = support::cpp14::make_unique(memory_manager); - function->configure(a, b, d, alpha, beta, pretranspose_hint, true /* use_dot */); - return std::move(function); - } - default: return nullptr; - } - return nullptr; -} - -template <> -std::unique_ptr create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, - std::shared_ptr memory_manager) -{ - ARM_COMPUTE_UNUSED(pretranspose_hint); - ARM_COMPUTE_UNUSED(memory_manager); - switch(method) - { - case arm_gemm::GemmMethod::GEMM_NATIVE: - { - auto kernel = support::cpp14::make_unique>(); - kernel->configure(a, b, d, alpha, beta); - auto function = support::cpp14::make_unique(); - function->configure(std::move(kernel)); - return std::move(function); } +#endif // defined(__aarch64__) default: return nullptr; } } -#endif /* __aarch64__ */ /** Fallback in case ACL doesn't have a function */ template @@ -189,7 +118,7 @@ private: template void Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs &args, MemoryGroup &memory_group) { - _gemm_kernel_asm = arm_gemm::gemm(args, nullptr); + _gemm_kernel_asm = arm_gemm::gemm(args); if(_gemm_kernel_asm == nullptr) { //configuration not supported: Leave function unconfigured: @@ -334,12 +263,8 @@ void create_function_or_arm_gemm(std::unique_ptr &acl_function, std:: arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint); //Try to create an ACL function: - acl_function = create_function_all_types(arm_gemm::get_gemm_method(args), a, b, d, alpha, beta, pretranspose_hint, memory_manager); - // If the type agnostic factory failed to create an ACL function, try the specialised one: - if(acl_function == nullptr) - { - acl_function = create_function(arm_gemm::get_gemm_method(args), a, b, d, alpha, beta, pretranspose_hint, memory_manager); - } + acl_function = create_function_all_types(arm_gemm::get_gemm_method(args), a, b, d, alpha, beta, pretranspose_hint, std::move(memory_manager)); + //If we still don't have an ACL function: if(acl_function == nullptr) { -- cgit v1.2.1