aboutsummaryrefslogtreecommitdiff
path: root/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp
diff options
context:
space:
mode:
authorAdnan AlSinan <adnan.alsinan@arm.com>2023-10-24 12:03:21 +0100
committerAdnan AlSinan <adnan.alsinan@arm.com>2023-10-31 11:00:45 +0000
commitfde45d836cf753a94915ac42d8a13da7edc52221 (patch)
tree6ed787749aa3caec13a0b3c2c64ea591b423089c /src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp
parent5ef0bdd53dd2ce6bc7ad28077ffac3bf9e939b5f (diff)
downloadComputeLibrary-fde45d836cf753a94915ac42d8a13da7edc52221.tar.gz
Extend CKW MatMul with nt_t
- Add the kernel variant: (nt_t) to GpuCKWMatMul. - Extend CKW MatMul validation test with nt_t. - Fixes a bug in CKW where z-dim = 1. Resolves: COMPMID-6435 Signed-off-by: Adnan AlSinan <adnan.alsinan@arm.com> Change-Id: I4c5e8791e55f21ffff3c11eca7802c51a4259977 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10525 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp')
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp20
1 files changed, 10 insertions, 10 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp
index eada61e1b3..f238d42d98 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp
@@ -91,14 +91,16 @@ Status ClComponentMatMul::validate(const Properties &properties,
const auto rhs = tensors.get_const_tensor(TensorType::ACL_SRC_1);
const auto dst = tensors.get_const_tensor(TensorType::ACL_DST_0);
+ // Currently, the only supported case is when adj_lhs = false and adj_rhs = true
+ ARM_COMPUTE_RETURN_ERROR_ON((attributes.adj_lhs() != false) && (attributes.adj_rhs() != true));
+
// Check if Matching data type
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, dst);
// Data type
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F16, DataType::F32);
- // Data layout
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(lhs, DataLayout::NHWC);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, dst);
// All tensor infos are initialized
ARM_COMPUTE_RETURN_ERROR_ON(lhs->tensor_shape().total_size() == 0);
@@ -108,20 +110,18 @@ Status ClComponentMatMul::validate(const Properties &properties,
// Device requirements are met
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(lhs);
- // Check if dst shape is correct
+ // Check if block sizes are supported
MatMulKernelInfo matmul_kernel_info =
MatMulKernelInfo(attributes.adj_lhs(), attributes.adj_rhs(), settings.m0(), settings.n0(), settings.k0());
- const auto expected_dst_shape =
- misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info);
-
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), expected_dst_shape);
-
- // Check if block sizes are supported
ARM_COMPUTE_RETURN_ON_ERROR(validate_matmul_kernel_info(attributes, settings));
-
ARM_COMPUTE_RETURN_ON_ERROR(
opencl::kernels::validate_matmul_input_shapes(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info));
+ // Check if dst shape is correct
+ const auto expected_dst_shape =
+ misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), expected_dst_shape);
+
return Status{};
}