author     Georgios Pinitas <georgios.pinitas@arm.com>    2019-10-14 19:03:09 +0100
committer  Georgios Pinitas <georgios.pinitas@arm.com>    2019-10-23 12:08:12 +0000
commit     48b3ef89de5f21a0169d8416e3d54081f82c7bf8 (patch)
tree       f857d733ccf446c704823dc7ac796a96eb55095e /src/runtime
parent     1dce3101ef8d77c8cf0af7dfd4af6595a0136b91 (diff)
download   ComputeLibrary-48b3ef89de5f21a0169d8416e3d54081f82c7bf8.tar.gz
COMPMID-2577: Fuse bias addition and activation in gemm assembly kernels
Change-Id: I7f52112d2d05b1ea3d3f3d4b19b8eafab05d6c44
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2141
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Marquez <pablo.tello@arm.com>
Diffstat (limited to 'src/runtime')
-rw-r--r--  src/runtime/NEON/functions/NEGEMM.cpp                               | 100
-rw-r--r--  src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp               | 180
-rw-r--r--  src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp               | 118
-rw-r--r--  src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp |   2
-rw-r--r--  src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp         |  35
-rw-r--r--  src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp    | 430
6 files changed, 227 insertions, 638 deletions
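
For orientation, a minimal usage sketch (not part of the patch) of how the fused path is exercised after this change. It assumes the caller has already initialised and allocated the tensors; the GEMMInfo positional arguments mirror the call added in NEGEMMConvolutionLayer::configure_mm() further down, and the helper name is hypothetical.

#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEGEMM.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

// Hypothetical helper: computes d = act(a * b + bias) with a single NEGEMM call.
void run_fused_gemm(Tensor &a, Tensor &b, Tensor &bias, Tensor &d)
{
    // The activation now travels inside GEMMInfo; the assembly dispatch fuses it when supported.
    const ActivationLayerInfo act(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f);
    const GEMMInfo            info(false, false, true /* reshape_b_only_on_first_run */,
                                   0, false, false, GEMMLowpOutputStageInfo(), false, false, act);

    NEGEMM gemm;
    // With reshape_b_only_on_first_run set, c is consumed as a bias (beta is not used on that path).
    gemm.configure(&a, &b, &bias, &d, 1.f /* alpha */, 1.f /* beta */, info);
    gemm.run();
}
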
diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp
index df92b7999c..baa22b7d32 100644
--- a/src/runtime/NEON/functions/NEGEMM.cpp
+++ b/src/runtime/NEON/functions/NEGEMM.cpp
@@ -34,7 +34,6 @@
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
#include "arm_compute/runtime/TensorAllocator.h"
-#include "support/ToolchainSupport.h"
#include <cmath>
@@ -43,8 +42,9 @@ using namespace arm_compute::misc::shape_calculator;
namespace arm_compute
{
NEGEMM::NEGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
- : _memory_group(memory_manager), _weights_manager(weights_manager), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _asm_glue(memory_manager, weights_manager), _ma_kernel(), _tmp_a(),
- _tmp_b(), _original_b(nullptr), _run_vector_matrix_multiplication(false), _run_addition(false), _reshape_b_only_on_first_run(false), _is_prepared(false)
+ : _memory_group(memory_manager), _weights_manager(weights_manager), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _asm_glue(memory_manager, weights_manager), _ma_kernel(),
+ _alpha_scale_func(nullptr), _add_bias_kernel(), _activation_func(), _tmp_a(), _tmp_b(), _tmp_d(), _original_b(nullptr), _run_vector_matrix_multiplication(false), _run_alpha_scale(false),
+ _run_addition(false), _run_bias_addition(false), _run_activation(false), _reshape_b_only_on_first_run(false), _is_prepared(false)
{
}
@@ -52,34 +52,55 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe
{
ARM_COMPUTE_ERROR_THROW_ON(NEGEMM::validate(a->info(), b->info(), (c != nullptr) ? c->info() : nullptr, d->info(), alpha, beta, gemm_info));
+ const bool is_c_bias = gemm_info.reshape_b_only_on_first_run();
+ bool run_optimised = bool(NEGEMMAssemblyDispatch::validate(a->info(), b->info(), (is_c_bias && c != nullptr) ? c->info() : nullptr, d->info(), gemm_info));
+
// Check if we need to reshape the matrix B only on the first run
_is_prepared = false;
_reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
_run_vector_matrix_multiplication = a->info()->dimension(1) < 2;
_original_b = b;
-
- bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, d->info(), alpha, beta, gemm_info));
+ _run_alpha_scale = alpha != 1.f;
+ _run_bias_addition = c != nullptr && gemm_info.reshape_b_only_on_first_run();
+ _run_addition = beta != 0 && c != nullptr && !gemm_info.reshape_b_only_on_first_run();
+ _run_activation = gemm_info.activation_info().enabled() && (!run_optimised || (run_optimised && !NEGEMMAssemblyDispatch::is_activation_supported(gemm_info.activation_info())));
if(run_optimised)
{
+ const ITensor *c_to_use = is_c_bias ? c : nullptr;
if(MEMInfo::get_policy() == MemoryPolicy::MINIMIZE)
{
GEMMInfo gemm_info_ntb = gemm_info;
gemm_info_ntb.set_pretranpose_B(false);
- _asm_glue.configure(a, b, c, d, alpha, beta, gemm_info_ntb);
+ _asm_glue.configure(a, b, c_to_use, d, gemm_info_ntb);
}
else
{
- _asm_glue.configure(a, b, c, d, alpha, beta, gemm_info);
+ _asm_glue.configure(a, b, c_to_use, d, gemm_info);
}
ARM_COMPUTE_ERROR_ON(!_asm_glue.is_configured());
+
+ // Scale product by alpha
+ if(_run_alpha_scale)
+ {
+ _alpha_scale_func.configure(d, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LINEAR, alpha, 0.f));
+ }
}
else
{
+ // Pick output tensor in case bias addition should be performed
+ ITensor *gemm_output_to_use = d;
+ if(_run_bias_addition)
+ {
+ gemm_output_to_use = &_tmp_d;
+ _memory_group.manage(&_tmp_d);
+ }
+
+ // Select between GEMV and GEMM
if(_run_vector_matrix_multiplication)
{
// Configure the matrix multiply kernel
- _mm_kernel.configure(a, b, d, alpha, false);
+ _mm_kernel.configure(a, b, gemm_output_to_use, alpha, false);
}
else
{
@@ -117,7 +138,7 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe
_transpose_kernel.configure(b, &_tmp_b);
// Configure matrix multiplication kernel
- _mm_kernel.configure(&_tmp_a, &_tmp_b, d, alpha, true, GEMMReshapeInfo(m, n, k));
+ _mm_kernel.configure(&_tmp_a, &_tmp_b, gemm_output_to_use, alpha, true, GEMMReshapeInfo(m, n, k));
// Allocate once the all configure methods have been called
_tmp_a.allocator()->allocate();
@@ -127,18 +148,31 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe
}
}
- // Configure matrix addition kernel
- if(beta != 0 && c != nullptr)
+ if(_run_bias_addition)
{
- _ma_kernel.configure(c, d, beta);
- _run_addition = true;
+ _add_bias_kernel.configure(gemm_output_to_use, c, d, ConvertPolicy::SATURATE);
+ _tmp_d.allocator()->allocate();
}
}
+
+ // Configure matrix addition kernel
+ if(_run_addition)
+ {
+ _ma_kernel.configure(c, d, beta);
+ }
+
+ // Configure activation
+ const ActivationLayerInfo &activation = gemm_info.activation_info();
+ if(_run_activation)
+ {
+ _activation_func.configure(d, nullptr, activation);
+ }
}
Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
ARM_COMPUTE_UNUSED(alpha);
+ const bool is_c_bias = gemm_info.reshape_b_only_on_first_run();
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::F16, DataType::F32);
@@ -147,7 +181,7 @@ Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
- if(c != nullptr)
+ if(c != nullptr && !is_c_bias)
{
ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.depth_output_gemm3d() != 0);
ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.reinterpret_input_as_3d());
@@ -178,7 +212,7 @@ Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso
}
// Check if we need to run the optimized assembly kernel
- const bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a, b, c, output, alpha, beta, gemm_info));
+ const bool run_optimised = bool(NEGEMMAssemblyDispatch::validate(a, b, is_c_bias ? c : nullptr, output, gemm_info));
if(!run_optimised)
{
@@ -225,14 +259,26 @@ Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso
// Validate matrix multiply
auto_init_if_empty(tmp_output_info, matrix_a_info->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, run_interleave_transpose, reshape_info)));
ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &tmp_output_info, alpha, run_interleave_transpose, reshape_info));
+
+ if(c != nullptr && gemm_info.reshape_b_only_on_first_run())
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&tmp_output_info, c, output, ConvertPolicy::SATURATE));
+ }
}
// Validate matrix addition kernel
- if(beta != 0 && c != nullptr)
+ if(beta != 0 && c != nullptr && !is_c_bias)
{
ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixAdditionKernel::validate(c, output, beta));
}
+ // Validate activation
+ const ActivationLayerInfo &activation = gemm_info.activation_info();
+ if(activation.enabled())
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(output, nullptr, activation));
+ }
+
return Status{};
}
@@ -245,6 +291,10 @@ void NEGEMM::run()
if(_asm_glue.is_configured())
{
_asm_glue.run();
+ if(_run_alpha_scale)
+ {
+ _alpha_scale_func.run();
+ }
}
else
{
@@ -262,12 +312,24 @@ void NEGEMM::run()
NEScheduler::get().schedule(&_mm_kernel, _run_vector_matrix_multiplication ? Window::DimX : Window::DimY);
- // Run matrix addition kernel
- if(_run_addition)
+ // Run bias addition kernel
+ if(_run_bias_addition)
{
- NEScheduler::get().schedule(&_ma_kernel, Window::DimY);
+ NEScheduler::get().schedule(&_add_bias_kernel, Window::DimY);
}
}
+
+ // Run matrix addition kernel
+ if(_run_addition)
+ {
+ NEScheduler::get().schedule(&_ma_kernel, Window::DimY);
+ }
+
+ // Run activation function
+ if(_run_activation)
+ {
+ _activation_func.run();
+ }
}
void NEGEMM::prepare()
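
One detail worth calling out: the assembly kernels no longer take alpha, so on the optimised path the alpha scale becomes an in-place LINEAR activation on d (_alpha_scale_func above). An explanatory sketch using plain NEActivationLayer, not code from the patch:

#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEActivationLayer.h"

// NEActivationLayer with LINEAR computes f(x) = a * x + b, so (a = alpha, b = 0)
// rescales the assembly-kernel output in place, matching _alpha_scale_func above.
void scale_in_place_by_alpha(arm_compute::ITensor *d, float alpha)
{
    arm_compute::NEActivationLayer scale;
    scale.configure(d, nullptr, arm_compute::ActivationLayerInfo(
                        arm_compute::ActivationLayerInfo::ActivationFunction::LINEAR, alpha, 0.f));
    scale.run();
}
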
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index 956ded55d2..b31ecb91e9 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -24,10 +24,8 @@
#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
#include "arm_compute/core/CPP/Validate.h"
-#include "arm_compute/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "arm_compute/runtime/NEON/functions/NESimpleAssemblyFunction.h"
-#include "arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h"
#include <arm_neon.h>
@@ -35,43 +33,36 @@ namespace arm_compute
{
namespace
{
-std::unique_ptr<IFunction> create_function_all_types(const arm_gemm::KernelDescription &gemm_kernel_info,
- const ITensor *a, const ITensor *b, ITensor *d,
- float alpha, float beta, const GEMMInfo &gemm_info,
- std::shared_ptr<IMemoryManager> memory_manager,
- IWeightsManager *weights_manager)
-
+arm_gemm::Activation map_to_arm_gemm_activation(const ActivationLayerInfo &act)
{
- // Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure()
- switch(gemm_kernel_info.method)
+ arm_gemm::Activation gemm_act;
+
+ // Early exit in case lower bound is other than 0, as it's not yet supported
+ if(act.b() != 0.f)
{
- case arm_gemm::GemmMethod::GEMM_INTERLEAVED:
- {
- if(!gemm_info.pretranpose_B())
- {
- return nullptr;
- }
- auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager, weights_manager);
- function->configure(a, b, d, alpha, beta, gemm_info);
- return std::move(function);
- }
-#if defined(__aarch64__)
- case arm_gemm::GemmMethod::GEMM_NATIVE:
- {
- if(gemm_kernel_info.name.find("sgemm_native_16x4") != std::string::npos)
- {
- auto kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>();
- kernel->configure(a, b, d, alpha, beta, gemm_info);
- auto function = support::cpp14::make_unique<NESimpleAssemblyFunction>();
- function->configure(std::move(kernel));
- return std::move(function);
- }
- return nullptr;
- }
-#endif // defined(__aarch64__)
+ return gemm_act;
+ }
+
+ switch(act.activation())
+ {
+ case ActivationLayerInfo::ActivationFunction::RELU:
+ gemm_act.type = arm_gemm::Activation::Type::ReLU;
+ break;
+ case ActivationLayerInfo::ActivationFunction::BOUNDED_RELU:
+ gemm_act.type = arm_gemm::Activation::Type::BoundedReLU;
+ gemm_act.param1 = act.a();
+ gemm_act.param2 = 0.f;
+ break;
+ case ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU:
+ gemm_act.type = arm_gemm::Activation::Type::BoundedReLU;
+ gemm_act.param1 = act.a();
+ gemm_act.param2 = act.b();
+ break;
default:
- return nullptr;
+ gemm_act.type = arm_gemm::Activation::Type::None;
}
+
+ return gemm_act;
}
template <typename TypeInput, typename TypeOutput>
@@ -161,7 +152,7 @@ public:
* @param[in] os Output stage meta-data.
*/
void configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d,
- arm_gemm::GemmArgs<TypeOutput> args, const GEMMInfo &gemm_info,
+ arm_gemm::GemmArgs args, const GEMMInfo &gemm_info,
MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os = {});
// Inherited methods overridden:
@@ -214,7 +205,7 @@ private:
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d,
- arm_gemm::GemmArgs<TypeOutput> args, const GEMMInfo &gemm_info,
+ arm_gemm::GemmArgs args, const GEMMInfo &gemm_info,
MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os)
{
arm_gemm::GemmConfig gemm_cfg;
@@ -287,7 +278,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare()
// Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C.
if(_c && _c->info()->data_type() == DataType::S32)
{
- _gemm_kernel_asm->set_quantized_bias(reinterpret_cast<const int32_t *>(_c->buffer() + _c->info()->offset_first_element_in_bytes()));
+ _gemm_kernel_asm->set_quantized_bias(reinterpret_cast<const int32_t *>(_c->buffer() + _c->info()->offset_first_element_in_bytes()), 0);
}
// Pretranspose B if required
@@ -383,83 +374,76 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run()
// Prepare assembly kernel
prepare();
+ TypeOutput *bias = nullptr;
+ // Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C.
+ if(_c && _c->info()->data_type() != DataType::S32)
+ {
+ bias = reinterpret_cast<TypeOutput *>(_c->buffer() + _c->info()->offset_first_element_in_bytes());
+ }
// Set gemm parameters
- _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d);
+ _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a,
+ in1_ptr, ldb, multi_stride_b,
+ out_ptr, ldd, batch_stride_d, multi_stride_d,
+ bias, 0);
// Schedule assembly kernel
NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX);
}
template <typename TypeInput, typename TypeOutput>
-void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
- const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info,
- std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
+void create_arm_gemm(std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
+ const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const GEMMInfo &gemm_info,
+ IWeightsManager *weights_manager)
{
INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info);
const CPUInfo &ci = NEScheduler::get().cpu_info();
unsigned int num_threads = NEScheduler::get().num_threads();
- arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, gemm_info.pretranpose_B());
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, activation, num_threads, gemm_info.pretranpose_B());
- // Try to create an ACL function:
- const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args);
- acl_function = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager), weights_manager);
-
- // If we still don't have an ACL function:
- if(acl_function == nullptr)
- {
- //Fallback onto arm_gemm function if ACL doesn't support this method.
- auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput>>();
- fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager);
- arm_gemm = std::move(fallback);
- }
+ // Create arm_gemm fallback
+ auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput>>();
+ fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager);
+ arm_gemm = std::move(fallback);
}
template <typename TypeInput, typename TypeOutput>
-void create_function_or_arm_gemm_quant(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
- const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info,
- std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
+void create_arm_gemm_quant(std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
+ const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const GEMMInfo &gemm_info,
+ IWeightsManager *weights_manager)
{
INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info);
const CPUInfo &ci = NEScheduler::get().cpu_info();
unsigned int num_threads = NEScheduler::get().num_threads();
- arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, gemm_info.pretranpose_B());
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, activation, num_threads, gemm_info.pretranpose_B());
// Configure requantization info
const int32_t a_offset = -a->info()->quantization_info().uniform().offset;
const int32_t b_offset = -b->info()->quantization_info().uniform().offset;
const GEMMLowpOutputStageInfo os_info = gemm_info.gemmlowp_output_stage();
- const arm_gemm::ARequantizeLayer32 gemm_requant_info(nullptr,
+ const arm_gemm::ARequantizeLayer32 gemm_requant_info(nullptr, 0,
a_offset, b_offset, os_info.gemmlowp_offset,
-os_info.gemmlowp_shift, os_info.gemmlowp_multiplier,
os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound);
- // Try to create an ACL function:
- const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args, gemm_requant_info);
- acl_function = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager), weights_manager);
-
- // If we still don't have an ACL function:
- if(acl_function == nullptr)
- {
- // Fallback onto arm_gemm function if ACL doesn't support this method.
- auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::ARequantizeLayer32>>();
- fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager, gemm_requant_info);
- arm_gemm = std::move(fallback);
- }
+ // Create arm_gemm fallback
+ auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::ARequantizeLayer32>>();
+ fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager, gemm_requant_info);
+ arm_gemm = std::move(fallback);
}
} //namespace
NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
- : _function(nullptr), _arm_gemm(nullptr), _memory_group(memory_manager), _memory_manager(memory_manager), _weights_manager(weights_manager)
+ : _arm_gemm(nullptr), _memory_group(std::move(memory_manager)), _weights_manager(weights_manager)
{
}
-Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info)
+Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const GEMMInfo &gemm_info)
{
- ARM_COMPUTE_UNUSED(alpha, beta, gemm_info);
+ ARM_COMPUTE_UNUSED(gemm_info);
ARM_COMPUTE_UNUSED(c);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
@@ -476,12 +460,19 @@ Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo
return Status{};
}
-void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info)
+bool NEGEMMAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &activation)
+{
+ arm_gemm::Activation act = map_to_arm_gemm_activation(activation);
+ return act.type != arm_gemm::Activation::Type::None;
+}
+
+void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, const GEMMInfo &gemm_info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
+ arm_gemm::Activation act = map_to_arm_gemm_activation(gemm_info.activation_info());
//If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured()
- if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, d->info(), alpha, beta, gemm_info))
+ if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, d->info(), gemm_info))
{
return;
}
@@ -489,27 +480,27 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const
switch(a->info()->data_type())
{
case DataType::F32:
- create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
+ create_arm_gemm<float, float>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
break;
#ifdef __aarch64__
case DataType::U8:
case DataType::QASYMM8:
if(d->info()->data_type() == DataType::S32)
{
- create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
+ create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
}
else
{
- create_function_or_arm_gemm_quant<uint8_t, uint8_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
+ create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
}
break;
case DataType::S8:
- create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
+ create_arm_gemm<int8_t, int32_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
break;
#endif /* __aarch64__ */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
- create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
+ create_arm_gemm<float16_t, float16_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
default:
@@ -519,33 +510,20 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const
void NEGEMMAssemblyDispatch::prepare()
{
- if(_function != nullptr)
- {
- _function->prepare();
- }
- else
- {
- ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
- _arm_gemm->prepare();
- }
+ ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
+ _arm_gemm->prepare();
}
bool NEGEMMAssemblyDispatch::is_configured() const
{
- return (_arm_gemm != nullptr && _arm_gemm->is_configured()) || _function != nullptr;
+ return _arm_gemm != nullptr && _arm_gemm->is_configured();
}
void NEGEMMAssemblyDispatch::run()
{
MemoryGroupResourceScope scope_mg(_memory_group);
- if(_function != nullptr)
- {
- _function->run();
- }
- else
- {
- ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
- _arm_gemm->run();
- }
+
+ ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
+ _arm_gemm->run();
}
} //namespace arm_compute
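
The new is_activation_supported() hook lets callers decide whether an activation can be fused into arm_gemm (per map_to_arm_gemm_activation() above: ReLU-family activations with a zero lower bound) or must still run as a separate NEActivationLayer. A small sketch of that gating, mirroring the _run_activation checks in NEGEMM and NEGEMMLowpMatrixMultiplyCore; the helper name is hypothetical:

#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"

// A separate activation pass is only needed when the non-assembly fallback runs,
// or when the assembly path runs but cannot fuse this particular activation.
bool needs_separate_activation(bool run_optimised, const arm_compute::ActivationLayerInfo &act)
{
    return act.enabled()
           && (!run_optimised || !arm_compute::NEGEMMAssemblyDispatch::is_activation_supported(act));
}
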
diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
index 0034dd2545..f4377cdaf2 100644
--- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
@@ -29,9 +29,7 @@
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
-#include "support/ToolchainSupport.h"
-#include <cmath>
#include <set>
#include <tuple>
@@ -90,19 +88,27 @@ void NEConvolutionLayerReshapeWeights::run()
NEGEMMConvolutionLayer::NEGEMMConvolutionLayer(const std::shared_ptr<IMemoryManager> &memory_manager, IWeightsManager *weights_manager)
: _memory_group(memory_manager), _weights_manager(weights_manager), _reshape_weights(), _reshape_weights_managed(), _im2col_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager),
- _col2im_kernel(), _activationlayer_function(), _add_bias_kernel(), _reshape_layer(), _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(),
- _data_layout(DataLayout::NCHW), _append_bias(false), _skip_im2col(false), _skip_col2im(false), _is_quantized(false), _is_activationlayer_enabled(false), _is_prepared(false)
+ _col2im_kernel(), _reshape_layer(), _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(), _data_layout(DataLayout::NCHW), _skip_im2col(false),
+ _skip_col2im(false), _is_quantized(false), _is_prepared(false)
{
}
void NEGEMMConvolutionLayer::configure_mm(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const ActivationLayerInfo &act_info, int gemm_3d_depth)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights);
- ARM_COMPUTE_ERROR_THROW_ON(validate_mm(input->info(), weights->info(), biases == nullptr ? nullptr : biases->info(), output == nullptr ? nullptr : output->info(), act_info, gemm_3d_depth,
- _skip_im2col));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_mm(input->info(), weights->info(), biases == nullptr ? nullptr : biases->info(), output == nullptr ? nullptr : output->info(),
+ act_info, gemm_3d_depth, _skip_im2col));
+ // Create GEMMInfo structure
const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
- gemm_3d_depth, _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */);
+ gemm_3d_depth, _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */,
+ false, GEMMLowpOutputStageInfo(), false, false, act_info);
+
+ // Supported activations in GEMM
+ const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = { ActivationLayerInfo::ActivationFunction::RELU,
+ ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
+ ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU
+ };
if(_is_quantized)
{
@@ -125,19 +131,13 @@ void NEGEMMConvolutionLayer::configure_mm(const ITensor *input, const ITensor *w
int min_activation = 0;
int max_activation = 255;
- const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = { ActivationLayerInfo::ActivationFunction::RELU,
- ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
- ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU
- };
- if(_is_activationlayer_enabled && supported_acts.count(act_info.activation()) != 0)
+ if(supported_acts.count(act_info.activation()) != 0)
{
const int a_const_int = quantize_qasymm8(act_info.a(), oqinfo);
const int b_const_int = quantize_qasymm8(act_info.b(), oqinfo);
min_activation = act_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU ? oqinfo.offset : b_const_int;
max_activation = act_info.activation() == ActivationLayerInfo::ActivationFunction::RELU ? 255 : a_const_int;
-
- _is_activationlayer_enabled = false;
}
GEMMLowpOutputStageInfo output_info;
@@ -157,18 +157,21 @@ void NEGEMMConvolutionLayer::configure_mm(const ITensor *input, const ITensor *w
else
{
// Configure matrix multiply function
- _mm_gemm.configure(input, weights, nullptr, output, 1.0f, 0.0f, gemm_info);
+ _mm_gemm.configure(input, weights, biases, output, 1.0f, 0.0f, gemm_info);
}
}
-Status NEGEMMConvolutionLayer::validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const ActivationLayerInfo &act_info,
- int gemm_3d_depth, bool skip_im2col)
+Status NEGEMMConvolutionLayer::validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
+ const ActivationLayerInfo &act_info, int gemm_3d_depth, bool skip_im2col)
{
const bool is_quantized = is_data_type_quantized_asymmetric(input->data_type());
const bool is_activation_enabled = act_info.enabled();
- const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
- gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */);
+ // Create GEMMInfo structure
+ const GEMMInfo gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
+ gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */,
+ false, GEMMLowpOutputStageInfo(), false, false, act_info);
+
if(is_quantized)
{
// Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
@@ -241,7 +244,7 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig
const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
- ARM_COMPUTE_UNUSED(num_groups);
+ ARM_COMPUTE_UNUSED(num_groups, weights_info);
ARM_COMPUTE_ERROR_THROW_ON(NEGEMMConvolutionLayer::validate(input->info(),
weights->info(),
biases != nullptr ? biases->info() : nullptr,
@@ -261,13 +264,11 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig
const unsigned int kernel_width = weights->info()->dimension(idx_width);
const unsigned int kernel_height = weights->info()->dimension(idx_height);
- _is_prepared = weights_info.retain_internal_weights();
- _original_weights = weights;
- _is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type());
- _data_layout = data_layout;
- _skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
- _append_bias = (biases != nullptr) && (!_is_quantized);
- _is_activationlayer_enabled = act_info.enabled();
+ _is_prepared = weights_info.retain_internal_weights();
+ _original_weights = weights;
+ _is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type());
+ _data_layout = data_layout;
+ _skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
const ITensor *gemm_input_to_use = input;
ITensor *gemm_output_to_use = output;
@@ -297,8 +298,6 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig
_skip_col2im = false;
}
- const ITensor *biases_to_use = (_append_bias && !_skip_im2col) ? biases : nullptr;
-
// Get parameters from conv_info
unsigned int stride_x = 0;
unsigned int stride_y = 0;
@@ -312,12 +311,12 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig
if(_weights_manager && _weights_manager->are_weights_managed(weights))
{
- _reshape_weights_managed.configure(weights, biases_to_use);
+ _reshape_weights_managed.configure(weights, nullptr);
weights_to_use = _weights_manager->acquire(weights, &_reshape_weights_managed);
}
else
{
- _reshape_weights.configure(weights, biases_to_use, &_weights_reshaped);
+ _reshape_weights.configure(weights, nullptr, &_weights_reshaped);
weights_to_use = &_weights_reshaped;
}
@@ -327,16 +326,11 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig
_memory_group.manage(&_im2col_output);
// Configure
- _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, _append_bias, dilation);
+ _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, false, dilation);
// Update GEMM input
gemm_input_to_use = &_im2col_output;
}
- else if(_append_bias)
- {
- // Configure add bias kernel
- _add_bias_kernel.configure(output, biases, output, ConvertPolicy::SATURATE);
- }
// Create temporary GEMM output tensor in case we cannot skip col2im
if(!_skip_col2im)
@@ -394,14 +388,6 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig
ARM_COMPUTE_ERROR_ON_MSG((output->info()->dimension(idx_width) != conv_w) || (output->info()->dimension(idx_height) != conv_h),
"Output shape does not match the expected one");
-
- // Configure Activation Layer
- if(_is_activationlayer_enabled)
- {
- _activationlayer_function.configure(output, nullptr, act_info);
- }
-
- ARM_COMPUTE_UNUSED(weights_info);
}
Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info,
@@ -432,10 +418,9 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
const ITensorInfo *gemm_output_to_use = output;
const ITensorInfo *weights_to_use = weights;
- const bool is_quantized = is_data_type_quantized_asymmetric(data_type);
- const bool append_bias = (biases != nullptr) && (!is_quantized);
- bool skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
- bool is_activation_enabled = act_info.enabled();
+ const bool append_bias = false;
+ const bool is_quantized = is_data_type_quantized_asymmetric(data_type);
+ bool skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
// Get convolved dimensions
unsigned int conv_w = 0;
@@ -470,9 +455,6 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
}
}
- const unsigned bias_element = (append_bias && !skip_im2col) ? 1 : 0;
- const ITensorInfo *biases_to_use = (append_bias && !skip_im2col) ? biases : nullptr;
-
ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_channel) != input->dimension(idx_channel));
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
@@ -491,17 +473,12 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
}
- if(act_info.enabled())
- {
- ARM_COMPUTE_ERROR_ON(act_info.b() > act_info.a());
- }
-
unsigned int mat_weights_cols = weights->dimension(idx_kernels);
- unsigned int mat_weights_rows = weights->dimension(idx_width) * weights->dimension(idx_height) * weights->dimension(idx_channel) + bias_element;
+ unsigned int mat_weights_rows = weights->dimension(idx_width) * weights->dimension(idx_height) * weights->dimension(idx_channel);
// Output tensor auto inizialization if not yet initialized
- ARM_COMPUTE_RETURN_ON_ERROR(NEConvolutionLayerReshapeWeights::validate(weights, biases_to_use, nullptr));
- weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, (append_bias && !skip_im2col)), 1, data_type);
+ ARM_COMPUTE_RETURN_ON_ERROR(NEConvolutionLayerReshapeWeights::validate(weights, nullptr, nullptr));
+ weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, append_bias), 1, data_type);
weights_reshaped_info.set_quantization_info(weights->quantization_info());
weights_to_use = &weights_reshaped_info;
@@ -521,11 +498,6 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
ARM_COMPUTE_RETURN_ON_ERROR(NEIm2ColKernel::validate(input, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation));
gemm_input_to_use = &im2col_reshaped_info;
}
- else if(append_bias)
- {
- // Validate add bias kernel
- ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(output, biases, output, ConvertPolicy::SATURATE));
- }
// Create temporary GEMM output tensor in case we cannot skip col2im
if(!skip_col2im)
@@ -549,12 +521,6 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
ARM_COMPUTE_RETURN_ON_ERROR(NECol2ImKernel::validate(gemm_output_to_use, output, Size2D(conv_w, conv_h)));
}
- //Validate Activation Layer
- if(is_activation_enabled)
- {
- ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(output, nullptr, act_info));
- }
-
return Status{};
}
@@ -583,11 +549,6 @@ void NEGEMMConvolutionLayer::run()
_mm_gemm.run();
}
- if(_skip_im2col && _append_bias)
- {
- NEScheduler::get().schedule(&_add_bias_kernel, Window::DimY);
- }
-
// Reshape output matrix
if(!_skip_col2im)
{
@@ -600,11 +561,6 @@ void NEGEMMConvolutionLayer::run()
_reshape_layer.run();
}
}
-
- if(_is_activationlayer_enabled)
- {
- _activationlayer_function.run();
- }
}
void NEGEMMConvolutionLayer::prepare()
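
On the quantized convolution path the activation is not run as a separate layer either: configure_mm() above folds it into the GEMMLowp output-stage clamp, mapping the real bounds act_info.a()/b() into the uint8 output domain. A plain-C++ sketch of that mapping (an illustration of what quantize_qasymm8() contributes there, not the library helper itself):

#include <algorithm>
#include <cmath>
#include <cstdint>

// quantize(x) ~= round(x / scale) + offset, saturated to [0, 255]. RELU clamps at the
// output zero point; the bounded variants clamp at the quantised a() (and b()) values.
uint8_t quantize_to_u8(float x, float scale, int32_t offset)
{
    const int32_t q = static_cast<int32_t>(std::lround(x / scale)) + offset;
    return static_cast<uint8_t>(std::min<int32_t>(std::max<int32_t>(q, 0), 255));
}
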
diff --git a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp
index aa40113c5e..346d025fd2 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp
@@ -59,7 +59,7 @@ void NEGEMMLowpAssemblyMatrixMultiplyCore::configure(const ITensor *a, const ITe
case DataType::QASYMM8:
case DataType::U8:
{
- _asm_glue.configure(a, b, c, output, 1.f, 0.f, GEMMInfo(false, false, true));
+ _asm_glue.configure(a, b, c, output, GEMMInfo(false, false, true));
run_optimised = _asm_glue.is_configured();
break;
}
diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
index a03ec108c6..617d66cf24 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
@@ -42,8 +42,9 @@ using namespace arm_compute::misc::shape_calculator;
NEGEMMLowpMatrixMultiplyCore::NEGEMMLowpMatrixMultiplyCore(std::shared_ptr<IMemoryManager> memory_manager)
: _memory_group(memory_manager), _asm_glue(memory_manager), _mm_kernel(nullptr), _mtx_a_reshape_kernel(nullptr), _mtx_b_reshape_kernel(nullptr), _mtx_a_reduction_kernel(), _mtx_b_reduction_kernel(),
- _offset_contribution_kernel(), _offset_contribution_output_stage_kernel(), _vector_sum_col(), _vector_sum_row(), _tmp_a(), _tmp_b(), _mm_result_s32(), _original_b(nullptr), _a_offset(0), _b_offset(0),
- _run_vector_matrix_multiplication(false), _assembly_path(false), _fused_assembly_path(false), _reshape_b_only_on_first_run(false), _is_prepared(false), _fuse_output_stage(false)
+ _offset_contribution_kernel(), _offset_contribution_output_stage_kernel(), _activation_func(), _vector_sum_col(), _vector_sum_row(), _tmp_a(), _tmp_b(), _mm_result_s32(), _original_b(nullptr),
+ _a_offset(0), _b_offset(0), _run_vector_matrix_multiplication(false), _assembly_path(false), _fused_assembly_path(false), _reshape_b_only_on_first_run(false), _is_prepared(false),
+ _fuse_output_stage(false), _run_activation(false)
{
}
@@ -87,12 +88,12 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b,
{
if(a->info()->data_type() == DataType::QASYMM8 && gemm_info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
{
- _asm_glue.configure(a, b, c, output, 1.f, 0.f, gemm_info);
+ _asm_glue.configure(a, b, c, output, gemm_info);
_fused_assembly_path = _asm_glue.is_configured();
}
else
{
- _asm_glue.configure(a, b, nullptr, _fuse_output_stage ? &_mm_result_s32 : output, 1.f, 0.f, gemm_info);
+ _asm_glue.configure(a, b, nullptr, _fuse_output_stage ? &_mm_result_s32 : output, gemm_info);
}
_assembly_path = _asm_glue.is_configured();
break;
@@ -192,6 +193,14 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b,
}
}
+ // Configure activation
+ const ActivationLayerInfo &activation = gemm_info.activation_info();
+ _run_activation = activation.enabled() && (!_assembly_path || (_assembly_path && !NEGEMMAssemblyDispatch::is_activation_supported(activation)));
+ if(_run_activation)
+ {
+ _activation_func.configure(output, nullptr, activation);
+ }
+
// Allocate tensors
if(!_assembly_path && !_run_vector_matrix_multiplication)
{
@@ -253,12 +262,12 @@ Status NEGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
bool run_optimised_requantized = false;
if(is_data_type_quantized_asymmetric(a->data_type()))
{
- run_optimised = bool(NEGEMMAssemblyDispatch::validate(a, b, c, output, 1.f, 0.f, gemm_info));
+ run_optimised = bool(NEGEMMAssemblyDispatch::validate(a, b, c, output, gemm_info));
run_optimised_requantized = run_optimised;
}
else
{
- run_optimised = bool(NEGEMMAssemblyDispatch::validate(a, b, nullptr, fuse_output_stage ? &mm_result_s32_info : output, 1.f, 0.f, gemm_info));
+ run_optimised = bool(NEGEMMAssemblyDispatch::validate(a, b, nullptr, fuse_output_stage ? &mm_result_s32_info : output, gemm_info));
}
if(run_optimised)
@@ -361,6 +370,14 @@ Status NEGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
a_offset, b_offset));
}
}
+
+ // Validate activation
+ const ActivationLayerInfo &activation = gemm_info.activation_info();
+ if(activation.enabled())
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(output, nullptr, activation));
+ }
+
return Status{};
}
@@ -415,6 +432,12 @@ void NEGEMMLowpMatrixMultiplyCore::run()
NEScheduler::get().schedule(&_offset_contribution_kernel, Window::DimY);
}
}
+
+ // Run fused activation
+ if(_run_activation)
+ {
+ _activation_func.run();
+ }
}
void NEGEMMLowpMatrixMultiplyCore::prepare()
diff --git a/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp b/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
deleted file mode 100644
index 1aeab5b9cb..0000000000
--- a/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
+++ /dev/null
@@ -1,430 +0,0 @@
-/*
- * Copyright (c) 2018-2019 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h"
-
-#include "arm_compute/core/ITensor.h"
-#include "arm_compute/core/NEON/kernels/assembly/Helpers.h"
-#include "arm_compute/core/Utils.h"
-#include "arm_compute/runtime/NEON/NEScheduler.h"
-
-#include "src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h"
-
-#include <atomic>
-#include <condition_variable>
-#include <mutex>
-
-namespace arm_compute
-{
-#ifndef NO_MULTI_THREADING
-class BufferManagerMultipleThreads final : public IBufferManager
-{
-public:
- /** Number of buffers to ping pong between */
- static constexpr unsigned int NUM_BUFFERS = 3;
-
- explicit BufferManagerMultipleThreads(unsigned int max_num_users)
- : _buffers(), _max_num_users(max_num_users)
- {
- }
- unsigned int num_buffers() const override
- {
- return NUM_BUFFERS;
- }
- /* - Lock the requested index if it's free and return true if it needs reshaping.
- * - Return false without acquiring the lock if the buffer at the index is already reshaped / being reshaped.
- * - Block if the corresponding buffer for the given index is still being used by a different index.
- */
- bool lock_to_reshape_if_needed(unsigned int index) override
- {
- Buffer &buf = get_buffer_from_index(index);
- while(true)
- {
- if(buf.index == index && buf.state != State::FREE)
- {
- //Another thread already is reshaping / has reshaped this block: nothing to do
- return false;
- }
- else
- {
- std::unique_lock<std::mutex> lock(buf.mutex);
- //If the buffer is free then lock it for reshaping:
- if(buf.state == State::FREE)
- {
- buf.index = index;
- buf.state = State::BEING_RESHAPED;
- return true;
- }
- // Check again just in case it changed while we were acquiring the lock:
- if(buf.index == index)
- {
- //Another thread is reshaping this block already, nothing to do
- return false;
- }
- // buf.index != index: Buffer still being used by another block, need to wait
- buf.sem.wait(lock);
- }
- }
- }
- /* Mark the buffer at the given index as reshaped and release the lock acquired via lock_to_reshape_if_needed() */
- void mark_as_reshaped(unsigned int index) override
- {
- Buffer &buf = get_buffer_from_index(index);
- {
- std::lock_guard<std::mutex> lock(buf.mutex);
- buf.users = _max_num_users;
- buf.state = State::IN_USE;
- }
- buf.sem.notify_all();
- }
-
- /* Block until the buffer at the given index is reshaped */
- void wait_for_reshaping(unsigned int index) override
- {
- Buffer &buf = get_buffer_from_index(index);
- ARM_COMPUTE_ERROR_ON(buf.index != index); // Should have blocked in lock_to_reshape_if_needed()
- // Check if it's already ready to use:
- if(buf.state == State::IN_USE)
- {
- return;
- }
- std::unique_lock<std::mutex> lock(buf.mutex);
- //Double check it didn't change while we were acquiring the lock:
- if(buf.state == State::IN_USE)
- {
- return;
- }
- buf.sem.wait(lock);
- }
- /* Mark the buffer at the given index as not used by this thread anymore.
- * Once all the threads have called this method then the buffer is marked as free again.
- */
- void mark_as_unused(unsigned int index) override
- {
- Buffer &buf = get_buffer_from_index(index);
- ARM_COMPUTE_ERROR_ON(buf.index != index); // Should have blocked in lock_to_reshape_if_needed()
- if(--buf.users == 0)
- {
- std::unique_lock<std::mutex> lock(buf.mutex);
- buf.state = State::FREE;
- lock.unlock();
- buf.sem.notify_all();
- }
- }
-
-private:
- enum class State
- {
- FREE,
- BEING_RESHAPED,
- IN_USE
- };
- struct Buffer
- {
- unsigned int index{};
- std::atomic_uint users{};
- State state{ State::FREE };
- std::mutex mutex{};
- std::condition_variable sem{};
- };
- std::array<struct Buffer, NUM_BUFFERS> _buffers;
- Buffer &get_buffer_from_index(unsigned int index)
- {
- return _buffers[index % NUM_BUFFERS];
- }
- unsigned int _max_num_users;
-};
-#endif /* NO_MULTI_THREADING */
-
-class BufferManagerSingleThread : public IBufferManager
-{
-public:
- unsigned int num_buffers() const override
- {
- return 1;
- }
- bool lock_to_reshape_if_needed(unsigned int index) override
- {
- ARM_COMPUTE_UNUSED(index);
- return true;
- }
- void mark_as_reshaped(unsigned int index) override
- {
- ARM_COMPUTE_UNUSED(index);
- }
- void wait_for_reshaping(unsigned int index) override
- {
- ARM_COMPUTE_UNUSED(index);
- }
- void mark_as_unused(unsigned int index) override
- {
- ARM_COMPUTE_UNUSED(index);
- }
-};
-
-NEGEMMInterleavedWrapper::NEGEMMInterleavedWrapper(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
- : _memory_group(std::move(memory_manager)),
- _weights_manager(weights_manager)
-{
-}
-
-void NEGEMMInterleavedWrapper::run()
-{
- prepare();
-
- MemoryGroupResourceScope scope_mg(_memory_group);
- NEScheduler::get().run_tagged_workloads(_workloads, _tag.c_str());
-}
-
-void NEGEMMInterleavedWrapper::prepare()
-{
- ARM_COMPUTE_UNUSED(_weights_manager);
- if(!_is_prepared)
- {
- if(_pretranspose_b)
- {
- _transformed_b.allocator()->allocate();
- NEScheduler::get().schedule(_prepare_b.get(), Window::DimX);
- _b->mark_as_unused();
- }
- else
- {
- _prepare_b->create_workloads(_b_workloads);
- }
- _transform_a->create_workloads(_a_workloads);
- _matrix_multiply->create_workloads(_mm_workloads);
-
- //Maximum number of workloads to create:
- const unsigned int num_threads = NEScheduler::get().num_threads();
- const unsigned int max_iterations = num_threads == 1 ? 1 : num_threads;
- //Maximum number of iterations the parameters allow:
- const unsigned int num_iterations = _batch_window.num_iterations_total();
- // Keep the smallest of the two:
- const unsigned int num_windows = std::min(num_iterations, max_iterations);
- const TensorShape window_shape = _batch_window.shape();
- const unsigned int num_x_blocks = _block_walker.num_iterations(Window::DimX);
-
- // Create a 1D window to dynamically split the batch window:
- Window win_1D;
- win_1D.set(0, Window::Dimension(0, num_iterations));
-
- // Create one workload for each sub-window:
- for(unsigned int w = 0; w < num_windows; w++)
- {
- Window win = win_1D.split_window(0, w, num_windows);
- const Coordinates start_offset = index2coords(window_shape, win.x().start());
- const Coordinates end_offset = index2coords(window_shape, win.x().end() - 1);
-
- if(_pretranspose_b)
- {
- auto workload = [start_offset, end_offset, num_x_blocks, this](const ThreadInfo & info)
- {
- //For each block of rows in "M"
- auto workload_mm = this->_mm_workloads.begin();
- for(auto &workload_a : this->_a_workloads)
- {
- // Transform one k_block from A:
- this->_transform_a->transform(workload_a, info, this->_batch_window, start_offset, end_offset);
- // Then perform the matrix multiplication for each x block along N:
- for(unsigned int i = 0; i < num_x_blocks; i++)
- {
- ARM_COMPUTE_ERROR_ON(workload_mm == this->_mm_workloads.end());
- this->_matrix_multiply->transform(*workload_mm++, info, this->_batch_window, start_offset, end_offset);
- }
- }
- };
- _workloads.emplace_back(workload);
- }
- else
- {
- auto workload = [num_threads, start_offset, end_offset, num_x_blocks, this](const ThreadInfo & info)
- {
- //For each block of rows in "M"
- auto workload_mm = this->_mm_workloads.begin();
- unsigned int workload_b = 0;
- //If there is only one thread then only reshape the B blocks as you need them:
- unsigned int workload_b_next = num_threads == 1 ? this->_b_workloads.size() : 1;
-
- for(auto &workload_a : this->_a_workloads)
- {
- // Transform one k_block from A:
- this->_transform_a->transform(workload_a, info, this->_batch_window, start_offset, end_offset);
- // Then perform the matrix multiplication for each x block along N:
- for(unsigned int i = 0; i < num_x_blocks; i++)
- {
- ARM_COMPUTE_ERROR_ON(workload_mm == this->_mm_workloads.end());
- if(workload_b_next < this->_b_workloads.size())
- {
- //Lock on BufferManager: need to run it ?
- if(this->_buffer_manager->lock_to_reshape_if_needed(workload_b_next))
- {
- this->_prepare_b->transform(this->_b_workloads[workload_b_next], info);
- this->_buffer_manager->mark_as_reshaped(workload_b_next);
- }
- workload_b_next++;
- }
- ARM_COMPUTE_ERROR_ON(workload_b >= this->_b_workloads.size());
- // Run if needed or wait
- if(this->_buffer_manager->lock_to_reshape_if_needed(workload_b))
- {
- this->_prepare_b->transform(this->_b_workloads[workload_b], info);
- this->_buffer_manager->mark_as_reshaped(workload_b);
- }
- this->_buffer_manager->wait_for_reshaping(workload_b);
- this->_matrix_multiply->transform(*workload_mm++, info, this->_batch_window, start_offset, end_offset);
- this->_buffer_manager->mark_as_unused(workload_b);
- workload_b++;
- }
- }
- };
- _workloads.emplace_back(workload);
- }
- }
- if(!_pretranspose_b && num_windows > 1 && num_windows % num_threads != 0)
- {
- //Make sure the number of workloads is a multiple of the number of threads to avoid dead locks:
- for(unsigned int leftover = num_windows % num_threads; leftover != num_threads; leftover++)
- {
- auto workload = [this](const ThreadInfo & info)
- {
- unsigned int workload_b = 0;
- //If there is only one thread then only reshape the B blocks as you need them:
- unsigned int workload_b_next = 1;
-
- for(unsigned int iteration = 0; iteration < this->_mm_workloads.size(); iteration++)
- {
- if(workload_b_next < this->_b_workloads.size())
- {
- //Lock on BufferManager: need to run it ?
- if(this->_buffer_manager->lock_to_reshape_if_needed(workload_b_next))
- {
- this->_prepare_b->transform(this->_b_workloads[workload_b_next], info);
- this->_buffer_manager->mark_as_reshaped(workload_b_next);
- }
- workload_b_next++;
- }
- ARM_COMPUTE_ERROR_ON(workload_b >= this->_b_workloads.size());
- // Run if needed or wait
- if(this->_buffer_manager->lock_to_reshape_if_needed(workload_b))
- {
- this->_prepare_b->transform(this->_b_workloads[workload_b], info);
- this->_buffer_manager->mark_as_reshaped(workload_b);
- }
- this->_buffer_manager->wait_for_reshaping(workload_b);
- this->_buffer_manager->mark_as_unused(workload_b);
- workload_b++;
- }
- };
- _workloads.emplace_back(workload);
- }
- }
-
- _is_prepared = true;
- }
-}
-
-void NEGEMMInterleavedWrapper::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, const GEMMInfo &gemm_info)
-{
- _params = INEGEMMWrapperKernel::extract_parameters(a, b, c, gemm_info);
- _a = a;
- _b = b;
- _c = c;
- _pretranspose_b = gemm_info.pretranpose_B();
-
- const DataType input_type = a->info()->data_type();
- const CPUInfo &ci = NEScheduler::get().cpu_info();
- const unsigned int num_threads = NEScheduler::get().num_threads();
-
- const arm_gemm::KernelDescription gemm_kernel_info = get_gemm_info(input_type, ci, num_threads, _params, alpha, beta, _pretranspose_b);
- ARM_COMPUTE_ERROR_ON(gemm_kernel_info.method != arm_gemm::GemmMethod::GEMM_INTERLEAVED);
-
- // Forcing 128-byte alignment (required by 32-bit kernels)
- const unsigned int alignment = 128;
- _transformed_b.allocator()->init(TensorInfo{}, alignment);
- _tmp_c.allocator()->init(TensorInfo{}, alignment);
- _tag = "NEGEMMInterleaved_" + gemm_kernel_info.name;
-
- // Get strategy
- std::unique_ptr<detail::IInterleavedStrategy> strategy = detail::create_strategy(gemm_kernel_info.name);
- ARM_COMPUTE_ERROR_ON(strategy == nullptr);
-
- if(!_pretranspose_b)
- {
- _block_sizes = strategy->calculate_block_sizes_for_strategy(ci, _params);
- _batch_window.set(Window::DimX, Window::Dimension(0, ceil_to_multiple(_block_sizes.m_round, _block_sizes.strategy_out_height), _block_sizes.strategy_out_height));
- _batch_window.set(Window::DimY, Window::Dimension(0, _params.batches));
- // If the execution is single threaded or has only one window then the buffer manager only needs 1 buffer else we will use NUM_BUFFERS buffers and ping pong between them:
- const unsigned int num_iterations = _batch_window.num_iterations_total();
- if(NEScheduler::get().num_threads() == 1 || num_iterations == 1)
- {
- _buffer_manager = support::cpp14::make_unique<BufferManagerSingleThread>();
- }
- else
- {
-#ifdef NO_MULTI_THREADING
- ARM_COMPUTE_ERROR("Can't have more than 1 buffer without multiple threads");
-#else /* NO_MULTI_THREADING */
- _buffer_manager = support::cpp14::make_unique<BufferManagerMultipleThreads>(NEScheduler::get().num_threads());
-#endif /* NO_MULTI_THREADING */
- }
- // If B is transposed at every iteration then transformed_B can be managed:
- _memory_group.manage(&_transformed_b);
- auto_init_if_empty(*_transformed_b.info(), _b->info()->clone()->set_tensor_shape(TensorShape(_block_sizes.x_block * _block_sizes.k_block, _buffer_manager->num_buffers())));
- }
- else
- {
- _tag += "_preB";
- }
-
- _prepare_b = strategy->instantiate_prepareB(b, &_transformed_b, _params, ci);
- ARM_COMPUTE_ERROR_ON(_prepare_b == nullptr);
-
- if(_pretranspose_b)
- {
- _block_sizes = _prepare_b->block_sizes();
- _batch_window.set(Window::DimX, Window::Dimension(0, ceil_to_multiple(_block_sizes.m_round, _block_sizes.strategy_out_height), _block_sizes.strategy_out_height));
- _batch_window.set(Window::DimY, Window::Dimension(0, _params.batches));
- }
-
- _block_walker.set(Window::DimX, Window::Dimension(0, ceil_to_multiple(_params.N, _block_sizes.x_block), _block_sizes.x_block));
- _block_walker.set(Window::DimY, Window::Dimension(0, ceil_to_multiple(_params.K, _block_sizes.k_block), _block_sizes.k_block));
- _block_walker.set(Window::DimZ, Window::Dimension(0, _params.multis));
-
- _transformed_a.allocator()->init(TensorInfo(TensorShape{ _block_sizes.k_block, _block_sizes.m_round, _params.batches }, 1, input_type), alignment);
- _memory_group.manage(&_transformed_a);
- _memory_group.manage(&_tmp_c);
-
- _transform_a = strategy->instantiate_transformA(_a, &_transformed_a, _block_walker, _params, gemm_info);
- _matrix_multiply = strategy->instantiate_matrix_multiply(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, alpha, beta, gemm_info, num_threads);
- ARM_COMPUTE_ERROR_ON(_transform_a == nullptr);
- ARM_COMPUTE_ERROR_ON(_matrix_multiply == nullptr);
-
- _transformed_a.allocator()->allocate();
- _tmp_c.allocator()->allocate();
- if(!_pretranspose_b)
- {
- _transformed_b.allocator()->allocate();
- }
-}
-} // namespace arm_compute