diff options
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp | 180 |
1 files changed, 79 insertions, 101 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp index 956ded55d2..b31ecb91e9 100644 --- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp +++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp @@ -24,10 +24,8 @@ #include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h" #include "arm_compute/core/CPP/Validate.h" -#include "arm_compute/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.h" #include "arm_compute/runtime/NEON/NEScheduler.h" #include "arm_compute/runtime/NEON/functions/NESimpleAssemblyFunction.h" -#include "arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h" #include <arm_neon.h> @@ -35,43 +33,36 @@ namespace arm_compute { namespace { -std::unique_ptr<IFunction> create_function_all_types(const arm_gemm::KernelDescription &gemm_kernel_info, - const ITensor *a, const ITensor *b, ITensor *d, - float alpha, float beta, const GEMMInfo &gemm_info, - std::shared_ptr<IMemoryManager> memory_manager, - IWeightsManager *weights_manager) - +arm_gemm::Activation map_to_arm_gemm_activation(const ActivationLayerInfo &act) { - // Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure() - switch(gemm_kernel_info.method) + arm_gemm::Activation gemm_act; + + // Early exit in case lower bound is other than 0, as it's not yet supported + if(act.b() != 0.f) { - case arm_gemm::GemmMethod::GEMM_INTERLEAVED: - { - if(!gemm_info.pretranpose_B()) - { - return nullptr; - } - auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager, weights_manager); - function->configure(a, b, d, alpha, beta, gemm_info); - return std::move(function); - } -#if defined(__aarch64__) - case arm_gemm::GemmMethod::GEMM_NATIVE: - { - if(gemm_kernel_info.name.find("sgemm_native_16x4") != std::string::npos) - { - auto kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>(); - kernel->configure(a, b, d, alpha, beta, gemm_info); - auto function = support::cpp14::make_unique<NESimpleAssemblyFunction>(); - function->configure(std::move(kernel)); - return std::move(function); - } - return nullptr; - } -#endif // defined(__aarch64__) + return gemm_act; + } + + switch(act.activation()) + { + case ActivationLayerInfo::ActivationFunction::RELU: + gemm_act.type = arm_gemm::Activation::Type::ReLU; + break; + case ActivationLayerInfo::ActivationFunction::BOUNDED_RELU: + gemm_act.type = arm_gemm::Activation::Type::BoundedReLU; + gemm_act.param1 = act.a(); + gemm_act.param2 = 0.f; + break; + case ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU: + gemm_act.type = arm_gemm::Activation::Type::BoundedReLU; + gemm_act.param1 = act.a(); + gemm_act.param2 = act.b(); + break; default: - return nullptr; + gemm_act.type = arm_gemm::Activation::Type::None; } + + return gemm_act; } template <typename TypeInput, typename TypeOutput> @@ -161,7 +152,7 @@ public: * @param[in] os Output stage meta-data. */ void configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, - arm_gemm::GemmArgs<TypeOutput> args, const GEMMInfo &gemm_info, + arm_gemm::GemmArgs args, const GEMMInfo &gemm_info, MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os = {}); // Inherited methods overridden: @@ -214,7 +205,7 @@ private: template <typename TypeInput, typename TypeOutput, class OutputStage> void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, - arm_gemm::GemmArgs<TypeOutput> args, const GEMMInfo &gemm_info, + arm_gemm::GemmArgs args, const GEMMInfo &gemm_info, MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os) { arm_gemm::GemmConfig gemm_cfg; @@ -287,7 +278,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare() // Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C. if(_c && _c->info()->data_type() == DataType::S32) { - _gemm_kernel_asm->set_quantized_bias(reinterpret_cast<const int32_t *>(_c->buffer() + _c->info()->offset_first_element_in_bytes())); + _gemm_kernel_asm->set_quantized_bias(reinterpret_cast<const int32_t *>(_c->buffer() + _c->info()->offset_first_element_in_bytes()), 0); } // Pretranspose B if required @@ -383,83 +374,76 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run() // Prepare assembly kernel prepare(); + TypeOutput *bias = nullptr; + // Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C. + if(_c && _c->info()->data_type() != DataType::S32) + { + bias = reinterpret_cast<TypeOutput *>(_c->buffer() + _c->info()->offset_first_element_in_bytes()); + } // Set gemm parameters - _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d); + _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, + in1_ptr, ldb, multi_stride_b, + out_ptr, ldd, batch_stride_d, multi_stride_d, + bias, 0); // Schedule assembly kernel NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX); } template <typename TypeInput, typename TypeOutput> -void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group, - const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info, - std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager) +void create_arm_gemm(std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group, + const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const GEMMInfo &gemm_info, + IWeightsManager *weights_manager) { INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info); const CPUInfo &ci = NEScheduler::get().cpu_info(); unsigned int num_threads = NEScheduler::get().num_threads(); - arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, gemm_info.pretranpose_B()); + arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, activation, num_threads, gemm_info.pretranpose_B()); - // Try to create an ACL function: - const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args); - acl_function = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager), weights_manager); - - // If we still don't have an ACL function: - if(acl_function == nullptr) - { - //Fallback onto arm_gemm function if ACL doesn't support this method. - auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput>>(); - fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager); - arm_gemm = std::move(fallback); - } + // Create arm_gemm fallback + auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput>>(); + fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager); + arm_gemm = std::move(fallback); } template <typename TypeInput, typename TypeOutput> -void create_function_or_arm_gemm_quant(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group, - const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info, - std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager) +void create_arm_gemm_quant(std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group, + const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const GEMMInfo &gemm_info, + IWeightsManager *weights_manager) { INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info); const CPUInfo &ci = NEScheduler::get().cpu_info(); unsigned int num_threads = NEScheduler::get().num_threads(); - arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, gemm_info.pretranpose_B()); + arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, activation, num_threads, gemm_info.pretranpose_B()); // Configure requantization info const int32_t a_offset = -a->info()->quantization_info().uniform().offset; const int32_t b_offset = -b->info()->quantization_info().uniform().offset; const GEMMLowpOutputStageInfo os_info = gemm_info.gemmlowp_output_stage(); - const arm_gemm::ARequantizeLayer32 gemm_requant_info(nullptr, + const arm_gemm::ARequantizeLayer32 gemm_requant_info(nullptr, 0, a_offset, b_offset, os_info.gemmlowp_offset, -os_info.gemmlowp_shift, os_info.gemmlowp_multiplier, os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound); - // Try to create an ACL function: - const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args, gemm_requant_info); - acl_function = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager), weights_manager); - - // If we still don't have an ACL function: - if(acl_function == nullptr) - { - // Fallback onto arm_gemm function if ACL doesn't support this method. - auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::ARequantizeLayer32>>(); - fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager, gemm_requant_info); - arm_gemm = std::move(fallback); - } + // Create arm_gemm fallback + auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::ARequantizeLayer32>>(); + fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager, gemm_requant_info); + arm_gemm = std::move(fallback); } } //namespace NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager) - : _function(nullptr), _arm_gemm(nullptr), _memory_group(memory_manager), _memory_manager(memory_manager), _weights_manager(weights_manager) + : _arm_gemm(nullptr), _memory_group(std::move(memory_manager)), _weights_manager(weights_manager) { } -Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info) +Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const GEMMInfo &gemm_info) { - ARM_COMPUTE_UNUSED(alpha, beta, gemm_info); + ARM_COMPUTE_UNUSED(gemm_info); ARM_COMPUTE_UNUSED(c); ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d); ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a); @@ -476,12 +460,19 @@ Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo return Status{}; } -void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info) +bool NEGEMMAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &activation) +{ + arm_gemm::Activation act = map_to_arm_gemm_activation(activation); + return act.type != arm_gemm::Activation::Type::None; +} + +void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, const GEMMInfo &gemm_info) { ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d); + arm_gemm::Activation act = map_to_arm_gemm_activation(gemm_info.activation_info()); //If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured() - if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, d->info(), alpha, beta, gemm_info)) + if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, d->info(), gemm_info)) { return; } @@ -489,27 +480,27 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const switch(a->info()->data_type()) { case DataType::F32: - create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager); + create_arm_gemm<float, float>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); break; #ifdef __aarch64__ case DataType::U8: case DataType::QASYMM8: if(d->info()->data_type() == DataType::S32) { - create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager); + create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); } else { - create_function_or_arm_gemm_quant<uint8_t, uint8_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager); + create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); } break; case DataType::S8: - create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager); + create_arm_gemm<int8_t, int32_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); break; #endif /* __aarch64__ */ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: - create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager); + create_arm_gemm<float16_t, float16_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); break; #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ default: @@ -519,33 +510,20 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const void NEGEMMAssemblyDispatch::prepare() { - if(_function != nullptr) - { - _function->prepare(); - } - else - { - ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr); - _arm_gemm->prepare(); - } + ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr); + _arm_gemm->prepare(); } bool NEGEMMAssemblyDispatch::is_configured() const { - return (_arm_gemm != nullptr && _arm_gemm->is_configured()) || _function != nullptr; + return _arm_gemm != nullptr && _arm_gemm->is_configured(); } void NEGEMMAssemblyDispatch::run() { MemoryGroupResourceScope scope_mg(_memory_group); - if(_function != nullptr) - { - _function->run(); - } - else - { - ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr); - _arm_gemm->run(); - } + + ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr); + _arm_gemm->run(); } } //namespace arm_compute |