about | summary | refs | log | tree | commit | diff
path: root/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp')
-rw-r--r-- src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp 63
1 files changed, 58 insertions, 5 deletions
diff --git a/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp b/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp
index de58ce70ed..cab51a2ce6 100644
--- a/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp
+++ b/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp
@@ -124,7 +124,7 @@ bool ClDirectConv2dKernel::operator==(const ClKernel &other) const
return config() == other.config() && tensors() == other.tensors() && desc == converted.desc;
}
-Status ClAddKernel::generate(ClKernelBlueprint &bp) const
+Status ClElementwiseKernel::generate(ClKernelBlueprint &bp) const
{
const auto lhs = _tensors.get_const_tensor(TensorType::ACL_SRC_0);
const auto rhs = _tensors.get_const_tensor(TensorType::ACL_SRC_1);
@@ -137,11 +137,11 @@ Status ClAddKernel::generate(ClKernelBlueprint &bp) const
ArgumentID dst_id;
add_tensor(bp, dst->desc, dst_id, dst->id);
- add_kcomp_eltwise_add(bp, desc, lhs_id, rhs_id, dst_id);
+ add_kcomp_eltwise_op(bp, desc, lhs_id, rhs_id, dst_id);
return Status{};
}
-Status ClAddKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst)
+Status ClElementwiseKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst)
{
// 1. Check validity
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lhs, rhs, dst);
@@ -186,9 +186,61 @@ Status ClAddKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, con
return Status{};
}
-bool ClAddKernel::operator==(const ClKernel &other) const
+bool ClElementwiseKernel::operator==(const ClKernel &other) const // Equality against another kernel passed as the ClKernel base type
{
- const auto converted = *utils::cast::polymorphic_downcast<const ClAddKernel *>(&other);
+ const auto converted = *utils::cast::polymorphic_downcast<const ClElementwiseKernel *>(&other); // Copy of `other` downcast to this class; NOTE(review): presumably only same-type kernels are compared - confirm
+ return config() == other.config() && tensors() == other.tensors() && desc == converted.desc; // Equal only if config, tensors and descriptor all match
+}
+
+Status ClFloorKernel::generate(ClKernelBlueprint &bp) const // Adds this kernel's src/dst tensors and a floor component to the blueprint `bp`
+{
+ const auto src = _tensors.get_const_tensor(TensorType::ACL_SRC_0); // Input tensor, looked up under ACL_SRC_0
+ const auto dst = _tensors.get_const_tensor(TensorType::ACL_DST_0); // Output tensor, looked up under ACL_DST_0
+ ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst); // Both lookups must succeed: src and dst are dereferenced below
+ ArgumentID src_id; // NOTE(review): presumably filled in by add_tensor as an out-parameter - confirm
+ add_tensor(bp, src->desc, src_id, src->id);
+ ArgumentID dst_id; // NOTE(review): same assumption as src_id - confirm
+ add_tensor(bp, dst->desc, dst_id, dst->id);
+
+ add_kcomp_floor(bp, desc, src_id, dst_id); // Floor component is wired to the two argument ids declared above
+ return Status{}; // A default-constructed Status is returned on every path
+}
+
+Status ClFloorKernel::validate(const ITensorInfo *src, const ITensorInfo *dst) // Checks a src/dst tensor-info pair; returns a default Status when no check triggers
+{
+ // 1. Check validity: null pointers, src/dst mismatches, uninitialized shapes, device support
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, dst);
+
+ // src and dst must have matching data types
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst);
+
+ // src and dst must have matching data layouts
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(src, dst);
+
+ // All tensor infos are initialized: a shape whose total size is 0 is rejected
+ ARM_COMPUTE_RETURN_ERROR_ON(src->tensor_shape().total_size() == 0);
+ ARM_COMPUTE_RETURN_ERROR_ON(dst->tensor_shape().total_size() == 0);
+
+ // Device requirements are met (F16 availability is checked against src)
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(src);
+
+ // dst shape must agree with src shape; NOTE(review): presumably the trailing 0 is the first dimension compared - confirm
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(src->tensor_shape(), dst->tensor_shape(), 0), "Wrong shape for dst");
+
+ // 2. Check support level: restrictions on what this kernel accepts
+
+ // Data type: per the macro arguments, 1 channel with F32 or F16
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F32, DataType::F16);
+
+ // Data layout: NHWC is the only layout listed
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(src, DataLayout::NHWC);
+
+ return Status{};
+}
+
+bool ClFloorKernel::operator==(const ClKernel &other) const // Equality against another kernel passed as the ClKernel base type
+{
+ const auto converted = *utils::cast::polymorphic_downcast<const ClFloorKernel *>(&other); // Copy of `other` downcast to ClFloorKernel; NOTE(review): presumably only same-type kernels are compared - confirm
return config() == other.config() && tensors() == other.tensors() && desc == converted.desc;
}
@@ -202,6 +254,7 @@ std::vector<const ClKernel *> traverse(const ClKernelGraph &graph)
}
return kernels;
}
+
std::vector<ClKernel *> traverse(ClKernelGraph &graph)
{
std::vector<ClKernel *> kernels;