From fde45d836cf753a94915ac42d8a13da7edc52221 Mon Sep 17 00:00:00 2001 From: Adnan AlSinan Date: Tue, 24 Oct 2023 12:03:21 +0100 Subject: Extend CKW MatMul with nt_t - Add the kernel variant: (nt_t) to GpuCKWMatMul. - Extend CKW MatMul validation test with nt_t. - Fixes a bug in CKW where z-dim = 1. Resolves: COMPMID-6435 Signed-off-by: Adnan AlSinan Change-Id: I4c5e8791e55f21ffff3c11eca7802c51a4259977 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10525 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Gian Marco Iodice Benchmark: Arm Jenkins --- .../sketch/gpu/components/cl/ClComponentMatMul.cpp | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) (limited to 'src/dynamic_fusion/sketch/gpu/components/cl') diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp index eada61e1b3..f238d42d98 100644 --- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp +++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp @@ -91,14 +91,16 @@ Status ClComponentMatMul::validate(const Properties &properties, const auto rhs = tensors.get_const_tensor(TensorType::ACL_SRC_1); const auto dst = tensors.get_const_tensor(TensorType::ACL_DST_0); + // Currently, the only supported case is when adj_lhs = false and adj_rhs = true + ARM_COMPUTE_RETURN_ERROR_ON((attributes.adj_lhs() != false) && (attributes.adj_rhs() != true)); + // Check if Matching data type ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, dst); // Data type ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F16, DataType::F32); - // Data layout - ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(lhs, DataLayout::NHWC); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, dst); // All tensor infos are initialized ARM_COMPUTE_RETURN_ERROR_ON(lhs->tensor_shape().total_size() == 0); @@ -108,20 +110,18 @@ Status ClComponentMatMul::validate(const Properties &properties, // Device requirements are met ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(lhs); - // Check if dst shape is correct + // Check if block sizes are supported MatMulKernelInfo matmul_kernel_info = MatMulKernelInfo(attributes.adj_lhs(), attributes.adj_rhs(), settings.m0(), settings.n0(), settings.k0()); - const auto expected_dst_shape = - misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info); - - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), expected_dst_shape); - - // Check if block sizes are supported ARM_COMPUTE_RETURN_ON_ERROR(validate_matmul_kernel_info(attributes, settings)); - ARM_COMPUTE_RETURN_ON_ERROR( opencl::kernels::validate_matmul_input_shapes(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)); + // Check if dst shape is correct + const auto expected_dst_shape = + misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), expected_dst_shape); + return Status{}; } -- cgit v1.2.1