aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichele Di Giorgio <michele.digiorgio@arm.com>2021-06-30 12:05:34 +0100
committerMichele Di Giorgio <michele.digiorgio@arm.com>2021-07-02 13:20:41 +0000
commit4dfc5538948c196def6d2e3305fe8051a5df3f15 (patch)
treede9619bc7f19d09be5ca5642fc15092d31d74ace
parentbc4e31113be0af320f44b338969d6972b64ca4de (diff)
downloadComputeLibrary-4dfc5538948c196def6d2e3305fe8051a5df3f15.tar.gz
Port NEGEMM to memory injecting interface (Part 3)
- Complete porting of NEGEMM to the new API Resolves: COMPMID-4402 Change-Id: I14904102b25332dbb4fc048d45dca068a15b6eca Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5890 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--Android.bp1
-rw-r--r--arm_compute/runtime/NEON/functions/NEGEMM.h16
-rw-r--r--filelist.json1
-rw-r--r--src/runtime/NEON/functions/NEGEMM.cpp418
-rw-r--r--src/runtime/cpu/operators/CpuGemm.cpp366
-rw-r--r--src/runtime/cpu/operators/CpuGemm.h145
-rw-r--r--src/runtime/cpu/utils/CpuAuxTensorHandler.h4
-rw-r--r--tests/validation/NEON/GEMM.cpp122
8 files changed, 654 insertions, 419 deletions
diff --git a/Android.bp b/Android.bp
index 3435f02d70..5943b56450 100644
--- a/Android.bp
+++ b/Android.bp
@@ -644,6 +644,7 @@ cc_library_static {
"src/runtime/cpu/operators/CpuFill.cpp",
"src/runtime/cpu/operators/CpuFlatten.cpp",
"src/runtime/cpu/operators/CpuFloor.cpp",
+ "src/runtime/cpu/operators/CpuGemm.cpp",
"src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp",
"src/runtime/cpu/operators/CpuGemmLowpOutputStage.cpp",
"src/runtime/cpu/operators/CpuMul.cpp",
diff --git a/arm_compute/runtime/NEON/functions/NEGEMM.h b/arm_compute/runtime/NEON/functions/NEGEMM.h
index 5daa0406a5..ce68a61923 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMM.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMM.h
@@ -24,11 +24,9 @@
#ifndef ARM_COMPUTE_NEGEMM_H
#define ARM_COMPUTE_NEGEMM_H
-#include "arm_compute/core/ITensorPack.h"
#include "arm_compute/runtime/IFunction.h"
#include "arm_compute/runtime/IMemoryManager.h"
#include "arm_compute/runtime/IWeightsManager.h"
-#include "arm_compute/runtime/MemoryGroup.h"
#include <memory>
@@ -36,19 +34,7 @@ namespace arm_compute
{
/** Basic function to execute GEMM. This function calls the following kernels:
*
- * If optimized assembly is available:
- * -# @ref cpu::CpuGemmAssemblyDispatch
- * -# @ref cpu::CpuActivation (if alpha != 1.0)
- * Else:
- * -# @ref cpu::kernels::CpuGemmInterleave4x4Kernel (if the output tensor is a matrix)
- * -# @ref cpu::kernels::CpuGemmTranspose1xWKernel (if the output tensor is a matrix)
- * -# @ref cpu::kernels::CpuGemmMatrixMultiplyKernel
- * In both cases:
- * -# @ref cpu::kernels::CpuGemmMatrixAdditionKernel (if c != nullptr and beta != 0.0 and is not reshaped once)
- * Else:
- * -# @ref cpu::CpuAdd (if c != nullptr and is reshaped once and not optimized assembly in place)
- *
- * -# @ref cpu::CpuActivation (if activation is specified in GEMMInfo)
+ * -# @ref cpu::CpuGemm
*/
class NEGEMM : public IFunction
{
diff --git a/filelist.json b/filelist.json
index 97f1db0901..7129efbf21 100644
--- a/filelist.json
+++ b/filelist.json
@@ -1131,6 +1131,7 @@
},
"GEMM": {
"files": {
+ "operator" : ["src/runtime/cpu/operators/CpuGemm.cpp"],
"kernel": [
"src/core/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp",
"src/core/cpu/kernels/CpuGemmMatrixMultiplyKernel.cpp",
diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp
index a52ca79504..168d93022f 100644
--- a/src/runtime/NEON/functions/NEGEMM.cpp
+++ b/src/runtime/NEON/functions/NEGEMM.cpp
@@ -23,79 +23,32 @@
*/
#include "arm_compute/runtime/NEON/functions/NEGEMM.h"
-#include "arm_compute/core/Error.h"
-#include "arm_compute/core/Helpers.h"
-#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/ITensorPack.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
-#include "arm_compute/core/Validate.h"
-#include "arm_compute/core/utils/misc/ShapeCalculator.h"
-#include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "arm_compute/runtime/MemoryGroup.h"
#include "arm_compute/runtime/Tensor.h"
-#include "arm_compute/runtime/TensorAllocator.h"
#include "src/core/CPP/Validate.h"
-#include "src/core/cpu/kernels/CpuGemmInterleave4x4Kernel.h"
-#include "src/core/cpu/kernels/CpuGemmMatrixAdditionKernel.h"
-#include "src/core/cpu/kernels/CpuGemmMatrixMultiplyKernel.h"
-#include "src/core/cpu/kernels/CpuGemmTranspose1xWKernel.h"
-#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/MemoryHelpers.h"
-#include "src/runtime/cpu/operators/CpuActivation.h"
-#include "src/runtime/cpu/operators/CpuAdd.h"
-#include "src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h"
+#include "src/runtime/cpu/operators/CpuGemm.h"
using namespace arm_compute::experimental;
-using namespace arm_compute::misc::shape_calculator;
namespace arm_compute
{
-namespace
-{
-cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
-{
- cpu::AsmGemmInfo asm_info;
- asm_info.method = cpu::AsmConvMethod::Im2Col;
- asm_info.reinterpret_input_as_3d = info.reinterpret_input_as_3d();
- asm_info.depth_output_gemm3d = info.depth_output_gemm3d();
- asm_info.activation_info = info.activation_info();
-
- return asm_info;
-}
-} // namespace
-
struct NEGEMM::Impl
{
MemoryGroup memory_group{};
IWeightsManager *weights_manager{ nullptr };
- std::unique_ptr<cpu::kernels::CpuGemmInterleave4x4Kernel> interleave_kernel{ nullptr };
- std::unique_ptr<cpu::kernels::CpuGemmTranspose1xWKernel> transpose_kernel{ nullptr };
- std::unique_ptr<cpu::kernels::CpuGemmMatrixMultiplyKernel> mm_kernel{ nullptr };
- std::unique_ptr<cpu::CpuGemmAssemblyDispatch> asm_glue{ nullptr };
- std::unique_ptr<cpu::kernels::CpuGemmMatrixAdditionKernel> ma_kernel{ nullptr };
- std::unique_ptr<cpu::CpuActivation> alpha_scale_func{ nullptr };
- std::unique_ptr<cpu::CpuAdd> add_bias{ nullptr };
- std::unique_ptr<cpu::CpuActivation> activation_func{ nullptr };
+ std::unique_ptr<cpu::CpuGemm> op{ nullptr };
- const ITensor *a{ nullptr };
- const ITensor *c{ nullptr };
- ITensor *d{ nullptr };
- ITensor *gemm_output_to_use{ nullptr };
- Tensor tmp_a{};
- Tensor tmp_b{};
- Tensor tmp_d{};
const ITensor *original_b{ nullptr };
- bool run_vector_matrix_multiplication{ false };
- bool run_alpha_scale{ false };
- bool run_addition{ false };
- bool run_bias_addition{ false };
- bool run_activation{ false };
- bool reshape_b_only_on_first_run{ false };
bool is_prepared{ false };
- ITensorPack asm_glue_run_pack{};
- ITensorPack asm_glue_prep_pack{};
- WorkspaceData<Tensor> asm_glue_workspace{};
+ ITensorPack run_pack{};
+ ITensorPack prep_pack{};
+ WorkspaceData<Tensor> workspace{};
experimental::MemoryRequirements aux_mem_req{};
};
@@ -111,259 +64,24 @@ NEGEMM::~NEGEMM() = default;
void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
- ARM_COMPUTE_ERROR_THROW_ON(NEGEMM::validate(a->info(), b->info(), (c != nullptr) ? c->info() : nullptr, d->info(), alpha, beta, gemm_info));
-
- const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
- const bool is_c_bias = gemm_info.reshape_b_only_on_first_run();
- bool run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a->info(), b->info(), (is_c_bias && c != nullptr) ? c->info() : nullptr, d->info(), asm_info));
+ ARM_COMPUTE_ERROR_THROW_ON(cpu::CpuGemm::validate(a->info(), b->info(), (c != nullptr) ? c->info() : nullptr, d->info(), alpha, beta, gemm_info));
- _impl->a = a;
- _impl->c = c;
- _impl->d = d;
- _impl->gemm_output_to_use = d;
// Check if we need to reshape the matrix B only on the first run
- _impl->is_prepared = false;
- _impl->reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
- _impl->run_vector_matrix_multiplication = a->info()->dimension(1) < 2;
- _impl->original_b = b;
- _impl->run_alpha_scale = alpha != 1.f;
- _impl->run_bias_addition = c != nullptr && gemm_info.reshape_b_only_on_first_run();
- _impl->run_addition = beta != 0 && c != nullptr && !gemm_info.reshape_b_only_on_first_run();
- _impl->run_activation = gemm_info.activation_info().enabled() && (!run_optimised || (run_optimised
- && !cpu::CpuGemmAssemblyDispatch::is_activation_supported(gemm_info.activation_info())));
-
- if(run_optimised)
- {
- const ITensor *c_to_use = is_c_bias ? c : nullptr;
- const ITensorInfo *c_info_to_use = c_to_use != nullptr ? c_to_use->info() : nullptr;
- _impl->asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>();
- _impl->asm_glue->configure(a->info(), b->info(), c_info_to_use, d->info(), asm_info);
- ARM_COMPUTE_ERROR_ON(!_impl->asm_glue->is_configured());
-
- _impl->aux_mem_req = _impl->asm_glue->workspace();
- _impl->asm_glue_run_pack =
- {
- { ACL_SRC_0, a },
- { ACL_SRC_1, b },
- { ACL_SRC_2, c_to_use },
- { ACL_DST, d },
- };
- _impl->asm_glue_prep_pack = { { ACL_SRC_1, b }, { ACL_SRC_2, c_to_use } };
- _impl->asm_glue_workspace = manage_workspace<Tensor>(_impl->aux_mem_req, _impl->memory_group, _impl->asm_glue_run_pack, _impl->asm_glue_prep_pack);
-
- // Scale product by alpha
- if(_impl->run_alpha_scale)
- {
- _impl->alpha_scale_func = std::make_unique<cpu::CpuActivation>();
- _impl->alpha_scale_func->configure(d->info(), nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LINEAR, alpha, 0.f));
- }
- }
- else
- {
- // Pick output tensor in case bias addition should be performed
- if(_impl->run_bias_addition)
- {
- _impl->gemm_output_to_use = &_impl->tmp_d;
- _impl->memory_group.manage(&_impl->tmp_d);
- }
-
- _impl->mm_kernel = std::make_unique<cpu::kernels::CpuGemmMatrixMultiplyKernel>();
-
- // Select between GEMV and GEMM
- if(_impl->run_vector_matrix_multiplication)
- {
- // Configure the matrix multiply kernel
- _impl->mm_kernel->configure(a->info(), b->info(), _impl->gemm_output_to_use->info(), alpha, false);
- }
- else
- {
- TensorShape shape_tmp_a = a->info()->tensor_shape();
- TensorShape shape_tmp_b = b->info()->tensor_shape();
-
- shape_tmp_a.set(0, a->info()->dimension(0) * 4);
- shape_tmp_a.set(1, std::ceil(a->info()->dimension(1) / 4.0f));
-
- const unsigned int transpose_w = 16 / data_size_from_type(b->info()->data_type());
- shape_tmp_b.set(0, b->info()->dimension(1) * transpose_w);
- shape_tmp_b.set(1, std::ceil(b->info()->dimension(0) / static_cast<float>(transpose_w)));
-
- TensorInfo info_a = a->info()->clone()->set_tensor_shape(shape_tmp_a).set_is_resizable(true);
- TensorInfo info_b = b->info()->clone()->set_tensor_shape(shape_tmp_b).set_is_resizable(true);
-
- _impl->tmp_a.allocator()->init(info_a);
- _impl->tmp_b.allocator()->init(info_b);
-
- // Manage intermediate buffers
- _impl->memory_group.manage(&_impl->tmp_a);
- if(!_impl->reshape_b_only_on_first_run)
- {
- _impl->memory_group.manage(&_impl->tmp_b);
- }
-
- int m = a->info()->dimension(1);
- int n = b->info()->dimension(0);
- int k = a->info()->dimension(0);
+ _impl->is_prepared = false;
+ _impl->original_b = b;
+ _impl->op = std::make_unique<cpu::CpuGemm>();
- // Configure interleave kernel
- _impl->interleave_kernel = std::make_unique<cpu::kernels::CpuGemmInterleave4x4Kernel>();
- _impl->interleave_kernel->configure(a->info(), &info_a);
+ _impl->op->configure(a->info(), b->info(), (c != nullptr) ? c->info() : nullptr, d->info(), alpha, beta, gemm_info);
- // Configure transpose kernel
- _impl->transpose_kernel = std::make_unique<cpu::kernels::CpuGemmTranspose1xWKernel>();
- _impl->transpose_kernel->configure(b->info(), _impl->tmp_b.info());
-
- // Configure matrix multiplication kernel
- _impl->mm_kernel->configure(_impl->tmp_a.info(), _impl->tmp_b.info(), _impl->gemm_output_to_use->info(), alpha, true, GEMMReshapeInfo(m, n, k));
-
- // Allocate once the all configure methods have been called
- _impl->tmp_a.allocator()->allocate();
- if(!_impl->reshape_b_only_on_first_run)
- {
- _impl->tmp_b.allocator()->allocate();
- }
- }
-
- if(_impl->run_bias_addition)
- {
- _impl->add_bias = std::make_unique<cpu::CpuAdd>();
- _impl->add_bias->configure(_impl->gemm_output_to_use->info(), c->info(), d->info(), ConvertPolicy::SATURATE);
- _impl->tmp_d.allocator()->allocate();
- }
- }
-
- // Configure matrix addition kernel
- if(_impl->run_addition)
- {
- _impl->ma_kernel = std::make_unique<cpu::kernels::CpuGemmMatrixAdditionKernel>();
- _impl->ma_kernel->configure(c->info(), d->info(), beta);
- }
-
- // Configure activation
- const ActivationLayerInfo &activation = gemm_info.activation_info();
- if(_impl->run_activation)
- {
- _impl->activation_func = std::make_unique<cpu::CpuActivation>();
- _impl->activation_func->configure(d->info(), nullptr, activation);
- }
+ _impl->aux_mem_req = _impl->op->workspace();
+ _impl->run_pack = { { ACL_SRC_0, a }, { ACL_SRC_1, b }, { ACL_SRC_2, c }, { ACL_DST, d } };
+ _impl->prep_pack = { { ACL_SRC_1, b }, { ACL_SRC_2, c } };
+ _impl->workspace = manage_workspace<Tensor>(_impl->aux_mem_req, _impl->memory_group, _impl->run_pack, _impl->prep_pack);
}
Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
- ARM_COMPUTE_UNUSED(alpha);
- const bool is_c_bias = gemm_info.reshape_b_only_on_first_run();
-
- ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
- ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::BFLOAT16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(0) != b->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
- if(a->data_type() != DataType::BFLOAT16)
- {
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, output);
- }
-
- if(c != nullptr && !is_c_bias)
- {
- ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.depth_output_gemm3d() != 0);
- ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.reinterpret_input_as_3d());
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(c, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(1) != c->dimension(1), "The C matrix must have the same number of rows as the matrix A");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(b->dimension(0) != c->dimension(0), "The C matrix must have the same number of columns as the matrix B");
- }
-
- if(output->total_size() != 0)
- {
- ARM_COMPUTE_RETURN_ERROR_ON(b->dimension(0) != output->dimension(0));
- if(gemm_info.depth_output_gemm3d() != 0)
- {
- if(gemm_info.reinterpret_input_as_3d())
- {
- ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != output->dimension(1));
- ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(2) != output->dimension(2));
- }
- else
- {
- ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != output->dimension(1) * output->dimension(2));
- }
- }
- else
- {
- ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != output->dimension(1));
- }
- }
-
- // Check if we need to run the optimized assembly kernel
- cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
- const bool run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, is_c_bias ? c : nullptr, output, asm_info));
-
- if(!run_optimised)
- {
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.reinterpret_input_as_3d(), "NEGEMM cannot reinterpret the input tensor as 3D");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.depth_output_gemm3d() != 0, "NEGEMM cannot reinterpret the output tensor as 3D");
-
- // Check if the first input tensor is a vector.
- const bool run_vector_matrix_multiplication = a->dimension(1) < 2;
- // Check if we need to reshape the matrix A and matrix B
- const bool run_interleave_transpose = !run_vector_matrix_multiplication && !(gemm_info.reshape_b_only_on_first_run());
-
- // Arguments used by GEMMReshapeInfo
- // If we pass the matrix A and matrix B reshaped to NEGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to GEMMReshapeInfo
- // in order to know how the matrices have been reshaped
- const int m = a->dimension(1);
- const int n = b->dimension(0);
- const int k = a->dimension(0);
- int mult_transpose1xW_width = 1;
- int mult_interleave4x4_height = 1;
-
- const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, gemm_info.depth_output_gemm3d());
-
- const ITensorInfo *matrix_a_info = a;
- const ITensorInfo *matrix_b_info = b;
-
- TensorInfo tmp_a_info{};
- TensorInfo tmp_b_info{};
- TensorInfo tmp_output_info = *output->clone();
-
- if(run_interleave_transpose)
- {
- matrix_a_info = &tmp_a_info;
- matrix_b_info = &tmp_b_info;
-
- // Validate interleave kernel
- auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d())));
- ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmInterleave4x4Kernel::validate(a, &tmp_a_info));
-
- // Validate transpose kernel
- auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width)));
- ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmTranspose1xWKernel::validate(b, &tmp_b_info));
- }
-
- // Validate matrix multiply
- auto_init_if_empty(tmp_output_info, matrix_a_info->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, run_interleave_transpose, reshape_info)));
- ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &tmp_output_info, alpha, run_interleave_transpose, reshape_info));
-
- if(c != nullptr && gemm_info.reshape_b_only_on_first_run())
- {
- ARM_COMPUTE_RETURN_ON_ERROR(cpu::CpuAdd::validate(&tmp_output_info, c, output, ConvertPolicy::SATURATE));
- }
- }
-
- // Validate matrix addition kernel
- if(beta != 0 && c != nullptr && !is_c_bias)
- {
- ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmMatrixAdditionKernel::validate(c, output, beta));
- }
-
- // Validate activation
- const ActivationLayerInfo &activation = gemm_info.activation_info();
- if(activation.enabled())
- {
- ARM_COMPUTE_RETURN_ON_ERROR(cpu::CpuActivation::validate(output, nullptr, activation));
- }
-
- return Status{};
+ return cpu::CpuGemm::validate(a, b, c, output, alpha, beta, gemm_info);
}
void NEGEMM::run()
@@ -371,100 +89,42 @@ void NEGEMM::run()
prepare();
MemoryGroupResourceScope scope_mg(_impl->memory_group);
-
- if(_impl->asm_glue->is_configured())
- {
- _impl->asm_glue->run(_impl->asm_glue_run_pack);
- if(_impl->run_alpha_scale)
- {
- ITensorPack pack{ { ACL_SRC, _impl->d }, { ACL_DST, _impl->d } };
- _impl->alpha_scale_func->run(pack);
- }
- }
- else
- {
- ITensorPack mm_pack{ { ACL_SRC_0, _impl->a }, { ACL_SRC_1, _impl->original_b }, { ACL_DST, _impl->gemm_output_to_use } };
- if(!_impl->run_vector_matrix_multiplication)
- {
- // Run interleave kernel
- ITensorPack interleave_pack{ { ACL_SRC, _impl->a }, { ACL_DST, &_impl->tmp_a } };
- NEScheduler::get().schedule_op(_impl->interleave_kernel.get(), Window::DimY, _impl->interleave_kernel->window(), interleave_pack);
-
- if(!_impl->reshape_b_only_on_first_run)
- {
- // Run transpose kernel
- ITensorPack transpose_pack{ { ACL_SRC, _impl->original_b }, { ACL_DST, &_impl->tmp_b } };
- NEScheduler::get().schedule_op(_impl->transpose_kernel.get(), Window::DimY, _impl->transpose_kernel->window(), transpose_pack);
- }
-
- // Use reshaped matrices
- mm_pack.add_const_tensor(ACL_SRC_0, &_impl->tmp_a);
- mm_pack.add_const_tensor(ACL_SRC_1, &_impl->tmp_b);
- }
-
- NEScheduler::get().schedule_op(_impl->mm_kernel.get(), _impl->run_vector_matrix_multiplication ? Window::DimX : Window::DimY, _impl->mm_kernel->window(), mm_pack);
-
- // Run bias addition kernel
- if(_impl->run_bias_addition)
- {
- ITensorPack pack{ { ACL_SRC_0, _impl->gemm_output_to_use }, { ACL_SRC_1, _impl->c }, { ACL_DST, _impl->d } };
- _impl->add_bias->run(pack);
- }
- }
-
- // Run matrix addition kernel
- if(_impl->run_addition)
- {
- ITensorPack c_add_pack{ { ACL_SRC, _impl->c }, { ACL_DST, _impl->d } };
- NEScheduler::get().schedule_op(_impl->ma_kernel.get(), Window::DimY, _impl->ma_kernel->window(), c_add_pack);
- }
-
- // Run activation function
- if(_impl->run_activation)
- {
- ITensorPack pack{ { ACL_SRC, _impl->d }, { ACL_DST, _impl->d } };
- _impl->activation_func->run(pack);
- }
+ _impl->op->run(_impl->run_pack);
}
void NEGEMM::prepare()
{
if(!_impl->is_prepared)
{
- const bool original_b_managed_by_weights_manager = _impl->weights_manager && _impl->weights_manager->are_weights_managed(_impl->original_b);
- if(_impl->asm_glue->is_configured())
- {
- _impl->asm_glue->prepare(_impl->asm_glue_prep_pack);
+ _impl->op->prepare(_impl->prep_pack);
- auto has_reshape = std::find_if(_impl->aux_mem_req.begin(),
- _impl->aux_mem_req.end(),
- [](const MemoryInfo & m) -> bool { return m.lifetime == MemoryLifetime::Persistent; });
+ auto has_reshape = std::find_if(_impl->aux_mem_req.begin(),
+ _impl->aux_mem_req.end(),
+ [](const MemoryInfo & m) -> bool { return m.lifetime == MemoryLifetime::Persistent; });
- if(has_reshape != std::end(_impl->aux_mem_req))
- {
- _impl->original_b->mark_as_unused();
- }
- else
- {
- _impl->asm_glue_run_pack.add_const_tensor(ACL_SRC_1, _impl->original_b);
- }
+ if(has_reshape != std::end(_impl->aux_mem_req))
+ {
+ _impl->original_b->mark_as_unused();
}
- else if(_impl->reshape_b_only_on_first_run && !_impl->run_vector_matrix_multiplication && !_impl->asm_glue->is_configured())
+ else
{
- if(!original_b_managed_by_weights_manager)
- {
- ARM_COMPUTE_ERROR_ON(!_impl->original_b->is_used());
- }
+ _impl->run_pack.add_const_tensor(ACL_SRC_1, _impl->original_b);
+ }
- _impl->tmp_b.allocator()->allocate();
- ITensorPack transpose_pack{ { ACL_SRC, _impl->original_b }, { ACL_DST, &_impl->tmp_b } };
- NEScheduler::get().schedule_op(_impl->transpose_kernel.get(), Window::DimY, _impl->transpose_kernel->window(), transpose_pack);
- if(!original_b_managed_by_weights_manager)
+ // Release temporary tensors that are only used in prepare stage
+ for(auto &ws : _impl->workspace)
+ {
+ const int slot = ws.first;
+ for(auto &m : _impl->aux_mem_req)
{
- _impl->original_b->mark_as_unused();
+ if(m.slot == slot && m.lifetime == MemoryLifetime::Prepare)
+ {
+ auto tensor = ws.second.get();
+ tensor->allocator()->free();
+ break;
+ }
}
}
-
_impl->is_prepared = true;
}
}
diff --git a/src/runtime/cpu/operators/CpuGemm.cpp b/src/runtime/cpu/operators/CpuGemm.cpp
new file mode 100644
index 0000000000..9a4d171ce6
--- /dev/null
+++ b/src/runtime/cpu/operators/CpuGemm.cpp
@@ -0,0 +1,366 @@
+/*
+ * Copyright (c) 2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/runtime/cpu/operators/CpuGemm.h"
+
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+#include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "src/core/CPP/Validate.h"
+#include "src/core/helpers/AutoConfiguration.h"
+#include "src/core/helpers/MemoryHelpers.h"
+#include "src/runtime/cpu/utils/CpuAuxTensorHandler.h"
+
+using namespace arm_compute::experimental;
+using namespace arm_compute::misc::shape_calculator;
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace
+{
+cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
+{
+ cpu::AsmGemmInfo asm_info;
+ asm_info.method = cpu::AsmConvMethod::Im2Col;
+ asm_info.reinterpret_input_as_3d = info.reinterpret_input_as_3d();
+ asm_info.depth_output_gemm3d = info.depth_output_gemm3d();
+ asm_info.activation_info = info.activation_info();
+
+ return asm_info;
+}
+} // namespace
+
+void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info)
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
+ ARM_COMPUTE_ERROR_THROW_ON(CpuGemm::validate(a, b, c, d, alpha, beta, gemm_info));
+
+ const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
+ const bool is_c_bias = gemm_info.reshape_b_only_on_first_run();
+ bool run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, (is_c_bias) ? c : nullptr, d, asm_info));
+
+ // Check if we need to reshape the matrix B only on the first run
+ _is_prepared = false;
+ _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
+ _run_vector_matrix_multiplication = a->dimension(1) < 2;
+ _run_alpha_scale = alpha != 1.f;
+ _run_bias_addition = c != nullptr && gemm_info.reshape_b_only_on_first_run();
+ _run_addition = beta != 0 && c != nullptr && !gemm_info.reshape_b_only_on_first_run();
+ _run_activation = gemm_info.activation_info().enabled() && (!run_optimised || (run_optimised
+ && !cpu::CpuGemmAssemblyDispatch::is_activation_supported(gemm_info.activation_info())));
+
+ if(run_optimised)
+ {
+ const ITensorInfo *c_to_use = is_c_bias ? c : nullptr;
+ _asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>();
+ _asm_glue->configure(a, b, c_to_use, d, asm_info);
+ ARM_COMPUTE_ERROR_ON(!_asm_glue->is_configured());
+
+ auto asm_mem_req = _asm_glue->workspace();
+ _aux_mem[AsmGemmWorkspace] = asm_mem_req[AsmGemmWorkspace];
+ _aux_mem[Pretraspose] = asm_mem_req[Pretraspose];
+
+ // Scale product by alpha
+ if(_run_alpha_scale)
+ {
+ _alpha_scale_func = std::make_unique<cpu::CpuActivation>();
+ _alpha_scale_func->configure(d, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LINEAR, alpha, 0.f));
+ }
+ }
+ else
+ {
+ // Pick output tensor in case bias addition should be performed
+ ITensorInfo *gemm_output_to_use = (_run_bias_addition) ? &_tmp_d : d;
+
+ _mm_kernel = std::make_unique<cpu::kernels::CpuGemmMatrixMultiplyKernel>();
+
+ // Select between GEMV and GEMM
+ if(_run_vector_matrix_multiplication)
+ {
+ // Configure the matrix multiply kernel
+ _mm_kernel->configure(a, b, gemm_output_to_use, alpha, false);
+ }
+ else
+ {
+ const int m = a->dimension(1);
+ const int n = b->dimension(0);
+ const int k = a->dimension(0);
+
+ // Configure interleave kernel
+ _interleave_kernel = std::make_unique<cpu::kernels::CpuGemmInterleave4x4Kernel>();
+ _interleave_kernel->configure(a, &_tmp_a);
+ _aux_mem[InterleavedLHS] = MemoryInfo(offset_int_vec(InterleavedLHS), MemoryLifetime::Temporary, _tmp_a.total_size());
+
+ // Configure transpose kernel
+ _transpose_kernel = std::make_unique<cpu::kernels::CpuGemmTranspose1xWKernel>();
+ _transpose_kernel->configure(b, &_tmp_b);
+ _aux_mem[TransposedRHS] = MemoryInfo(offset_int_vec(TransposedRHS), MemoryLifetime::Persistent, _tmp_b.total_size());
+
+ // Configure matrix multiplication kernel
+ _mm_kernel->configure(&_tmp_a, &_tmp_b, gemm_output_to_use, alpha, true, GEMMReshapeInfo(m, n, k));
+ }
+
+ if(_run_bias_addition)
+ {
+ _add_bias = std::make_unique<cpu::CpuAdd>();
+ _add_bias->configure(gemm_output_to_use, c, d, ConvertPolicy::SATURATE);
+ _aux_mem[TempResult] = MemoryInfo(offset_int_vec(TempResult), MemoryLifetime::Persistent, _tmp_d.total_size());
+ }
+ }
+
+ // Configure matrix addition kernel
+ if(_run_addition)
+ {
+ _ma_kernel = std::make_unique<cpu::kernels::CpuGemmMatrixAdditionKernel>();
+ _ma_kernel->configure(c, d, beta);
+ }
+
+ // Configure activation
+ if(_run_activation)
+ {
+ _activation_func = std::make_unique<cpu::CpuActivation>();
+ _activation_func->configure(d, nullptr, gemm_info.activation_info());
+ }
+}
+
+Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info)
+{
+ ARM_COMPUTE_UNUSED(alpha);
+ const bool is_c_bias = gemm_info.reshape_b_only_on_first_run();
+
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::BFLOAT16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(0) != b->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
+ if(a->data_type() != DataType::BFLOAT16)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, d);
+ }
+
+ if(c != nullptr && !is_c_bias)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.depth_output_gemm3d() != 0);
+ ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.reinterpret_input_as_3d());
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(c, d);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(1) != c->dimension(1), "The C matrix must have the same number of rows as the matrix A");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(b->dimension(0) != c->dimension(0), "The C matrix must have the same number of columns as the matrix B");
+ }
+
+ if(d->total_size() != 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(b->dimension(0) != d->dimension(0));
+ if(gemm_info.depth_output_gemm3d() != 0)
+ {
+ if(gemm_info.reinterpret_input_as_3d())
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != d->dimension(1));
+ ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(2) != d->dimension(2));
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != d->dimension(1) * d->dimension(2));
+ }
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != d->dimension(1));
+ }
+ }
+
+ // Check if we need to run the optimized assembly kernel
+ cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
+ const bool run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, is_c_bias ? c : nullptr, d, asm_info));
+
+ if(!run_optimised)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.reinterpret_input_as_3d(), "CpuGemm cannot reinterpret the input tensor as 3D");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.depth_output_gemm3d() != 0, "CpuGemm cannot reinterpret the output tensor as 3D");
+
+ // Check if the first input tensor is a vector.
+ const bool run_vector_matrix_multiplication = a->dimension(1) < 2;
+ // Check if we need to reshape the matrix A and matrix B
+ const bool run_interleave_transpose = !run_vector_matrix_multiplication && !(gemm_info.reshape_b_only_on_first_run());
+
+ // Arguments used by GEMMReshapeInfo
+ // If we pass the matrix A and matrix B reshaped to CpuGemmMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to GEMMReshapeInfo
+ // in order to know how the matrices have been reshaped
+ const int m = a->dimension(1);
+ const int n = b->dimension(0);
+ const int k = a->dimension(0);
+ int mult_transpose1xW_width = 1;
+ int mult_interleave4x4_height = 1;
+
+ const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, gemm_info.depth_output_gemm3d());
+
+ const ITensorInfo *matrix_a_info = a;
+ const ITensorInfo *matrix_b_info = b;
+
+ TensorInfo tmp_a_info{};
+ TensorInfo tmp_b_info{};
+ TensorInfo tmp_output_info = *d->clone();
+
+ if(run_interleave_transpose)
+ {
+ matrix_a_info = &tmp_a_info;
+ matrix_b_info = &tmp_b_info;
+
+ // Validate interleave kernel
+ auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d())));
+ ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmInterleave4x4Kernel::validate(a, &tmp_a_info));
+
+ // Validate transpose kernel
+ auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width)));
+ ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmTranspose1xWKernel::validate(b, &tmp_b_info));
+ }
+
+ // Validate matrix multiply
+ auto_init_if_empty(tmp_output_info, matrix_a_info->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, run_interleave_transpose, reshape_info)));
+ ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &tmp_output_info, alpha, run_interleave_transpose, reshape_info));
+
+ if(c != nullptr && gemm_info.reshape_b_only_on_first_run())
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(cpu::CpuAdd::validate(&tmp_output_info, c, d, ConvertPolicy::SATURATE));
+ }
+ }
+
+ // Validate matrix addition kernel
+ if(beta != 0 && c != nullptr && !is_c_bias)
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmMatrixAdditionKernel::validate(c, d, beta));
+ }
+
+ // Validate activation
+ const ActivationLayerInfo &activation = gemm_info.activation_info();
+ if(activation.enabled())
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(cpu::CpuActivation::validate(d, nullptr, activation));
+ }
+
+ return Status{};
+}
+
+void CpuGemm::run(ITensorPack &tensors)
+{
+ prepare(tensors);
+
+ auto a = tensors.get_const_tensor(ACL_SRC_0);
+ auto b = tensors.get_const_tensor(ACL_SRC_1);
+ auto c = tensors.get_const_tensor(ACL_SRC_2);
+ auto d = tensors.get_tensor(ACL_DST);
+
+ if(_asm_glue->is_configured())
+ {
+ // Pass c to asm dispatch only if it's the bias tensor
+ ITensorPack asm_pack = tensors;
+ asm_pack.add_const_tensor(ACL_SRC_2, (_reshape_b_only_on_first_run) ? c : nullptr);
+ _asm_glue->run(asm_pack);
+ if(_run_alpha_scale)
+ {
+ ITensorPack pack{ { ACL_SRC, d }, { ACL_DST, d } };
+ _alpha_scale_func->run(pack);
+ }
+ }
+ else
+ {
+ CpuAuxTensorHandler interleaved_a(offset_int_vec(InterleavedLHS), _tmp_a, tensors, true);
+ CpuAuxTensorHandler transposed_b(offset_int_vec(TransposedRHS), _tmp_b, tensors, true);
+ CpuAuxTensorHandler temp_d(offset_int_vec(TempResult), _tmp_d, tensors, true);
+
+ ITensorPack mm_pack{ { ACL_SRC_0, a }, { ACL_SRC_1, b }, { ACL_DST, (_run_bias_addition) ? temp_d.get() : d } };
+ if(!_run_vector_matrix_multiplication)
+ {
+ // Run interleave kernel
+ ITensorPack interleave_pack{ { ACL_SRC, a }, { ACL_DST, interleaved_a.get() } };
+ NEScheduler::get().schedule_op(_interleave_kernel.get(), Window::DimY, _interleave_kernel->window(), interleave_pack);
+
+ if(!_reshape_b_only_on_first_run)
+ {
+ // Run transpose kernel
+ ITensorPack transpose_pack{ { ACL_SRC, b }, { ACL_DST, transposed_b.get() } };
+ NEScheduler::get().schedule_op(_transpose_kernel.get(), Window::DimY, _transpose_kernel->window(), transpose_pack);
+ }
+
+ // Use reshaped matrices
+ mm_pack.add_const_tensor(ACL_SRC_0, interleaved_a.get());
+ mm_pack.add_const_tensor(ACL_SRC_1, transposed_b.get());
+ }
+
+ NEScheduler::get().schedule_op(_mm_kernel.get(), _run_vector_matrix_multiplication ? Window::DimX : Window::DimY, _mm_kernel->window(), mm_pack);
+
+ // Run bias addition kernel
+ if(_run_bias_addition)
+ {
+ ITensorPack pack{ { ACL_SRC_0, temp_d.get() }, { ACL_SRC_1, c }, { ACL_DST, d } };
+ _add_bias->run(pack);
+ }
+ }
+
+ // Run matrix addition kernel
+ if(_run_addition)
+ {
+ ITensorPack c_add_pack{ { ACL_SRC, c }, { ACL_DST, d } };
+ NEScheduler::get().schedule_op(_ma_kernel.get(), Window::DimY, _ma_kernel->window(), c_add_pack);
+ }
+
+ // Run activation function
+ if(_run_activation)
+ {
+ ITensorPack pack{ { ACL_SRC, d }, { ACL_DST, d } };
+ _activation_func->run(pack);
+ }
+}
+
+void CpuGemm::prepare(ITensorPack &tensors)
+{
+ if(!_is_prepared)
+ {
+ if(_asm_glue->is_configured())
+ {
+ _asm_glue->prepare(tensors);
+ }
+ else if(_reshape_b_only_on_first_run && !_run_vector_matrix_multiplication)
+ {
+ const ITensor *b = tensors.get_const_tensor(ACL_SRC_1);
+ ITensor *b_aux = utils::cast::polymorphic_cast<ITensor *>(tensors.get_tensor(offset_int_vec(TransposedRHS)));
+ ARM_COMPUTE_ERROR_ON_NULLPTR(b, b_aux);
+
+ CpuAuxTensorHandler transposed_b(_tmp_b, *b_aux);
+ ITensorPack transpose_pack{ { ACL_SRC, b }, { ACL_DST, transposed_b.get() } };
+ NEScheduler::get().schedule_op(_transpose_kernel.get(), Window::DimY, _transpose_kernel->window(), transpose_pack);
+ }
+ _is_prepared = true;
+ }
+}
+
+experimental::MemoryRequirements CpuGemm::workspace() const
+{
+ return _aux_mem;
+}
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/runtime/cpu/operators/CpuGemm.h b/src/runtime/cpu/operators/CpuGemm.h
new file mode 100644
index 0000000000..8d859791f5
--- /dev/null
+++ b/src/runtime/cpu/operators/CpuGemm.h
@@ -0,0 +1,145 @@
+/*
+ * Copyright (c) 2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_CPU_GEMM_H
+#define ARM_COMPUTE_CPU_GEMM_H
+
+#include "src/runtime/cpu/ICpuOperator.h"
+
+#include "arm_compute/core/ITensorPack.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Types.h"
+#include "src/core/cpu/kernels/CpuGemmInterleave4x4Kernel.h"
+#include "src/core/cpu/kernels/CpuGemmMatrixAdditionKernel.h"
+#include "src/core/cpu/kernels/CpuGemmMatrixMultiplyKernel.h"
+#include "src/core/cpu/kernels/CpuGemmTranspose1xWKernel.h"
+#include "src/runtime/cpu/operators/CpuActivation.h"
+#include "src/runtime/cpu/operators/CpuAdd.h"
+#include "src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h"
+
+#include <memory>
+
+namespace arm_compute
+{
+namespace cpu
+{
+/** Basic function to execute GEMM. This function calls the following kernels:
+ *
+ * If optimized assembly is available:
+ * -# @ref cpu::CpuGemmAssemblyDispatch
+ * -# @ref cpu::CpuActivation (if alpha != 1.0)
+ * Else:
+ * -# @ref cpu::kernels::CpuGemmInterleave4x4Kernel (if the output tensor is a matrix)
+ * -# @ref cpu::kernels::CpuGemmTranspose1xWKernel (if the output tensor is a matrix)
+ * -# @ref cpu::kernels::CpuGemmMatrixMultiplyKernel
+ * In both cases:
+ * -# @ref cpu::kernels::CpuGemmMatrixAdditionKernel (if c != nullptr and beta != 0.0 and is not reshaped once)
+ * Else:
+ * -# @ref cpu::CpuAdd (if c != nullptr and is reshaped once and not optimized assembly in place)
+ *
+ * -# @ref cpu::CpuActivation (if activation is specified in GEMMInfo)
+ */
+class CpuGemm : public ICpuOperator
+{
+public:
+ /** Default constructor */
+ CpuGemm() = default;
+ /** Default destructor */
+ ~CpuGemm() = default;
+ /** Configure operator for a given list of arguments
+ *
+ * Valid data layouts:
+ * - All
+ *
+ * Valid data type configurations:
+ * |src0 |src1 |src2 |dst |
+ * |:------------|:-----------|:---------|:--------------|
+ * |F32 |F32 |F32 |F32 |
+ * |F16 |F16 |F16 |F16 |
+ * |BFLOAT16 |BFLOAT16 |BFLOAT16 |BFLOAT16 |
+ *
+ * @note GEMM: General Matrix Multiply - [alpha * A * B + beta * C].
+ * @note GEMM: The tensors a, b, c, d must have the same data type. You should not mix data types when calling this function.
+ *
+ * @param[in] a First input tensor info (Matrix A or Vector A). Data type supported: BFLOAT16/F16/F32
+ * @param[in] b Second input tensor info (Matrix B). Data type supported: same as @p a
+ * @param[in] c Third input tensor info (Matrix C). It can be a nullptr if just the multiplication between @p a and @p b is needed. Data type supported: same as @p a
+ * @param[out] d Output tensor info. Data type supported: same as @p a
+ * @param[in] alpha Weight of the matrix product
+ * @param[in] beta Weight of matrix C
+ * @param[in] gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and
+ * if the reshape of matrix B should happen only for the first run
+ */
+ void configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
+ float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo());
+ /** Static function to check if given info will lead to a valid configuration of @ref CpuGemm.
+ *
+ * Similar to @ref CpuGemm::configure()
+ *
+ * @return a status
+ */
+ static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
+ float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo());
+
+ // Inherited methods overridden:
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &constants) override;
+ experimental::MemoryRequirements workspace() const override;
+
+private:
+ enum AuxTensorIdx
+ {
+ AsmGemmWorkspace = 0,
+ Pretraspose,
+ InterleavedLHS,
+ TransposedRHS,
+ TempResult,
+ Count
+ };
+
+ std::unique_ptr<kernels::CpuGemmInterleave4x4Kernel> _interleave_kernel{ nullptr };
+ std::unique_ptr<kernels::CpuGemmTranspose1xWKernel> _transpose_kernel{ nullptr };
+ std::unique_ptr<kernels::CpuGemmMatrixMultiplyKernel> _mm_kernel{ nullptr };
+ std::unique_ptr<CpuGemmAssemblyDispatch> _asm_glue{ nullptr };
+ std::unique_ptr<kernels::CpuGemmMatrixAdditionKernel> _ma_kernel{ nullptr };
+ std::unique_ptr<CpuActivation> _alpha_scale_func{ nullptr };
+ std::unique_ptr<CpuAdd> _add_bias{ nullptr };
+ std::unique_ptr<CpuActivation> _activation_func{ nullptr };
+
+ TensorInfo _tmp_a{};
+ TensorInfo _tmp_b{};
+ TensorInfo _tmp_d{};
+
+ bool _run_vector_matrix_multiplication{ false };
+ bool _run_alpha_scale{ false };
+ bool _run_addition{ false };
+ bool _run_bias_addition{ false };
+ bool _run_activation{ false };
+ bool _reshape_b_only_on_first_run{ false };
+ bool _is_prepared{ false };
+
+ experimental::MemoryRequirements _aux_mem{ Count };
+};
+} // namespace cpu
+} // namespace arm_compute
+#endif /*ARM_COMPUTE_CPU_GEMM_H */
diff --git a/src/runtime/cpu/utils/CpuAuxTensorHandler.h b/src/runtime/cpu/utils/CpuAuxTensorHandler.h
index 644018a718..0d1c927b5a 100644
--- a/src/runtime/cpu/utils/CpuAuxTensorHandler.h
+++ b/src/runtime/cpu/utils/CpuAuxTensorHandler.h
@@ -41,6 +41,10 @@ public:
CpuAuxTensorHandler(int slot_id, TensorInfo &info, ITensorPack &pack, bool pack_inject = false)
: _tensor()
{
+ if(info.total_size() == 0)
+ {
+ return;
+ }
_tensor.allocator()->soft_init(info);
ITensor *packed_tensor = utils::cast::polymorphic_downcast<ITensor *>(pack.get_tensor(slot_id));
diff --git a/tests/validation/NEON/GEMM.cpp b/tests/validation/NEON/GEMM.cpp
index 36c1943f27..27f0109590 100644
--- a/tests/validation/NEON/GEMM.cpp
+++ b/tests/validation/NEON/GEMM.cpp
@@ -28,6 +28,8 @@
#include "src/core/cpu/kernels/CpuGemmInterleave4x4Kernel.h"
#include "src/core/cpu/kernels/CpuGemmMatrixMultiplyKernel.h"
#include "src/core/cpu/kernels/CpuGemmTranspose1xWKernel.h"
+#include "src/core/helpers/MemoryHelpers.h"
+#include "src/runtime/cpu/operators/CpuGemm.h"
#include "tests/NEON/Accessor.h"
#include "tests/NEON/Helper.h"
#include "tests/PaddingCalculator.h"
@@ -73,29 +75,6 @@ template <typename FunctionType>
bool validate_zero_padding(unsigned int dim0_value, unsigned int dim1_value)
{
const TensorShape in_shape(dim0_value, dim1_value);
-
- // Create tensors
- Tensor in = create_tensor<Tensor>(in_shape, DataType::U32);
- Tensor dst;
-
- ARM_COMPUTE_EXPECT(in.info()->is_resizable(), framework::LogLevel::ERRORS);
-
- // Validate zero-padding
- FunctionType func;
-
- func.configure(&in, &dst);
-
- return in.info()->padding().empty();
-}
-
-/** Zero padding test
- *
- * TODO(COMPMID-4402): merge with previous when all kernels have been ported
- */
-template <typename FunctionType>
-bool validate_zero_padding_new(unsigned int dim0_value, unsigned int dim1_value)
-{
- const TensorShape in_shape(dim0_value, dim1_value);
TensorInfo in(in_shape, 1, DataType::U32);
TensorInfo dst;
@@ -128,6 +107,99 @@ bool validate_gemm_zero_padding(const TensorShape shape0, const TensorShape shap
TEST_SUITE(NEON)
TEST_SUITE(GEMM)
+/** Test case for memory injection in @ref cpu::CpuGemm.
+ *
+ * Configure the operator once and inject memory at run-time in multiple executions.
+ *
+ * Checks performed in order:
+ * - Both runs compute the same output
+ */
+TEST_CASE(MemoryInjection, framework::DatasetMode::ALL)
+{
+ auto gemm = std::make_unique<cpu::CpuGemm>();
+ const auto lhs_info = TensorInfo(TensorShape(3U, 3U), 1, DataType::F32);
+ const auto rhs_info = TensorInfo(TensorShape(4U, 3U), 1, DataType::F32);
+ const auto c_info = TensorInfo(TensorShape(4U, 3U), 1, DataType::F32);
+ auto dst_info = TensorInfo(TensorShape(4U, 3U), 1, DataType::F32);
+ const auto gemm_info = GEMMInfo{};
+ gemm->configure(&lhs_info, &rhs_info, &c_info, &dst_info, 1.f, 1.f, gemm_info);
+
+ // telhs are newly created every call of this lambda function
+ auto lhs = create_tensor<Tensor>(lhs_info);
+ auto rhs = create_tensor<Tensor>(rhs_info);
+ auto c = create_tensor<Tensor>(c_info);
+ lhs.allocator()->allocate();
+ rhs.allocator()->allocate();
+ c.allocator()->allocate();
+
+ ITensorPack run_pack{ { TensorType::ACL_SRC_0, &lhs }, { TensorType::ACL_SRC_1, &rhs }, { TensorType::ACL_SRC_2, &c } };
+ ITensorPack prep_pack{ { TensorType::ACL_SRC_1, &rhs }, { TensorType::ACL_SRC_2, &c } };
+
+ auto mg = MemoryGroup{};
+ auto ws = manage_workspace<Tensor>(gemm->workspace(), mg, run_pack, prep_pack);
+
+ auto run_conv = [&]() -> Tensor
+ {
+ auto dst = create_tensor<Tensor>(dst_info);
+ dst.allocator()->allocate();
+ run_pack.add_tensor(TensorType::ACL_DST, &dst);
+
+ library->fill_tensor_value(Accessor(lhs), 1.f);
+ library->fill_tensor_value(Accessor(rhs), 2.f);
+ library->fill_tensor_value(Accessor(c), 3.f);
+ // This operator is configured once and captured by this lambda.
+ gemm->prepare(prep_pack);
+ gemm->run(run_pack);
+ return dst;
+ };
+ auto result_0 = run_conv();
+ auto result_1 = run_conv();
+ for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); ++i)
+ {
+ ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+ }
+}
+
+/** Test case for memory injection in @ref NEGEMM.
+ *
+ * Make sure @ref NEGEMM still works through injecting the memory at configure time using the old API.
+ *
+ * Checks performed in order:
+ * - Both runs compute the same output
+ */
+TEST_CASE(MultipleExecutionWithConfigure, framework::DatasetMode::ALL)
+{
+ auto gemm = std::make_unique<NEGEMM>();
+ const auto lhs_info = TensorInfo(TensorShape(3U, 3U), 1, DataType::F32);
+ const auto rhs_info = TensorInfo(TensorShape(4U, 3U), 1, DataType::F32);
+ const auto c_info = TensorInfo(TensorShape(4U, 3U), 1, DataType::F32);
+ auto dst_info = TensorInfo(TensorShape(4U, 3U), 1, DataType::F32);
+ const auto gemm_info = GEMMInfo{};
+ auto run_conv = [&]()
+ {
+ auto lhs = create_tensor<Tensor>(lhs_info);
+ auto rhs = create_tensor<Tensor>(rhs_info);
+ auto c = create_tensor<Tensor>(c_info);
+ auto dst = create_tensor<Tensor>(dst_info);
+ gemm->configure(&lhs, &rhs, &c, &dst, 1.f, 1.f, gemm_info);
+ lhs.allocator()->allocate();
+ rhs.allocator()->allocate();
+ c.allocator()->allocate();
+ dst.allocator()->allocate();
+ library->fill_tensor_value(Accessor(lhs), 1.f);
+ library->fill_tensor_value(Accessor(rhs), 2.f);
+ library->fill_tensor_value(Accessor(c), 3.f);
+ gemm->run();
+ return dst;
+ };
+ auto result_0 = run_conv();
+ auto result_1 = run_conv();
+ for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); ++i)
+ {
+ ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+ }
+}
+
TEST_SUITE(TRANSPOSE_1XW)
using CpuGemmTranspose1xW = NESynthetizeFunctionWithZeroConstantKernelBorder<cpu::kernels::CpuGemmTranspose1xWKernel>;
DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(
@@ -135,7 +207,7 @@ DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(
framework::dataset::make("K", { 1, 47, 29, 27 })),
n_value, k_value)
{
- bool status = validate_zero_padding_new<CpuGemmTranspose1xW>(n_value, k_value);
+ bool status = validate_zero_padding<CpuGemmTranspose1xW>(n_value, k_value);
ARM_COMPUTE_EXPECT(status, framework::LogLevel::ERRORS);
}
@@ -176,7 +248,7 @@ DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(
framework::dataset::make("K", { 1, 47, 29, 27 })),
m_value, k_value)
{
- bool status = validate_zero_padding_new<cpu::kernels::CpuGemmInterleave4x4Kernel>(m_value, k_value);
+ bool status = validate_zero_padding<cpu::kernels::CpuGemmInterleave4x4Kernel>(m_value, k_value);
ARM_COMPUTE_EXPECT(status, framework::LogLevel::ERRORS);
}