diff options
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/components')
-rw-r--r-- | src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.cpp | 32 |
1 files changed, 25 insertions, 7 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.cpp b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.cpp index a17d835ac6..9b006b13ce 100644 --- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.cpp +++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Arm Limited. + * Copyright (c) 2022-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -66,8 +66,30 @@ Status ClComponentElementwiseBinary::validate(const ArgumentPack<ITensorInfo> &t const TensorShape out_shape = TensorShape::broadcast_shape(lhs->tensor_shape(), rhs->tensor_shape()); ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst->tensor_shape(), 0), "Wrong shape for dst."); - ARM_COMPUTE_RETURN_ERROR_ON_MSG((!rhs_in_place && !lhs_in_place) && detail::have_different_dimensions(lhs->tensor_shape(), dst->tensor_shape(), 0), - "Only the rhs operand can be broadcast to match the accumulator's (lhs) shape"); + + const auto &lhs_shape = lhs->tensor_shape(); + const auto &rhs_shape = rhs->tensor_shape(); + const auto &dst_shape = dst->tensor_shape(); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG( + detail::have_different_dimensions(lhs_shape, dst_shape, 0) && detail::have_different_dimensions(rhs_shape, dst_shape, 0), + "Only LHS or RHS can be broadcasting, not both."); + + // Dimension Y and Z are collapsed together in the current kernel implementation, + // hence they cannot be independently broadcast or non-broadcast. + // See: ClTemplateElementwiseBinary::get_window + ARM_COMPUTE_RETURN_ERROR_ON_MSG( + (lhs_shape[1] != dst_shape[1] || rhs_shape[1] != dst_shape[1]) != (lhs_shape[2] != dst_shape[2] || rhs_shape[2] != dst_shape[2]), + "Dimension Y and Z must both be either broadcast or non-broadcast."); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG( + detail::have_different_dimensions(lhs_shape, dst_shape, 3), + "LHS broadcast in dimension 3 or higher is not supported."); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG( + detail::have_different_dimensions(rhs_shape, dst_shape, 3), + "RHS broadcast in dimension 3 or higher is not supported."); + // Matching data type ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, dst); @@ -76,10 +98,6 @@ Status ClComponentElementwiseBinary::validate(const ArgumentPack<ITensorInfo> &t ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(lhs, rhs); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(lhs, dst); - // Batching case not supported yet - const size_t idx_batch = get_data_layout_dimension_index(lhs->data_layout(), DataLayoutDimension::BATCHES); - ARM_COMPUTE_RETURN_ERROR_ON_MSG((lhs->tensor_shape()[idx_batch] != 1) || (rhs->tensor_shape()[idx_batch] != 1) || (dst->tensor_shape()[idx_batch] != 1), "Batching case not supported yet"); - // All tensor infos are initialized ARM_COMPUTE_RETURN_ERROR_ON(lhs->tensor_shape().total_size() == 0); ARM_COMPUTE_RETURN_ERROR_ON(rhs->tensor_shape().total_size() == 0); |