Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.cpp')
-rw-r--r--  src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.cpp  65
1 file changed, 29 insertions(+), 36 deletions(-)
diff --git a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.cpp b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.cpp
index 88d729170c..5b136427e4 100644
--- a/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.cpp
+++ b/src/dynamic_fusion/sketch/gpu/components/cl/ClComponentElementwiseBinary.cpp
@@ -24,6 +24,7 @@
#include "ClComponentElementwiseBinary.h"
#include "arm_compute/core/Validate.h"
+
#include "src/core/CL/CLValidate.h"
#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
#include "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateElementwiseBinary.h"
@@ -39,56 +40,55 @@ namespace dynamic_fusion
{
namespace
{
-std::set<ElementwiseBinaryCommonAttributes::ElementwiseOp> supported_ops
-{
- ElementwiseBinaryCommonAttributes::ElementwiseOp::Add,
- ElementwiseBinaryCommonAttributes::ElementwiseOp::Sub,
- ElementwiseBinaryCommonAttributes::ElementwiseOp::Mul
-};
+std::set<ElementwiseBinaryCommonAttributes::ElementwiseOp> supported_ops{
+ ElementwiseBinaryCommonAttributes::ElementwiseOp::Add, ElementwiseBinaryCommonAttributes::ElementwiseOp::Sub,
+ ElementwiseBinaryCommonAttributes::ElementwiseOp::Mul};
}
-Status ClComponentElementwiseBinary::validate(const ArgumentPack<ITensorInfo> &tensors, const ElementwiseBinaryCommonAttributes &attributes)
+Status ClComponentElementwiseBinary::validate(const ArgumentPack<ITensorInfo> &tensors,
+ const ElementwiseBinaryCommonAttributes &attributes)
{
const auto lhs = tensors.get_const_tensor(TensorType::ACL_SRC_0);
const auto rhs = tensors.get_const_tensor(TensorType::ACL_SRC_1);
const auto dst = tensors.get_const_tensor(TensorType::ACL_DST_0);
// Check operator type
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(supported_ops.find(attributes.operation()) == supported_ops.end(), "Provided Elementwise operation not supported.");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(supported_ops.find(attributes.operation()) == supported_ops.end(),
+ "Provided Elementwise operation not supported.");
// Check validity
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lhs, rhs, dst);
//Check data type for different elementwise operators
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16, DataType::S32, DataType::S16, DataType::U8);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16, DataType::S32,
+ DataType::S16, DataType::U8);
// dst shape is correct
const TensorShape out_shape = TensorShape::broadcast_shape(lhs->tensor_shape(), rhs->tensor_shape());
ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst->tensor_shape(), 0), "Wrong shape for dst.");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst->tensor_shape(), 0),
+ "Wrong shape for dst.");
const auto &lhs_shape = lhs->tensor_shape();
const auto &rhs_shape = rhs->tensor_shape();
const auto &dst_shape = dst->tensor_shape();
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- detail::have_different_dimensions(lhs_shape, dst_shape, 0) && detail::have_different_dimensions(rhs_shape, dst_shape, 0),
- "Only LHS or RHS can be broadcasting, not both.");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(lhs_shape, dst_shape, 0) &&
+ detail::have_different_dimensions(rhs_shape, dst_shape, 0),
+ "Only LHS or RHS can be broadcasting, not both.");
// Dimension Y and Z are collapsed together in the current kernel implementation,
// hence they cannot be independently broadcast or non-broadcast.
// See: ClTemplateElementwiseBinary::get_window
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- (lhs_shape[1] != dst_shape[1] || rhs_shape[1] != dst_shape[1]) != (lhs_shape[2] != dst_shape[2] || rhs_shape[2] != dst_shape[2]),
- "Dimension Y and Z must both be either broadcast or non-broadcast.");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((lhs_shape[1] != dst_shape[1] || rhs_shape[1] != dst_shape[1]) !=
+ (lhs_shape[2] != dst_shape[2] || rhs_shape[2] != dst_shape[2]),
+ "Dimension Y and Z must both be either broadcast or non-broadcast.");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- detail::have_different_dimensions(lhs_shape, dst_shape, 3),
- "LHS broadcast in dimension 3 or higher is not supported.");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(lhs_shape, dst_shape, 3),
+ "LHS broadcast in dimension 3 or higher is not supported.");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- detail::have_different_dimensions(rhs_shape, dst_shape, 3),
- "RHS broadcast in dimension 3 or higher is not supported.");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(rhs_shape, dst_shape, 3),
+ "RHS broadcast in dimension 3 or higher is not supported.");
// Matching data type
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs);
@@ -112,22 +112,15 @@ Status ClComponentElementwiseBinary::validate(const ArgumentPack<ITensorInfo> &t
ClComponentElementwiseBinary::~ClComponentElementwiseBinary()
{
}
-ClComponentElementwiseBinary::ClComponentElementwiseBinary(
- ComponentId id,
- const Properties &properties,
- const ArgumentPack<ITensorInfo> &tensors,
- const Attributes &attributes)
- : IGpuKernelComponent{ id, properties, tensors },
+ClComponentElementwiseBinary::ClComponentElementwiseBinary(ComponentId id,
+ const Properties &properties,
+ const ArgumentPack<ITensorInfo> &tensors,
+ const Attributes &attributes)
+ : IGpuKernelComponent{id, properties, tensors},
#ifndef ACL_INTERNAL_TEST_CKW_IN_DF
- _component_writer
-{
- std::make_unique<ClTemplateElementwiseBinary>(id, tensors, attributes)
-}
+ _component_writer{std::make_unique<ClTemplateElementwiseBinary>(id, tensors, attributes)}
#else //ACL_INTERNAL_TEST_CKW_IN_DF
- _component_writer
-{
- std::make_unique<GpuCkwElementwiseBinary>(id, tensors, attributes)
-}
+ _component_writer{std::make_unique<GpuCkwElementwiseBinary>(id, tensors, attributes)}
#endif //ACL_INTERNAL_TEST_CKW_IN_DF
{
}
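
For reference, below is a minimal standalone sketch (not part of the commit) of the broadcast rules that ClComponentElementwiseBinary::validate() enforces in the hunk above: only one of LHS/RHS may broadcast, dimensions Y and Z must broadcast together because the kernel collapses them, and no broadcasting is allowed in dimension 3 or higher. The names Shape, dims_differ_from and check_broadcast are hypothetical helpers used only for this illustration; they stand in for TensorShape and detail::have_different_dimensions, and the sketch assumes plain 4-dimensional shape vectors rather than the library's tensor types.

// Illustrative sketch only -- re-expresses the shape checks from validate()
// on plain shape vectors, independent of the Compute Library types.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

using Shape = std::vector<std::size_t>; // dimension 0 first, as in TensorShape

// True if the two shapes differ in any dimension at index >= first_dim
// (plays the role of detail::have_different_dimensions in the diff).
bool dims_differ_from(const Shape &a, const Shape &b, std::size_t first_dim)
{
    const std::size_t n = std::max(a.size(), b.size());
    for (std::size_t d = first_dim; d < n; ++d)
    {
        const std::size_t da = d < a.size() ? a[d] : 1;
        const std::size_t db = d < b.size() ? b[d] : 1;
        if (da != db)
        {
            return true;
        }
    }
    return false;
}

// Re-states the broadcast constraints checked in the validate() hunk.
bool check_broadcast(const Shape &lhs, const Shape &rhs, const Shape &dst)
{
    // Only LHS or RHS may broadcast, not both.
    if (dims_differ_from(lhs, dst, 0) && dims_differ_from(rhs, dst, 0))
    {
        return false;
    }
    // Y and Z are collapsed in the kernel, so they must be broadcast (or not) together.
    const bool y_bcast = lhs[1] != dst[1] || rhs[1] != dst[1];
    const bool z_bcast = lhs[2] != dst[2] || rhs[2] != dst[2];
    if (y_bcast != z_bcast)
    {
        return false;
    }
    // No broadcasting in dimension 3 or higher on either side.
    if (dims_differ_from(lhs, dst, 3) || dims_differ_from(rhs, dst, 3))
    {
        return false;
    }
    return true;
}

int main()
{
    // LHS broadcasts in both Y and Z: accepted (prints 1).
    std::cout << check_broadcast({8, 1, 1, 2}, {8, 4, 3, 2}, {8, 4, 3, 2}) << '\n';
    // LHS broadcasts in Y only: rejected, Y and Z must match (prints 0).
    std::cout << check_broadcast({8, 1, 3, 2}, {8, 4, 3, 2}, {8, 4, 3, 2}) << '\n';
}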