From b1fcefddf3f59219a9d7930d607175b7e6c39347 Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Wed, 15 Jun 2022 19:02:28 +0100 Subject: Implement new Elementwise Dynamic Fusion Operators: Div, Floor Resolves: COMPMID-5355 Change-Id: I92f73fbe885f28bbe7b07965b90cfd807c93602f Signed-off-by: Michalis Spyrou Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7745 Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: SiCong Li --- .../dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp | 63 ++++++++++++++++++++-- 1 file changed, 58 insertions(+), 5 deletions(-) (limited to 'src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp') diff --git a/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp b/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp index de58ce70ed..cab51a2ce6 100644 --- a/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp +++ b/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp @@ -124,7 +124,7 @@ bool ClDirectConv2dKernel::operator==(const ClKernel &other) const return config() == other.config() && tensors() == other.tensors() && desc == converted.desc; } -Status ClAddKernel::generate(ClKernelBlueprint &bp) const +Status ClElementwiseKernel::generate(ClKernelBlueprint &bp) const { const auto lhs = _tensors.get_const_tensor(TensorType::ACL_SRC_0); const auto rhs = _tensors.get_const_tensor(TensorType::ACL_SRC_1); @@ -137,11 +137,11 @@ Status ClAddKernel::generate(ClKernelBlueprint &bp) const ArgumentID dst_id; add_tensor(bp, dst->desc, dst_id, dst->id); - add_kcomp_eltwise_add(bp, desc, lhs_id, rhs_id, dst_id); + add_kcomp_eltwise_op(bp, desc, lhs_id, rhs_id, dst_id); return Status{}; } -Status ClAddKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst) +Status ClElementwiseKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst) { // 1. Check validity ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lhs, rhs, dst); @@ -186,9 +186,61 @@ Status ClAddKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, con return Status{}; } -bool ClAddKernel::operator==(const ClKernel &other) const +bool ClElementwiseKernel::operator==(const ClKernel &other) const { - const auto converted = *utils::cast::polymorphic_downcast(&other); + const auto converted = *utils::cast::polymorphic_downcast(&other); + return config() == other.config() && tensors() == other.tensors() && desc == converted.desc; +} + +Status ClFloorKernel::generate(ClKernelBlueprint &bp) const +{ + const auto src = _tensors.get_const_tensor(TensorType::ACL_SRC_0); + const auto dst = _tensors.get_const_tensor(TensorType::ACL_DST_0); + ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst); + ArgumentID src_id; + add_tensor(bp, src->desc, src_id, src->id); + ArgumentID dst_id; + add_tensor(bp, dst->desc, dst_id, dst->id); + + add_kcomp_floor(bp, desc, src_id, dst_id); + return Status{}; +} + +Status ClFloorKernel::validate(const ITensorInfo *src, const ITensorInfo *dst) +{ + // 1. Check validity + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, dst); + + // Matching data type + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst); + + // Matching data layout + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(src, dst); + + // All tensor infos are initialized + ARM_COMPUTE_RETURN_ERROR_ON(src->tensor_shape().total_size() == 0); + ARM_COMPUTE_RETURN_ERROR_ON(dst->tensor_shape().total_size() == 0); + + // Device requirements are met + ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(src); + + // dst shape is correct + ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(src->tensor_shape(), dst->tensor_shape(), 0), "Wrong shape for dst"); + + // 2. Check support level + + // Data type + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F32, DataType::F16); + + // Data layout + ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(src, DataLayout::NHWC); + + return Status{}; +} + +bool ClFloorKernel::operator==(const ClKernel &other) const +{ + const auto converted = *utils::cast::polymorphic_downcast(&other); return config() == other.config() && tensors() == other.tensors() && desc == converted.desc; } @@ -202,6 +254,7 @@ std::vector traverse(const ClKernelGraph &graph) } return kernels; } + std::vector traverse(ClKernelGraph &graph) { std::vector kernels; -- cgit v1.2.1