From 1a378107af40669eaa23a12e064bb8fabff2473e Mon Sep 17 00:00:00 2001 From: Sheri Zhang Date: Thu, 30 Apr 2020 12:59:39 +0100 Subject: COMPMID-3290: Test improvement for CLGEMMMatrixMultiplyReshapedOnlyRHSKernel Signed-off-by: Sheri Zhang Change-Id: I7335ee07f777087e06ca26f762b2b5e3668362ab Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3175 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Sang-Hoon Park --- .../CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) (limited to 'src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp') diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp index 13f8152fb4..8e194d5139 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp @@ -24,23 +24,16 @@ #include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h" #include "arm_compute/core/AccessWindowStatic.h" -#include "arm_compute/core/CL/CLHelpers.h" -#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/CL/CLValidate.h" #include "arm_compute/core/CL/ICLTensor.h" -#include "arm_compute/core/CL/OpenCL.h" #include "arm_compute/core/Error.h" #include "arm_compute/core/Helpers.h" #include "arm_compute/core/TensorInfo.h" -#include "arm_compute/core/Types.h" #include "arm_compute/core/Utils.h" -#include "arm_compute/core/Validate.h" -#include "arm_compute/core/Window.h" #include "arm_compute/core/utils/helpers/float_ops.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "support/StringSupport.h" -#include -#include #include using namespace arm_compute::misc::shape_calculator; @@ -57,13 +50,15 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, { ARM_COMPUTE_UNUSED(alpha); ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output); + ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input0); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1); ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the LHS matrix must be <= 4"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the RHS matrix must be <= 3"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_info.m0 < 1 || lhs_info.m0 > 8, "Only 1,2,3,4,5,6,7,8 are supported for m0"); + ARM_COMPUTE_RETURN_ERROR_ON(rhs_info.k0 > 16 || rhs_info.k0 < 2); ARM_COMPUTE_RETURN_ERROR_ON_MSG(((rhs_info.k0 & (rhs_info.k0 - 1)) && rhs_info.k0 != 3), "Only 2,3,4,8,16 are supported for k0"); - ARM_COMPUTE_RETURN_ERROR_ON(rhs_info.k0 > 16); - ARM_COMPUTE_RETURN_ERROR_ON(lhs_info.m0 < 1 || lhs_info.m0 > 8); + ARM_COMPUTE_RETURN_ERROR_ON(rhs_info.n0 > 16 || rhs_info.n0 < 2); ARM_COMPUTE_RETURN_ERROR_ON_MSG(((rhs_info.n0 & (rhs_info.n0 - 1)) && rhs_info.n0 != 3), "Only 2,3,4,8,16 are supported for n0"); ARM_COMPUTE_RETURN_ERROR_ON_MSG((gemm_info.reinterpret_input_as_3d || gemm_info.depth_output_gemm3d != 0) && (input2 != nullptr) && (!gemm_info.broadcast_bias), @@ -83,7 +78,7 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const unsigned int input2_dim0 = input2->dimension(0); const unsigned int input2_dim1 = input2->dimension(1); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input2, input1); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input2, input0); if(gemm_info.broadcast_bias) { ARM_COMPUTE_RETURN_ERROR_ON_MSG((input2_dim1 != 1 || input2_dim0 != n), "Incorrect dimension of bias matrix which is to be broadcasted"); @@ -220,7 +215,8 @@ void CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::configure(const ICLTensor *input configure(CLKernelLibrary::get().get_compile_context(), input0, input1, input2, output, alpha, beta, lhs_info, rhs_info, gemm_info); } -void CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::configure(const CLCompileContext &compile_context, const ICLTensor *input0, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float alpha, +void CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::configure(const CLCompileContext &compile_context, const ICLTensor *input0, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, + float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const GEMMKernelInfo &gemm_info) -- cgit v1.2.1