From a602f03f4c66e5ee2480f1a3fc66847968fc1076 Mon Sep 17 00:00:00 2001 From: Michele Di Giorgio Date: Thu, 12 Mar 2020 19:34:33 +0000 Subject: COMPMID-3237: Extend GEMMLowpReduction kernels to multiply reductions by a scalar value Change-Id: If2a242f52aea753591525d30a4cb64c1a766bf8d Signed-off-by: Michele Di Giorgio Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2881 Tested-by: Arm Jenkins Reviewed-by: Sang-Hoon Park Comments-Addressed: Arm Jenkins --- .../NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) (limited to 'src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp') diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp index 8c6cee78bb..3417c72735 100644 --- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp +++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp @@ -26,6 +26,7 @@ #include "arm_compute/core/Error.h" #include "arm_compute/core/Helpers.h" #include "arm_compute/core/ITensor.h" +#include "arm_compute/core/KernelDescriptors.h" #include "arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h" #include "arm_compute/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h" #include "arm_compute/core/NEON/kernels/NEGEMMTranspose1xWKernel.h" @@ -37,7 +38,8 @@ #include "arm_compute/runtime/TensorAllocator.h" #include "support/MemorySupport.h" -using namespace arm_compute; +namespace arm_compute +{ using namespace arm_compute::misc::shape_calculator; NEGEMMLowpMatrixMultiplyCore::NEGEMMLowpMatrixMultiplyCore(std::shared_ptr memory_manager) @@ -172,6 +174,9 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b, if(!_fused_assembly_path) { + // Build reduction info + const GEMMLowpReductionKernelInfo reduction_info(a_to_use->info()->dimension(0), false, 0, false); + // Initialize matrix B reduction kernel only if _a_offset is not equal to 0 if(_a_offset != 0) { @@ -184,7 +189,7 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b, } // Configure Matrix B reduction kernel - _mtx_b_reduction_kernel.configure(b, &_vector_sum_col, a_to_use->info()->dimension(0), false); + _mtx_b_reduction_kernel.configure(b, &_vector_sum_col, reduction_info); } // Initialize Matrix A reduction kernel only if _b_offset is not equal to 0 @@ -196,7 +201,7 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b, _memory_group.manage(&_vector_sum_row); // Configure matrix A reduction kernel - _mtx_a_reduction_kernel.configure(a_to_use, &_vector_sum_row, a_to_use->info()->dimension(0), false); + _mtx_a_reduction_kernel.configure(a_to_use, &_vector_sum_row, reduction_info); } if(_fuse_output_stage) @@ -418,13 +423,15 @@ Status NEGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso TensorInfo info_vector_sum_col{}; TensorInfo info_vector_sum_row{}; + const GEMMLowpReductionKernelInfo reduction_info(a_to_use->dimension(0), false, 0, false); + // Validate matrix B reduction kernel only if _a_offset is not equal to 0 if(a_offset != 0) { info_vector_sum_col = TensorInfo(compute_reductionA_shape(*b), 1, DataType::S32); // Configure Matrix B reduction kernel - ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixBReductionKernel::validate(b, &info_vector_sum_col, a->dimension(0), false)); + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixBReductionKernel::validate(b, &info_vector_sum_col, reduction_info)); } // Validate Matrix A reduction kernel only if _b_offset is not equal to 0 @@ -433,7 +440,7 @@ Status NEGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso info_vector_sum_row = TensorInfo(compute_reductionB_shape(*a), 1, DataType::S32); // Configure matrix A reduction kernel - ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixAReductionKernel::validate(a_to_use, &info_vector_sum_row, a->dimension(0), false)); + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixAReductionKernel::validate(a_to_use, &info_vector_sum_row, reduction_info)); } if(fuse_output_stage) @@ -580,3 +587,4 @@ void NEGEMMLowpMatrixMultiplyCore::prepare() _is_prepared = true; } } +} // namespace arm_compute -- cgit v1.2.1