From 636ab40d3741e12eaad11d5b50e4b34bfbb258b5 Mon Sep 17 00:00:00 2001 From: Aron Virginas-Tar Date: Mon, 16 Sep 2019 14:27:45 +0100 Subject: IVGCVSW-3875 Add frontend for SLICE layer Signed-off-by: Aron Virginas-Tar Change-Id: Iebe675a0cee02db6f133d48ce58cbc1e233061db --- src/armnn/InternalTypes.hpp | 1 + src/armnn/LayersFwd.hpp | 2 + src/armnn/Network.cpp | 5 ++ src/armnn/Network.hpp | 2 + src/armnn/layers/SliceLayer.cpp | 66 ++++++++++++++++ src/armnn/layers/SliceLayer.hpp | 49 ++++++++++++ src/armnn/test/TestNameOnlyLayerVisitor.cpp | 116 +++++++++++++++++----------- src/armnn/test/TestNameOnlyLayerVisitor.hpp | 82 ++++++++++++-------- 8 files changed, 246 insertions(+), 77 deletions(-) create mode 100644 src/armnn/layers/SliceLayer.cpp create mode 100644 src/armnn/layers/SliceLayer.hpp (limited to 'src/armnn') diff --git a/src/armnn/InternalTypes.hpp b/src/armnn/InternalTypes.hpp index 98308f92a1..1e05fff769 100644 --- a/src/armnn/InternalTypes.hpp +++ b/src/armnn/InternalTypes.hpp @@ -58,6 +58,7 @@ enum class LayerType Reshape, Resize, Rsqrt, + Slice, Softmax, SpaceToBatchNd, SpaceToDepth, diff --git a/src/armnn/LayersFwd.hpp b/src/armnn/LayersFwd.hpp index 6e4cf6ab04..a98c104f85 100644 --- a/src/armnn/LayersFwd.hpp +++ b/src/armnn/LayersFwd.hpp @@ -50,6 +50,7 @@ #include "layers/ReshapeLayer.hpp" #include "layers/ResizeLayer.hpp" #include "layers/RsqrtLayer.hpp" +#include "layers/SliceLayer.hpp" #include "layers/SoftmaxLayer.hpp" #include "layers/SpaceToBatchNdLayer.hpp" #include "layers/SpaceToDepthLayer.hpp" @@ -131,6 +132,7 @@ DECLARE_LAYER(QuantizedLstm) DECLARE_LAYER(Reshape) DECLARE_LAYER(Resize) DECLARE_LAYER(Rsqrt) +DECLARE_LAYER(Slice) DECLARE_LAYER(Softmax) DECLARE_LAYER(SpaceToBatchNd) DECLARE_LAYER(SpaceToDepth) diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index 6971cb89ba..c055407b3a 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -1129,6 +1129,11 @@ normalizationDescriptor, return m_Graph->AddLayer(normalizationDescriptor, name); } +IConnectableLayer* Network::AddSliceLayer(const SliceDescriptor& sliceDescriptor, const char* name) +{ + return m_Graph->AddLayer(sliceDescriptor, name); +} + IConnectableLayer* Network::AddSoftmaxLayer(const SoftmaxDescriptor& softmaxDescriptor, const char* name) { diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp index aac875aac7..274cc1ab7c 100644 --- a/src/armnn/Network.hpp +++ b/src/armnn/Network.hpp @@ -117,6 +117,8 @@ public: IConnectableLayer* AddNormalizationLayer(const NormalizationDescriptor& normalizationDescriptor, const char* name = nullptr) override; + IConnectableLayer* AddSliceLayer(const SliceDescriptor& sliceDescriptor, const char* name = nullptr) override; + IConnectableLayer* AddSoftmaxLayer(const SoftmaxDescriptor& softmaxDescriptor, const char* name = nullptr) override; diff --git a/src/armnn/layers/SliceLayer.cpp b/src/armnn/layers/SliceLayer.cpp new file mode 100644 index 0000000000..8ea5fd8f25 --- /dev/null +++ b/src/armnn/layers/SliceLayer.cpp @@ -0,0 +1,66 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "SliceLayer.hpp" + +#include "LayerCloneBase.hpp" + +#include + +#include +#include + +#include +#include + +namespace armnn +{ + +SliceLayer::SliceLayer(const SliceDescriptor& param, const char* name) + : LayerWithParameters(1, 1, LayerType::Slice, param, name) +{ +} + +std::unique_ptr SliceLayer::CreateWorkload(const Graph& graph, + const IWorkloadFactory& factory) const +{ + SliceQueueDescriptor descriptor; + return factory.CreateSlice(descriptor, PrepInfoAndDesc(descriptor, graph)); +} + +SliceLayer* SliceLayer::Clone(Graph& graph) const +{ + return CloneBase(graph, m_Param, GetName()); +} + +void SliceLayer::ValidateTensorShapesFromInputs() +{ + VerifyLayerConnections(1, CHECK_LOCATION()); + + auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() }); + + BOOST_ASSERT(inferredShapes.size() == 1); + + ConditionalThrowIfNotEqual( + "SliceLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.", + GetOutputSlot(0).GetTensorInfo().GetShape(), + inferredShapes[0]); +} + +std::vector SliceLayer::InferOutputShapes(const std::vector& inputShapes) const +{ + BOOST_ASSERT(inputShapes.size() == 1); + + TensorShape outputShape(boost::numeric_cast(m_Param.m_Size.size()), m_Param.m_Size.data()); + + return std::vector({ outputShape }); +} + +void SliceLayer::Accept(ILayerVisitor& visitor) const +{ + visitor.VisitSliceLayer(this, GetParameters(), GetName()); +} + +} // namespace armnn diff --git a/src/armnn/layers/SliceLayer.hpp b/src/armnn/layers/SliceLayer.hpp new file mode 100644 index 0000000000..38f0747f05 --- /dev/null +++ b/src/armnn/layers/SliceLayer.hpp @@ -0,0 +1,49 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "LayerWithParameters.hpp" + +namespace armnn +{ + +class SliceLayer : public LayerWithParameters +{ +public: + /// Makes a workload for the Slice type. + /// @param [in] graph The graph where this layer can be found. + /// @param [in] factory The workload factory which will create the workload. + /// @return A pointer to the created workload, or nullptr if not created. + virtual std::unique_ptr CreateWorkload(const Graph& graph, + const IWorkloadFactory& factory) const override; + + /// Creates a dynamically-allocated copy of this layer. + /// @param [in] graph The graph into which this layer is being cloned. + SliceLayer* Clone(Graph& graph) const override; + + /// Check if the input tensor shape(s) + /// will lead to a valid configuration of @ref SliceLayer. + void ValidateTensorShapesFromInputs() override; + + /// By default returns inputShapes if the number of inputs are equal to number of outputs, + /// otherwise infers the output shapes from given input shapes and layer properties. + /// @param [in] inputShapes The input shapes layer has. + /// @return A vector to the inferred output shape. + std::vector InferOutputShapes(const std::vector& inputShapes) const override; + + void Accept(ILayerVisitor& visitor) const override; + +protected: + /// Constructor to create a SliceLayer. + /// @param [in] param SliceDescriptor to configure the resize operation. + /// @param [in] name Optional name for the layer. + SliceLayer(const SliceDescriptor& param, const char* name); + + /// Default destructor. + ~SliceLayer() = default; +}; + +} // namespace armnn diff --git a/src/armnn/test/TestNameOnlyLayerVisitor.cpp b/src/armnn/test/TestNameOnlyLayerVisitor.cpp index 4bb9614385..c4c4a479eb 100644 --- a/src/armnn/test/TestNameOnlyLayerVisitor.cpp +++ b/src/armnn/test/TestNameOnlyLayerVisitor.cpp @@ -10,6 +10,7 @@ namespace armnn { BOOST_AUTO_TEST_SUITE(TestNameOnlyLayerVisitor) +// Addition BOOST_AUTO_TEST_CASE(CheckAdditionLayerVisitorName) { TestAdditionLayerVisitor visitor("AdditionLayer"); @@ -28,24 +29,45 @@ BOOST_AUTO_TEST_CASE(CheckAdditionLayerVisitorNameNullptr) layer->Accept(visitor); } -BOOST_AUTO_TEST_CASE(CheckMultiplicationLayerVisitorName) +// Division +BOOST_AUTO_TEST_CASE(CheckDivisionLayerVisitorName) { - TestMultiplicationLayerVisitor visitor("MultiplicationLayer"); + TestDivisionLayerVisitor visitor("DivisionLayer"); Network net; - IConnectableLayer *const layer = net.AddMultiplicationLayer("MultiplicationLayer"); + IConnectableLayer *const layer = net.AddAdditionLayer("DivisionLayer"); layer->Accept(visitor); } -BOOST_AUTO_TEST_CASE(CheckMultiplicationLayerVisitorNameNullptr) +BOOST_AUTO_TEST_CASE(CheckDivisionLayerVisitorNameNullptr) { - TestMultiplicationLayerVisitor visitor; + TestDivisionLayerVisitor visitor; Network net; - IConnectableLayer *const layer = net.AddMultiplicationLayer(); + IConnectableLayer *const layer = net.AddDivisionLayer(); + layer->Accept(visitor); +} + +// Equal +BOOST_AUTO_TEST_CASE(CheckEqualLayerVisitorName) +{ + TestEqualLayerVisitor visitor("EqualLayer"); + Network net; + + IConnectableLayer *const layer = net.AddEqualLayer("EqualLayer"); + layer->Accept(visitor); +} + +BOOST_AUTO_TEST_CASE(CheckEqualLayerVisitorNameNullptr) +{ + TestEqualLayerVisitor visitor; + Network net; + + IConnectableLayer *const layer = net.AddEqualLayer(); layer->Accept(visitor); } +// Floor BOOST_AUTO_TEST_CASE(CheckFloorLayerVisitorName) { TestFloorLayerVisitor visitor("FloorLayer"); @@ -64,42 +86,45 @@ BOOST_AUTO_TEST_CASE(CheckFloorLayerVisitorNameNullptr) layer->Accept(visitor); } -BOOST_AUTO_TEST_CASE(CheckDivisionLayerVisitorName) +// Gather +BOOST_AUTO_TEST_CASE(CheckGatherLayerVisitorName) { - TestDivisionLayerVisitor visitor("DivisionLayer"); + TestGatherLayerVisitor visitor("GatherLayer"); Network net; - IConnectableLayer *const layer = net.AddAdditionLayer("DivisionLayer"); + IConnectableLayer *const layer = net.AddGatherLayer("GatherLayer"); layer->Accept(visitor); } -BOOST_AUTO_TEST_CASE(CheckDivisionLayerVisitorNameNullptr) +BOOST_AUTO_TEST_CASE(CheckGatherLayerVisitorNameNullptr) { - TestDivisionLayerVisitor visitor; + TestGatherLayerVisitor visitor; Network net; - IConnectableLayer *const layer = net.AddDivisionLayer(); + IConnectableLayer *const layer = net.AddGatherLayer(); layer->Accept(visitor); } -BOOST_AUTO_TEST_CASE(CheckSubtractionLayerVisitorName) +// Greater +BOOST_AUTO_TEST_CASE(CheckGreaterLayerVisitorName) { - TestSubtractionLayerVisitor visitor("SubtractionLayer"); + TestGreaterLayerVisitor visitor("GreaterLayer"); Network net; - IConnectableLayer *const layer = net.AddSubtractionLayer("SubtractionLayer"); + IConnectableLayer *const layer = net.AddGreaterLayer("GreaterLayer"); layer->Accept(visitor); } -BOOST_AUTO_TEST_CASE(CheckSubtractionLayerVisitorNameNullptr) +BOOST_AUTO_TEST_CASE(CheckGreaterLayerVisitorNameNullptr) { - TestSubtractionLayerVisitor visitor; + TestGreaterLayerVisitor visitor; Network net; - IConnectableLayer *const layer = net.AddSubtractionLayer(); + IConnectableLayer *const layer = net.AddGreaterLayer(); layer->Accept(visitor); } +// Maximum BOOST_AUTO_TEST_CASE(CheckMaximumLayerVisitorName) { TestMaximumLayerVisitor visitor("MaximumLayer"); @@ -118,6 +143,7 @@ BOOST_AUTO_TEST_CASE(CheckMaximumLayerVisitorNameNullptr) layer->Accept(visitor); } +// Minimum BOOST_AUTO_TEST_CASE(CheckMinimumLayerVisitorName) { TestMinimumLayerVisitor visitor("MinimumLayer"); @@ -136,78 +162,82 @@ BOOST_AUTO_TEST_CASE(CheckMinimumLayerVisitorNameNullptr) layer->Accept(visitor); } -BOOST_AUTO_TEST_CASE(CheckGreaterLayerVisitorName) +// Multiplication +BOOST_AUTO_TEST_CASE(CheckMultiplicationLayerVisitorName) { - TestGreaterLayerVisitor visitor("GreaterLayer"); + TestMultiplicationLayerVisitor visitor("MultiplicationLayer"); Network net; - IConnectableLayer *const layer = net.AddGreaterLayer("GreaterLayer"); + IConnectableLayer *const layer = net.AddMultiplicationLayer("MultiplicationLayer"); layer->Accept(visitor); } -BOOST_AUTO_TEST_CASE(CheckGreaterLayerVisitorNameNullptr) +BOOST_AUTO_TEST_CASE(CheckMultiplicationLayerVisitorNameNullptr) { - TestGreaterLayerVisitor visitor; + TestMultiplicationLayerVisitor visitor; Network net; - IConnectableLayer *const layer = net.AddGreaterLayer(); + IConnectableLayer *const layer = net.AddMultiplicationLayer(); layer->Accept(visitor); } -BOOST_AUTO_TEST_CASE(CheckEqualLayerVisitorName) +// Rsqrt +BOOST_AUTO_TEST_CASE(CheckRsqrtLayerVisitorName) { - TestEqualLayerVisitor visitor("EqualLayer"); + TestRsqrtLayerVisitor visitor("RsqrtLayer"); Network net; - IConnectableLayer *const layer = net.AddEqualLayer("EqualLayer"); + IConnectableLayer *const layer = net.AddRsqrtLayer("RsqrtLayer"); layer->Accept(visitor); } -BOOST_AUTO_TEST_CASE(CheckEqualLayerVisitorNameNullptr) +BOOST_AUTO_TEST_CASE(CheckRsqrtLayerVisitorNameNullptr) { - TestEqualLayerVisitor visitor; + TestRsqrtLayerVisitor visitor; Network net; - IConnectableLayer *const layer = net.AddEqualLayer(); + IConnectableLayer *const layer = net.AddRsqrtLayer(); layer->Accept(visitor); } -BOOST_AUTO_TEST_CASE(CheckRsqrtLayerVisitorName) +// Slice +BOOST_AUTO_TEST_CASE(CheckSliceLayerVisitorName) { - TestRsqrtLayerVisitor visitor("RsqrtLayer"); + TestSliceLayerVisitor visitor("SliceLayer"); Network net; - IConnectableLayer *const layer = net.AddRsqrtLayer("RsqrtLayer"); + IConnectableLayer *const layer = net.AddSliceLayer(SliceDescriptor(), "SliceLayer"); layer->Accept(visitor); } -BOOST_AUTO_TEST_CASE(CheckRsqrtLayerVisitorNameNullptr) +BOOST_AUTO_TEST_CASE(CheckSliceLayerVisitorNameNullptr) { - TestRsqrtLayerVisitor visitor; + TestSliceLayerVisitor visitor; Network net; - IConnectableLayer *const layer = net.AddRsqrtLayer(); + IConnectableLayer *const layer = net.AddSliceLayer(SliceDescriptor()); layer->Accept(visitor); } -BOOST_AUTO_TEST_CASE(CheckGatherLayerVisitorName) +// Subtraction +BOOST_AUTO_TEST_CASE(CheckSubtractionLayerVisitorName) { - TestGatherLayerVisitor visitor("GatherLayer"); + TestSubtractionLayerVisitor visitor("SubtractionLayer"); Network net; - IConnectableLayer *const layer = net.AddGatherLayer("GatherLayer"); + IConnectableLayer *const layer = net.AddSubtractionLayer("SubtractionLayer"); layer->Accept(visitor); } -BOOST_AUTO_TEST_CASE(CheckGatherLayerVisitorNameNullptr) +BOOST_AUTO_TEST_CASE(CheckSubtractionLayerVisitorNameNullptr) { - TestGatherLayerVisitor visitor; + TestSubtractionLayerVisitor visitor; Network net; - IConnectableLayer *const layer = net.AddGatherLayer(); + IConnectableLayer *const layer = net.AddSubtractionLayer(); layer->Accept(visitor); } BOOST_AUTO_TEST_SUITE_END() -} //namespace armnn \ No newline at end of file +} // namespace armnn diff --git a/src/armnn/test/TestNameOnlyLayerVisitor.hpp b/src/armnn/test/TestNameOnlyLayerVisitor.hpp index c0037ae28f..dec0d15a96 100644 --- a/src/armnn/test/TestNameOnlyLayerVisitor.hpp +++ b/src/armnn/test/TestNameOnlyLayerVisitor.hpp @@ -22,97 +22,97 @@ public: }; }; -class TestMultiplicationLayerVisitor : public TestLayerVisitor +class TestDivisionLayerVisitor : public TestLayerVisitor { public: - explicit TestMultiplicationLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; + explicit TestDivisionLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; - void VisitMultiplicationLayer(const IConnectableLayer* layer, - const char* name = nullptr) override { + void VisitDivisionLayer(const IConnectableLayer* layer, + const char* name = nullptr) override { CheckLayerPointer(layer); CheckLayerName(name); }; }; -class TestFloorLayerVisitor : public TestLayerVisitor +class TestEqualLayerVisitor : public TestLayerVisitor { public: - explicit TestFloorLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; + explicit TestEqualLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; - void VisitFloorLayer(const IConnectableLayer* layer, + void VisitEqualLayer(const IConnectableLayer* layer, const char* name = nullptr) override { CheckLayerPointer(layer); CheckLayerName(name); }; }; -class TestDivisionLayerVisitor : public TestLayerVisitor +class TestFloorLayerVisitor : public TestLayerVisitor { public: - explicit TestDivisionLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; + explicit TestFloorLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; - void VisitDivisionLayer(const IConnectableLayer* layer, - const char* name = nullptr) override { + void VisitFloorLayer(const IConnectableLayer* layer, + const char* name = nullptr) override { CheckLayerPointer(layer); CheckLayerName(name); }; }; -class TestSubtractionLayerVisitor : public TestLayerVisitor +class TestGatherLayerVisitor : public TestLayerVisitor { public: - explicit TestSubtractionLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; + explicit TestGatherLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; - void VisitSubtractionLayer(const IConnectableLayer* layer, - const char* name = nullptr) override { + void VisitGatherLayer(const IConnectableLayer* layer, + const char* name = nullptr) override { CheckLayerPointer(layer); CheckLayerName(name); }; }; -class TestMaximumLayerVisitor : public TestLayerVisitor +class TestGreaterLayerVisitor : public TestLayerVisitor { public: - explicit TestMaximumLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; + explicit TestGreaterLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; - void VisitMaximumLayer(const IConnectableLayer* layer, + void VisitGreaterLayer(const IConnectableLayer* layer, const char* name = nullptr) override { CheckLayerPointer(layer); CheckLayerName(name); }; }; -class TestMinimumLayerVisitor : public TestLayerVisitor +class TestMultiplicationLayerVisitor : public TestLayerVisitor { public: - explicit TestMinimumLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; + explicit TestMultiplicationLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; - void VisitMinimumLayer(const IConnectableLayer* layer, - const char* name = nullptr) override { + void VisitMultiplicationLayer(const IConnectableLayer* layer, + const char* name = nullptr) override { CheckLayerPointer(layer); CheckLayerName(name); }; }; -class TestGreaterLayerVisitor : public TestLayerVisitor +class TestMaximumLayerVisitor : public TestLayerVisitor { public: - explicit TestGreaterLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; + explicit TestMaximumLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; - void VisitGreaterLayer(const IConnectableLayer* layer, + void VisitMaximumLayer(const IConnectableLayer* layer, const char* name = nullptr) override { CheckLayerPointer(layer); CheckLayerName(name); }; }; -class TestEqualLayerVisitor : public TestLayerVisitor +class TestMinimumLayerVisitor : public TestLayerVisitor { public: - explicit TestEqualLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; + explicit TestMinimumLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; - void VisitEqualLayer(const IConnectableLayer* layer, - const char* name = nullptr) override { + void VisitMinimumLayer(const IConnectableLayer* layer, + const char* name = nullptr) override { CheckLayerPointer(layer); CheckLayerName(name); }; @@ -130,16 +130,30 @@ public: }; }; -class TestGatherLayerVisitor : public TestLayerVisitor +class TestSliceLayerVisitor : public TestLayerVisitor { public: - explicit TestGatherLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; + explicit TestSliceLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; - void VisitGatherLayer(const IConnectableLayer* layer, - const char* name = nullptr) override { + void VisitSliceLayer(const IConnectableLayer* layer, + const SliceDescriptor& sliceDescriptor, + const char* name = nullptr) override + { + CheckLayerPointer(layer); + CheckLayerName(name); + }; +}; + +class TestSubtractionLayerVisitor : public TestLayerVisitor +{ +public: + explicit TestSubtractionLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {}; + + void VisitSubtractionLayer(const IConnectableLayer* layer, + const char* name = nullptr) override { CheckLayerPointer(layer); CheckLayerName(name); }; }; -} //namespace armnn \ No newline at end of file +} // namespace armnn -- cgit v1.2.1