author     Michalis Spyrou <michalis.spyrou@arm.com>    2022-06-15 19:02:28 +0100
committer  Michalis Spyrou <michalis.spyrou@arm.com>    2022-06-27 14:05:05 +0000
commit     b1fcefddf3f59219a9d7930d607175b7e6c39347 (patch)
tree       34e95efded15194b3c8abe4ba3da308c3259301d /src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp
parent     41eb2d92c89274200d59ff97653e2bd66819b310 (diff)
Implement new Elementwise Dynamic Fusion Operators: Div, Floor
Resolves: COMPMID-5355

Change-Id: I92f73fbe885f28bbe7b07965b90cfd807c93602f
Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7745
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
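The substance of the change is a generalisation: the dedicated ClAddKernel becomes a ClElementwiseKernel whose behaviour is selected by an operation descriptor, which is what lets Div reuse the existing Add plumbing, while Floor arrives as a new unary ClFloorKernel. The sketch below illustrates that descriptor-driven pattern in isolation; the enum and helper names here are hypothetical stand-ins, not ComputeLibrary's actual definitions.

#include <functional>

enum class ElementwiseOp
{
    Add, // the only operator the old ClAddKernel handled
    Div  // newly supported; reuses the same kernel plumbing
};

struct ElementwiseDescriptor
{
    ElementwiseOp op;
};

// One code path picks the per-element function from the descriptor instead of
// hard-coding addition, so adding Div needs no new kernel class.
std::function<float(float, float)> select_op(const ElementwiseDescriptor &desc)
{
    switch (desc.op)
    {
        case ElementwiseOp::Add:
            return [](float a, float b) { return a + b; };
        case ElementwiseOp::Div:
            return [](float a, float b) { return a / b; };
    }
    return {}; // unreachable for valid descriptors
}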
Diffstat (limited to 'src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp')
-rw-r--r-- src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp | 63
1 file changed, 58 insertions(+), 5 deletions(-)
diff --git a/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp b/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp
index de58ce70ed..cab51a2ce6 100644
--- a/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp
+++ b/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp
@@ -124,7 +124,7 @@ bool ClDirectConv2dKernel::operator==(const ClKernel &other) const
return config() == other.config() && tensors() == other.tensors() && desc == converted.desc;
}
-Status ClAddKernel::generate(ClKernelBlueprint &bp) const
+Status ClElementwiseKernel::generate(ClKernelBlueprint &bp) const
{
const auto lhs = _tensors.get_const_tensor(TensorType::ACL_SRC_0);
const auto rhs = _tensors.get_const_tensor(TensorType::ACL_SRC_1);
@@ -137,11 +137,11 @@ Status ClAddKernel::generate(ClKernelBlueprint &bp) const
ArgumentID dst_id;
add_tensor(bp, dst->desc, dst_id, dst->id);
- add_kcomp_eltwise_add(bp, desc, lhs_id, rhs_id, dst_id);
+ add_kcomp_eltwise_op(bp, desc, lhs_id, rhs_id, dst_id);
return Status{};
}
-Status ClAddKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst)
+Status ClElementwiseKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst)
{
// 1. Check validity
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lhs, rhs, dst);
@@ -186,9 +186,61 @@ Status ClAddKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, con
return Status{};
}
-bool ClAddKernel::operator==(const ClKernel &other) const
+bool ClElementwiseKernel::operator==(const ClKernel &other) const
{
- const auto converted = *utils::cast::polymorphic_downcast<const ClAddKernel *>(&other);
+ const auto converted = *utils::cast::polymorphic_downcast<const ClElementwiseKernel *>(&other);
+ return config() == other.config() && tensors() == other.tensors() && desc == converted.desc;
+}
+
+Status ClFloorKernel::generate(ClKernelBlueprint &bp) const
+{
+ const auto src = _tensors.get_const_tensor(TensorType::ACL_SRC_0);
+ const auto dst = _tensors.get_const_tensor(TensorType::ACL_DST_0);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst);
+ ArgumentID src_id;
+ add_tensor(bp, src->desc, src_id, src->id);
+ ArgumentID dst_id;
+ add_tensor(bp, dst->desc, dst_id, dst->id);
+
+ add_kcomp_floor(bp, desc, src_id, dst_id);
+ return Status{};
+}
+
+Status ClFloorKernel::validate(const ITensorInfo *src, const ITensorInfo *dst)
+{
+ // 1. Check validity
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, dst);
+
+ // Matching data type
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst);
+
+ // Matching data layout
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(src, dst);
+
+ // All tensor infos are initialized
+ ARM_COMPUTE_RETURN_ERROR_ON(src->tensor_shape().total_size() == 0);
+ ARM_COMPUTE_RETURN_ERROR_ON(dst->tensor_shape().total_size() == 0);
+
+ // Device requirements are met
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(src);
+
+ // dst shape is correct
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(src->tensor_shape(), dst->tensor_shape(), 0), "Wrong shape for dst");
+
+ // 2. Check support level
+
+ // Data type
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F32, DataType::F16);
+
+ // Data layout
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(src, DataLayout::NHWC);
+
+ return Status{};
+}
+
+bool ClFloorKernel::operator==(const ClKernel &other) const
+{
+ const auto converted = *utils::cast::polymorphic_downcast<const ClFloorKernel *>(&other);
return config() == other.config() && tensors() == other.tensors() && desc == converted.desc;
}
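Both validate() functions in this hunk lean on ComputeLibrary's early-return validation macros: each ARM_COMPUTE_RETURN_ERROR_ON_* check either passes or immediately returns a failing Status, so the final return Status{} is reached only when every precondition holds. Below is a simplified, hypothetical rendering of that style, not the library's actual definition; the real macros (declared in arm_compute/core/Validate.h) additionally capture the function, file and line of the failing check.

#include <string>
#include <utility>

enum class ErrorCode
{
    OK,
    RUNTIME_ERROR
};

struct Status
{
    Status() = default;
    Status(ErrorCode c, std::string d) : code(c), description(std::move(d))
    {
    }
    ErrorCode   code{ ErrorCode::OK };
    std::string description{};
};

// Simplified stand-in for ARM_COMPUTE_RETURN_ERROR_ON_MSG: if the condition
// holds, abort validation immediately with a failing Status.
#define RETURN_ERROR_ON_MSG(cond, msg)                        \
    do                                                        \
    {                                                         \
        if (cond)                                             \
        {                                                     \
            return Status{ ErrorCode::RUNTIME_ERROR, (msg) }; \
        }                                                     \
    } while (false)

// Each check either falls through or returns, so reaching the end of the
// function means every precondition held.
Status validate_example(const float *src, const float *dst)
{
    RETURN_ERROR_ON_MSG(src == nullptr || dst == nullptr, "Null pointer");
    RETURN_ERROR_ON_MSG(src == dst, "In-place operation not supported");
    return Status{}; // default-constructed Status signals success
}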
@@ -202,6 +254,7 @@ std::vector<const ClKernel *> traverse(const ClKernelGraph &graph)
}
return kernels;
}
+
std::vector<ClKernel *> traverse(ClKernelGraph &graph)
{
std::vector<ClKernel *> kernels;
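For context, the traverse() functions touched by this last hunk form a const/non-const overload pair over the same walk: callers holding a const graph get const ClKernel pointers, while mutable access yields mutable ones. A minimal sketch of that overload pattern, with Node and Graph as assumed stand-ins for ClKernel and ClKernelGraph rather than the library's types:

#include <memory>
#include <vector>

struct Node
{
};

struct Graph
{
    std::vector<std::unique_ptr<Node>> nodes;
};

// Read-only traversal: a const graph hands out const pointers only.
std::vector<const Node *> traverse(const Graph &graph)
{
    std::vector<const Node *> out;
    for (const auto &n : graph.nodes)
    {
        out.push_back(n.get());
    }
    return out;
}

// Mutable traversal: the same walk, but the returned pointers permit
// modification of the visited nodes.
std::vector<Node *> traverse(Graph &graph)
{
    std::vector<Node *> out;
    for (auto &n : graph.nodes)
    {
        out.push_back(n.get());
    }
    return out;
}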