diff options
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp | 111 |
1 file changed, 93 insertions, 18 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp index f17da7d2e4..8ba620fe51 100644 --- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp +++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp @@ -24,9 +24,13 @@ #include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h" #include "arm_compute/core/CPP/Validate.h" +#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h" +#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.h" +#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h" #include "arm_compute/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.h" #include "arm_compute/runtime/NEON/NEScheduler.h" #include "arm_compute/runtime/NEON/functions/NESimpleAssemblyFunction.h" +#include "arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h" #include <arm_neon.h> @@ -34,8 +38,31 @@ namespace arm_compute { namespace { +std::unique_ptr<IFunction> create_function_all_types(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, + std::shared_ptr<IMemoryManager> memory_manager) + +{ + //Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure() + switch(method) + { + case arm_gemm::GemmMethod::GEMM_INTERLEAVED: + { + if(!pretranspose_hint) + { + return nullptr; + } + auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager); + function->configure(a, b, d, alpha, beta, pretranspose_hint); + return std::move(function); + } + default: + return nullptr; + } +} + template <typename TypeInput, typename TypeOutput> -std::unique_ptr<IFunction> create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint) +std::unique_ptr<IFunction> 
create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, + std::shared_ptr<IMemoryManager> memory_manager) { ARM_COMPUTE_UNUSED(method); ARM_COMPUTE_UNUSED(a); @@ -44,21 +71,63 @@ std::unique_ptr<IFunction> create_function(arm_gemm::GemmMethod method, const IT ARM_COMPUTE_UNUSED(alpha); ARM_COMPUTE_UNUSED(beta); ARM_COMPUTE_UNUSED(pretranspose_hint); + ARM_COMPUTE_UNUSED(memory_manager); return nullptr; } + +#ifdef __aarch64__ template <> -std::unique_ptr<IFunction> create_function<float, float>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint) +std::unique_ptr<IFunction> create_function<int8_t, int32_t>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, + std::shared_ptr<IMemoryManager> memory_manager) +{ + switch(method) + { + case arm_gemm::GemmMethod::GEMM_INTERLEAVED_DOT: + { + if(!pretranspose_hint) + { + return nullptr; + } + auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager); + function->configure(a, b, d, alpha, beta, pretranspose_hint, true /* use_dot */); + return std::move(function); + } + default: + return nullptr; + } + return nullptr; +} + +template <> +std::unique_ptr<IFunction> create_function<uint8_t, uint32_t>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, + std::shared_ptr<IMemoryManager> memory_manager) +{ + switch(method) + { + case arm_gemm::GemmMethod::GEMM_INTERLEAVED_DOT: + { + if(!pretranspose_hint) + { + return nullptr; + } + auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager); + function->configure(a, b, d, alpha, beta, pretranspose_hint, true /* use_dot */); + return std::move(function); + } + default: + return nullptr; + } + return nullptr; +} + +template <> 
+std::unique_ptr<IFunction> create_function<float, float>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, + std::shared_ptr<IMemoryManager> memory_manager) { - ARM_COMPUTE_UNUSED(method); - ARM_COMPUTE_UNUSED(a); - ARM_COMPUTE_UNUSED(b); - ARM_COMPUTE_UNUSED(d); - ARM_COMPUTE_UNUSED(alpha); - ARM_COMPUTE_UNUSED(beta); ARM_COMPUTE_UNUSED(pretranspose_hint); + ARM_COMPUTE_UNUSED(memory_manager); switch(method) { -#ifdef __aarch64__ case arm_gemm::GemmMethod::GEMM_NATIVE: { auto kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>(); @@ -67,11 +136,11 @@ std::unique_ptr<IFunction> create_function<float, float>(arm_gemm::GemmMethod me function->configure(std::move(kernel)); return std::move(function); } -#endif /* __aarch64__ */ default: return nullptr; } } +#endif /* __aarch64__ */ /** Fallback in case ACL doesn't have a function */ template <typename TypeInput, typename TypeOutput> @@ -173,11 +242,11 @@ void Fallback<TypeInput, TypeOutput>::prepare() // Pretranspose B if required if(_gemm_kernel_asm->B_pretranspose_required()) { + ARM_COMPUTE_ERROR_ON(_pretranspose.buffer() == nullptr); const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput); const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes()); const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput); - ARM_COMPUTE_ERROR_ON(_pretranspose.buffer() == nullptr); _gemm_kernel_asm->pretranspose_B_array(_pretranspose.buffer(), in1_ptr, ldb, multi_stride_b); _b->mark_as_unused(); } @@ -260,7 +329,7 @@ void Fallback<TypeInput, TypeOutput>::run() template <typename TypeInput, typename TypeOutput> void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group, const ITensor *a, const ITensor *b, - ITensor *d, float alpha, 
float beta, bool pretranspose_hint) + ITensor *d, float alpha, float beta, bool pretranspose_hint, std::shared_ptr<IMemoryManager> memory_manager) { INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d); const CPUInfo &ci = NEScheduler::get().cpu_info(); @@ -269,7 +338,13 @@ void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std:: arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint); //Try to create an ACL function: - acl_function = create_function<TypeInput, TypeOutput>(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint); + acl_function = create_function_all_types(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint, memory_manager); + // If the type agnostic factory failed to create an ACL function, try the specialised one: + if(acl_function == nullptr) + { + acl_function = create_function<TypeInput, TypeOutput>(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint, memory_manager); + } + //If we still don't have an ACL function: if(acl_function == nullptr) { //Fallback onto arm_gemm function if ACL doesn't support this method. 
@@ -282,7 +357,7 @@ void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std:: } //namespace NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager) - : _function(nullptr), _arm_gemm(nullptr), _memory_group(std::move(memory_manager)) + : _function(nullptr), _arm_gemm(nullptr), _memory_group(memory_manager), _memory_manager(memory_manager) { } @@ -321,20 +396,20 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, ITens switch(a->info()->data_type()) { case DataType::F32: - create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint); + create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager); break; #ifdef __aarch64__ case DataType::U8: case DataType::QASYMM8: - create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint); + create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager); break; case DataType::S8: - create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint); + create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager); break; #endif /* __aarch64__ */ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: - create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint); + create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager); break; #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ default: |