aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-09-16 14:27:45 +0100
committerÁron Virginás-Tar <aron.virginas-tar@arm.com>2019-09-17 09:03:43 +0000
commit636ab40d3741e12eaad11d5b50e4b34bfbb258b5 (patch)
treedefaba57dc28c7b5dbe19075e24b6c8c0cefc9b2 /src
parent4dc64a69ba383ece509d442598617445a3b4847f (diff)
downloadarmnn-636ab40d3741e12eaad11d5b50e4b34bfbb258b5.tar.gz
IVGCVSW-3875 Add frontend for SLICE layer
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> Change-Id: Iebe675a0cee02db6f133d48ce58cbc1e233061db
Diffstat (limited to 'src')
-rw-r--r--src/armnn/InternalTypes.hpp1
-rw-r--r--src/armnn/LayersFwd.hpp2
-rw-r--r--src/armnn/Network.cpp5
-rw-r--r--src/armnn/Network.hpp2
-rw-r--r--src/armnn/layers/SliceLayer.cpp66
-rw-r--r--src/armnn/layers/SliceLayer.hpp49
-rw-r--r--src/armnn/test/TestNameOnlyLayerVisitor.cpp116
-rw-r--r--src/armnn/test/TestNameOnlyLayerVisitor.hpp82
-rw-r--r--src/armnnSerializer/Serializer.cpp7
-rw-r--r--src/armnnSerializer/Serializer.hpp4
-rw-r--r--src/backends/backendsCommon/LayerSupportBase.cpp8
-rw-r--r--src/backends/backendsCommon/LayerSupportBase.hpp5
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp57
-rw-r--r--src/backends/backendsCommon/WorkloadData.hpp7
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp21
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.hpp3
-rw-r--r--src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp2
17 files changed, 357 insertions, 80 deletions
diff --git a/src/armnn/InternalTypes.hpp b/src/armnn/InternalTypes.hpp
index 98308f92a1..1e05fff769 100644
--- a/src/armnn/InternalTypes.hpp
+++ b/src/armnn/InternalTypes.hpp
@@ -58,6 +58,7 @@ enum class LayerType
Reshape,
Resize,
Rsqrt,
+ Slice,
Softmax,
SpaceToBatchNd,
SpaceToDepth,
diff --git a/src/armnn/LayersFwd.hpp b/src/armnn/LayersFwd.hpp
index 6e4cf6ab04..a98c104f85 100644
--- a/src/armnn/LayersFwd.hpp
+++ b/src/armnn/LayersFwd.hpp
@@ -50,6 +50,7 @@
#include "layers/ReshapeLayer.hpp"
#include "layers/ResizeLayer.hpp"
#include "layers/RsqrtLayer.hpp"
+#include "layers/SliceLayer.hpp"
#include "layers/SoftmaxLayer.hpp"
#include "layers/SpaceToBatchNdLayer.hpp"
#include "layers/SpaceToDepthLayer.hpp"
@@ -131,6 +132,7 @@ DECLARE_LAYER(QuantizedLstm)
DECLARE_LAYER(Reshape)
DECLARE_LAYER(Resize)
DECLARE_LAYER(Rsqrt)
+DECLARE_LAYER(Slice)
DECLARE_LAYER(Softmax)
DECLARE_LAYER(SpaceToBatchNd)
DECLARE_LAYER(SpaceToDepth)
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index 6971cb89ba..c055407b3a 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -1129,6 +1129,11 @@ normalizationDescriptor,
return m_Graph->AddLayer<NormalizationLayer>(normalizationDescriptor, name);
}
// Adds a Slice layer to the network graph. The SliceDescriptor carries the
// begin offsets and sizes that define the region to extract.
IConnectableLayer* Network::AddSliceLayer(const SliceDescriptor& sliceDescriptor, const char* name)
{
    return m_Graph->AddLayer<SliceLayer>(sliceDescriptor, name);
}
+
IConnectableLayer* Network::AddSoftmaxLayer(const SoftmaxDescriptor& softmaxDescriptor,
const char* name)
{
diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp
index aac875aac7..274cc1ab7c 100644
--- a/src/armnn/Network.hpp
+++ b/src/armnn/Network.hpp
@@ -117,6 +117,8 @@ public:
IConnectableLayer* AddNormalizationLayer(const NormalizationDescriptor& normalizationDescriptor,
const char* name = nullptr) override;
+ IConnectableLayer* AddSliceLayer(const SliceDescriptor& sliceDescriptor, const char* name = nullptr) override;
+
IConnectableLayer* AddSoftmaxLayer(const SoftmaxDescriptor& softmaxDescriptor,
const char* name = nullptr) override;
diff --git a/src/armnn/layers/SliceLayer.cpp b/src/armnn/layers/SliceLayer.cpp
new file mode 100644
index 0000000000..8ea5fd8f25
--- /dev/null
+++ b/src/armnn/layers/SliceLayer.cpp
@@ -0,0 +1,66 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "SliceLayer.hpp"
+
+#include "LayerCloneBase.hpp"
+
+#include <armnn/TypesUtils.hpp>
+
+#include <backendsCommon/WorkloadData.hpp>
+#include <backendsCommon/WorkloadFactory.hpp>
+
+#include <boost/assert.hpp>
+#include <boost/numeric/conversion/cast.hpp>
+
+namespace armnn
+{
+
/// Constructs a SliceLayer with exactly one input slot and one output slot,
/// storing the descriptor in m_Param via LayerWithParameters.
SliceLayer::SliceLayer(const SliceDescriptor& param, const char* name)
    : LayerWithParameters(1, 1, LayerType::Slice, param, name)
{
}
+
/// Creates the backend workload for this layer through the given factory.
/// PrepInfoAndDesc populates the queue descriptor (parameters and tensor
/// infos) before it is handed to the factory.
std::unique_ptr<IWorkload> SliceLayer::CreateWorkload(const Graph& graph,
                                                      const IWorkloadFactory& factory) const
{
    SliceQueueDescriptor descriptor;
    return factory.CreateSlice(descriptor, PrepInfoAndDesc(descriptor, graph));
}
+
/// Creates a copy of this layer (same parameters and name) in the given graph.
SliceLayer* SliceLayer::Clone(Graph& graph) const
{
    return CloneBase<SliceLayer>(graph, m_Param, GetName());
}
+
/// Verifies that the shape set on OutputSlot[0] matches the shape inferred
/// from the connected input; throws LayerValidationException on mismatch.
void SliceLayer::ValidateTensorShapesFromInputs()
{
    // Requires the single input slot to be connected before validation.
    VerifyLayerConnections(1, CHECK_LOCATION());

    auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() });

    BOOST_ASSERT(inferredShapes.size() == 1);

    ConditionalThrowIfNotEqual<LayerValidationException>(
        "SliceLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
        GetOutputSlot(0).GetTensorInfo().GetShape(),
        inferredShapes[0]);
}
+
/// The output shape of a Slice layer is given directly by the descriptor's
/// m_Size field (one extent per dimension); the input shape itself is not
/// consulted here beyond the single-input assertion.
std::vector<TensorShape> SliceLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
{
    BOOST_ASSERT(inputShapes.size() == 1);

    TensorShape outputShape(boost::numeric_cast<unsigned int>(m_Param.m_Size.size()), m_Param.m_Size.data());

    return std::vector<TensorShape>({ outputShape });
}
+
/// Visitor entry point: forwards this layer, its descriptor and its name
/// to the visitor's VisitSliceLayer.
void SliceLayer::Accept(ILayerVisitor& visitor) const
{
    visitor.VisitSliceLayer(this, GetParameters(), GetName());
}
+
+} // namespace armnn
diff --git a/src/armnn/layers/SliceLayer.hpp b/src/armnn/layers/SliceLayer.hpp
new file mode 100644
index 0000000000..38f0747f05
--- /dev/null
+++ b/src/armnn/layers/SliceLayer.hpp
@@ -0,0 +1,49 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "LayerWithParameters.hpp"
+
+namespace armnn
+{
+
/// This layer represents a Slice operation: it extracts a sub-tensor from its
/// input, defined by the begin offsets and sizes held in the SliceDescriptor.
class SliceLayer : public LayerWithParameters<SliceDescriptor>
{
public:
    /// Makes a workload for the Slice type.
    /// @param [in] graph The graph where this layer can be found.
    /// @param [in] factory The workload factory which will create the workload.
    /// @return A pointer to the created workload, or nullptr if not created.
    virtual std::unique_ptr<IWorkload> CreateWorkload(const Graph& graph,
                                                      const IWorkloadFactory& factory) const override;

    /// Creates a dynamically-allocated copy of this layer.
    /// @param [in] graph The graph into which this layer is being cloned.
    SliceLayer* Clone(Graph& graph) const override;

    /// Check if the input tensor shape(s)
    /// will lead to a valid configuration of @ref SliceLayer.
    void ValidateTensorShapesFromInputs() override;

    /// By default returns inputShapes if the number of inputs are equal to number of outputs,
    /// otherwise infers the output shapes from given input shapes and layer properties.
    /// @param [in] inputShapes The input shapes layer has.
    /// @return A vector to the inferred output shape.
    std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override;

    /// Accepts a visitor and calls VisitSliceLayer on it.
    void Accept(ILayerVisitor& visitor) const override;

protected:
    /// Constructor to create a SliceLayer.
    /// @param [in] param SliceDescriptor to configure the slice operation.
    /// @param [in] name Optional name for the layer.
    SliceLayer(const SliceDescriptor& param, const char* name);

    /// Default destructor.
    ~SliceLayer() = default;
};
+
+} // namespace armnn
diff --git a/src/armnn/test/TestNameOnlyLayerVisitor.cpp b/src/armnn/test/TestNameOnlyLayerVisitor.cpp
index 4bb9614385..c4c4a479eb 100644
--- a/src/armnn/test/TestNameOnlyLayerVisitor.cpp
+++ b/src/armnn/test/TestNameOnlyLayerVisitor.cpp
@@ -10,6 +10,7 @@ namespace armnn {
BOOST_AUTO_TEST_SUITE(TestNameOnlyLayerVisitor)
+// Addition
BOOST_AUTO_TEST_CASE(CheckAdditionLayerVisitorName)
{
TestAdditionLayerVisitor visitor("AdditionLayer");
@@ -28,24 +29,45 @@ BOOST_AUTO_TEST_CASE(CheckAdditionLayerVisitorNameNullptr)
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckMultiplicationLayerVisitorName)
+// Division
+BOOST_AUTO_TEST_CASE(CheckDivisionLayerVisitorName)
{
- TestMultiplicationLayerVisitor visitor("MultiplicationLayer");
+ TestDivisionLayerVisitor visitor("DivisionLayer");
Network net;
- IConnectableLayer *const layer = net.AddMultiplicationLayer("MultiplicationLayer");
+ IConnectableLayer *const layer = net.AddAdditionLayer("DivisionLayer");
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckMultiplicationLayerVisitorNameNullptr)
+BOOST_AUTO_TEST_CASE(CheckDivisionLayerVisitorNameNullptr)
{
- TestMultiplicationLayerVisitor visitor;
+ TestDivisionLayerVisitor visitor;
Network net;
- IConnectableLayer *const layer = net.AddMultiplicationLayer();
+ IConnectableLayer *const layer = net.AddDivisionLayer();
+ layer->Accept(visitor);
+}
+
+// Equal
+BOOST_AUTO_TEST_CASE(CheckEqualLayerVisitorName)
+{
+ TestEqualLayerVisitor visitor("EqualLayer");
+ Network net;
+
+ IConnectableLayer *const layer = net.AddEqualLayer("EqualLayer");
+ layer->Accept(visitor);
+}
+
+BOOST_AUTO_TEST_CASE(CheckEqualLayerVisitorNameNullptr)
+{
+ TestEqualLayerVisitor visitor;
+ Network net;
+
+ IConnectableLayer *const layer = net.AddEqualLayer();
layer->Accept(visitor);
}
+// Floor
BOOST_AUTO_TEST_CASE(CheckFloorLayerVisitorName)
{
TestFloorLayerVisitor visitor("FloorLayer");
@@ -64,42 +86,45 @@ BOOST_AUTO_TEST_CASE(CheckFloorLayerVisitorNameNullptr)
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckDivisionLayerVisitorName)
+// Gather
+BOOST_AUTO_TEST_CASE(CheckGatherLayerVisitorName)
{
- TestDivisionLayerVisitor visitor("DivisionLayer");
+ TestGatherLayerVisitor visitor("GatherLayer");
Network net;
- IConnectableLayer *const layer = net.AddAdditionLayer("DivisionLayer");
+ IConnectableLayer *const layer = net.AddGatherLayer("GatherLayer");
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckDivisionLayerVisitorNameNullptr)
+BOOST_AUTO_TEST_CASE(CheckGatherLayerVisitorNameNullptr)
{
- TestDivisionLayerVisitor visitor;
+ TestGatherLayerVisitor visitor;
Network net;
- IConnectableLayer *const layer = net.AddDivisionLayer();
+ IConnectableLayer *const layer = net.AddGatherLayer();
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckSubtractionLayerVisitorName)
+// Greater
+BOOST_AUTO_TEST_CASE(CheckGreaterLayerVisitorName)
{
- TestSubtractionLayerVisitor visitor("SubtractionLayer");
+ TestGreaterLayerVisitor visitor("GreaterLayer");
Network net;
- IConnectableLayer *const layer = net.AddSubtractionLayer("SubtractionLayer");
+ IConnectableLayer *const layer = net.AddGreaterLayer("GreaterLayer");
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckSubtractionLayerVisitorNameNullptr)
+BOOST_AUTO_TEST_CASE(CheckGreaterLayerVisitorNameNullptr)
{
- TestSubtractionLayerVisitor visitor;
+ TestGreaterLayerVisitor visitor;
Network net;
- IConnectableLayer *const layer = net.AddSubtractionLayer();
+ IConnectableLayer *const layer = net.AddGreaterLayer();
layer->Accept(visitor);
}
+// Maximum
BOOST_AUTO_TEST_CASE(CheckMaximumLayerVisitorName)
{
TestMaximumLayerVisitor visitor("MaximumLayer");
@@ -118,6 +143,7 @@ BOOST_AUTO_TEST_CASE(CheckMaximumLayerVisitorNameNullptr)
layer->Accept(visitor);
}
+// Minimum
BOOST_AUTO_TEST_CASE(CheckMinimumLayerVisitorName)
{
TestMinimumLayerVisitor visitor("MinimumLayer");
@@ -136,78 +162,82 @@ BOOST_AUTO_TEST_CASE(CheckMinimumLayerVisitorNameNullptr)
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckGreaterLayerVisitorName)
+// Multiplication
+BOOST_AUTO_TEST_CASE(CheckMultiplicationLayerVisitorName)
{
- TestGreaterLayerVisitor visitor("GreaterLayer");
+ TestMultiplicationLayerVisitor visitor("MultiplicationLayer");
Network net;
- IConnectableLayer *const layer = net.AddGreaterLayer("GreaterLayer");
+ IConnectableLayer *const layer = net.AddMultiplicationLayer("MultiplicationLayer");
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckGreaterLayerVisitorNameNullptr)
+BOOST_AUTO_TEST_CASE(CheckMultiplicationLayerVisitorNameNullptr)
{
- TestGreaterLayerVisitor visitor;
+ TestMultiplicationLayerVisitor visitor;
Network net;
- IConnectableLayer *const layer = net.AddGreaterLayer();
+ IConnectableLayer *const layer = net.AddMultiplicationLayer();
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckEqualLayerVisitorName)
+// Rsqrt
+BOOST_AUTO_TEST_CASE(CheckRsqrtLayerVisitorName)
{
- TestEqualLayerVisitor visitor("EqualLayer");
+ TestRsqrtLayerVisitor visitor("RsqrtLayer");
Network net;
- IConnectableLayer *const layer = net.AddEqualLayer("EqualLayer");
+ IConnectableLayer *const layer = net.AddRsqrtLayer("RsqrtLayer");
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckEqualLayerVisitorNameNullptr)
+BOOST_AUTO_TEST_CASE(CheckRsqrtLayerVisitorNameNullptr)
{
- TestEqualLayerVisitor visitor;
+ TestRsqrtLayerVisitor visitor;
Network net;
- IConnectableLayer *const layer = net.AddEqualLayer();
+ IConnectableLayer *const layer = net.AddRsqrtLayer();
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckRsqrtLayerVisitorName)
+// Slice
+BOOST_AUTO_TEST_CASE(CheckSliceLayerVisitorName)
{
- TestRsqrtLayerVisitor visitor("RsqrtLayer");
+ TestSliceLayerVisitor visitor("SliceLayer");
Network net;
- IConnectableLayer *const layer = net.AddRsqrtLayer("RsqrtLayer");
+ IConnectableLayer *const layer = net.AddSliceLayer(SliceDescriptor(), "SliceLayer");
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckRsqrtLayerVisitorNameNullptr)
+BOOST_AUTO_TEST_CASE(CheckSliceLayerVisitorNameNullptr)
{
- TestRsqrtLayerVisitor visitor;
+ TestSliceLayerVisitor visitor;
Network net;
- IConnectableLayer *const layer = net.AddRsqrtLayer();
+ IConnectableLayer *const layer = net.AddSliceLayer(SliceDescriptor());
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckGatherLayerVisitorName)
+// Subtraction
+BOOST_AUTO_TEST_CASE(CheckSubtractionLayerVisitorName)
{
- TestGatherLayerVisitor visitor("GatherLayer");
+ TestSubtractionLayerVisitor visitor("SubtractionLayer");
Network net;
- IConnectableLayer *const layer = net.AddGatherLayer("GatherLayer");
+ IConnectableLayer *const layer = net.AddSubtractionLayer("SubtractionLayer");
layer->Accept(visitor);
}
-BOOST_AUTO_TEST_CASE(CheckGatherLayerVisitorNameNullptr)
+BOOST_AUTO_TEST_CASE(CheckSubtractionLayerVisitorNameNullptr)
{
- TestGatherLayerVisitor visitor;
+ TestSubtractionLayerVisitor visitor;
Network net;
- IConnectableLayer *const layer = net.AddGatherLayer();
+ IConnectableLayer *const layer = net.AddSubtractionLayer();
layer->Accept(visitor);
}
BOOST_AUTO_TEST_SUITE_END()
-} //namespace armnn \ No newline at end of file
+} // namespace armnn
diff --git a/src/armnn/test/TestNameOnlyLayerVisitor.hpp b/src/armnn/test/TestNameOnlyLayerVisitor.hpp
index c0037ae28f..dec0d15a96 100644
--- a/src/armnn/test/TestNameOnlyLayerVisitor.hpp
+++ b/src/armnn/test/TestNameOnlyLayerVisitor.hpp
@@ -22,97 +22,97 @@ public:
};
};
-class TestMultiplicationLayerVisitor : public TestLayerVisitor
+class TestDivisionLayerVisitor : public TestLayerVisitor
{
public:
- explicit TestMultiplicationLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
+ explicit TestDivisionLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
- void VisitMultiplicationLayer(const IConnectableLayer* layer,
- const char* name = nullptr) override {
+ void VisitDivisionLayer(const IConnectableLayer* layer,
+ const char* name = nullptr) override {
CheckLayerPointer(layer);
CheckLayerName(name);
};
};
-class TestFloorLayerVisitor : public TestLayerVisitor
+class TestEqualLayerVisitor : public TestLayerVisitor
{
public:
- explicit TestFloorLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
+ explicit TestEqualLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
- void VisitFloorLayer(const IConnectableLayer* layer,
+ void VisitEqualLayer(const IConnectableLayer* layer,
const char* name = nullptr) override {
CheckLayerPointer(layer);
CheckLayerName(name);
};
};
-class TestDivisionLayerVisitor : public TestLayerVisitor
+class TestFloorLayerVisitor : public TestLayerVisitor
{
public:
- explicit TestDivisionLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
+ explicit TestFloorLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
- void VisitDivisionLayer(const IConnectableLayer* layer,
- const char* name = nullptr) override {
+ void VisitFloorLayer(const IConnectableLayer* layer,
+ const char* name = nullptr) override {
CheckLayerPointer(layer);
CheckLayerName(name);
};
};
-class TestSubtractionLayerVisitor : public TestLayerVisitor
+class TestGatherLayerVisitor : public TestLayerVisitor
{
public:
- explicit TestSubtractionLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
+ explicit TestGatherLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
- void VisitSubtractionLayer(const IConnectableLayer* layer,
- const char* name = nullptr) override {
+ void VisitGatherLayer(const IConnectableLayer* layer,
+ const char* name = nullptr) override {
CheckLayerPointer(layer);
CheckLayerName(name);
};
};
-class TestMaximumLayerVisitor : public TestLayerVisitor
+class TestGreaterLayerVisitor : public TestLayerVisitor
{
public:
- explicit TestMaximumLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
+ explicit TestGreaterLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
- void VisitMaximumLayer(const IConnectableLayer* layer,
+ void VisitGreaterLayer(const IConnectableLayer* layer,
const char* name = nullptr) override {
CheckLayerPointer(layer);
CheckLayerName(name);
};
};
-class TestMinimumLayerVisitor : public TestLayerVisitor
+class TestMultiplicationLayerVisitor : public TestLayerVisitor
{
public:
- explicit TestMinimumLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
+ explicit TestMultiplicationLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
- void VisitMinimumLayer(const IConnectableLayer* layer,
- const char* name = nullptr) override {
+ void VisitMultiplicationLayer(const IConnectableLayer* layer,
+ const char* name = nullptr) override {
CheckLayerPointer(layer);
CheckLayerName(name);
};
};
-class TestGreaterLayerVisitor : public TestLayerVisitor
+class TestMaximumLayerVisitor : public TestLayerVisitor
{
public:
- explicit TestGreaterLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
+ explicit TestMaximumLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
- void VisitGreaterLayer(const IConnectableLayer* layer,
+ void VisitMaximumLayer(const IConnectableLayer* layer,
const char* name = nullptr) override {
CheckLayerPointer(layer);
CheckLayerName(name);
};
};
-class TestEqualLayerVisitor : public TestLayerVisitor
+class TestMinimumLayerVisitor : public TestLayerVisitor
{
public:
- explicit TestEqualLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
+ explicit TestMinimumLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
- void VisitEqualLayer(const IConnectableLayer* layer,
- const char* name = nullptr) override {
+ void VisitMinimumLayer(const IConnectableLayer* layer,
+ const char* name = nullptr) override {
CheckLayerPointer(layer);
CheckLayerName(name);
};
@@ -130,16 +130,30 @@ public:
};
};
-class TestGatherLayerVisitor : public TestLayerVisitor
+class TestSliceLayerVisitor : public TestLayerVisitor
{
public:
- explicit TestGatherLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
+ explicit TestSliceLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
- void VisitGatherLayer(const IConnectableLayer* layer,
- const char* name = nullptr) override {
+ void VisitSliceLayer(const IConnectableLayer* layer,
+ const SliceDescriptor& sliceDescriptor,
+ const char* name = nullptr) override
+ {
+ CheckLayerPointer(layer);
+ CheckLayerName(name);
+ };
+};
+
+class TestSubtractionLayerVisitor : public TestLayerVisitor
+{
+public:
+ explicit TestSubtractionLayerVisitor(const char* name = nullptr) : TestLayerVisitor(name) {};
+
+ void VisitSubtractionLayer(const IConnectableLayer* layer,
+ const char* name = nullptr) override {
CheckLayerPointer(layer);
CheckLayerName(name);
};
};
-} //namespace armnn \ No newline at end of file
+} // namespace armnn
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index 602c4ab99f..06bfb91e83 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -745,6 +745,13 @@ void SerializerVisitor::VisitRsqrtLayer(const armnn::IConnectableLayer* layer, c
CreateAnyLayer(fbRsqrtLayer.o, serializer::Layer::Layer_RsqrtLayer);
}
// Build FlatBuffer for Slice Layer
// NOTE(review): Slice serialization is intentionally not implemented yet;
// this stub keeps the visitor interface complete and fails loudly rather
// than silently dropping the layer.
void SerializerVisitor::VisitSliceLayer(const armnn::IConnectableLayer* layer,
                                        const armnn::SliceDescriptor& sliceDescriptor,
                                        const char* name)
{
    throw UnimplementedException("SerializerVisitor::VisitSliceLayer is not implemented");
}
+
// Build FlatBuffer for Softmax Layer
void SerializerVisitor::VisitSoftmaxLayer(const armnn::IConnectableLayer* layer,
const armnn::SoftmaxDescriptor& softmaxDescriptor,
diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp
index 190ed231e3..8e65902002 100644
--- a/src/armnnSerializer/Serializer.hpp
+++ b/src/armnnSerializer/Serializer.hpp
@@ -193,6 +193,10 @@ public:
void VisitRsqrtLayer(const armnn::IConnectableLayer* layer,
const char* name = nullptr) override;
+ void VisitSliceLayer(const armnn::IConnectableLayer* layer,
+ const armnn::SliceDescriptor& sliceDescriptor,
+ const char* name = nullptr) override;
+
void VisitSoftmaxLayer(const armnn::IConnectableLayer* layer,
const armnn::SoftmaxDescriptor& softmaxDescriptor,
const char* name = nullptr) override;
diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp
index a8d1eaddc3..7f1fd1097a 100644
--- a/src/backends/backendsCommon/LayerSupportBase.cpp
+++ b/src/backends/backendsCommon/LayerSupportBase.cpp
@@ -414,6 +414,14 @@ bool LayerSupportBase::IsRsqrtSupported(const TensorInfo &input,
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
// Base implementation: reports Slice as unsupported (with a default reason).
// Backends that provide a Slice workload override this method.
bool LayerSupportBase::IsSliceSupported(const TensorInfo& input,
                                        const TensorInfo& output,
                                        const SliceDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
+
bool LayerSupportBase::IsSoftmaxSupported(const TensorInfo& input,
const TensorInfo& output,
const SoftmaxDescriptor& descriptor,
diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp
index 25dbdf2906..8df1f8d54f 100644
--- a/src/backends/backendsCommon/LayerSupportBase.hpp
+++ b/src/backends/backendsCommon/LayerSupportBase.hpp
@@ -253,6 +253,11 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsSliceSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const SliceDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsSoftmaxSupported(const TensorInfo& input,
const TensorInfo& output,
const SoftmaxDescriptor& descriptor,
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index f290cbd9cf..2fa0c92daf 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -2443,7 +2443,7 @@ void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co
ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
-
+
// Infer number of batches, input size and output size from tensor dimensions
const uint32_t numBatches = inputInfo.GetShape()[0];
const uint32_t inputSize = inputInfo.GetShape()[1];
@@ -2584,4 +2584,59 @@ void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
}
+void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ const std::string descriptorName{"SliceQueueDescriptor"};
+
+ ValidateNumInputs(workloadInfo, descriptorName, 1);
+ ValidateNumOutputs(workloadInfo, descriptorName, 1);
+
+ const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
+ const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
+
+ ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+
+ const unsigned int rank = inputTensorInfo.GetNumDimensions();
+ if (rank > 4)
+ {
+ throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
+ }
+
+ ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
+
+ // Check if m_Begin and m_Size have the expected length
+ if (m_Parameters.m_Begin.size() != rank)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
+ }
+ if (m_Parameters.m_Size.size() != rank)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Length of size descriptor must equal rank " + std::to_string(rank));
+ }
+
+ // Check if the shape of the output tensor matches m_Size
+ const TensorShape& outputShape = outputTensorInfo.GetShape();
+ for (unsigned int i = 0u; i < rank; ++i)
+ {
+ if (m_Parameters.m_Size[i] != outputShape[i])
+ {
+ throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
+ }
+ }
+
+ // Check if the sum of begin offset and size in a given dimension
+ // does not exceed the size of corresponding input
+ const TensorShape& inputShape = inputTensorInfo.GetShape();
+ for(unsigned int i = 0u; i < rank; ++i)
+ {
+ if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] >= inputShape[i])
+ {
+ throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
+ std::to_string(i) + " exceeds input size.");
+ }
+ }
+}
+
} // namespace armnn
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp
index 35130ad160..1e49243b34 100644
--- a/src/backends/backendsCommon/WorkloadData.hpp
+++ b/src/backends/backendsCommon/WorkloadData.hpp
@@ -533,4 +533,9 @@ struct AbsQueueDescriptor : QueueDescriptor
void Validate(const WorkloadInfo& workloadInfo) const;
};
-} //namespace armnn
+struct SliceQueueDescriptor : QueueDescriptorWithParameters<SliceDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+} // namespace armnn
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 17bd98b349..9d6b2bd6a9 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -760,6 +760,19 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
reason);
break;
}
+ case LayerType::Slice:
+ {
+ auto cLayer = boost::polymorphic_downcast<const SliceLayer*>(&layer);
+
+ const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+ const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+
+ result = layerSupportObject->IsSliceSupported(OverrideDataType(input, dataType),
+ OverrideDataType(output, dataType),
+ cLayer->GetParameters(),
+ reason);
+ break;
+ }
case LayerType::Softmax:
{
auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
@@ -1245,6 +1258,12 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateRsqrt(const RsqrtQueueDescrip
return std::unique_ptr<IWorkload>();
}
// Default factory implementation: returns a null workload pointer, meaning
// the backend does not provide a Slice workload unless it overrides this.
std::unique_ptr<IWorkload> IWorkloadFactory::CreateSlice(const SliceQueueDescriptor& descriptor,
                                                         const WorkloadInfo& info) const
{
    return std::unique_ptr<IWorkload>();
}
+
std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
@@ -1300,4 +1319,4 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
return std::unique_ptr<IWorkload>();
}
-} // namepsace armnn \ No newline at end of file
+} // namespace armnn
diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp
index 6fd334b49c..91cf2c742c 100644
--- a/src/backends/backendsCommon/WorkloadFactory.hpp
+++ b/src/backends/backendsCommon/WorkloadFactory.hpp
@@ -186,6 +186,9 @@ public:
virtual std::unique_ptr<IWorkload> CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
const WorkloadInfo& info) const;
+ virtual std::unique_ptr<IWorkload> CreateSlice(const SliceQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
virtual std::unique_ptr<IWorkload> CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
const WorkloadInfo& info) const;
diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
index 1dc9e9700f..17b7934e9f 100644
--- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
@@ -473,6 +473,8 @@ DECLARE_LAYER_POLICY_2_PARAM(Reshape)
DECLARE_LAYER_POLICY_1_PARAM(Rsqrt)
+DECLARE_LAYER_POLICY_2_PARAM(Slice)
+
DECLARE_LAYER_POLICY_2_PARAM(Softmax)
DECLARE_LAYER_POLICY_2_PARAM(SpaceToBatchNd)