From b1fcefddf3f59219a9d7930d607175b7e6c39347 Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Wed, 15 Jun 2022 19:02:28 +0100 Subject: Implement new Elementwise Dynamic Fusion Operators: Div, Floor Resolves: COMPMID-5355 Change-Id: I92f73fbe885f28bbe7b07965b90cfd807c93602f Signed-off-by: Michalis Spyrou Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7745 Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: SiCong Li --- .../WorkloadImpl/OperatorGraphImpl.cpp | 55 +++++++++++++++++++--- 1 file changed, 48 insertions(+), 7 deletions(-) (limited to 'src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.cpp') diff --git a/src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.cpp b/src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.cpp index f971196729..274a2517bb 100644 --- a/src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.cpp +++ b/src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.cpp @@ -113,9 +113,14 @@ bool operator==(const Conv2dDescriptor &conv2d0, const Conv2dDescriptor &conv2d1 return std::make_tuple(conv2d0.pad, conv2d0.stride, conv2d0.dilation) == std::make_tuple(conv2d1.pad, conv2d1.stride, conv2d1.dilation); } -bool operator==(const AddDescriptor &, const AddDescriptor &) +bool operator==(const ElementwiseDescriptor &ed0, const ElementwiseDescriptor &ed1) { - return std::make_tuple() == std::make_tuple(); // Currently two Add ops are always the same + return ed0.op == ed1.op; // Compare Arithmetic Operations of two ElementwiseDescriptor objects +} + +bool operator==(const FloorDescriptor &, const FloorDescriptor &) +{ + return std::make_tuple() == std::make_tuple(); // Currently two Floor ops are always the same } bool Conv2dContent::operator==(const OperatorContent &other) const @@ -124,9 +129,15 @@ bool Conv2dContent::operator==(const OperatorContent &other) const return desc == converted.desc; } -bool
AddContent::operator==(const OperatorContent &other) const +bool ElementwiseContent::operator==(const OperatorContent &other) const +{ + const auto converted = *utils::cast::polymorphic_downcast<const ElementwiseContent *>(&other); + return desc == converted.desc; +} + +bool FloorContent::operator==(const OperatorContent &other) const { - const auto converted = *utils::cast::polymorphic_downcast<const AddContent *>(&other); + const auto converted = *utils::cast::polymorphic_downcast<const FloorContent *>(&other); return desc == converted.desc; } @@ -311,7 +322,7 @@ Status Conv2dContent::translate_direct_conv2d(ClKernelGraph &kernel_graph) const return Status{}; } -Status AddContent::translate(ClKernelGraph &kernel_graph) const +Status ElementwiseContent::translate(ClKernelGraph &kernel_graph) const { const auto lhs = _tensors.get_const_tensor(TensorType::ACL_SRC_0); const auto rhs = _tensors.get_const_tensor(TensorType::ACL_SRC_1); @@ -338,16 +349,46 @@ Status AddContent::translate(ClKernelGraph &kernel_graph) const DependencyGraph::Id add_id; ClKernelConfig config{ UnitWorkloadStage{ UnitWorkloadStage::Stage::Run }, TileDescriptor{}, StoreType::TStoreIndirectWidthSelect }; - st = ClAddKernel::validate(lhs->desc, rhs->desc, dst->desc); + st = ClElementwiseKernel::validate(lhs->desc, rhs->desc, dst->desc); ARM_COMPUTE_RETURN_ON_ERROR(st); - st = kernel_graph.add_kernel(config, ClEltwiseAddKernelDescriptor{ desc }, tensors, add_id); + st = kernel_graph.add_kernel(config, ClElementwiseKernelDescriptor{ desc }, tensors, add_id); ARM_COMPUTE_RETURN_ON_ERROR(st); ARM_COMPUTE_UNUSED(add_id); return Status{}; } +Status FloorContent::translate(ClKernelGraph &kernel_graph) const +{ + const auto src = _tensors.get_const_tensor(TensorType::ACL_SRC_0); + const auto dst = _tensors.get_const_tensor(TensorType::ACL_DST_0); + ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst); + + ITensorDescPack tensors; + + DependencyGraph::Id src_id; + auto st = add_kernel_tensor(kernel_graph, *_graph, *src, src_id); + ARM_COMPUTE_RETURN_ON_ERROR(st); +
tensors.add_const_tensor(ACL_SRC_0, kernel_graph.get_tensor(src_id)); + + DependencyGraph::Id dst_id; + st = add_kernel_tensor(kernel_graph, *_graph, *dst, dst_id); + ARM_COMPUTE_RETURN_ON_ERROR(st); + tensors.add_const_tensor(ACL_DST_0, kernel_graph.get_tensor(dst_id)); + + DependencyGraph::Id add_id; + ClKernelConfig config{ UnitWorkloadStage{ UnitWorkloadStage::Stage::Run }, TileDescriptor{}, StoreType::TStoreIndirectWidthSelect }; + + st = ClFloorKernel::validate(src->desc, dst->desc); + ARM_COMPUTE_RETURN_ON_ERROR(st); + + st = kernel_graph.add_kernel(config, ClFloorKernelDescriptor{ desc }, tensors, add_id); + ARM_COMPUTE_RETURN_ON_ERROR(st); + + return Status{}; +} + std::vector traverse(const OperatorGraph::Implementation &graph) { std::vector ops; -- cgit v1.2.1