aboutsummaryrefslogtreecommitdiff
path: root/src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.cpp')
-rw-r--r--src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.cpp55
1 files changed, 48 insertions, 7 deletions
diff --git a/src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.cpp b/src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.cpp
index f971196729..274a2517bb 100644
--- a/src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.cpp
+++ b/src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.cpp
@@ -113,9 +113,14 @@ bool operator==(const Conv2dDescriptor &conv2d0, const Conv2dDescriptor &conv2d1
return std::make_tuple(conv2d0.pad, conv2d0.stride, conv2d0.dilation) == std::make_tuple(conv2d1.pad, conv2d1.stride, conv2d1.dilation);
}
-bool operator==(const AddDescriptor &, const AddDescriptor &)
+bool operator==(const ElementwiseDescriptor &ed0, const ElementwiseDescriptor &ed1)
{
- return std::make_tuple() == std::make_tuple(); // Currently two Add ops are always the same
+ return ed0.op == ed1.op; // Compare Arithmatic Operations of two ElementwiseDescriptor objects
+}
+
+bool operator==(const FloorDescriptor &, const FloorDescriptor &)
+{
+ return std::make_tuple() == std::make_tuple(); // Currently two Floor ops are always the same
}
bool Conv2dContent::operator==(const OperatorContent &other) const
@@ -124,9 +129,15 @@ bool Conv2dContent::operator==(const OperatorContent &other) const
return desc == converted.desc;
}
-bool AddContent::operator==(const OperatorContent &other) const
+bool ElementwiseContent::operator==(const OperatorContent &other) const
+{
+ const auto converted = *utils::cast::polymorphic_downcast<const ElementwiseContent *>(&other);
+ return desc == converted.desc;
+}
+
+bool FloorContent::operator==(const OperatorContent &other) const
{
- const auto converted = *utils::cast::polymorphic_downcast<const AddContent *>(&other);
+ const auto converted = *utils::cast::polymorphic_downcast<const FloorContent *>(&other);
return desc == converted.desc;
}
@@ -311,7 +322,7 @@ Status Conv2dContent::translate_direct_conv2d(ClKernelGraph &kernel_graph) const
return Status{};
}
-Status AddContent::translate(ClKernelGraph &kernel_graph) const
+Status ElementwiseContent::translate(ClKernelGraph &kernel_graph) const
{
const auto lhs = _tensors.get_const_tensor(TensorType::ACL_SRC_0);
const auto rhs = _tensors.get_const_tensor(TensorType::ACL_SRC_1);
@@ -338,16 +349,46 @@ Status AddContent::translate(ClKernelGraph &kernel_graph) const
DependencyGraph::Id add_id;
ClKernelConfig config{ UnitWorkloadStage{ UnitWorkloadStage::Stage::Run }, TileDescriptor{}, StoreType::TStoreIndirectWidthSelect };
- st = ClAddKernel::validate(lhs->desc, rhs->desc, dst->desc);
+ st = ClElementwiseKernel::validate(lhs->desc, rhs->desc, dst->desc);
ARM_COMPUTE_RETURN_ON_ERROR(st);
- st = kernel_graph.add_kernel<ClAddKernel>(config, ClEltwiseAddKernelDescriptor{ desc }, tensors, add_id);
+ st = kernel_graph.add_kernel<ClElementwiseKernel>(config, ClElementwiseKernelDescriptor{ desc }, tensors, add_id);
ARM_COMPUTE_RETURN_ON_ERROR(st);
ARM_COMPUTE_UNUSED(add_id);
return Status{};
}
+Status FloorContent::translate(ClKernelGraph &kernel_graph) const
+{
+ const auto src = _tensors.get_const_tensor(TensorType::ACL_SRC_0);
+ const auto dst = _tensors.get_const_tensor(TensorType::ACL_DST_0);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst);
+
+ ITensorDescPack<ClKernelTensor> tensors;
+
+ DependencyGraph::Id src_id;
+ auto st = add_kernel_tensor(kernel_graph, *_graph, *src, src_id);
+ ARM_COMPUTE_RETURN_ON_ERROR(st);
+ tensors.add_const_tensor(ACL_SRC_0, kernel_graph.get_tensor(src_id));
+
+ DependencyGraph::Id dst_id;
+ st = add_kernel_tensor(kernel_graph, *_graph, *dst, dst_id);
+ ARM_COMPUTE_RETURN_ON_ERROR(st);
+ tensors.add_const_tensor(ACL_DST_0, kernel_graph.get_tensor(dst_id));
+
+ DependencyGraph::Id add_id;
+ ClKernelConfig config{ UnitWorkloadStage{ UnitWorkloadStage::Stage::Run }, TileDescriptor{}, StoreType::TStoreIndirectWidthSelect };
+
+ st = ClFloorKernel::validate(src->desc, dst->desc);
+ ARM_COMPUTE_RETURN_ON_ERROR(st);
+
+ st = kernel_graph.add_kernel<ClFloorKernel>(config, ClFloorKernelDescriptor{ desc }, tensors, add_id);
+ ARM_COMPUTE_RETURN_ON_ERROR(st);
+
+ return Status{};
+}
+
std::vector<const OperatorContent *> traverse(const OperatorGraph::Implementation &graph)
{
std::vector<const OperatorContent *> ops;