aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
diff options
context:
space:
mode:
authorMichele Di Giorgio <michele.digiorgio@arm.com>2020-04-03 12:40:10 +0100
committerMichele Di Giorgio <michele.digiorgio@arm.com>2020-04-08 08:43:39 +0000
commitf64d33619827ce6ec9af4566c4743834e521328e (patch)
tree2ca123ae92c2f91ecbccbacc77687bc359973fd4 /src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
parent8abbabd6ad946441c8ef1a03896fa98f7801af1f (diff)
downloadComputeLibrary-f64d33619827ce6ec9af4566c4743834e521328e.tar.gz
COMPMID-3236: Extend CLGEMMLowpReduction kernels to multiply by a scalar value
Change-Id: Iebd6afac65d10a42d60c2c9df9e1895fadb205ae Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2981 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Sang-Hoon Park <sang-hoon.park@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp')
-rw-r--r--src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp13
1 files changed, 9 insertions, 4 deletions
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
index 9346e9357c..90e5698fd8 100644
--- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
@@ -28,6 +28,7 @@
#include "arm_compute/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
@@ -149,6 +150,9 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const ICLTensor *a, const ICLTensor
_mtx_b_reshape_kernel.configure(_convert_to_qasymm8 ? &_qasymm8_weights : b, &_tmp_b, rhs_info);
}
+ // Using default reduction info
+ const GEMMLowpReductionKernelInfo reduction_info {};
+
// Initialize matrix B reduction kernel only if _a_offset is not equal to 0
if(_a_offset != 0)
{
@@ -160,7 +164,7 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const ICLTensor *a, const ICLTensor
}
// Configure Matrix B reduction kernel
- _mtx_b_reduction_kernel.configure(_convert_to_qasymm8 ? &_qasymm8_weights : b, &_vector_sum_col);
+ _mtx_b_reduction_kernel.configure(_convert_to_qasymm8 ? &_qasymm8_weights : b, &_vector_sum_col, reduction_info);
}
// Initialize Matrix A reduction kernel only if _b_offset is not equal to 0
@@ -171,7 +175,7 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const ICLTensor *a, const ICLTensor
_memory_group.manage(&_vector_sum_row);
// Configure matrix A reduction kernel
- _mtx_a_reduction_kernel.configure(a, &_vector_sum_row);
+ _mtx_a_reduction_kernel.configure(a, &_vector_sum_row, reduction_info);
}
GEMMKernelInfo gemm_kernel_info;
@@ -356,13 +360,14 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
TensorInfo info_vector_sum_col{};
TensorInfo info_vector_sum_row{};
+ const GEMMLowpReductionKernelInfo reduction_info;
// Validate matrix B reduction kernel only if _a_offset is not equal to 0
if(a_offset != 0)
{
info_vector_sum_col = TensorInfo(compute_reductionA_shape(weights_info), 1, DataType::S32);
// Configure Matrix B reduction kernel
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixBReductionKernel::validate(&weights_info, &info_vector_sum_col));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixBReductionKernel::validate(&weights_info, &info_vector_sum_col, reduction_info));
}
// Validate Matrix A reduction kernel only if _b_offset is not equal to 0
@@ -371,7 +376,7 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
info_vector_sum_row = TensorInfo(compute_reductionB_shape(*a), 1, DataType::S32);
// Configure matrix A reduction kernel
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(a, &info_vector_sum_row));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(a, &info_vector_sum_row, reduction_info));
}
GEMMKernelInfo gemm_kernel_info;