diff options
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp')
-rw-r--r-- | src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp index eada61e1b3..f238d42d98 100644 --- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp +++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentMatMul.cpp @@ -91,14 +91,16 @@ Status ClComponentMatMul::validate(const Properties &properties, const auto rhs = tensors.get_const_tensor(TensorType::ACL_SRC_1); const auto dst = tensors.get_const_tensor(TensorType::ACL_DST_0); + // Currently, the only supported case is when adj_lhs = false and adj_rhs = true + ARM_COMPUTE_RETURN_ERROR_ON((attributes.adj_lhs() != false) && (attributes.adj_rhs() != true)); + // Check if Matching data type ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, dst); // Data type ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F16, DataType::F32); - // Data layout - ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(lhs, DataLayout::NHWC); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, dst); // All tensor infos are initialized ARM_COMPUTE_RETURN_ERROR_ON(lhs->tensor_shape().total_size() == 0); @@ -108,20 +110,18 @@ Status ClComponentMatMul::validate(const Properties &properties, // Device requirements are met ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(lhs); - // Check if dst shape is correct + // Check if block sizes are supported MatMulKernelInfo matmul_kernel_info = MatMulKernelInfo(attributes.adj_lhs(), attributes.adj_rhs(), settings.m0(), settings.n0(), settings.k0()); - const auto expected_dst_shape = - misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info); - - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), expected_dst_shape); - - // Check if block sizes are supported ARM_COMPUTE_RETURN_ON_ERROR(validate_matmul_kernel_info(attributes, settings)); - ARM_COMPUTE_RETURN_ON_ERROR( opencl::kernels::validate_matmul_input_shapes(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)); + // Check if dst shape is correct + const auto expected_dst_shape = + misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), expected_dst_shape); + return Status{}; } |