From 48b3ef89de5f21a0169d8416e3d54081f82c7bf8 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Mon, 14 Oct 2019 19:03:09 +0100 Subject: COMPMID-2577: Fuse bias addition and activation in gemm assembly kernels Change-Id: I7f52112d2d05b1ea3d3f3d4b19b8eafab05d6c44 Signed-off-by: Georgios Pinitas Reviewed-on: https://review.mlplatform.org/c/2141 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Pablo Marquez --- .../NEON/functions/NEGEMMAssemblyDispatch.cpp | 180 +++++++++------------ 1 file changed, 79 insertions(+), 101 deletions(-) (limited to 'src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp') diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp index 956ded55d2..b31ecb91e9 100644 --- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp +++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp @@ -24,10 +24,8 @@ #include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h" #include "arm_compute/core/CPP/Validate.h" -#include "arm_compute/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.h" #include "arm_compute/runtime/NEON/NEScheduler.h" #include "arm_compute/runtime/NEON/functions/NESimpleAssemblyFunction.h" -#include "arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h" #include @@ -35,43 +33,36 @@ namespace arm_compute { namespace { -std::unique_ptr create_function_all_types(const arm_gemm::KernelDescription &gemm_kernel_info, - const ITensor *a, const ITensor *b, ITensor *d, - float alpha, float beta, const GEMMInfo &gemm_info, - std::shared_ptr memory_manager, - IWeightsManager *weights_manager) - +arm_gemm::Activation map_to_arm_gemm_activation(const ActivationLayerInfo &act) { - // Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure() - switch(gemm_kernel_info.method) + arm_gemm::Activation gemm_act; + + // Early exit in case lower bound is other than 0, as it's not yet supported + if(act.b() != 0.f) { - case arm_gemm::GemmMethod::GEMM_INTERLEAVED: - { - if(!gemm_info.pretranpose_B()) - { - return nullptr; - } - auto function = support::cpp14::make_unique(memory_manager, weights_manager); - function->configure(a, b, d, alpha, beta, gemm_info); - return std::move(function); - } -#if defined(__aarch64__) - case arm_gemm::GemmMethod::GEMM_NATIVE: - { - if(gemm_kernel_info.name.find("sgemm_native_16x4") != std::string::npos) - { - auto kernel = support::cpp14::make_unique>(); - kernel->configure(a, b, d, alpha, beta, gemm_info); - auto function = support::cpp14::make_unique(); - function->configure(std::move(kernel)); - return std::move(function); - } - return nullptr; - } -#endif // defined(__aarch64__) + return gemm_act; + } + + switch(act.activation()) + { + case ActivationLayerInfo::ActivationFunction::RELU: + gemm_act.type = arm_gemm::Activation::Type::ReLU; + break; + case ActivationLayerInfo::ActivationFunction::BOUNDED_RELU: + gemm_act.type = arm_gemm::Activation::Type::BoundedReLU; + gemm_act.param1 = act.a(); + gemm_act.param2 = 0.f; + break; + case ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU: + gemm_act.type = arm_gemm::Activation::Type::BoundedReLU; + gemm_act.param1 = act.a(); + gemm_act.param2 = act.b(); + break; default: - return nullptr; + gemm_act.type = arm_gemm::Activation::Type::None; } + + return gemm_act; } template @@ -161,7 +152,7 @@ public: * @param[in] os Output stage meta-data. */ void configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, - arm_gemm::GemmArgs args, const GEMMInfo &gemm_info, + arm_gemm::GemmArgs args, const GEMMInfo &gemm_info, MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os = {}); // Inherited methods overridden: @@ -214,7 +205,7 @@ private: template void Fallback::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, - arm_gemm::GemmArgs args, const GEMMInfo &gemm_info, + arm_gemm::GemmArgs args, const GEMMInfo &gemm_info, MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os) { arm_gemm::GemmConfig gemm_cfg; @@ -287,7 +278,7 @@ void Fallback::prepare() // Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C. if(_c && _c->info()->data_type() == DataType::S32) { - _gemm_kernel_asm->set_quantized_bias(reinterpret_cast(_c->buffer() + _c->info()->offset_first_element_in_bytes())); + _gemm_kernel_asm->set_quantized_bias(reinterpret_cast(_c->buffer() + _c->info()->offset_first_element_in_bytes()), 0); } // Pretranspose B if required @@ -383,83 +374,76 @@ void Fallback::run() // Prepare assembly kernel prepare(); + TypeOutput *bias = nullptr; + // Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C. + if(_c && _c->info()->data_type() != DataType::S32) + { + bias = reinterpret_cast(_c->buffer() + _c->info()->offset_first_element_in_bytes()); + } // Set gemm parameters - _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d); + _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, + in1_ptr, ldb, multi_stride_b, + out_ptr, ldd, batch_stride_d, multi_stride_d, + bias, 0); // Schedule assembly kernel NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX); } template -void create_function_or_arm_gemm(std::unique_ptr &acl_function, std::unique_ptr &arm_gemm, MemoryGroup &memory_group, - const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info, - std::shared_ptr memory_manager, IWeightsManager *weights_manager) +void create_arm_gemm(std::unique_ptr &arm_gemm, MemoryGroup &memory_group, + const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const GEMMInfo &gemm_info, + IWeightsManager *weights_manager) { INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info); const CPUInfo &ci = NEScheduler::get().cpu_info(); unsigned int num_threads = NEScheduler::get().num_threads(); - arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, gemm_info.pretranpose_B()); + arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, activation, num_threads, gemm_info.pretranpose_B()); - // Try to create an ACL function: - const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method(args); - acl_function = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager), weights_manager); - - // If we still don't have an ACL function: - if(acl_function == nullptr) - { - //Fallback onto arm_gemm function if ACL doesn't support this method. - auto fallback = support::cpp14::make_unique>(); - fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager); - arm_gemm = std::move(fallback); - } + // Create arm_gemm fallback + auto fallback = support::cpp14::make_unique>(); + fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager); + arm_gemm = std::move(fallback); } template -void create_function_or_arm_gemm_quant(std::unique_ptr &acl_function, std::unique_ptr &arm_gemm, MemoryGroup &memory_group, - const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info, - std::shared_ptr memory_manager, IWeightsManager *weights_manager) +void create_arm_gemm_quant(std::unique_ptr &arm_gemm, MemoryGroup &memory_group, + const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const GEMMInfo &gemm_info, + IWeightsManager *weights_manager) { INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info); const CPUInfo &ci = NEScheduler::get().cpu_info(); unsigned int num_threads = NEScheduler::get().num_threads(); - arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, gemm_info.pretranpose_B()); + arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, activation, num_threads, gemm_info.pretranpose_B()); // Configure requantization info const int32_t a_offset = -a->info()->quantization_info().uniform().offset; const int32_t b_offset = -b->info()->quantization_info().uniform().offset; const GEMMLowpOutputStageInfo os_info = gemm_info.gemmlowp_output_stage(); - const arm_gemm::ARequantizeLayer32 gemm_requant_info(nullptr, + const arm_gemm::ARequantizeLayer32 gemm_requant_info(nullptr, 0, a_offset, b_offset, os_info.gemmlowp_offset, -os_info.gemmlowp_shift, os_info.gemmlowp_multiplier, os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound); - // Try to create an ACL function: - const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method(args, gemm_requant_info); - acl_function = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager), weights_manager); - - // If we still don't have an ACL function: - if(acl_function == nullptr) - { - // Fallback onto arm_gemm function if ACL doesn't support this method. - auto fallback = support::cpp14::make_unique>(); - fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager, gemm_requant_info); - arm_gemm = std::move(fallback); - } + // Create arm_gemm fallback + auto fallback = support::cpp14::make_unique>(); + fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager, gemm_requant_info); + arm_gemm = std::move(fallback); } } //namespace NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr memory_manager, IWeightsManager *weights_manager) - : _function(nullptr), _arm_gemm(nullptr), _memory_group(memory_manager), _memory_manager(memory_manager), _weights_manager(weights_manager) + : _arm_gemm(nullptr), _memory_group(std::move(memory_manager)), _weights_manager(weights_manager) { } -Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info) +Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const GEMMInfo &gemm_info) { - ARM_COMPUTE_UNUSED(alpha, beta, gemm_info); + ARM_COMPUTE_UNUSED(gemm_info); ARM_COMPUTE_UNUSED(c); ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d); ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a); @@ -476,12 +460,19 @@ Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo return Status{}; } -void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info) +bool NEGEMMAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &activation) +{ + arm_gemm::Activation act = map_to_arm_gemm_activation(activation); + return act.type != arm_gemm::Activation::Type::None; +} + +void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, const GEMMInfo &gemm_info) { ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d); + arm_gemm::Activation act = map_to_arm_gemm_activation(gemm_info.activation_info()); //If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured() - if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, d->info(), alpha, beta, gemm_info)) + if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, d->info(), gemm_info)) { return; } @@ -489,27 +480,27 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const switch(a->info()->data_type()) { case DataType::F32: - create_function_or_arm_gemm(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager); + create_arm_gemm(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); break; #ifdef __aarch64__ case DataType::U8: case DataType::QASYMM8: if(d->info()->data_type() == DataType::S32) { - create_function_or_arm_gemm(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager); + create_arm_gemm(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); } else { - create_function_or_arm_gemm_quant(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager); + create_arm_gemm_quant(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); } break; case DataType::S8: - create_function_or_arm_gemm(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager); + create_arm_gemm(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); break; #endif /* __aarch64__ */ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: - create_function_or_arm_gemm(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager); + create_arm_gemm(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); break; #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ default: @@ -519,33 +510,20 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const void NEGEMMAssemblyDispatch::prepare() { - if(_function != nullptr) - { - _function->prepare(); - } - else - { - ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr); - _arm_gemm->prepare(); - } + ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr); + _arm_gemm->prepare(); } bool NEGEMMAssemblyDispatch::is_configured() const { - return (_arm_gemm != nullptr && _arm_gemm->is_configured()) || _function != nullptr; + return _arm_gemm != nullptr && _arm_gemm->is_configured(); } void NEGEMMAssemblyDispatch::run() { MemoryGroupResourceScope scope_mg(_memory_group); - if(_function != nullptr) - { - _function->run(); - } - else - { - ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr); - _arm_gemm->run(); - } + + ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr); + _arm_gemm->run(); } } //namespace arm_compute -- cgit v1.2.1