aboutsummaryrefslogtreecommitdiff
path: root/src/armnn
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn')
-rw-r--r--src/armnn/InternalTypes.hpp1
-rw-r--r--src/armnn/LayersFwd.hpp2
-rw-r--r--src/armnn/Network.cpp5
-rw-r--r--src/armnn/Network.hpp2
-rw-r--r--src/armnn/layers/SliceLayer.cpp66
-rw-r--r--src/armnn/layers/SliceLayer.hpp49
-rw-r--r--src/armnn/test/TestNameOnlyLayerVisitor.cpp116
-rw-r--r--src/armnn/test/TestNameOnlyLayerVisitor.hpp82
8 files changed, 246 insertions, 77 deletions
diff --git a/src/armnn/InternalTypes.hpp b/src/armnn/InternalTypes.hpp
index 98308f92a1..1e05fff769 100644
--- a/src/armnn/InternalTypes.hpp
+++ b/src/armnn/InternalTypes.hpp
@@ -58,6 +58,7 @@ enum class LayerType
Reshape,
Resize,
Rsqrt,
+ Slice,
Softmax,
SpaceToBatchNd,
SpaceToDepth,
diff --git a/src/armnn/LayersFwd.hpp b/src/armnn/LayersFwd.hpp
index 6e4cf6ab04..a98c104f85 100644
--- a/src/armnn/LayersFwd.hpp
+++ b/src/armnn/LayersFwd.hpp
@@ -50,6 +50,7 @@
#include "layers/ReshapeLayer.hpp"
#include "layers/ResizeLayer.hpp"
#include "layers/RsqrtLayer.hpp"
+#include "layers/SliceLayer.hpp"
#include "layers/SoftmaxLayer.hpp"
#include "layers/SpaceToBatchNdLayer.hpp"
#include "layers/SpaceToDepthLayer.hpp"
@@ -131,6 +132,7 @@ DECLARE_LAYER(QuantizedLstm)
DECLARE_LAYER(Reshape)
DECLARE_LAYER(Resize)
DECLARE_LAYER(Rsqrt)
+DECLARE_LAYER(Slice)
DECLARE_LAYER(Softmax)
DECLARE_LAYER(SpaceToBatchNd)
DECLARE_LAYER(SpaceToDepth)
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index 6971cb89ba..c055407b3a 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -1129,6 +1129,11 @@ normalizationDescriptor,
return m_Graph->AddLayer<NormalizationLayer>(normalizationDescriptor, name);
}
+IConnectableLayer* Network::AddSliceLayer(const SliceDescriptor& sliceDescriptor, const char* name)
+{
+ return m_Graph->AddLayer<SliceLayer>(sliceDescriptor, name);
+}
+
IConnectableLayer* Network::AddSoftmaxLayer(const SoftmaxDescriptor& softmaxDescriptor,
const char* name)
{
diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp
index aac875aac7..274cc1ab7c 100644
--- a/src/armnn/Network.hpp
+++ b/src/armnn/Network.hpp
@@ -117,6 +117,8 @@ public:
IConnectableLayer* AddNormalizationLayer(const NormalizationDescriptor& normalizationDescriptor,
const char* name = nullptr) override;
+ IConnectableLayer* AddSliceLayer(const SliceDescriptor& sliceDescriptor, const char* name = nullptr) override;
+
IConnectableLayer* AddSoftmaxLayer(const SoftmaxDescriptor& softmaxDescriptor,
const char* name = nullptr) override;
diff --git a/src/armnn/layers/SliceLayer.cpp b/src/armnn/layers/SliceLayer.cpp
new file mode 100644
index 0000000000..8ea5fd8f25
--- /dev/null
+++ b/src/armnn/layers/SliceLayer.cpp
@@ -0,0 +1,66 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "SliceLayer.hpp"
+
+#include "LayerCloneBase.hpp"
+
+#include <armnn/TypesUtils.hpp>
+
+#include <backendsCommon/WorkloadData.hpp>
+#include <backendsCommon/WorkloadFactory.hpp>
+
+#include <boost/assert.hpp>
+#include <boost/numeric/conversion/cast.hpp>
+
+namespace armnn
+{
+
+SliceLayer::SliceLayer(const SliceDescriptor& param, const char* name)
+ : LayerWithParameters(1, 1, LayerType::Slice, param, name)
+{
+}
+
+std::unique_ptr<IWorkload> SliceLayer::CreateWorkload(const Graph& graph,
+ const IWorkloadFactory& factory) const
+{
+ SliceQueueDescriptor descriptor;
+ return factory.CreateSlice(descriptor, PrepInfoAndDesc(descriptor, graph));
+}
+
+SliceLayer* SliceLayer::Clone(Graph& graph) const
+{
+ return CloneBase<SliceLayer>(graph, m_Param, GetName());
+}
+
+void SliceLayer::ValidateTensorShapesFromInputs()
+{
+ VerifyLayerConnections(1, CHECK_LOCATION());
+
+ auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() });
+
+ BOOST_ASSERT(inferredShapes.size() == 1);
+
+ ConditionalThrowIfNotEqual<LayerValidationException>(
+ "SliceLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
+ GetOutputSlot(0).GetTensorInfo().GetShape(),
+ inferredShapes[0]);
+}
+
+std::vector<TensorShape> SliceLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
+{
+ BOOST_ASSERT(inputShapes.size() == 1);
+
+ TensorShape outputShape(boost::numeric_cast<unsigned int>(m_Param.m_Size.size()), m_Param.m_Size.data());
+
+ return std::vector<TensorShape>({ outputShape });
+}
+
+void SliceLayer::Accept(ILayerVisitor& visitor) const
+{
+ visitor.VisitSliceLayer(this, GetParameters(), GetName());
+}
+
+} // namespace armnn
diff --git a/src/armnn/layers/SliceLayer.hpp b/src/armnn/layers/SliceLayer.hpp
new file mode 100644
index 0000000000..38f0747f05
--- /dev/null
+++ b/src/armnn/layers/SliceLayer.hpp
@@ -0,0 +1,49 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "LayerWithParameters.hpp"
+
+namespace armnn
+{
+
+class SliceLayer : public LayerWithParameters<SliceDescriptor>
+{
+public:
+ /// Makes a workload for the Slice type.
+ /// @param [in] graph The graph where this layer can be found.
+ /// @param [in] factory The workload factory which will create the workload.
+ /// @return A pointer to the created workload, or nullptr if not created.
+ virtual std::unique_ptr<IWorkload> CreateWorkload(const Graph& graph,
+ const IWorkloadFactory& factory) const override;
+
+ /// Creates a dynamically-allocated copy of this layer.
+ /// @param [in] graph The graph into which this layer is being cloned.
+ SliceLayer* Clone(Graph& graph) const override;
+
+ /// Check if the input tensor shape(s)
+ /// will lead to a valid configuration of @ref SliceLayer.
+ void ValidateTensorShapesFromInputs() override;
+
+ /// By default returns inputShapes if the number of inputs are equal to number of outputs,
+ /// otherwise infers the output shapes from given input shapes and layer properties.
+ /// @param [in] inputShapes The input shapes layer has.
+ /// @return A vector to the inferred output shape.
+ std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override;
+
+ void Accept(ILayerVisitor& visitor) const override;
+
+protected:
+ /// Constructor to create a SliceLayer.
+ /// @param [in] param SliceDescriptor to configure the resize operation.
+ /// @param [in] name Optional name for the layer.
+ SliceLayer(const SliceDescriptor& param, const char* name);
+
+ /// Default destructor.
+ ~SliceLayer() = default;
+};
+
+} // namespace armnn
diff --git a/src/armnn/test/TestNameOnlyLayerVisitor.cpp b/src/armnn/test/TestNameOnlyLayerVisitor.cpp
index 4bb9614385..c4c4a479eb 100644
--- a/src/armnn/test/TestNameOnlyLayerVisitor.cpp
+++ b/src/armnn/test/TestNameOnlyLayerVisitor.cpp
@@ -10,6 +10,7 @@ namespace armnn {
BOOST_AUTO_TEST_SUITE(TestNameOnlyLayerVisitor)
+// Addition
BOOST_AUTO_TEST_CASE(CheckAdditionLayerVisitorName)
{
TestAdditionLayerVisitor visitor("AdditionLayer");
@@ -28,24 +29,45 @@ BOOST_AUTO_TEST_CASE(CheckAdditionLayerVisitorNameNullptr)
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckMultiplicationLayerVisitorName)
+// Division
+BOOST_AUTO_TEST_CASE(CheckDivisionLayerVisitorName)
{
- TestMultiplicationLayerVisitor visitor("MultiplicationLayer");
+ TestDivisionLayerVisitor visitor("DivisionLayer");
Network net;
- IConnectableLayer *const layer = net.AddMultiplicationLayer("MultiplicationLayer");
+ IConnectableLayer *const layer = net.AddAdditionLayer("DivisionLayer");
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckMultiplicationLayerVisitorNameNullptr)
+BOOST_AUTO_TEST_CASE(CheckDivisionLayerVisitorNameNullptr)
{
- TestMultiplicationLayerVisitor visitor;
+ TestDivisionLayerVisitor visitor;
Network net;
- IConnectableLayer *const layer = net.AddMultiplicationLayer();
+ IConnectableLayer *const layer = net.AddDivisionLayer();
+ layer->Accept(visitor);
+}
+
+// Equal
+BOOST_AUTO_TEST_CASE(CheckEqualLayerVisitorName)
+{
+ TestEqualLayerVisitor visitor("EqualLayer");
+ Network net;
+
+ IConnectableLayer *const layer = net.AddEqualLayer("EqualLayer");
+ layer->Accept(visitor);
+}
+
+BOOST_AUTO_TEST_CASE(CheckEqualLayerVisitorNameNullptr)
+{
+ TestEqualLayerVisitor visitor;
+ Network net;
+
+ IConnectableLayer *const layer = net.AddEqualLayer();
layer->Accept(visitor);
}
+// Floor
BOOST_AUTO_TEST_CASE(CheckFloorLayerVisitorName)
{
TestFloorLayerVisitor visitor("FloorLayer");
@@ -64,42 +86,45 @@ BOOST_AUTO_TEST_CASE(CheckFloorLayerVisitorNameNullptr)
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckDivisionLayerVisitorName)
+// Gather
+BOOST_AUTO_TEST_CASE(CheckGatherLayerVisitorName)
{
- TestDivisionLayerVisitor visitor("DivisionLayer");
+ TestGatherLayerVisitor visitor("GatherLayer");
Network net;
- IConnectableLayer *const layer = net.AddAdditionLayer("DivisionLayer");
+ IConnectableLayer *const layer = net.AddGatherLayer("GatherLayer");
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckDivisionLayerVisitorNameNullptr)
+BOOST_AUTO_TEST_CASE(CheckGatherLayerVisitorNameNullptr)
{
- TestDivisionLayerVisitor visitor;
+ TestGatherLayerVisitor visitor;
Network net;
- IConnectableLayer *const layer = net.AddDivisionLayer();
+ IConnectableLayer *const layer = net.AddGatherLayer();
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckSubtractionLayerVisitorName)
+// Greater
+BOOST_AUTO_TEST_CASE(CheckGreaterLayerVisitorName)
{
- TestSubtractionLayerVisitor visitor("SubtractionLayer");
+ TestGreaterLayerVisitor visitor("GreaterLayer");
Network net;
- IConnectableLayer *const layer = net.AddSubtractionLayer("SubtractionLayer");
+ IConnectableLayer *const layer = net.AddGreaterLayer("GreaterLayer");
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckSubtractionLayerVisitorNameNullptr)
+BOOST_AUTO_TEST_CASE(CheckGreaterLayerVisitorNameNullptr)
{
- TestSubtractionLayerVisitor visitor;
+ TestGreaterLayerVisitor visitor;
Network net;
- IConnectableLayer *const layer = net.AddSubtractionLayer();
+ IConnectableLayer *const layer = net.AddGreaterLayer();
layer->Accept(visitor);
}
+// Maximum
BOOST_AUTO_TEST_CASE(CheckMaximumLayerVisitorName)
{
TestMaximumLayerVisitor visitor("MaximumLayer");
@@ -118,6 +143,7 @@ BOOST_AUTO_TEST_CASE(CheckMaximumLayerVisitorNameNullptr)
layer->Accept(visitor);
}
+// Minimum
BOOST_AUTO_TEST_CASE(CheckMinimumLayerVisitorName)
{
TestMinimumLayerVisitor visitor("MinimumLayer");
@@ -136,78 +162,82 @@ BOOST_AUTO_TEST_CASE(CheckMinimumLayerVisitorNameNullptr)
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckGreaterLayerVisitorName)
+// Multiplication
+BOOST_AUTO_TEST_CASE(CheckMultiplicationLayerVisitorName)
{
- TestGreaterLayerVisitor visitor("GreaterLayer");
+ TestMultiplicationLayerVisitor visitor("MultiplicationLayer");
Network net;
- IConnectableLayer *const layer = net.AddGreaterLayer("GreaterLayer");
+ IConnectableLayer *const layer = net.AddMultiplicationLayer("MultiplicationLayer");
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckGreaterLayerVisitorNameNullptr)
+BOOST_AUTO_TEST_CASE(CheckMultiplicationLayerVisitorNameNullptr)
{
- TestGreaterLayerVisitor visitor;
+ TestMultiplicationLayerVisitor visitor;
Network net;
- IConnectableLayer *const layer = net.AddGreaterLayer();
+ IConnectableLayer *const layer = net.AddMultiplicationLayer();
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckEqualLayerVisitorName)
+// Rsqrt
+BOOST_AUTO_TEST_CASE(CheckRsqrtLayerVisitorName)
{
- TestEqualLayerVisitor visitor("EqualLayer");
+ TestRsqrtLayerVisitor visitor("RsqrtLayer");
Network net;
- IConnectableLayer *const layer = net.AddEqualLayer("EqualLayer");
+ IConnectableLayer *const layer = net.AddRsqrtLayer("RsqrtLayer");
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckEqualLayerVisitorNameNullptr)
+BOOST_AUTO_TEST_CASE(CheckRsqrtLayerVisitorNameNullptr)
{
- TestEqualLayerVisitor visitor;
+ TestRsqrtLayerVisitor visitor;
Network net;
- IConnectableLayer *const layer = net.AddEqualLayer();
+ IConnectableLayer *const layer = net.AddRsqrtLayer();
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckRsqrtLayerVisitorName)
+// Slice
+BOOST_AUTO_TEST_CASE(CheckSliceLayerVisitorName)
{
- TestRsqrtLayerVisitor visitor("RsqrtLayer");
+ TestSliceLayerVisitor visitor("SliceLayer");
Network net;
- IConnectableLayer *const layer = net.AddRsqrtLayer("RsqrtLayer");
+ IConnectableLayer *const layer = net.AddSliceLayer(SliceDescriptor(), "SliceLayer");
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckRsqrtLayerVisitorNameNullptr)
+BOOST_AUTO_TEST_CASE(CheckSliceLayerVisitorNameNullptr)
{
- TestRsqrtLayerVisitor visitor;
+ TestSliceLayerVisitor visitor;
Network net;
- IConnectableLayer *const layer = net.AddRsqrtLayer();
+ IConnectableLayer *const layer = net.AddSliceLayer(SliceDescriptor());
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckGatherLayerVisitorName)
+// Subtraction
+BOOST_AUTO_TEST_CASE(CheckSubtractionLayerVisitorName)
{
- TestGatherLayerVisitor visitor("GatherLayer");
+ TestSubtractionLayerVisitor visitor("SubtractionLayer");
Network net;
- IConnectableLayer *const layer = net.AddGatherLayer("GatherLayer");
+ IConnectableLayer *const layer = net.AddSubtractionLayer("SubtractionLayer");
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckGatherLayerVisitorNameNullptr)
+BOOST_AUTO_TEST_CASE(CheckSubtractionLayerVisitorNameNullptr)
{
- TestGatherLayerVisitor visitor;
+ TestSubtractionLayerVisitor visitor;
Network net;
- IConnectableLayer *const layer = net.AddGatherLayer();
+ IConnectableLayer *const layer = net.AddSubtractionLayer();
layer->Accept(visitor);
}
BOOST_AUTO_TEST_SUITE_END()
-} //namespace armnn \ No newline at end of file
+} // namespace armnn
diff --git a/src/armnn/test/TestNameOnlyLayerVisitor.hpp b/src/armnn/test/TestNameOnlyLayerVisitor.hpp
index c0037ae28f..dec0d15a96 100644
--- a/src/armnn/test/TestNameOnlyLayerVisitor.hpp
+++ b/src/armnn/test/TestNameOnlyLayerVisitor.hpp
@@ -22,97 +22,97 @@ public:
};
};
-class TestMultiplicationLayerVisitor : public TestLayerVisitor
+class TestDivisionLayerVisitor : public TestLayerVisitor
{
public:
- explicit TestMultiplicationLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
+ explicit TestDivisionLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
- void VisitMultiplicationLayer(const IConnectableLayer* layer,
- const char* name = nullptr) override {
+ void VisitDivisionLayer(const IConnectableLayer* layer,
+ const char* name = nullptr) override {
CheckLayerPointer(layer);
CheckLayerName(name);
};
};
-class TestFloorLayerVisitor : public TestLayerVisitor
+class TestEqualLayerVisitor : public TestLayerVisitor
{
public:
- explicit TestFloorLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
+ explicit TestEqualLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
- void VisitFloorLayer(const IConnectableLayer* layer,
+ void VisitEqualLayer(const IConnectableLayer* layer,
const char* name = nullptr) override {
CheckLayerPointer(layer);
CheckLayerName(name);
};
};
-class TestDivisionLayerVisitor : public TestLayerVisitor
+class TestFloorLayerVisitor : public TestLayerVisitor
{
public:
- explicit TestDivisionLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
+ explicit TestFloorLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
- void VisitDivisionLayer(const IConnectableLayer* layer,
- const char* name = nullptr) override {
+ void VisitFloorLayer(const IConnectableLayer* layer,
+ const char* name = nullptr) override {
CheckLayerPointer(layer);
CheckLayerName(name);
};
};
-class TestSubtractionLayerVisitor : public TestLayerVisitor
+class TestGatherLayerVisitor : public TestLayerVisitor
{
public:
- explicit TestSubtractionLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
+ explicit TestGatherLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
- void VisitSubtractionLayer(const IConnectableLayer* layer,
- const char* name = nullptr) override {
+ void VisitGatherLayer(const IConnectableLayer* layer,
+ const char* name = nullptr) override {
CheckLayerPointer(layer);
CheckLayerName(name);
};
};
-class TestMaximumLayerVisitor : public TestLayerVisitor
+class TestGreaterLayerVisitor : public TestLayerVisitor
{
public:
- explicit TestMaximumLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
+ explicit TestGreaterLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
- void VisitMaximumLayer(const IConnectableLayer* layer,
+ void VisitGreaterLayer(const IConnectableLayer* layer,
const char* name = nullptr) override {
CheckLayerPointer(layer);
CheckLayerName(name);
};
};
-class TestMinimumLayerVisitor : public TestLayerVisitor
+class TestMultiplicationLayerVisitor : public TestLayerVisitor
{
public:
- explicit TestMinimumLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
+ explicit TestMultiplicationLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
- void VisitMinimumLayer(const IConnectableLayer* layer,
- const char* name = nullptr) override {
+ void VisitMultiplicationLayer(const IConnectableLayer* layer,
+ const char* name = nullptr) override {
CheckLayerPointer(layer);
CheckLayerName(name);
};
};
-class TestGreaterLayerVisitor : public TestLayerVisitor
+class TestMaximumLayerVisitor : public TestLayerVisitor
{
public:
- explicit TestGreaterLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
+ explicit TestMaximumLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
- void VisitGreaterLayer(const IConnectableLayer* layer,
+ void VisitMaximumLayer(const IConnectableLayer* layer,
const char* name = nullptr) override {
CheckLayerPointer(layer);
CheckLayerName(name);
};
};
-class TestEqualLayerVisitor : public TestLayerVisitor
+class TestMinimumLayerVisitor : public TestLayerVisitor
{
public:
- explicit TestEqualLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
+ explicit TestMinimumLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
- void VisitEqualLayer(const IConnectableLayer* layer,
- const char* name = nullptr) override {
+ void VisitMinimumLayer(const IConnectableLayer* layer,
+ const char* name = nullptr) override {
CheckLayerPointer(layer);
CheckLayerName(name);
};
@@ -130,16 +130,30 @@ public:
};
};
-class TestGatherLayerVisitor : public TestLayerVisitor
+class TestSliceLayerVisitor : public TestLayerVisitor
{
public:
- explicit TestGatherLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
+ explicit TestSliceLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
- void VisitGatherLayer(const IConnectableLayer* layer,
- const char* name = nullptr) override {
+ void VisitSliceLayer(const IConnectableLayer* layer,
+ const SliceDescriptor& sliceDescriptor,
+ const char* name = nullptr) override
+ {
+ CheckLayerPointer(layer);
+ CheckLayerName(name);
+ };
+};
+
+class TestSubtractionLayerVisitor : public TestLayerVisitor
+{
+public:
+ explicit TestSubtractionLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
+
+ void VisitSubtractionLayer(const IConnectableLayer* layer,
+ const char* name = nullptr) override {
CheckLayerPointer(layer);
CheckLayerName(name);
};
};
-} //namespace armnn \ No newline at end of file
+} // namespace armnn