From 3ae3f978cf9ce3174609b7152af87acb410b0fe0 Mon Sep 17 00:00:00 2001 From: Keith Davis Date: Fri, 21 May 2021 16:33:48 +0100 Subject: MLCE-510 Add CpuRef Shape Operator to ArmNN * Add front end * Add reference workload * Serialization/Deserialization * Add unit tests * Update ArmNN Versioning Signed-off-by: Keith Davis Change-Id: I6fcb1fa341d6f08dea4003b13544e6e9f53fefd3 --- src/armnn/layers/ShapeLayer.cpp | 73 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 src/armnn/layers/ShapeLayer.cpp (limited to 'src/armnn/layers/ShapeLayer.cpp') diff --git a/src/armnn/layers/ShapeLayer.cpp b/src/armnn/layers/ShapeLayer.cpp new file mode 100644 index 0000000000..4193fa9aab --- /dev/null +++ b/src/armnn/layers/ShapeLayer.cpp @@ -0,0 +1,73 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ShapeLayer.hpp" + +#include "LayerCloneBase.hpp" + +#include +#include + +#include +#include + +namespace armnn +{ + +ShapeLayer::ShapeLayer(const char* name) + : Layer(1, 1, LayerType::Shape, name) +{ +} + +std::unique_ptr ShapeLayer::CreateWorkload(const IWorkloadFactory& factory) const +{ + ShapeQueueDescriptor descriptor; + SetAdditionalInfo(descriptor); + + return factory.CreateShape(descriptor, PrepInfoAndDesc(descriptor)); +} + +ShapeLayer* ShapeLayer::Clone(Graph& graph) const +{ + return CloneBase(graph, GetName()); +} + +void ShapeLayer::ValidateTensorShapesFromInputs() +{ + VerifyLayerConnections(1, CHECK_LOCATION()); + + const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape(); + + VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod); + + auto inferredShape = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() }); + + ARMNN_ASSERT(inferredShape.size() == 1); + + ValidateAndCopyShape(outputShape, inferredShape[0], m_ShapeInferenceMethod, "ShapeLayer"); +} + +std::vector ShapeLayer::InferOutputShapes(const std::vector& inputShapes) const +{ + IgnoreUnused(inputShapes); + ARMNN_ASSERT(inputShapes.size() == 1); + + TensorShape outputShape({ inputShapes[0].GetNumDimensions()} ); + + return std::vector({ outputShape }); +} + +void ShapeLayer::Accept(ILayerVisitor& visitor) const +{ + IgnoreUnused(visitor); + throw armnn::Exception("ShapeLayer VisitShapeLayer is not implemented"); +} + +void ShapeLayer::ExecuteStrategy(IStrategy& strategy) const +{ + strategy.ExecuteStrategy(this, BaseDescriptor(), {}, GetName()); +} + +} // namespace armnn -- cgit v1.2.1