aboutsummaryrefslogtreecommitdiff
path: root/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp')
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp20
1 files changed, 10 insertions, 10 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp
index eada61e1b3..f238d42d98 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp
@@ -91,14 +91,16 @@ Status ClComponentMatMul::validate(const Properties &properties,
const auto rhs = tensors.get_const_tensor(TensorType::ACL_SRC_1);
const auto dst = tensors.get_const_tensor(TensorType::ACL_DST_0);
+ // Currently, the only supported case is when adj_lhs = false and adj_rhs = true
+ ARM_COMPUTE_RETURN_ERROR_ON((attributes.adj_lhs() != false) && (attributes.adj_rhs() != true));
+
// Check if Matching data type
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, dst);
// Data type
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F16, DataType::F32);
- // Data layout
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(lhs, DataLayout::NHWC);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, dst);
// All tensor infos are initialized
ARM_COMPUTE_RETURN_ERROR_ON(lhs->tensor_shape().total_size() == 0);
@@ -108,20 +110,18 @@ Status ClComponentMatMul::validate(const Properties &properties,
// Device requirements are met
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(lhs);
- // Check if dst shape is correct
+ // Check if block sizes are supported
MatMulKernelInfo matmul_kernel_info =
MatMulKernelInfo(attributes.adj_lhs(), attributes.adj_rhs(), settings.m0(), settings.n0(), settings.k0());
- const auto expected_dst_shape =
- misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info);
-
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), expected_dst_shape);
-
- // Check if block sizes are supported
ARM_COMPUTE_RETURN_ON_ERROR(validate_matmul_kernel_info(attributes, settings));
-
ARM_COMPUTE_RETURN_ON_ERROR(
opencl::kernels::validate_matmul_input_shapes(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info));
+ // Check if dst shape is correct
+ const auto expected_dst_shape =
+ misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), expected_dst_shape);
+
return Status{};
}