diff options
Diffstat (limited to 'src/armnn/layers')
-rw-r--r-- | src/armnn/layers/GatherNdLayer.cpp | 104 | ||||
-rw-r--r-- | src/armnn/layers/GatherNdLayer.hpp | 48 |
2 files changed, 152 insertions, 0 deletions
diff --git a/src/armnn/layers/GatherNdLayer.cpp b/src/armnn/layers/GatherNdLayer.cpp new file mode 100644 index 0000000000..1ca2cbbae3 --- /dev/null +++ b/src/armnn/layers/GatherNdLayer.cpp @@ -0,0 +1,104 @@ +// +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "GatherNdLayer.hpp" +#include "LayerCloneBase.hpp" + +#include <armnn/TypesUtils.hpp> +#include <armnn/backends/WorkloadData.hpp> +#include <armnn/backends/WorkloadFactory.hpp> + +namespace armnn +{ + +GatherNdLayer::GatherNdLayer(const char* name) + : Layer(2, 1, LayerType::GatherNd, name) +{ +} + +std::unique_ptr<IWorkload> GatherNdLayer::CreateWorkload(const armnn::IWorkloadFactory& factory) const +{ + GatherNdQueueDescriptor descriptor; + SetAdditionalInfo(descriptor); + + return factory.CreateWorkload(LayerType::GatherNd, descriptor, PrepInfoAndDesc(descriptor)); +} + +GatherNdLayer* GatherNdLayer::Clone(Graph& graph) const +{ + return CloneBase<GatherNdLayer>(graph, GetName()); +} + +std::vector<TensorShape> GatherNdLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const +{ + ARMNN_ASSERT(inputShapes.size() == 2); + const TensorShape& params = inputShapes[0]; + const TensorShape& indices = inputShapes[1]; + + if (indices.GetDimensionality() == Dimensionality::Scalar && indices.GetNumDimensions() == 1) + { + return std::vector<TensorShape>({ TensorShape(Dimensionality::Scalar)}); + } + + const unsigned int paramsDim = params.GetNumDimensions(); + const unsigned int indicesDim = indices.GetNumDimensions(); + + // last dimension of indices + unsigned int index_depth = indices[indicesDim - 1]; + ARMNN_ASSERT(index_depth <= paramsDim); + + // all but the last dimension of indices + std::vector<unsigned int> outer_shape; + outer_shape.reserve(indicesDim - 1); + for (unsigned int i = 0; i < indicesDim - 1; ++i) + { + outer_shape.emplace_back(indices[i]); + } + + // elements after index_depth + std::vector<unsigned int> inner_shape; + inner_shape.reserve(paramsDim - index_depth); + for (unsigned int i = index_depth; i < paramsDim; ++i) + { + inner_shape.emplace_back(params[i]); + } + + // concatenate outer_shape + inner_shape + std::vector<unsigned int> output_shape; + output_shape.reserve( outer_shape.size() + inner_shape.size() ); + output_shape.insert( output_shape.end(), outer_shape.begin(), outer_shape.end() ); + output_shape.insert( output_shape.end(), inner_shape.begin(), inner_shape.end() ); + + const auto outputDim = static_cast<unsigned int>(output_shape.size()); + return std::vector<TensorShape>({ TensorShape({outputDim, output_shape.data()})}); +} + +void GatherNdLayer::ValidateTensorShapesFromInputs() +{ + VerifyLayerConnections(2, CHECK_LOCATION()); + + const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape(); + + VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod); + + std::vector<TensorShape> inferredShapes = InferOutputShapes( + {GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), + GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()}); + ARMNN_ASSERT(inferredShapes.size() == 1); + ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified || + inferredShapes[0].GetDimensionality() == Dimensionality::Scalar); + + ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "GatherNdLayer"); +} + +ARMNN_NO_DEPRECATE_WARN_BEGIN +void GatherNdLayer::Accept(ILayerVisitor& visitor) const +{ + IgnoreUnused(visitor); + throw armnn::Exception("GatherNdLayer VisitGatherNdLayer is not implemented"); +} +ARMNN_NO_DEPRECATE_WARN_END + +} // namespace armnn diff --git a/src/armnn/layers/GatherNdLayer.hpp b/src/armnn/layers/GatherNdLayer.hpp new file mode 100644 index 0000000000..9e07715f90 --- /dev/null +++ b/src/armnn/layers/GatherNdLayer.hpp @@ -0,0 +1,48 @@ +// +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "Layer.hpp" + +namespace armnn +{ + +/// This layer represents a GatherNd operator. +class GatherNdLayer : public Layer +{ +public: + /// Makes a workload for the Gather type. + /// @param [in] factory The workload factory which will create the workload. + /// @return A pointer to the created workload, or nullptr if not created. + virtual std::unique_ptr<IWorkload> CreateWorkload(const IWorkloadFactory& factory) const override; + + /// Creates a dynamically-allocated copy of this layer. + /// @param [in] graph The graph into which this layer is being cloned. + GatherNdLayer* Clone(Graph& graph) const override; + + /// Infers the output shapes from given input shapes and layer properties. + /// @param [in] inputShapes The input shapes layer has. + /// @return A vector to the inferred output shape. + std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override; + + /// Check if the input tensor shape(s) + /// will lead to a valid configuration of @ref GatherNdLayer. + void ValidateTensorShapesFromInputs() override; + + ARMNN_NO_DEPRECATE_WARN_BEGIN + void Accept(ILayerVisitor& visitor) const override; + ARMNN_NO_DEPRECATE_WARN_END + +protected: + /// Constructor to create a GatherNdLayer. + /// @param [in] name Optional name for the layer. + GatherNdLayer(const char* name); + + /// Default destructor + ~GatherNdLayer() = default; +}; + +} // namespace armnn |