From ce5045a00485f8a8c35814c0781ccbcca5678e5c Mon Sep 17 00:00:00 2001 From: Kevin May Date: Wed, 2 Oct 2019 14:07:47 +0100 Subject: IVGCVSW-3932 Add frontend for INSTANCE_NORMALIZATION Signed-off-by: Kevin May Change-Id: Ib152148ccd8d2733c617d0cf9402661fc6b71316 --- src/armnn/InternalTypes.cpp | 1 + src/armnn/InternalTypes.hpp | 1 + src/armnn/LayersFwd.hpp | 2 + src/armnn/Network.cpp | 6 +++ src/armnn/Network.hpp | 3 ++ src/armnn/layers/InstanceNormalizationLayer.cpp | 52 ++++++++++++++++++++++ src/armnn/layers/InstanceNormalizationLayer.hpp | 43 ++++++++++++++++++ .../test/TestNameAndDescriptorLayerVisitor.cpp | 23 ++++++++++ .../test/TestNameAndDescriptorLayerVisitor.hpp | 34 ++++++++++++++ 9 files changed, 165 insertions(+) create mode 100644 src/armnn/layers/InstanceNormalizationLayer.cpp create mode 100644 src/armnn/layers/InstanceNormalizationLayer.hpp (limited to 'src/armnn') diff --git a/src/armnn/InternalTypes.cpp b/src/armnn/InternalTypes.cpp index e6f7367ab5..612d00be5f 100644 --- a/src/armnn/InternalTypes.cpp +++ b/src/armnn/InternalTypes.cpp @@ -38,6 +38,7 @@ char const* GetLayerTypeAsCString(LayerType type) case LayerType::Gather: return "Gather"; case LayerType::Greater: return "Greater"; case LayerType::Input: return "Input"; + case LayerType::InstanceNormalization: return "InstanceNormalization"; case LayerType::L2Normalization: return "L2Normalization"; case LayerType::Lstm: return "Lstm"; case LayerType::Maximum: return "Maximum"; diff --git a/src/armnn/InternalTypes.hpp b/src/armnn/InternalTypes.hpp index fbca9bcbcb..039d0f8ac8 100644 --- a/src/armnn/InternalTypes.hpp +++ b/src/armnn/InternalTypes.hpp @@ -38,6 +38,7 @@ enum class LayerType Gather, Greater, Input, + InstanceNormalization, L2Normalization, Lstm, Maximum, diff --git a/src/armnn/LayersFwd.hpp b/src/armnn/LayersFwd.hpp index 3599eacf7d..1f539f3076 100644 --- a/src/armnn/LayersFwd.hpp +++ b/src/armnn/LayersFwd.hpp @@ -30,6 +30,7 @@ #include "layers/GatherLayer.hpp" #include "layers/GreaterLayer.hpp" #include "layers/InputLayer.hpp" +#include "layers/InstanceNormalizationLayer.hpp" #include "layers/L2NormalizationLayer.hpp" #include "layers/LstmLayer.hpp" #include "layers/MaximumLayer.hpp" @@ -113,6 +114,7 @@ DECLARE_LAYER(FullyConnected) DECLARE_LAYER(Gather) DECLARE_LAYER(Greater) DECLARE_LAYER(Input) +DECLARE_LAYER(InstanceNormalization) DECLARE_LAYER(L2Normalization) DECLARE_LAYER(Lstm) DECLARE_LAYER(Maximum) diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index cf9a138084..9d10b9ace1 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -1224,6 +1224,12 @@ resizeDescriptor, const char* name) return m_Graph->AddLayer(resizeDescriptor, name); } +IConnectableLayer* Network::AddInstanceNormalizationLayer(const InstanceNormalizationDescriptor& desc, + const char* name) +{ + return m_Graph->AddLayer(desc, name); +} + IConnectableLayer* Network::AddL2NormalizationLayer(const L2NormalizationDescriptor& desc, const char* name) { diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp index 4a8bfbc9f2..e11f3d2185 100644 --- a/src/armnn/Network.hpp +++ b/src/armnn/Network.hpp @@ -152,6 +152,9 @@ public: IConnectableLayer* AddResizeLayer(const ResizeDescriptor& resizeDescriptor, const char* name = nullptr) override; + IConnectableLayer* AddInstanceNormalizationLayer(const InstanceNormalizationDescriptor& desc, + const char* name = nullptr) override; + IConnectableLayer* AddL2NormalizationLayer(const L2NormalizationDescriptor& desc, const char* name = nullptr) override; diff --git a/src/armnn/layers/InstanceNormalizationLayer.cpp b/src/armnn/layers/InstanceNormalizationLayer.cpp new file mode 100644 index 0000000000..fc3044af50 --- /dev/null +++ b/src/armnn/layers/InstanceNormalizationLayer.cpp @@ -0,0 +1,52 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#include "InstanceNormalizationLayer.hpp" + +#include "LayerCloneBase.hpp" + +#include +#include +#include + +namespace armnn +{ + +InstanceNormalizationLayer::InstanceNormalizationLayer(const InstanceNormalizationDescriptor& param, const char* name) + : LayerWithParameters(1, 1, LayerType::InstanceNormalization, param, name) +{ +} + +std::unique_ptr InstanceNormalizationLayer::CreateWorkload(const Graph& graph, + const IWorkloadFactory& factory) const +{ + InstanceNormalizationQueueDescriptor descriptor; + return factory.CreateInstanceNormalization(descriptor, PrepInfoAndDesc(descriptor, graph)); +} + +InstanceNormalizationLayer* InstanceNormalizationLayer::Clone(Graph& graph) const +{ + return CloneBase(graph, m_Param, GetName()); +} + +void InstanceNormalizationLayer::ValidateTensorShapesFromInputs() +{ + VerifyLayerConnections(1, CHECK_LOCATION()); + + auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() }); + + BOOST_ASSERT(inferredShapes.size() == 1); + + ConditionalThrowIfNotEqual( + "InstanceNormalizationLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.", + GetOutputSlot(0).GetTensorInfo().GetShape(), + inferredShapes[0]); +} + +void InstanceNormalizationLayer::Accept(ILayerVisitor& visitor) const +{ + visitor.VisitInstanceNormalizationLayer(this, GetParameters(), GetName()); +} + +} // namespace armnn diff --git a/src/armnn/layers/InstanceNormalizationLayer.hpp b/src/armnn/layers/InstanceNormalizationLayer.hpp new file mode 100644 index 0000000000..9ba56731c6 --- /dev/null +++ b/src/armnn/layers/InstanceNormalizationLayer.hpp @@ -0,0 +1,43 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include "LayerWithParameters.hpp" + +namespace armnn +{ + +/// This layer represents an instance normalization operation. +class InstanceNormalizationLayer : public LayerWithParameters +{ +public: + /// Makes a workload for the InstanceNormalization type. + /// @param [in] graph The graph where this layer can be found. + /// @param [in] factory The workload factory which will create the workload. + /// @return A pointer to the created workload, or nullptr if not created. + virtual std::unique_ptr CreateWorkload(const Graph& graph, + const IWorkloadFactory& factory) const override; + + /// Creates a dynamically-allocated copy of this layer. + /// @param [in] graph The graph into which this layer is being cloned. + InstanceNormalizationLayer* Clone(Graph& graph) const override; + + /// Check if the input tensor shape(s) + /// will lead to a valid configuration of @ref InstanceNormalizationLayer. + void ValidateTensorShapesFromInputs() override; + + void Accept(ILayerVisitor& visitor) const override; + +protected: + /// Constructor to create a InstanceNormalizationLayer. + /// @param [in] param InstanceNormalizationDescriptor to configure the Instance normalization operation. + /// @param [in] name Optional name for the layer. + InstanceNormalizationLayer(const InstanceNormalizationDescriptor& param, const char* name); + + /// Default destructor + ~InstanceNormalizationLayer() = default; +}; + +} // namespace diff --git a/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp b/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp index 653612f208..dcc5dc4cfb 100644 --- a/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp +++ b/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp @@ -282,6 +282,29 @@ BOOST_AUTO_TEST_CASE(CheckResizeLayerVisitorNameNullAndDescriptor) layer->Accept(visitor); } +BOOST_AUTO_TEST_CASE(CheckInstanceNormalizationLayerVisitorNameAndDescriptor) +{ + const char* layerName = "InstanceNormalizationLayer"; + InstanceNormalizationDescriptor descriptor; + descriptor.m_DataLayout = DataLayout::NHWC; + TestInstanceNormalizationLayerVisitor visitor(descriptor, layerName); + Network net; + + IConnectableLayer *const layer = net.AddInstanceNormalizationLayer(descriptor, layerName); + layer->Accept(visitor); +} + +BOOST_AUTO_TEST_CASE(CheckInstanceNormalizationLayerVisitorNameNullAndDescriptor) +{ + InstanceNormalizationDescriptor descriptor; + descriptor.m_DataLayout = DataLayout::NHWC; + TestInstanceNormalizationLayerVisitor visitor(descriptor); + Network net; + + IConnectableLayer *const layer = net.AddInstanceNormalizationLayer(descriptor); + layer->Accept(visitor); +} + BOOST_AUTO_TEST_CASE(CheckL2NormalizationLayerVisitorNameAndDescriptor) { const char* layerName = "L2NormalizationLayer"; diff --git a/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp b/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp index f1936d6847..aa0b3597fa 100644 --- a/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp +++ b/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp @@ -418,6 +418,40 @@ public: }; }; +class TestInstanceNormalizationLayerVisitor : public TestLayerVisitor +{ +private: + InstanceNormalizationDescriptor m_VisitorDescriptor; + +public: + explicit TestInstanceNormalizationLayerVisitor(const InstanceNormalizationDescriptor& desc, + const char* name = nullptr) + : TestLayerVisitor(name) + { + m_VisitorDescriptor.m_Beta = desc.m_Beta; + m_VisitorDescriptor.m_Gamma = desc.m_Gamma; + m_VisitorDescriptor.m_Eps = desc.m_Eps; + m_VisitorDescriptor.m_DataLayout = desc.m_DataLayout; + }; + + void CheckDescriptor(const InstanceNormalizationDescriptor& desc) + { + BOOST_CHECK(desc.m_Beta == m_VisitorDescriptor.m_Beta); + BOOST_CHECK(desc.m_Gamma == m_VisitorDescriptor.m_Gamma); + BOOST_CHECK(desc.m_Eps == m_VisitorDescriptor.m_Eps); + BOOST_CHECK(desc.m_DataLayout == m_VisitorDescriptor.m_DataLayout); + } + + void VisitInstanceNormalizationLayer(const IConnectableLayer* layer, + const InstanceNormalizationDescriptor& desc, + const char* name = nullptr) override + { + CheckLayerPointer(layer); + CheckDescriptor(desc); + CheckLayerName(name); + }; +}; + class TestL2NormalizationLayerVisitor : public TestLayerVisitor { private: -- cgit v1.2.1