aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2019-01-09 18:35:17 +0000
committerGeorgios Pinitas <georgios.pinitas@arm.com>2019-01-18 13:41:40 +0000
commit7cd26d4a1b14bc4bf7c61496803416ab3d84791f (patch)
tree12cc4a27d7ecebc69a43e96b1f46c7eb05437978 /src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
parent3ac2f3a1d9297220d1b0ce920dd13fdd4edcc187 (diff)
downloadComputeLibrary-7cd26d4a1b14bc4bf7c61496803416ab3d84791f.tar.gz
COMPMID-1867: Add NEON/SVE GEMM Hybrid kernels.
Change-Id: Ib40a9921e7f9a6a8be6c38872d6b3a0f24ed0cd3 Reviewed-on: https://review.mlplatform.org/515 Reviewed-by: Anthony Barbier <Anthony.barbier@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp')
-rw-r--r--src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp107
1 files changed, 16 insertions, 91 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index 25be4a5349..cd614ba582 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,9 +24,6 @@
#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
#include "arm_compute/core/CPP/Validate.h"
-#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h"
-#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.h"
-#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h"
#include "arm_compute/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "arm_compute/runtime/NEON/functions/NESimpleAssemblyFunction.h"
@@ -38,14 +35,14 @@ namespace arm_compute
{
namespace
{
-std::unique_ptr<IFunction> create_function_all_types(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+std::unique_ptr<IFunction> create_function_all_types(arm_gemm::KernelDescription gemm_kernel_info,
+ const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
std::shared_ptr<IMemoryManager> memory_manager)
{
//Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure()
- switch(method)
+ switch(gemm_kernel_info.method)
{
- case arm_gemm::GemmMethod::GEMM_INTERLEAVED_FP16:
case arm_gemm::GemmMethod::GEMM_INTERLEAVED:
{
if(!pretranspose_hint)
@@ -56,92 +53,24 @@ std::unique_ptr<IFunction> create_function_all_types(arm_gemm::GemmMethod method
function->configure(a, b, d, alpha, beta, pretranspose_hint);
return std::move(function);
}
- default:
- return nullptr;
- }
-}
-
-template <typename TypeInput, typename TypeOutput>
-std::unique_ptr<IFunction> create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
- std::shared_ptr<IMemoryManager> memory_manager)
-{
- ARM_COMPUTE_UNUSED(method);
- ARM_COMPUTE_UNUSED(a);
- ARM_COMPUTE_UNUSED(b);
- ARM_COMPUTE_UNUSED(d);
- ARM_COMPUTE_UNUSED(alpha);
- ARM_COMPUTE_UNUSED(beta);
- ARM_COMPUTE_UNUSED(pretranspose_hint);
- ARM_COMPUTE_UNUSED(memory_manager);
- return nullptr;
-}
-
-#ifdef __aarch64__
-template <>
-std::unique_ptr<IFunction> create_function<int8_t, int32_t>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
- std::shared_ptr<IMemoryManager> memory_manager)
-{
- switch(method)
- {
- case arm_gemm::GemmMethod::GEMM_INTERLEAVED_DOT:
- {
- if(!pretranspose_hint)
- {
- return nullptr;
- }
- auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
- function->configure(a, b, d, alpha, beta, pretranspose_hint, true /* use_dot */);
- return std::move(function);
- }
- default:
- return nullptr;
- }
- return nullptr;
-}
-
-template <>
-std::unique_ptr<IFunction> create_function<uint8_t, uint32_t>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
- std::shared_ptr<IMemoryManager> memory_manager)
-{
- switch(method)
- {
- case arm_gemm::GemmMethod::GEMM_INTERLEAVED_DOT:
+#if defined(__aarch64__)
+ case arm_gemm::GemmMethod::GEMM_NATIVE:
{
- if(!pretranspose_hint)
+ if(gemm_kernel_info.name.find("sgemm_native_16x4") != std::string::npos)
{
- return nullptr;
+ auto kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>();
+ kernel->configure(a, b, d, alpha, beta);
+ auto function = support::cpp14::make_unique<NESimpleAssemblyFunction>();
+ function->configure(std::move(kernel));
+ return std::move(function);
}
- auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
- function->configure(a, b, d, alpha, beta, pretranspose_hint, true /* use_dot */);
- return std::move(function);
- }
- default:
return nullptr;
- }
- return nullptr;
-}
-
-template <>
-std::unique_ptr<IFunction> create_function<float, float>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
- std::shared_ptr<IMemoryManager> memory_manager)
-{
- ARM_COMPUTE_UNUSED(pretranspose_hint);
- ARM_COMPUTE_UNUSED(memory_manager);
- switch(method)
- {
- case arm_gemm::GemmMethod::GEMM_NATIVE:
- {
- auto kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>();
- kernel->configure(a, b, d, alpha, beta);
- auto function = support::cpp14::make_unique<NESimpleAssemblyFunction>();
- function->configure(std::move(kernel));
- return std::move(function);
}
+#endif // defined(__aarch64__)
default:
return nullptr;
}
}
-#endif /* __aarch64__ */
/** Fallback in case ACL doesn't have a function */
template <typename TypeInput, typename TypeOutput>
@@ -189,7 +118,7 @@ private:
template <typename TypeInput, typename TypeOutput>
void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group)
{
- _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput>(args, nullptr);
+ _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput>(args);
if(_gemm_kernel_asm == nullptr)
{
//configuration not supported: Leave function unconfigured:
@@ -334,12 +263,8 @@ void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::
arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
//Try to create an ACL function:
- acl_function = create_function_all_types(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint, memory_manager);
- // If the type agnostic factory failed to create an ACL function, try the specialised one:
- if(acl_function == nullptr)
- {
- acl_function = create_function<TypeInput, TypeOutput>(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint, memory_manager);
- }
+ acl_function = create_function_all_types(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint, std::move(memory_manager));
+
//If we still don't have an ACL function:
if(acl_function == nullptr)
{