author    Michele Di Giorgio <michele.digiorgio@arm.com>  2020-04-03 12:40:10 +0100
committer Michele Di Giorgio <michele.digiorgio@arm.com>  2020-04-08 08:43:39 +0000
commit    f64d33619827ce6ec9af4566c4743834e521328e (patch)
tree      2ca123ae92c2f91ecbccbacc77687bc359973fd4
parent    8abbabd6ad946441c8ef1a03896fa98f7801af1f (diff)
COMPMID-3236: Extend CLGEMMLowpReduction kernels to multiply by a scalar value
Change-Id: Iebd6afac65d10a42d60c2c9df9e1895fadb205ae
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2981
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--  arm_compute/core/CL/kernels/CLGEMMLowpReductionKernel.h    36
-rw-r--r--  src/core/CL/cl_kernels/gemmlowp.cl                          26
-rw-r--r--  src/core/CL/kernels/CLGEMMLowpReductionKernel.cpp           62
-rw-r--r--  src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp   13
4 files changed, 85 insertions(+), 52 deletions(-)
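The new configure() overloads take a GEMMLowpReductionKernelInfo descriptor instead of implicit defaults. The following is a minimal caller-side sketch, assuming the descriptor is default-constructible and exposes the four documented fields (k, is_reshaped, scalar, mul_by_scalar) as plain public members; the helper function and the scalar value are illustrative, not part of this patch.

    // Hypothetical caller-side sketch: configure the Matrix A reduction kernel
    // so that each reduced row is multiplied by a scalar value.
    #include "arm_compute/core/CL/kernels/CLGEMMLowpReductionKernel.h"
    #include "arm_compute/core/KernelDescriptors.h"

    using namespace arm_compute;

    void configure_a_reduction(const ICLTensor *mtx_a, ICLTensor *vector_sum_row)
    {
        GEMMLowpReductionKernelInfo info{};
        info.k             = static_cast<int32_t>(mtx_a->info()->dimension(0)); // number of columns of A
        info.is_reshaped   = false;                                             // A has not been reshaped
        info.scalar        = -2;                                                // illustrative value, e.g. -b_offset
        info.mul_by_scalar = true;                                              // request the -DSCALAR build option

        CLGEMMLowpMatrixAReductionKernel mtx_a_reduction_kernel;
        mtx_a_reduction_kernel.configure(mtx_a, vector_sum_row, info);
    }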
diff --git a/arm_compute/core/CL/kernels/CLGEMMLowpReductionKernel.h b/arm_compute/core/CL/kernels/CLGEMMLowpReductionKernel.h
index 4e52a8029e..71681cf628 100644
--- a/arm_compute/core/CL/kernels/CLGEMMLowpReductionKernel.h
+++ b/arm_compute/core/CL/kernels/CLGEMMLowpReductionKernel.h
@@ -29,6 +29,7 @@
namespace arm_compute
{
class ICLTensor;
+struct GEMMLowpReductionKernelInfo;
/** Common interface for all OpenCL reduction kernels */
class ICLGEMMLowpReductionKernel : public ICLKernel
@@ -49,8 +50,13 @@ public:
*
* @param[in] input Input tensor. Data type supported: S8
* @param[out] output Output row-vector of sums of all the entries in each row/col of input tensor. Data type supported: S32
+ * @param[in] info Kernel metadata:
+ * - k Number of matrix columns/rows depending on the type of reduction.
+ * - is_reshaped True if the matrix has been reshaped.
+ * - scalar Scalar value to multiply each reduced column/row by.
+ * - mul_by_scalar True if each reduced column/row must be multiplied by a scalar value.
*/
- virtual void configure(const ICLTensor *input, ICLTensor *output) = 0;
+ virtual void configure(const ICLTensor *input, ICLTensor *output, const GEMMLowpReductionKernelInfo &info) = 0;
protected:
const ICLTensor *_input;
@@ -69,16 +75,26 @@ public:
*
* @param[in] mtx_a Input tensor. Data type supported: QASYMM8/QASYMM8_SIGNED
* @param[out] vector_sum_row Output row-vector of sums of all the entries in each row of mtx_a. Data type supported: S32
+ * @param[in] info Kernel metadata:
+ * - k Number of matrix columns/rows depending on the type of reduction.
+ * - is_reshaped True if the matrix has been reshaped.
+ * - scalar Scalar value to multiply each reduced column/row by.
+ * - mul_by_scalar True if each reduced column/row must be multiplied by a scalar value.
*/
- void configure(const ICLTensor *mtx_a, ICLTensor *vector_sum_row) override;
+ void configure(const ICLTensor *mtx_a, ICLTensor *vector_sum_row, const GEMMLowpReductionKernelInfo &info) override;
/** Static function to check if given info will lead to a valid configuration of @ref CLGEMMLowpMatrixAReductionKernel
*
* @param[in] mtx_a Input tensor. Data type supported: QASYMM8/QASYMM8_SIGNED
* @param[in] vector_sum_row Output row-vector of sums of all the entries in each row of mtx_a. Data type supported: S32
+ * @param[in] info Kernel metadata:
+ * - k Number of matrix columns/rows depending on the type of reduction.
+ * - is_reshaped True if the matrix has been reshaped.
+ * - scalar Scalar value to multiply each reduced column/row by.
+ * - mul_by_scalar True if each reduced column/row must be multiplied by a scalar value.
*
* @return a status
*/
- static Status validate(const ITensorInfo *mtx_a, const ITensorInfo *vector_sum_row);
+ static Status validate(const ITensorInfo *mtx_a, const ITensorInfo *vector_sum_row, const GEMMLowpReductionKernelInfo &info);
// Inherited methods overridden:
void run(const Window &window, cl::CommandQueue &queue) override;
@@ -96,16 +112,26 @@ public:
*
* @param[in] mtx_b Input tensor. Data type supported: QASYMM8/QASYMM8_SIGNED
* @param[out] vector_sum_col Output row-vector of sums of all the entries in each column of mtx_b. Data type supported: S32
+ * @param[in] info Kernel metadata:
+ * - k Number of matrix columns/rows depending on the type of reduction.
+ * - is_reshaped True if the matrix has been reshaped.
+ * - scalar Scalar value to multiply each reduced column/row by.
+ * - mul_by_scalar True if each reduced column/row must be multiplied by a scalar value.
*/
- void configure(const ICLTensor *mtx_b, ICLTensor *vector_sum_col) override;
+ void configure(const ICLTensor *mtx_b, ICLTensor *vector_sum_col, const GEMMLowpReductionKernelInfo &info) override;
/** Static function to check if given info will lead to a valid configuration of @ref CLGEMMLowpMatrixBReductionKernel
*
* @param[in] mtx_b Input tensor. Data type supported: QASYMM8/QASYMM8_SIGNED
* @param[in] vector_sum_col Output row-vector of sums of all the entries in each column of mtx_b. Data type supported: S32
+ * @param[in] info Kernel metadata:
+ * - k Number of matrix columns/rows depending on the type of reduction.
+ * - is_reshaped True if the matrix has been reshaped.
+ * - scalar Scalar value to multiply each reduced column/row by.
+ * - mul_by_scalar True if each reduced column/row must be multiplied by a scalar value.
*
* @return a status
*/
- static Status validate(const ITensorInfo *mtx_b, const ITensorInfo *vector_sum_col);
+ static Status validate(const ITensorInfo *mtx_b, const ITensorInfo *vector_sum_col, const GEMMLowpReductionKernelInfo &info);
// Inherited methods overridden:
void run(const Window &window, cl::CommandQueue &queue) override;
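The static validate() overloads above take the same descriptor, so a configuration can be checked before allocating anything. A minimal sketch follows; the tensor shapes and names are illustrative and not taken from this patch.

    // Hypothetical validation sketch: check a Matrix B reduction configuration
    // up front. The output vector length must equal the number of columns of
    // the input matrix.
    #include "arm_compute/core/CL/kernels/CLGEMMLowpReductionKernel.h"
    #include "arm_compute/core/Error.h"
    #include "arm_compute/core/KernelDescriptors.h"
    #include "arm_compute/core/TensorInfo.h"

    using namespace arm_compute;

    void validate_b_reduction()
    {
        const TensorInfo mtx_b(TensorShape(16U, 32U), 1, DataType::QASYMM8); // 32 rows x 16 columns
        const TensorInfo vector_sum_col(TensorShape(16U), 1, DataType::S32); // one sum per column

        const GEMMLowpReductionKernelInfo info{}; // default descriptor: no scalar multiplication
        const Status status = CLGEMMLowpMatrixBReductionKernel::validate(&mtx_b, &vector_sum_col, info);
        ARM_COMPUTE_ERROR_THROW_ON(status);
    }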
diff --git a/src/core/CL/cl_kernels/gemmlowp.cl b/src/core/CL/cl_kernels/gemmlowp.cl
index 71de1d4b27..b707ec8175 100644
--- a/src/core/CL/cl_kernels/gemmlowp.cl
+++ b/src/core/CL/cl_kernels/gemmlowp.cl
@@ -1287,6 +1287,7 @@ __kernel void gemmlowp_mm_native(IMAGE_DECLARATION(lhs),
#if defined(COLS_A)
/** OpenCL kernel used to compute the row-vectors of sums of all the entries in each row of Matrix A.
+ * It is also possible to multiply each reduced row by a scalar value, if SCALAR is passed at compile time.
*
* @note This stage is needed to handle the offset of matrix product
* https://github.com/google/gemmlowp/blob/master/doc/low-precision.md
@@ -1294,8 +1295,9 @@ __kernel void gemmlowp_mm_native(IMAGE_DECLARATION(lhs),
* @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
* @note The input data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=uchar)
* @note The data type for the accumulation must be passed at compile time using -DACC_DATA_TYPE (i.e. -DACC_DATA_TYPE=uint)
+ * @note In case of scaling, the scalar value must be passed at compile time using -DSCALAR (e.g. -DSCALAR=3)
*
- * @param[in] src_ptr Pointer to the source tensor. Supported data type: QASYMM8
+ * @param[in] src_ptr Pointer to the source tensor. Supported data type: QASYMM8/QASYMM8_SIGNED
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
@@ -1342,11 +1344,15 @@ __kernel void gemmlowp_matrix_a_reduction(TENSOR3D_DECLARATION(src),
sum_row += sum_row_32.s0 + sum_row_32.s1 + sum_row_32.s2 + sum_row_32.s3;
+#if defined(SCALAR)
+ sum_row *= (int)SCALAR;
+#endif // defined(SCALAR)
*((__global int *)dst.ptr) = (int)sum_row;
}
#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
-/** OpenCL kernel used to compute the row-vectors of sums of all the entries in each row of Matrix A using the arm dot product instruction
+/** OpenCL kernel used to compute the row-vectors of sums of all the entries in each row of Matrix A using the arm dot product instruction.
+ * It is also possible to multiply each reduced row by a scalar value, if SCALAR is passed at compile time.
*
* @note This stage is needed to handle the offset of matrix product
* https://github.com/google/gemmlowp/blob/master/doc/low-precision.md
@@ -1354,8 +1360,9 @@ __kernel void gemmlowp_matrix_a_reduction(TENSOR3D_DECLARATION(src),
* @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
* @note The input data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=uchar)
* @note The data type for the accumulation must be passed at compile time using -DACC_DATA_TYPE (i.e. -DACC_DATA_TYPE=uint)
+ * @note In case of scaling, the scalar value must be passed at compile time using -DSCALAR (e.g. -DSCALAR=3)
*
- * @param[in] src_ptr Pointer to the source tensor. Supported data type: QASYMM8
+ * @param[in] src_ptr Pointer to the source tensor. Supported data type: QASYMM8/QASYMM8_SIGNED
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
@@ -1408,6 +1415,9 @@ __kernel void gemmlowp_matrix_a_reduction_dot8(TENSOR3D_DECLARATION(src),
sum_row += (ACC_DATA_TYPE)matrix_a[i];
}
+#if defined(SCALAR)
+ sum_row *= (int)SCALAR;
+#endif // defined(SCALAR)
*((__global int *)dst.ptr) = (int)sum_row;
}
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
@@ -1415,6 +1425,7 @@ __kernel void gemmlowp_matrix_a_reduction_dot8(TENSOR3D_DECLARATION(src),
#if defined(COLS_B) && defined(ROWS_B)
/** OpenCL kernel used to compute the row-vectors of sums of all the entries in each column of Matrix B.
+ * It is also possible to multiply each reduced column by a scalar value, if SCALAR is passed at compile time.
*
* @note This stage is needed to handle the offset of matrix product
* https://github.com/google/gemmlowp/blob/master/doc/low-precision.md
@@ -1422,8 +1433,9 @@ __kernel void gemmlowp_matrix_a_reduction_dot8(TENSOR3D_DECLARATION(src),
* @attention The number of matrix B columns and rows needs to be passed at compile time using -DCOLS_B and -DROWS_B
* @note The input data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=uchar)
* @note The data type for the accumulation must be passed at compile time using -DACC_DATA_TYPE (i.e. -DACC_DATA_TYPE=uint)
+ * @note In case of scaling, the scalar value must be passed at compile time using -DSCALAR (e.g. -DSCALAR=3)
*
- * @param[in] src_ptr Pointer to the source tensor. Supported data type: QASYMM8
+ * @param[in] src_ptr Pointer to the source tensor. Supported data type: QASYMM8/QASYMM8_SIGNED
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
@@ -1480,7 +1492,11 @@ __kernel void gemmlowp_matrix_b_reduction(TENSOR3D_DECLARATION(src),
matrix_b += src_stride_y;
}
- vstore16(convert_int16(sum_col_32), 0, (__global int *)dst.ptr);
+#if defined(SCALAR)
+ sum_col_32 *= (VEC_DATA_TYPE(ACC_DATA_TYPE, 16))SCALAR;
+#endif // defined(SCALAR)
+ VSTORE(16)
+ (sum_col_32, 0, (__global int *)dst.ptr);
}
#endif // defined(COLS_B) && defined(ROWS_B)
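For reference, the computation the reduction kernels above perform can be modelled on the host as follows. This is an illustrative C++ sketch, not library code: it sums each of the M rows of an M x K matrix and, when scaling is requested, multiplies each sum by the scalar, mirroring the #if defined(SCALAR) branches above.

    // Illustrative host-side model of the row reduction with optional scaling.
    #include <cstdint>
    #include <vector>

    std::vector<int32_t> reduce_rows(const std::vector<uint8_t> &a, int m, int k,
                                     bool mul_by_scalar, int32_t scalar)
    {
        std::vector<int32_t> sum_row(m, 0);
        for(int r = 0; r < m; ++r)
        {
            // Accumulate all K entries of row r
            for(int c = 0; c < k; ++c)
            {
                sum_row[r] += static_cast<int32_t>(a[r * k + c]);
            }
            if(mul_by_scalar)
            {
                sum_row[r] *= scalar; // same effect as compiling with -DSCALAR=<scalar>
            }
        }
        return sum_row;
    }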
diff --git a/src/core/CL/kernels/CLGEMMLowpReductionKernel.cpp b/src/core/CL/kernels/CLGEMMLowpReductionKernel.cpp
index 8f3d53a412..832e6281f4 100644
--- a/src/core/CL/kernels/CLGEMMLowpReductionKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMLowpReductionKernel.cpp
@@ -26,53 +26,36 @@
#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/ICLTensor.h"
-#include "arm_compute/core/Error.h"
-#include "arm_compute/core/Helpers.h"
-#include "arm_compute/core/TensorInfo.h"
-#include "arm_compute/core/Types.h"
-#include "arm_compute/core/Utils.h"
-#include "arm_compute/core/Validate.h"
-#include "arm_compute/core/Window.h"
+#include "arm_compute/core/KernelDescriptors.h"
#include "support/StringSupport.h"
-#include <cstddef>
-#include <cstdint>
-
namespace arm_compute
{
-class Coordinates;
-
namespace
{
Status validate_arguments_matrix_a_reduction(const ITensorInfo *input, const ITensorInfo *output)
{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);
+ if(output->total_size() > 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->dimension(0) != input->dimension(1), "Output vector must have length equal to the number of rows of the input matrix");
+ }
return Status{};
}
-std::pair<Status, Window> validate_and_configure_window_matrix_a_reduction(ITensorInfo *input, ITensorInfo *output)
-{
- constexpr unsigned int num_elems_processed_per_iteration = 1;
-
- Window win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration));
-
- AccessWindowStatic input_access(input, 0, 0, input->dimension(0), input->dimension(1));
- AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
-
- bool window_changed = update_window_and_padding(win, input_access, output_access);
-
- output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->tensor_shape()));
-
- Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
- return std::make_pair(err, win);
-}
Status validate_arguments_matrix_b_reduction(const ITensorInfo *input, const ITensorInfo *output)
{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);
+ if(output->total_size() > 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->dimension(0) != input->dimension(0), "Output vector must have length equal to the number of columns of the input matrix");
+ }
return Status{};
}
@@ -100,7 +83,7 @@ ICLGEMMLowpReductionKernel::ICLGEMMLowpReductionKernel()
{
}
-void CLGEMMLowpMatrixAReductionKernel::configure(const ICLTensor *mtx_a, ICLTensor *vector_sum_row)
+void CLGEMMLowpMatrixAReductionKernel::configure(const ICLTensor *mtx_a, ICLTensor *vector_sum_row, const GEMMLowpReductionKernelInfo &info)
{
// Perform validate step
ARM_COMPUTE_ERROR_ON_NULLPTR(mtx_a, vector_sum_row);
@@ -114,6 +97,7 @@ void CLGEMMLowpMatrixAReductionKernel::configure(const ICLTensor *mtx_a, ICLTens
build_opts.add_option("-DCOLS_A=" + support::cpp11::to_string(mtx_a->info()->dimension(0)));
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(mtx_a->info()->data_type()));
build_opts.add_option("-DACC_DATA_TYPE=" + get_cl_dot8_acc_type_from_data_type(mtx_a->info()->data_type()));
+ build_opts.add_option_if(info.mul_by_scalar, "-DSCALAR=" + support::cpp11::to_string(info.scalar));
const bool is_dot8_supported = dot8_supported(CLKernelLibrary::get().get_device());
@@ -123,9 +107,9 @@ void CLGEMMLowpMatrixAReductionKernel::configure(const ICLTensor *mtx_a, ICLTens
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options()));
// Configure kernel window
- auto win_config = validate_and_configure_window_matrix_a_reduction(_input->info(), _output->info());
- ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure_internal(win_config.second);
+ // This kernel does not need padding
+ Window win = calculate_max_window(*vector_sum_row->info(), Steps());
+ ICLKernel::configure_internal(win);
_config_id = kernel_name;
_config_id += "_";
@@ -136,10 +120,10 @@ void CLGEMMLowpMatrixAReductionKernel::configure(const ICLTensor *mtx_a, ICLTens
_config_id += support::cpp11::to_string(_input->info()->dimension(2));
}
-Status CLGEMMLowpMatrixAReductionKernel::validate(const ITensorInfo *mtx_a, const ITensorInfo *vector_sum_row)
+Status CLGEMMLowpMatrixAReductionKernel::validate(const ITensorInfo *mtx_a, const ITensorInfo *vector_sum_row, const GEMMLowpReductionKernelInfo &info)
{
+ ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_matrix_a_reduction(mtx_a, vector_sum_row));
- ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_matrix_a_reduction(mtx_a->clone().get(), vector_sum_row->clone().get()).first);
return Status{};
}
@@ -168,7 +152,7 @@ void CLGEMMLowpMatrixAReductionKernel::run(const Window &window, cl::CommandQueu
while(collapsed.slide_window_slice_2D(slice_out));
}
-void CLGEMMLowpMatrixBReductionKernel::configure(const ICLTensor *mtx_b, ICLTensor *vector_sum_col)
+void CLGEMMLowpMatrixBReductionKernel::configure(const ICLTensor *mtx_b, ICLTensor *vector_sum_col, const GEMMLowpReductionKernelInfo &info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(mtx_b, vector_sum_col);
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_matrix_b_reduction(mtx_b->info(), vector_sum_col->info()));
@@ -182,6 +166,7 @@ void CLGEMMLowpMatrixBReductionKernel::configure(const ICLTensor *mtx_b, ICLTens
build_opts.add_option("-DROWS_B=" + support::cpp11::to_string(mtx_b->info()->dimension(1)));
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(mtx_b->info()->data_type()));
build_opts.add_option("-DACC_DATA_TYPE=" + get_cl_dot8_acc_type_from_data_type(mtx_b->info()->data_type()));
+ build_opts.add_option_if(info.mul_by_scalar, "-DSCALAR=" + support::cpp11::to_string(info.scalar));
// Create kernel
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("gemmlowp_matrix_b_reduction", build_opts.options()));
@@ -192,8 +177,9 @@ void CLGEMMLowpMatrixBReductionKernel::configure(const ICLTensor *mtx_b, ICLTens
ICLKernel::configure_internal(win_config.second);
}
-Status CLGEMMLowpMatrixBReductionKernel::validate(const ITensorInfo *mtx_b, const ITensorInfo *vector_sum_col)
+Status CLGEMMLowpMatrixBReductionKernel::validate(const ITensorInfo *mtx_b, const ITensorInfo *vector_sum_col, const GEMMLowpReductionKernelInfo &info)
{
+ ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_matrix_b_reduction(mtx_b, vector_sum_col));
ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_matrix_b_reduction(mtx_b->clone().get(), vector_sum_col->clone().get()).first);
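On the configure side, the descriptor is only used to forward the optional define into the kernel build options. A stand-alone model of that logic, written in plain C++ rather than with the CLBuildOptions API, is sketched below under the assumption that add_option_if behaves as shown in the code above; CLGEMMLowpMatrixMultiplyCore (next file) passes a default-constructed descriptor, so the reduction kernels keep their previous, unscaled behaviour.

    // Illustrative model of how the descriptor becomes the optional -DSCALAR define.
    #include <cstdint>
    #include <string>
    #include <vector>

    std::vector<std::string> make_reduction_build_options(bool mul_by_scalar, int32_t scalar)
    {
        std::vector<std::string> opts;
        if(mul_by_scalar)
        {
            // Mirrors build_opts.add_option_if(info.mul_by_scalar, "-DSCALAR=" + to_string(info.scalar))
            opts.push_back("-DSCALAR=" + std::to_string(scalar));
        }
        return opts; // empty when mul_by_scalar is false (the default descriptor)
    }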
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
index 9346e9357c..90e5698fd8 100644
--- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
@@ -28,6 +28,7 @@
#include "arm_compute/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
@@ -149,6 +150,9 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const ICLTensor *a, const ICLTensor
_mtx_b_reshape_kernel.configure(_convert_to_qasymm8 ? &_qasymm8_weights : b, &_tmp_b, rhs_info);
}
+ // Using default reduction info
+ const GEMMLowpReductionKernelInfo reduction_info {};
+
// Initialize matrix B reduction kernel only if _a_offset is not equal to 0
if(_a_offset != 0)
{
@@ -160,7 +164,7 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const ICLTensor *a, const ICLTensor
}
// Configure Matrix B reduction kernel
- _mtx_b_reduction_kernel.configure(_convert_to_qasymm8 ? &_qasymm8_weights : b, &_vector_sum_col);
+ _mtx_b_reduction_kernel.configure(_convert_to_qasymm8 ? &_qasymm8_weights : b, &_vector_sum_col, reduction_info);
}
// Initialize Matrix A reduction kernel only if _b_offset is not equal to 0
@@ -171,7 +175,7 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const ICLTensor *a, const ICLTensor
_memory_group.manage(&_vector_sum_row);
// Configure matrix A reduction kernel
- _mtx_a_reduction_kernel.configure(a, &_vector_sum_row);
+ _mtx_a_reduction_kernel.configure(a, &_vector_sum_row, reduction_info);
}
GEMMKernelInfo gemm_kernel_info;
@@ -356,13 +360,14 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
TensorInfo info_vector_sum_col{};
TensorInfo info_vector_sum_row{};
+ const GEMMLowpReductionKernelInfo reduction_info;
// Validate matrix B reduction kernel only if _a_offset is not equal to 0
if(a_offset != 0)
{
info_vector_sum_col = TensorInfo(compute_reductionA_shape(weights_info), 1, DataType::S32);
// Configure Matrix B reduction kernel
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixBReductionKernel::validate(&weights_info, &info_vector_sum_col));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixBReductionKernel::validate(&weights_info, &info_vector_sum_col, reduction_info));
}
// Validate Matrix A reduction kernel only if _b_offset is not equal to 0
@@ -371,7 +376,7 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
info_vector_sum_row = TensorInfo(compute_reductionB_shape(*a), 1, DataType::S32);
// Configure matrix A reduction kernel
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(a, &info_vector_sum_row));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(a, &info_vector_sum_row, reduction_info));
}
GEMMKernelInfo gemm_kernel_info;