diff options
Diffstat (limited to 'src/backends/cl/ClLayerSupport.cpp')
-rw-r--r-- | src/backends/cl/ClLayerSupport.cpp | 27 |
1 files changed, 26 insertions, 1 deletions
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp index bfe4f6e9fd..030b4c2d09 100644 --- a/src/backends/cl/ClLayerSupport.cpp +++ b/src/backends/cl/ClLayerSupport.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -18,6 +18,7 @@ #if defined(ARMCOMPUTECL_ENABLED) #include <aclCommon/ArmComputeUtils.hpp> #include <aclCommon/ArmComputeTensorUtils.hpp> +#include <backendsCommon/WorkloadUtils.hpp> #include "workloads/ClAbsWorkload.hpp" #include "workloads/ClAdditionWorkload.hpp" #include "workloads/ClActivationWorkload.hpp" @@ -72,6 +73,7 @@ #include "workloads/ClResizeWorkload.hpp" #include "workloads/ClReverseV2Workload.hpp" #include "workloads/ClRsqrtWorkload.hpp" +#include "workloads/ClScatterNdWorkload.hpp" #include "workloads/ClSinWorkload.hpp" #include "workloads/ClSliceWorkload.hpp" #include "workloads/ClSoftmaxWorkload.hpp" @@ -577,6 +579,13 @@ bool ClLayerSupport::IsLayerSupported(const LayerType& type, infos[1], infos[2], reasonIfUnsupported); + case LayerType::ScatterNd: + return IsScatterNdSupported(infos[0], // input/shape + infos[1], // indices + infos[2], // updates + infos[3], // output + *(PolymorphicDowncast<const ScatterNdDescriptor*>(&descriptor)), + reasonIfUnsupported); case LayerType::Shape: return LayerSupportBase::IsShapeSupported(infos[0], infos[1], @@ -1441,6 +1450,22 @@ bool ClLayerSupport::IsReverseV2Supported(const TensorInfo& input, output); } +bool ClLayerSupport::IsScatterNdSupported(const TensorInfo& input, + const TensorInfo& indices, + const TensorInfo& updates, + const TensorInfo& output, + const ScatterNdDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported) const +{ + FORWARD_WORKLOAD_VALIDATE_FUNC(ClScatterNdWorkloadValidate, + reasonIfUnsupported, + input, + indices, + updates, + output, + descriptor); +} + bool ClLayerSupport::IsSliceSupported(const TensorInfo& input, const TensorInfo& output, const SliceDescriptor& descriptor, |