From eaefd002a5d6509dd5f12e98b538c99b33c2c1ee Mon Sep 17 00:00:00 2001
From: Anthony Barbier
Date: Fri, 20 Jul 2018 17:49:35 +0100
Subject: COMPMID-1419: Make NEGEMMAssemblyDispatch dynamically typed instead of templated

This makes it easier to integrate into GEMMLowpMatrixMultiplyCore

Change-Id: Ibf80803f016a2e6a24d943ffafb50b48f04ec545
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/140868
Reviewed-by: Georgios Pinitas
Tested-by: Jenkins
---
 .../NEON/functions/NEGEMMAssemblyDispatch.cpp | 300 ++++++++++++---------
 1 file changed, 175 insertions(+), 125 deletions(-)

diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index f4710fab84..e60fe80e0f 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -23,20 +23,31 @@
  */
 #include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
 
+#include "arm_compute/core/CPP/Validate.h"
 #include "arm_compute/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.h"
 #include "arm_compute/runtime/NEON/NEScheduler.h"
 #include "arm_compute/runtime/NEON/functions/NESimpleAssemblyFunction.h"
 
+#include <arm_neon.h>
+
 namespace arm_compute
 {
+namespace
+{
 template <typename TypeInput, typename TypeOutput>
-NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager)
-    : _function(nullptr), _arm_gemm(), _memory_group(std::move(memory_manager))
+std::unique_ptr<IFunction> create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
 {
+    ARM_COMPUTE_UNUSED(method);
+    ARM_COMPUTE_UNUSED(a);
+    ARM_COMPUTE_UNUSED(b);
+    ARM_COMPUTE_UNUSED(d);
+    ARM_COMPUTE_UNUSED(alpha);
+    ARM_COMPUTE_UNUSED(beta);
+    ARM_COMPUTE_UNUSED(pretranspose_hint);
+    return nullptr;
 }
-
 template <>
-bool NEGEMMAssemblyDispatch<float, float>::create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
+std::unique_ptr<IFunction> create_function<float, float>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
 {
     ARM_COMPUTE_UNUSED(method);
     ARM_COMPUTE_UNUSED(a);
@@ -54,132 +65,59 @@ bool NEGEMMAssemblyDispatch<float, float>::create_function(arm_gemm::GemmMethod
             kernel->configure(a, b, d, alpha, beta);
             auto function = support::cpp14::make_unique<NESimpleAssemblyFunction>();
             function->configure(std::move(kernel));
-            _function = std::move(function);
-            return true;
+            return std::move(function);
         }
 #endif /* __aarch64__ */
         default:
-            return false;
+            return nullptr;
     }
 }
 
+/** Fallback in case ACL doesn't have a function */
 template <typename TypeInput, typename TypeOutput>
-bool NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
+class Fallback : public NEGEMMAssemblyDispatch::IFallback
 {
-    ARM_COMPUTE_UNUSED(method);
-    ARM_COMPUTE_UNUSED(a);
-    ARM_COMPUTE_UNUSED(b);
-    ARM_COMPUTE_UNUSED(d);
-    ARM_COMPUTE_UNUSED(alpha);
-    ARM_COMPUTE_UNUSED(beta);
-    ARM_COMPUTE_UNUSED(pretranspose_hint);
-    return false;
-}
-
-template <typename TypeInput, typename TypeOutput>
-void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
-{
-    INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d);
-    const CPUInfo &ci              = NEScheduler::get().cpu_info();
-    unsigned int num_threads       = NEScheduler::get().num_threads();
-
-    arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
-
-    //Try to create an ACL function:
-    if(!create_function(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint))
+public:
+    void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group);
+    void run() override;
+    void prepare() override;
+    bool is_configured() const override;
+
+private:
+    /** Allocate a workspace tensor.
+     *
+     * @param[in] workspace_size Size to allocate.
+     * @param[in] memory_group   Tensor memory group.
+     * @param[in] alignment      Workspace memory alignment.
+     */
+    void allocate_workspace(size_t workspace_size, MemoryGroup *memory_group, size_t alignment);
+
+    /** Assembly Gemm kernel */
+    std::unique_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
+    /** Optimised NEON kernel */
+    std::unique_ptr<INEKernel> _optimised_kernel{ nullptr };
+    /** Input A */
+    const ITensor *_a
     {
-        //Fallback onto arm_gemm function if ACL doesn't support this method.
-        _arm_gemm.configure(a, b, d, args, _memory_group);
-    }
-}
-
-template <typename TypeInput, typename TypeOutput>
-void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::prepare()
-{
-    if(_function != nullptr)
-    {
-        _function->prepare();
-    }
-    else
+        nullptr
+    };
+    /** Input B */
+    const ITensor *_b
     {
-        _arm_gemm.prepare();
-    }
-}
-
-template <typename TypeInput, typename TypeOutput>
-bool NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::is_configured() const
-{
-    return _arm_gemm.is_configured() || _function != nullptr;
-}
+        nullptr
+    };
+    /** Output */
+    ITensor *_d{ nullptr };
+    /** GEMM workspace */
+    Tensor _workspace{};
+    /** Pre-transpose tensor */
+    Tensor _pretranspose{};
+    /** Prepared flag */
+    bool _is_prepared{ false };
+};
 
 template <typename TypeInput, typename TypeOutput>
-void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::run()
-{
-    _memory_group.acquire();
-    if(_function != nullptr)
-    {
-        _function->run();
-    }
-    else
-    {
-        _arm_gemm.run();
-    }
-    _memory_group.release();
-}
-
-#ifndef __aarch64__
-template <>
-void NEGEMMAssemblyDispatch<uint8_t, uint32_t>::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
-{
-    // arm_gemm::gemm for 8bit only exists for aarch64
-    ARM_COMPUTE_UNUSED(a);
-    ARM_COMPUTE_UNUSED(b);
-    ARM_COMPUTE_UNUSED(d);
-    ARM_COMPUTE_UNUSED(alpha);
-    ARM_COMPUTE_UNUSED(beta);
-    ARM_COMPUTE_UNUSED(pretranspose_hint);
-    ARM_COMPUTE_ERROR("Not supported for this architecture");
-}
-
-template <>
-void NEGEMMAssemblyDispatch<int8_t, int32_t>::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
-{
-    // arm_gemm::gemm for 8bit only exists for aarch64
-    ARM_COMPUTE_UNUSED(a);
-    ARM_COMPUTE_UNUSED(b);
-    ARM_COMPUTE_UNUSED(d);
-    ARM_COMPUTE_UNUSED(alpha);
-    ARM_COMPUTE_UNUSED(beta);
-    ARM_COMPUTE_UNUSED(pretranspose_hint);
-    ARM_COMPUTE_ERROR("Not supported for this architecture");
-}
-
-template <>
-void NEGEMMAssemblyDispatch<uint8_t, uint32_t>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<uint32_t> &args, MemoryGroup &memory_group)
-{
-    // arm_gemm::gemm for 8bit only exists for aarch64
-    ARM_COMPUTE_UNUSED(a);
-    ARM_COMPUTE_UNUSED(b);
-    ARM_COMPUTE_UNUSED(d);
-    ARM_COMPUTE_UNUSED(args);
-    ARM_COMPUTE_UNUSED(memory_group);
-    ARM_COMPUTE_ERROR("Not supported for this architecture");
-}
-
-template <>
-void NEGEMMAssemblyDispatch<int8_t, int32_t>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<int32_t> &args, MemoryGroup &memory_group)
-{
-    // arm_gemm::gemm for 8bit only exists for aarch64
-    ARM_COMPUTE_UNUSED(a);
-    ARM_COMPUTE_UNUSED(b);
-    ARM_COMPUTE_UNUSED(d);
-    ARM_COMPUTE_UNUSED(args);
-    ARM_COMPUTE_UNUSED(memory_group);
-    ARM_COMPUTE_ERROR("Not supported for this architecture");
-}
-#endif // aarch64
-template <typename TypeInput, typename TypeOutput>
-void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group)
+void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group)
 {
     _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput>(args, nullptr);
     if(_gemm_kernel_asm == nullptr)
@@ -228,7 +166,7 @@ void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::Fallback::configure(const IT
 }
 
 template <typename TypeInput, typename TypeOutput>
-void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::Fallback::prepare()
+void Fallback<TypeInput, TypeOutput>::prepare()
 {
     if(!_is_prepared)
     {
@@ -249,7 +187,7 @@ void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::Fallback::prepare()
 }
 
 template <typename TypeInput, typename TypeOutput>
-void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::Fallback::allocate_workspace(size_t workspace_size, MemoryGroup *memory_group, size_t alignment)
+void Fallback<TypeInput, TypeOutput>::allocate_workspace(size_t workspace_size, MemoryGroup *memory_group, size_t alignment)
 {
     ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "size cannot be 0");
     _workspace.allocator()->init(TensorInfo(TensorShape{ (workspace_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment);
@@ -261,13 +199,13 @@ void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::Fallback::allocate_workspace
 }
 
 template <typename TypeInput, typename TypeOutput>
-bool NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::Fallback::is_configured() const
+bool Fallback<TypeInput, TypeOutput>::is_configured() const
 {
     return _optimised_kernel != nullptr;
 }
 
 template <typename TypeInput, typename TypeOutput>
-void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::Fallback::run()
+void Fallback<TypeInput, TypeOutput>::run()
 {
     const int lda = _a->info()->strides_in_bytes().y() / sizeof(TypeInput);
     const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
@@ -312,7 +250,119 @@ void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::Fallback::run()
     NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX);
 }
 
-template class NEGEMMAssemblyDispatch<float, float>;
-template class NEGEMMAssemblyDispatch<uint8_t, uint32_t>;
-template class NEGEMMAssemblyDispatch<int8_t, int32_t>;
+template <typename TypeInput, typename TypeOutput>
+void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group, const ITensor *a, const ITensor *b,
+                                 ITensor *d, float alpha, float beta, bool pretranspose_hint)
+{
+    INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d);
+    const CPUInfo &ci              = NEScheduler::get().cpu_info();
+    unsigned int num_threads       = NEScheduler::get().num_threads();
+
+    arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
+
+    //Try to create an ACL function:
+    acl_function = create_function<TypeInput, TypeOutput>(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint);
+    if(acl_function == nullptr)
+    {
+        //Fallback onto arm_gemm function if ACL doesn't support this method.
+        auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput>>();
+        fallback->configure(a, b, d, args, memory_group);
+        arm_gemm = std::move(fallback);
+    }
+}
+
+} //namespace
+
+NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager)
+    : _function(nullptr), _arm_gemm(nullptr), _memory_group(std::move(memory_manager))
+{
+}
+
+Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, float alpha, float beta, bool pretranspose_hint)
+{
+    ARM_COMPUTE_UNUSED(alpha);
+    ARM_COMPUTE_UNUSED(beta);
+    ARM_COMPUTE_UNUSED(pretranspose_hint);
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d);
+    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
+#ifndef __aarch64__
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 || a->data_type() == DataType::S8 || a->data_type() == DataType::QASYMM8, "8bit integer types only supported for aarch64");
+#endif /* __aarch64__ */
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::F32, DataType::U8, DataType::QASYMM8, DataType::S8, DataType::F16);
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F32 && d->data_type() != DataType::F32, "Only F32 output supported for F32 input");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F16 && d->data_type() != DataType::F16, "Only F16 output supported for F16 input");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG((a->data_type() == DataType::QASYMM8 || a->data_type() == DataType::U8) && d->data_type() != DataType::U32, "Only U32 output supported for U8 / QASYMM8 input");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input");
+    return Status{};
+}
+
+void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
+{
+    ARM_COMPUTE_ERROR_ON_NULLPTR(a);
+    ARM_COMPUTE_ERROR_ON_NULLPTR(b);
+    ARM_COMPUTE_ERROR_ON_NULLPTR(d);
+
+    //If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured()
+    if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, pretranspose_hint))
+    {
+        return;
+    }
+
+    switch(a->info()->data_type())
+    {
+        case DataType::F32:
+            create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
+            break;
+#ifdef __aarch64__
+        case DataType::U8:
+        case DataType::QASYMM8:
+            create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
+            break;
+        case DataType::S8:
+            create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
+            break;
+#endif /* __aarch64__ */
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+        case DataType::F16:
+            create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
+            break;
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+        default:
+            break;
+    }
+}
+
+void NEGEMMAssemblyDispatch::prepare()
+{
+    if(_function != nullptr)
+    {
+        _function->prepare();
+    }
+    else
+    {
+        ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
+        _arm_gemm->prepare();
+    }
+}
+
+bool NEGEMMAssemblyDispatch::is_configured() const
+{
+    return (_arm_gemm != nullptr && _arm_gemm->is_configured()) || _function != nullptr;
+}
+
+void NEGEMMAssemblyDispatch::run()
+{
+    _memory_group.acquire();
+    if(_function != nullptr)
+    {
+        _function->run();
+    }
+    else
+    {
+        ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
+        _arm_gemm->run();
+    }
+    _memory_group.release();
+}
 } //namespace arm_compute
--
cgit v1.2.1
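
Below is a minimal usage sketch (not part of the patch) of what the change enables: a single, non-templated NEGEMMAssemblyDispatch is configured from untyped ITensor pointers, and the switch on a->info()->data_type() inside configure() selects the matching ACL function or arm_gemm Fallback<TypeInput, TypeOutput> at runtime. The tensor shapes, alpha/beta values and the helper function name are illustrative assumptions; only the configure()/is_configured()/prepare()/run() calls come from the code above.

#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

// Hypothetical helper: run an F32 GEMM (D = alpha * A * B) through the assembly dispatcher.
void run_f32_assembly_gemm()
{
    // A: M x K, B: K x N, D: M x N (TensorShape is width-first, i.e. (cols, rows)).
    Tensor a{}, b{}, d{};
    a.allocator()->init(TensorInfo(TensorShape(64U, 32U), 1, DataType::F32));
    b.allocator()->init(TensorInfo(TensorShape(128U, 64U), 1, DataType::F32));
    d.allocator()->init(TensorInfo(TensorShape(128U, 32U), 1, DataType::F32));

    // No memory manager passed here; the dispatcher's MemoryGroup then uses plain allocations.
    NEGEMMAssemblyDispatch gemm(nullptr);
    gemm.configure(&a, &b, &d, 1.f, 0.f, false /* pretranspose_hint */);

    // configure() returns silently for unsupported type combinations, so the caller
    // is expected to check is_configured() before running (see the comment in configure()).
    if(!gemm.is_configured())
    {
        return; // e.g. fall back to a reference GEMM path instead
    }

    a.allocator()->allocate();
    b.allocator()->allocate();
    d.allocator()->allocate();
    // ... fill a and b with data here ...

    gemm.prepare(); // one-off work such as pretransposing B
    gemm.run();
}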