aboutsummaryrefslogtreecommitdiff
path: root/src/dynamic_fusion/sketch/gpu/components/cl
diff options
context:
space:
mode:
authorViet-Hoa Do <viet-hoa.do@arm.com>2023-01-03 17:59:14 +0000
committerViet-Hoa Do <viet-hoa.do@arm.com>2023-01-06 13:45:20 +0000
commitb3077fbaee868579f9a41888fef1f71286d6757c (patch)
tree23e6a6c63dc860697ae8e9301da7ddbb29d62c98 /src/dynamic_fusion/sketch/gpu/components/cl
parent3558c5840e7c973e2b1a86ae3a9335b44cad59d4 (diff)
downloadComputeLibrary-b3077fbaee868579f9a41888fef1f71286d6757c.tar.gz
LHS broadcasting addition for dynamic fusion
* Binary elementwise operator now can have broadcasting in either X dimension, Y+Z dimension, or both, in either LHS or RHS operand. * Fix bug in CL code to support batching. Resolves: COMPMID-5704 Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com> Change-Id: I51b04986d30861f255ca9f754adffa0e6c85a26b Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8898 Reviewed-by: SiCong Li <sicong.li@arm.com> Reviewed-by: Ramy Elgammal <ramy.elgammal@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Dynamic-Fusion: Ramy Elgammal <ramy.elgammal@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/components/cl')
-rw-r--r--src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.cpp32
1 files changed, 25 insertions, 7 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.cpp b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.cpp
index a17d835ac6..9b006b13ce 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.cpp
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022 Arm Limited.
+ * Copyright (c) 2022-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -66,8 +66,30 @@ Status ClComponentElementwiseBinary::validate(const ArgumentPack<ITensorInfo> &t
const TensorShape out_shape = TensorShape::broadcast_shape(lhs->tensor_shape(), rhs->tensor_shape());
ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst->tensor_shape(), 0), "Wrong shape for dst.");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((!rhs_in_place && !lhs_in_place) && detail::have_different_dimensions(lhs->tensor_shape(), dst->tensor_shape(), 0),
- "Only the rhs operand can be broadcast to match the accumulator's (lhs) shape");
+
+ const auto &lhs_shape = lhs->tensor_shape();
+ const auto &rhs_shape = rhs->tensor_shape();
+ const auto &dst_shape = dst->tensor_shape();
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ detail::have_different_dimensions(lhs_shape, dst_shape, 0) && detail::have_different_dimensions(rhs_shape, dst_shape, 0),
+ "Only LHS or RHS can be broadcasting, not both.");
+
+ // Dimension Y and Z are collapsed together in the current kernel implementation,
+ // hence they cannot be independently broadcast or non-broadcast.
+ // See: ClTemplateElementwiseBinary::get_window
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ (lhs_shape[1] != dst_shape[1] || rhs_shape[1] != dst_shape[1]) != (lhs_shape[2] != dst_shape[2] || rhs_shape[2] != dst_shape[2]),
+ "Dimension Y and Z must both be either broadcast or non-broadcast.");
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ detail::have_different_dimensions(lhs_shape, dst_shape, 3),
+ "LHS broadcast in dimension 3 or higher is not supported.");
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ detail::have_different_dimensions(rhs_shape, dst_shape, 3),
+ "RHS broadcast in dimension 3 or higher is not supported.");
+
// Matching data type
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, dst);
@@ -76,10 +98,6 @@ Status ClComponentElementwiseBinary::validate(const ArgumentPack<ITensorInfo> &t
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(lhs, rhs);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(lhs, dst);
- // Batching case not supported yet
- const size_t idx_batch = get_data_layout_dimension_index(lhs->data_layout(), DataLayoutDimension::BATCHES);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((lhs->tensor_shape()[idx_batch] != 1) || (rhs->tensor_shape()[idx_batch] != 1) || (dst->tensor_shape()[idx_batch] != 1), "Batching case not supported yet");
-
// All tensor infos are initialized
ARM_COMPUTE_RETURN_ERROR_ON(lhs->tensor_shape().total_size() == 0);
ARM_COMPUTE_RETURN_ERROR_ON(rhs->tensor_shape().total_size() == 0);