From b238f5f720be59c10b6caa633a481d0d9a3cc7a0 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Fri, 2 Aug 2019 09:09:53 +0100 Subject: COMPMID-2539: Add bias addition check in CLGEMM validation Change-Id: Ib33574662d2b62ce80dd7f74a656199ed64225bc Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/1676 Reviewed-by: Michele Di Giorgio Reviewed-by: Georgios Pinitas Tested-by: Arm Jenkins --- src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp | 2 ++ src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp | 2 ++ src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp | 2 ++ src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp | 2 ++ 4 files changed, 8 insertions(+) diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp index e793c65059..c64ed580ce 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp @@ -61,6 +61,8 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the matrix B must be <= 3"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_interleaved_transposed && reshape_info.reinterpret_input_as_3d(), "The input tensor cannot be reinterpreted as 3D if is_interleaved_transposed is true"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 2 && reshape_info.reinterpret_input_as_3d(), "The input1 tensor cannot have more than 2 dimensions if input0 has to be reinterpreted as 3D"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG((reshape_info.reinterpret_input_as_3d() || reshape_info.depth_output_gemm3d() != 0) && (input2 != nullptr) + && (!reshape_info.broadcast_bias()), "Bias addition only supported with broadcast mode in case the input or output has to be reinterpreted as 3D"); if(!is_interleaved_transposed) { diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp index 3c07c1ddee..00b06f6e24 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp @@ -65,6 +65,8 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, ARM_COMPUTE_RETURN_ERROR_ON(rhs_info.k0 > 16); ARM_COMPUTE_RETURN_ERROR_ON(lhs_info.m0 < 1 || lhs_info.m0 > 8); ARM_COMPUTE_RETURN_ERROR_ON_MSG(((rhs_info.n0 & (rhs_info.n0 - 1)) && rhs_info.n0 != 3), "Only 2,3,4,8,16 are supported for n0"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG((gemm_info.reinterpret_input_as_3d || gemm_info.depth_output_gemm3d != 0) && (input2 != nullptr) + && (!gemm_info.broadcast_bias), "Bias addition only supported with broadcast mode in case the input or output has to be reinterpreted as 3D"); const unsigned int m = gemm_info.m; const unsigned int n = gemm_info.n; diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp index fd6fd7c970..f0405bfd76 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp @@ -74,6 +74,8 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, ARM_COMPUTE_RETURN_ERROR_ON(lhs_info.k0 > 16); ARM_COMPUTE_RETURN_ERROR_ON(lhs_info.m0 < 2 || lhs_info.m0 > 8); ARM_COMPUTE_RETURN_ERROR_ON_MSG(((rhs_info.n0 & (rhs_info.n0 - 1)) && rhs_info.n0 != 3), "Only 2,3,4,8,16 are supported for n0"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG((gemm_info.reinterpret_input_as_3d || gemm_info.depth_output_gemm3d != 0) && (input2 != nullptr) + && (!gemm_info.broadcast_bias), "Bias addition only supported with broadcast mode in case the input or output has to be reinterpreted as 3D"); const unsigned int m = gemm_info.m; const unsigned int n = gemm_info.n; diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp index 5f92cad8a7..411a122968 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp @@ -65,6 +65,8 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, ARM_COMPUTE_RETURN_ERROR_ON(rhs_info.k0 > 16); ARM_COMPUTE_RETURN_ERROR_ON(lhs_info.m0 < 1 || lhs_info.m0 > 8); ARM_COMPUTE_RETURN_ERROR_ON_MSG(((rhs_info.n0 & (rhs_info.n0 - 1)) && rhs_info.n0 != 3), "Only 2,3,4,8,16 are supported for n0"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG((gemm_info.reinterpret_input_as_3d || gemm_info.depth_output_gemm3d != 0) && (input2 != nullptr) + && (!gemm_info.broadcast_bias), "Bias addition only supported with broadcast mode in case the input or output has to be reinterpreted as 3D"); const unsigned int m = gemm_info.m; const unsigned int n = gemm_info.n; -- cgit v1.2.1