From d820db6fc479f7daef6788377cb765369fcddc22 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Mon, 5 Aug 2019 14:23:23 +0100 Subject: COMPMID-2545: Reduce tests required by GEMM (OpenCL) Removed FP16 tests from the new GEMM functions (GEMMNative, GEMMReshaped and GEMMReshapedOnlyRHS) since not called by CLGEMM Change-Id: Id52281fc9557d45e29db0a74964d4bdec55d8f46 Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/1695 Reviewed-by: Michele Di Giorgio Tested-by: Arm Jenkins --- src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp | 5 +++-- src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp | 5 +++-- src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp | 5 +++-- 3 files changed, 9 insertions(+), 6 deletions(-) (limited to 'src') diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp index 00b06f6e24..b1d0059057 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp @@ -57,7 +57,7 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, { ARM_COMPUTE_UNUSED(alpha); ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F32, DataType::F16); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1); ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the LHS matrix must be <= 4"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the RHS matrix must be <= 3"); @@ -66,7 +66,8 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, ARM_COMPUTE_RETURN_ERROR_ON(lhs_info.m0 < 1 || lhs_info.m0 > 8); ARM_COMPUTE_RETURN_ERROR_ON_MSG(((rhs_info.n0 & (rhs_info.n0 - 1)) && rhs_info.n0 != 3), "Only 2,3,4,8,16 are supported for n0"); ARM_COMPUTE_RETURN_ERROR_ON_MSG((gemm_info.reinterpret_input_as_3d || gemm_info.depth_output_gemm3d != 0) && (input2 != nullptr) - && (!gemm_info.broadcast_bias), "Bias addition only supported with broadcast mode in case the input or output has to be reinterpreted as 3D"); + && (!gemm_info.broadcast_bias), + "Bias addition only supported with broadcast mode in case the input or output has to be reinterpreted as 3D"); const unsigned int m = gemm_info.m; const unsigned int n = gemm_info.n; diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp index f0405bfd76..63451b49b8 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp @@ -63,7 +63,7 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, ARM_COMPUTE_UNUSED(alpha); ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output); ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input0); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F32, DataType::F16); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1); ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the LHS matrix must be <= 4"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the RHS matrix must be <= 3"); @@ -75,7 +75,8 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, ARM_COMPUTE_RETURN_ERROR_ON(lhs_info.m0 < 2 || lhs_info.m0 > 8); ARM_COMPUTE_RETURN_ERROR_ON_MSG(((rhs_info.n0 & (rhs_info.n0 - 1)) && rhs_info.n0 != 3), "Only 2,3,4,8,16 are supported for n0"); ARM_COMPUTE_RETURN_ERROR_ON_MSG((gemm_info.reinterpret_input_as_3d || gemm_info.depth_output_gemm3d != 0) && (input2 != nullptr) - && (!gemm_info.broadcast_bias), "Bias addition only supported with broadcast mode in case the input or output has to be reinterpreted as 3D"); + && (!gemm_info.broadcast_bias), + "Bias addition only supported with broadcast mode in case the input or output has to be reinterpreted as 3D"); const unsigned int m = gemm_info.m; const unsigned int n = gemm_info.n; diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp index 411a122968..0e9ca78918 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp @@ -57,7 +57,7 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, { ARM_COMPUTE_UNUSED(alpha); ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F32, DataType::F16); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1); ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the LHS matrix must be <= 4"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the RHS matrix must be <= 3"); @@ -66,7 +66,8 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, ARM_COMPUTE_RETURN_ERROR_ON(lhs_info.m0 < 1 || lhs_info.m0 > 8); ARM_COMPUTE_RETURN_ERROR_ON_MSG(((rhs_info.n0 & (rhs_info.n0 - 1)) && rhs_info.n0 != 3), "Only 2,3,4,8,16 are supported for n0"); ARM_COMPUTE_RETURN_ERROR_ON_MSG((gemm_info.reinterpret_input_as_3d || gemm_info.depth_output_gemm3d != 0) && (input2 != nullptr) - && (!gemm_info.broadcast_bias), "Bias addition only supported with broadcast mode in case the input or output has to be reinterpreted as 3D"); + && (!gemm_info.broadcast_bias), + "Bias addition only supported with broadcast mode in case the input or output has to be reinterpreted as 3D"); const unsigned int m = gemm_info.m; const unsigned int n = gemm_info.n; -- cgit v1.2.1