about | summary | refs | log | tree | commit | diff
path: root/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
diff options
context:
space:
mode:
author Georgios Pinitas <georgios.pinitas@arm.com> 2019-10-14 19:03:09 +0100
committer Georgios Pinitas <georgios.pinitas@arm.com> 2019-10-23 12:08:12 +0000
commit 48b3ef89de5f21a0169d8416e3d54081f82c7bf8 (patch)
tree f857d733ccf446c704823dc7ac796a96eb55095e /src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
parent 1dce3101ef8d77c8cf0af7dfd4af6595a0136b91 (diff)
download ComputeLibrary-48b3ef89de5f21a0169d8416e3d54081f82c7bf8.tar.gz
COMPMID-2577: Fuse bias addition and activation in gemm assembly kernels
Change-Id: I7f52112d2d05b1ea3d3f3d4b19b8eafab05d6c44
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2141
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Marquez <pablo.tello@arm.com>
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp')
-rw-r--r-- src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp | 180
1 files changed, 79 insertions, 101 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index 956ded55d2..b31ecb91e9 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -24,10 +24,8 @@
#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
#include "arm_compute/core/CPP/Validate.h"
-#include "arm_compute/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "arm_compute/runtime/NEON/functions/NESimpleAssemblyFunction.h"
-#include "arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h"
#include <arm_neon.h>
@@ -35,43 +33,36 @@ namespace arm_compute
{
namespace
{
-std::unique_ptr<IFunction> create_function_all_types(const arm_gemm::KernelDescription &gemm_kernel_info,
- const ITensor *a, const ITensor *b, ITensor *d,
- float alpha, float beta, const GEMMInfo &gemm_info,
- std::shared_ptr<IMemoryManager> memory_manager,
- IWeightsManager *weights_manager)
-
+arm_gemm::Activation map_to_arm_gemm_activation(const ActivationLayerInfo &act)
{
- // Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure()
- switch(gemm_kernel_info.method)
+ arm_gemm::Activation gemm_act;
+
+ // Early exit in case lower bound is other than 0, as it's not yet supported
+ if(act.b() != 0.f)
{
- case arm_gemm::GemmMethod::GEMM_INTERLEAVED:
- {
- if(!gemm_info.pretranpose_B())
- {
- return nullptr;
- }
- auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager, weights_manager);
- function->configure(a, b, d, alpha, beta, gemm_info);
- return std::move(function);
- }
-#if defined(__aarch64__)
- case arm_gemm::GemmMethod::GEMM_NATIVE:
- {
- if(gemm_kernel_info.name.find("sgemm_native_16x4") != std::string::npos)
- {
- auto kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>();
- kernel->configure(a, b, d, alpha, beta, gemm_info);
- auto function = support::cpp14::make_unique<NESimpleAssemblyFunction>();
- function->configure(std::move(kernel));
- return std::move(function);
- }
- return nullptr;
- }
-#endif // defined(__aarch64__)
+ return gemm_act;
+ }
+
+ switch(act.activation())
+ {
+ case ActivationLayerInfo::ActivationFunction::RELU:
+ gemm_act.type = arm_gemm::Activation::Type::ReLU;
+ break;
+ case ActivationLayerInfo::ActivationFunction::BOUNDED_RELU:
+ gemm_act.type = arm_gemm::Activation::Type::BoundedReLU;
+ gemm_act.param1 = act.a();
+ gemm_act.param2 = 0.f;
+ break;
+ case ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU:
+ gemm_act.type = arm_gemm::Activation::Type::BoundedReLU;
+ gemm_act.param1 = act.a();
+ gemm_act.param2 = act.b();
+ break;
default:
- return nullptr;
+ gemm_act.type = arm_gemm::Activation::Type::None;
}
+
+ return gemm_act;
}
template <typename TypeInput, typename TypeOutput>
@@ -161,7 +152,7 @@ public:
* @param[in] os Output stage meta-data.
*/
void configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d,
- arm_gemm::GemmArgs<TypeOutput> args, const GEMMInfo &gemm_info,
+ arm_gemm::GemmArgs args, const GEMMInfo &gemm_info,
MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os = {});
// Inherited methods overridden:
@@ -214,7 +205,7 @@ private:
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d,
- arm_gemm::GemmArgs<TypeOutput> args, const GEMMInfo &gemm_info,
+ arm_gemm::GemmArgs args, const GEMMInfo &gemm_info,
MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os)
{
arm_gemm::GemmConfig gemm_cfg;
@@ -287,7 +278,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare()
// Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C.
if(_c && _c->info()->data_type() == DataType::S32)
{
- _gemm_kernel_asm->set_quantized_bias(reinterpret_cast<const int32_t *>(_c->buffer() + _c->info()->offset_first_element_in_bytes()));
+ _gemm_kernel_asm->set_quantized_bias(reinterpret_cast<const int32_t *>(_c->buffer() + _c->info()->offset_first_element_in_bytes()), 0);
}
// Pretranspose B if required
@@ -383,83 +374,76 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run()
// Prepare assembly kernel
prepare();
+ TypeOutput *bias = nullptr;
+ // Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C.
+ if(_c && _c->info()->data_type() != DataType::S32)
+ {
+ bias = reinterpret_cast<TypeOutput *>(_c->buffer() + _c->info()->offset_first_element_in_bytes());
+ }
// Set gemm parameters
- _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d);
+ _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a,
+ in1_ptr, ldb, multi_stride_b,
+ out_ptr, ldd, batch_stride_d, multi_stride_d,
+ bias, 0);
// Schedule assembly kernel
NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX);
}
template <typename TypeInput, typename TypeOutput>
-void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
- const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info,
- std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
+void create_arm_gemm(std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
+ const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const GEMMInfo &gemm_info,
+ IWeightsManager *weights_manager)
{
INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info);
const CPUInfo &ci = NEScheduler::get().cpu_info();
unsigned int num_threads = NEScheduler::get().num_threads();
- arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, gemm_info.pretranpose_B());
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, activation, num_threads, gemm_info.pretranpose_B());
- // Try to create an ACL function:
- const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args);
- acl_function = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager), weights_manager);
-
- // If we still don't have an ACL function:
- if(acl_function == nullptr)
- {
- //Fallback onto arm_gemm function if ACL doesn't support this method.
- auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput>>();
- fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager);
- arm_gemm = std::move(fallback);
- }
+ // Create arm_gemm fallback
+ auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput>>();
+ fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager);
+ arm_gemm = std::move(fallback);
}
template <typename TypeInput, typename TypeOutput>
-void create_function_or_arm_gemm_quant(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
- const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info,
- std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
+void create_arm_gemm_quant(std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
+ const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const GEMMInfo &gemm_info,
+ IWeightsManager *weights_manager)
{
INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info);
const CPUInfo &ci = NEScheduler::get().cpu_info();
unsigned int num_threads = NEScheduler::get().num_threads();
- arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, gemm_info.pretranpose_B());
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, activation, num_threads, gemm_info.pretranpose_B());
// Configure requantization info
const int32_t a_offset = -a->info()->quantization_info().uniform().offset;
const int32_t b_offset = -b->info()->quantization_info().uniform().offset;
const GEMMLowpOutputStageInfo os_info = gemm_info.gemmlowp_output_stage();
- const arm_gemm::ARequantizeLayer32 gemm_requant_info(nullptr,
+ const arm_gemm::ARequantizeLayer32 gemm_requant_info(nullptr, 0,
a_offset, b_offset, os_info.gemmlowp_offset,
-os_info.gemmlowp_shift, os_info.gemmlowp_multiplier,
os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound);
- // Try to create an ACL function:
- const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args, gemm_requant_info);
- acl_function = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager), weights_manager);
-
- // If we still don't have an ACL function:
- if(acl_function == nullptr)
- {
- // Fallback onto arm_gemm function if ACL doesn't support this method.
- auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::ARequantizeLayer32>>();
- fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager, gemm_requant_info);
- arm_gemm = std::move(fallback);
- }
+ // Create arm_gemm fallback
+ auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::ARequantizeLayer32>>();
+ fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager, gemm_requant_info);
+ arm_gemm = std::move(fallback);
}
} //namespace
NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
- : _function(nullptr), _arm_gemm(nullptr), _memory_group(memory_manager), _memory_manager(memory_manager), _weights_manager(weights_manager)
+ : _arm_gemm(nullptr), _memory_group(std::move(memory_manager)), _weights_manager(weights_manager)
{
}
-Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info)
+Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const GEMMInfo &gemm_info)
{
- ARM_COMPUTE_UNUSED(alpha, beta, gemm_info);
+ ARM_COMPUTE_UNUSED(gemm_info);
ARM_COMPUTE_UNUSED(c);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
@@ -476,12 +460,19 @@ Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo
return Status{};
}
-void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info)
+bool NEGEMMAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &activation)
+{
+ arm_gemm::Activation act = map_to_arm_gemm_activation(activation);
+ return act.type != arm_gemm::Activation::Type::None;
+}
+
+void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, const GEMMInfo &gemm_info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
+ arm_gemm::Activation act = map_to_arm_gemm_activation(gemm_info.activation_info());
//If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured()
- if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, d->info(), alpha, beta, gemm_info))
+ if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, d->info(), gemm_info))
{
return;
}
@@ -489,27 +480,27 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const
switch(a->info()->data_type())
{
case DataType::F32:
- create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
+ create_arm_gemm<float, float>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
break;
#ifdef __aarch64__
case DataType::U8:
case DataType::QASYMM8:
if(d->info()->data_type() == DataType::S32)
{
- create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
+ create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
}
else
{
- create_function_or_arm_gemm_quant<uint8_t, uint8_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
+ create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
}
break;
case DataType::S8:
- create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
+ create_arm_gemm<int8_t, int32_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
break;
#endif /* __aarch64__ */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
- create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
+ create_arm_gemm<float16_t, float16_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
default:
@@ -519,33 +510,20 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const
void NEGEMMAssemblyDispatch::prepare()
{
- if(_function != nullptr)
- {
- _function->prepare();
- }
- else
- {
- ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
- _arm_gemm->prepare();
- }
+ ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
+ _arm_gemm->prepare();
}
bool NEGEMMAssemblyDispatch::is_configured() const
{
- return (_arm_gemm != nullptr && _arm_gemm->is_configured()) || _function != nullptr;
+ return _arm_gemm != nullptr && _arm_gemm->is_configured();
}
void NEGEMMAssemblyDispatch::run()
{
MemoryGroupResourceScope scope_mg(_memory_group);
- if(_function != nullptr)
- {
- _function->run();
- }
- else
- {
- ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
- _arm_gemm->run();
- }
+
+ ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
+ _arm_gemm->run();
}
} //namespace arm_compute