From b392e9845b7f40ab0c389f29f13f6ec84dd814d1 Mon Sep 17 00:00:00 2001 From: mathad01 Date: Wed, 7 Apr 2021 12:07:30 +0100 Subject: IVGCVSW-5410 Add front-end support for CAST IVGCVSW-5415 Add TfLiteParser support for CAST * Added front end support for CAST, including support in the Reference workload, Serialization, Deserializtion, Unit tests, and TfLiteParser. Signed-off-by: mathad01 Change-Id: Iaf670ca5912a21ed6bc84f7f83a68b42154846bb --- src/armnn/BackendHelper.cpp | 7 ++++++ src/armnn/LayersFwd.hpp | 3 ++- src/armnn/Network.cpp | 9 +++++++ src/armnn/Network.hpp | 2 ++ src/armnn/layers/CastLayer.cpp | 55 ++++++++++++++++++++++++++++++++++++++++++ src/armnn/layers/CastLayer.hpp | 40 ++++++++++++++++++++++++++++++ 6 files changed, 115 insertions(+), 1 deletion(-) create mode 100644 src/armnn/layers/CastLayer.cpp create mode 100644 src/armnn/layers/CastLayer.hpp (limited to 'src/armnn') diff --git a/src/armnn/BackendHelper.cpp b/src/armnn/BackendHelper.cpp index 1c926f4d30..31dfaa53a3 100644 --- a/src/armnn/BackendHelper.cpp +++ b/src/armnn/BackendHelper.cpp @@ -113,6 +113,13 @@ bool LayerSupportHandle::IsBatchToSpaceNdSupported(const TensorInfo& input, reasonIfUnsupported.value()); } +bool LayerSupportHandle::IsCastSupported(const TensorInfo& input, + const TensorInfo& output, + Optional reasonIfUnsupported) +{ + return m_LayerSupport->IsCastSupported(input, output, reasonIfUnsupported.value()); +} + bool LayerSupportHandle::IsComparisonSupported(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, diff --git a/src/armnn/LayersFwd.hpp b/src/armnn/LayersFwd.hpp index 6782fb5eb7..19cd9bdf6c 100644 --- a/src/armnn/LayersFwd.hpp +++ b/src/armnn/LayersFwd.hpp @@ -11,6 +11,7 @@ #include "layers/ArgMinMaxLayer.hpp" #include "layers/BatchNormalizationLayer.hpp" #include "layers/BatchToSpaceNdLayer.hpp" +#include "layers/CastLayer.hpp" #include "layers/ComparisonLayer.hpp" #include "layers/ConcatLayer.hpp" #include "layers/ConstantLayer.hpp" @@ -166,5 +167,5 @@ DECLARE_LAYER(Switch) DECLARE_LAYER(Transpose) DECLARE_LAYER(TransposeConvolution2d) DECLARE_LAYER(Unmap) - +DECLARE_LAYER(Cast) } diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index b9a0e47ec5..860048fecd 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -59,6 +59,10 @@ IConnectableLayer* INetwork::AddArgMinMaxLayer(const ArgMinMaxDescriptor& desc, return pNetworkImpl->AddArgMinMaxLayer(desc, name); } +IConnectableLayer* INetwork::AddCastLayer(const char* name) +{ + return pNetworkImpl->AddCastLayer(name); +} IConnectableLayer* INetwork::AddComparisonLayer(const ComparisonDescriptor& comparisonDescriptor, const char* name) @@ -1705,6 +1709,11 @@ IConnectableLayer* NetworkImpl::AddBatchToSpaceNdLayer(const BatchToSpaceNdDescr return m_Graph->AddLayer(batchToSpaceNdDescriptor, name); } +IConnectableLayer* NetworkImpl::AddCastLayer(const char* name) +{ + return m_Graph->AddLayer(name); +} + IConnectableLayer* NetworkImpl::AddComparisonLayer(const ComparisonDescriptor& comparisonDescriptor, const char* name) { diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp index 30941ca9e4..ad9b51cf35 100644 --- a/src/armnn/Network.hpp +++ b/src/armnn/Network.hpp @@ -46,6 +46,8 @@ public: IConnectableLayer* AddBatchToSpaceNdLayer(const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor, const char* name = nullptr); + IConnectableLayer* AddCastLayer(const char* name = nullptr); + IConnectableLayer* AddComparisonLayer(const ComparisonDescriptor& comparisonDescriptor, const char* name = nullptr); diff --git a/src/armnn/layers/CastLayer.cpp b/src/armnn/layers/CastLayer.cpp new file mode 100644 index 0000000000..16dd9a3744 --- /dev/null +++ b/src/armnn/layers/CastLayer.cpp @@ -0,0 +1,55 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// +#include "CastLayer.hpp" + +#include "LayerCloneBase.hpp" +#include + +#include +#include + +namespace armnn +{ + +CastLayer::CastLayer(const char* name) + : Layer(1, 1, LayerType::Cast, name) +{ +} + +std::unique_ptr CastLayer::CreateWorkload(const IWorkloadFactory& factory) const +{ + CastQueueDescriptor descriptor; + SetAdditionalInfo(descriptor); + + return factory.CreateCast(descriptor, PrepInfoAndDesc(descriptor)); +} + +CastLayer* CastLayer::Clone(Graph& graph) const +{ + return CloneBase(graph, GetName()); +} + +void CastLayer::ValidateTensorShapesFromInputs() +{ + VerifyLayerConnections(1, CHECK_LOCATION()); + + const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape(); + + VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod); + + auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() }); + + ARMNN_ASSERT(inferredShapes.size() == 1); + + ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "CastLayer"); +} + +void CastLayer::Accept(ILayerVisitor& visitor) const +{ + IgnoreUnused(visitor); + throw armnn::Exception("CastLayer VisitCastLayer is not implemented"); +} + +} // namespace armnn diff --git a/src/armnn/layers/CastLayer.hpp b/src/armnn/layers/CastLayer.hpp new file mode 100644 index 0000000000..8a9ea43934 --- /dev/null +++ b/src/armnn/layers/CastLayer.hpp @@ -0,0 +1,40 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "Layer.hpp" + +namespace armnn +{ + +/// This layer represents a cast operation +class CastLayer : public Layer +{ +public: + /// Makes a workload for the Cast type. + /// @param [in] factory The workload factory which will create the workload. + /// @return A pointer to the created workload, or nullptr if not created. + virtual std::unique_ptr CreateWorkload(const IWorkloadFactory &factory) const override; + + /// Creates a dynamically-allocated copy of this layer. + /// @param [in] graph The graph into which this layer is being cloned. + CastLayer* Clone(Graph& graph) const override; + + /// Check if the input tensor shape(s) + /// will lead to a valid configuration of @ref ConvertFp16ToFp32Layer. + /// @param [in] shapeInferenceMethod Indicates if output shape shall be overwritten or just validated. + void ValidateTensorShapesFromInputs() override; + + void Accept(ILayerVisitor& visitor) const override; + +protected: + /// Constructor to create a CastLayer. + CastLayer(const char *name); + + /// Default destructor + ~CastLayer() = default; +}; +} // namespace armnn \ No newline at end of file -- cgit v1.2.1