diff options
Diffstat (limited to 'src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp')
-rw-r--r-- | src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp | 63 |
1 files changed, 58 insertions, 5 deletions
diff --git a/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp b/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp index de58ce70ed..cab51a2ce6 100644 --- a/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp +++ b/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp @@ -124,7 +124,7 @@ bool ClDirectConv2dKernel::operator==(const ClKernel &other) const return config() == other.config() && tensors() == other.tensors() && desc == converted.desc; } -Status ClAddKernel::generate(ClKernelBlueprint &bp) const +Status ClElementwiseKernel::generate(ClKernelBlueprint &bp) const { const auto lhs = _tensors.get_const_tensor(TensorType::ACL_SRC_0); const auto rhs = _tensors.get_const_tensor(TensorType::ACL_SRC_1); @@ -137,11 +137,11 @@ Status ClAddKernel::generate(ClKernelBlueprint &bp) const ArgumentID dst_id; add_tensor(bp, dst->desc, dst_id, dst->id); - add_kcomp_eltwise_add(bp, desc, lhs_id, rhs_id, dst_id); + add_kcomp_eltwise_op(bp, desc, lhs_id, rhs_id, dst_id); return Status{}; } -Status ClAddKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst) +Status ClElementwiseKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst) { // 1. Check validity ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lhs, rhs, dst); @@ -186,9 +186,61 @@ Status ClAddKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, con return Status{}; } -bool ClAddKernel::operator==(const ClKernel &other) const +bool ClElementwiseKernel::operator==(const ClKernel &other) const { - const auto converted = *utils::cast::polymorphic_downcast<const ClAddKernel *>(&other); + const auto converted = *utils::cast::polymorphic_downcast<const ClElementwiseKernel *>(&other); + return config() == other.config() && tensors() == other.tensors() && desc == converted.desc; +} + +Status ClFloorKernel::generate(ClKernelBlueprint &bp) const +{ + const auto src = _tensors.get_const_tensor(TensorType::ACL_SRC_0); + const auto dst = _tensors.get_const_tensor(TensorType::ACL_DST_0); + ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst); + ArgumentID src_id; + add_tensor(bp, src->desc, src_id, src->id); + ArgumentID dst_id; + add_tensor(bp, dst->desc, dst_id, dst->id); + + add_kcomp_floor(bp, desc, src_id, dst_id); + return Status{}; +} + +Status ClFloorKernel::validate(const ITensorInfo *src, const ITensorInfo *dst) +{ + // 1. Check validity + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, dst); + + // Matching data type + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst); + + // Matching data layout + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(src, dst); + + // All tensor infos are initialized + ARM_COMPUTE_RETURN_ERROR_ON(src->tensor_shape().total_size() == 0); + ARM_COMPUTE_RETURN_ERROR_ON(dst->tensor_shape().total_size() == 0); + + // Device requirements are met + ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(src); + + // dst shape is correct + ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(src->tensor_shape(), dst->tensor_shape(), 0), "Wrong shape for dst"); + + // 2. Check support level + + // Data type + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F32, DataType::F16); + + // Data layout + ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(src, DataLayout::NHWC); + + return Status{}; +} + +bool ClFloorKernel::operator==(const ClKernel &other) const +{ + const auto converted = *utils::cast::polymorphic_downcast<const ClFloorKernel *>(&other); return config() == other.config() && tensors() == other.tensors() && desc == converted.desc; } @@ -202,6 +254,7 @@ std::vector<const ClKernel *> traverse(const ClKernelGraph &graph) } return kernels; } + std::vector<ClKernel *> traverse(ClKernelGraph &graph) { std::vector<ClKernel *> kernels; |