diff options
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp | 107 |
1 files changed, 16 insertions, 91 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp index 25be4a5349..cd614ba582 100644 --- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp +++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -24,9 +24,6 @@ #include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h" #include "arm_compute/core/CPP/Validate.h" -#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h" -#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.h" -#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h" #include "arm_compute/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.h" #include "arm_compute/runtime/NEON/NEScheduler.h" #include "arm_compute/runtime/NEON/functions/NESimpleAssemblyFunction.h" @@ -38,14 +35,14 @@ namespace arm_compute { namespace { -std::unique_ptr<IFunction> create_function_all_types(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, +std::unique_ptr<IFunction> create_function_all_types(arm_gemm::KernelDescription gemm_kernel_info, + const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, std::shared_ptr<IMemoryManager> memory_manager) { //Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure() - switch(method) + switch(gemm_kernel_info.method) { - case arm_gemm::GemmMethod::GEMM_INTERLEAVED_FP16: case arm_gemm::GemmMethod::GEMM_INTERLEAVED: { if(!pretranspose_hint) @@ -56,92 +53,24 @@ std::unique_ptr<IFunction> create_function_all_types(arm_gemm::GemmMethod method function->configure(a, b, d, alpha, beta, pretranspose_hint); return std::move(function); } - default: - return nullptr; - } -} - -template <typename TypeInput, typename TypeOutput> -std::unique_ptr<IFunction> create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, - std::shared_ptr<IMemoryManager> memory_manager) -{ - ARM_COMPUTE_UNUSED(method); - ARM_COMPUTE_UNUSED(a); - ARM_COMPUTE_UNUSED(b); - ARM_COMPUTE_UNUSED(d); - ARM_COMPUTE_UNUSED(alpha); - ARM_COMPUTE_UNUSED(beta); - ARM_COMPUTE_UNUSED(pretranspose_hint); - ARM_COMPUTE_UNUSED(memory_manager); - return nullptr; -} - -#ifdef __aarch64__ -template <> -std::unique_ptr<IFunction> create_function<int8_t, int32_t>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, - std::shared_ptr<IMemoryManager> memory_manager) -{ - switch(method) - { - case arm_gemm::GemmMethod::GEMM_INTERLEAVED_DOT: - { - if(!pretranspose_hint) - { - return nullptr; - } - auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager); - function->configure(a, b, d, alpha, beta, pretranspose_hint, true /* use_dot */); - return std::move(function); - } - default: - return nullptr; - } - return nullptr; -} - -template <> -std::unique_ptr<IFunction> create_function<uint8_t, uint32_t>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, - std::shared_ptr<IMemoryManager> memory_manager) -{ - switch(method) - { - case arm_gemm::GemmMethod::GEMM_INTERLEAVED_DOT: +#if defined(__aarch64__) + case arm_gemm::GemmMethod::GEMM_NATIVE: { - if(!pretranspose_hint) + if(gemm_kernel_info.name.find("sgemm_native_16x4") != std::string::npos) { - return nullptr; + auto kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>(); + kernel->configure(a, b, d, alpha, beta); + auto function = support::cpp14::make_unique<NESimpleAssemblyFunction>(); + function->configure(std::move(kernel)); + return std::move(function); } - auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager); - function->configure(a, b, d, alpha, beta, pretranspose_hint, true /* use_dot */); - return std::move(function); - } - default: return nullptr; - } - return nullptr; -} - -template <> -std::unique_ptr<IFunction> create_function<float, float>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, - std::shared_ptr<IMemoryManager> memory_manager) -{ - ARM_COMPUTE_UNUSED(pretranspose_hint); - ARM_COMPUTE_UNUSED(memory_manager); - switch(method) - { - case arm_gemm::GemmMethod::GEMM_NATIVE: - { - auto kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>(); - kernel->configure(a, b, d, alpha, beta); - auto function = support::cpp14::make_unique<NESimpleAssemblyFunction>(); - function->configure(std::move(kernel)); - return std::move(function); } +#endif // defined(__aarch64__) default: return nullptr; } } -#endif /* __aarch64__ */ /** Fallback in case ACL doesn't have a function */ template <typename TypeInput, typename TypeOutput> @@ -189,7 +118,7 @@ private: template <typename TypeInput, typename TypeOutput> void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group) { - _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput>(args, nullptr); + _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput>(args); if(_gemm_kernel_asm == nullptr) { //configuration not supported: Leave function unconfigured: @@ -334,12 +263,8 @@ void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std:: arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint); //Try to create an ACL function: - acl_function = create_function_all_types(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint, memory_manager); - // If the type agnostic factory failed to create an ACL function, try the specialised one: - if(acl_function == nullptr) - { - acl_function = create_function<TypeInput, TypeOutput>(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint, memory_manager); - } + acl_function = create_function_all_types(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint, std::move(memory_manager)); + //If we still don't have an ACL function: if(acl_function == nullptr) { |