aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/StackLayer.cpp
diff options
context:
space:
mode:
authorMatthew Jackson <matthew.jackson@arm.com>2019-07-04 14:59:16 +0100
committerÁron Virginás-Tar <aron.virginas-tar@arm.com>2019-07-10 12:06:51 +0000
commit2b8c1da565871b3e69567c2cfc46c8dcbef301aa (patch)
tree682327de212e273405cb257028568db997644c35 /src/armnn/layers/StackLayer.cpp
parentad5293a86e315049de36afd723dcd1a7e70681a7 (diff)
downloadarmnn-2b8c1da565871b3e69567c2cfc46c8dcbef301aa.tar.gz
IVGCVSW-3418 Add Arm NN front end support for the new Stack layer
* Added new StackLayer class * Made necessary changes to Descriptors, ILayerSupport, ILayerVisitor, etc. * Added unit tests Signed-off-by: Matthew Jackson <matthew.jackson@arm.com> Change-Id: Ieb97a928a342ffe1901c6058eb895711c358fd3d
Diffstat (limited to 'src/armnn/layers/StackLayer.cpp')
-rw-r--r--src/armnn/layers/StackLayer.cpp98
1 files changed, 98 insertions, 0 deletions
diff --git a/src/armnn/layers/StackLayer.cpp b/src/armnn/layers/StackLayer.cpp
new file mode 100644
index 0000000000..59bc8d5a13
--- /dev/null
+++ b/src/armnn/layers/StackLayer.cpp
@@ -0,0 +1,98 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#include "StackLayer.hpp"
+#include "LayerCloneBase.hpp"
+
+#include <armnn/TypesUtils.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+#include <backendsCommon/WorkloadFactory.hpp>
+
+#include <queue>
+
+namespace armnn
+{
+
+StackLayer::StackLayer(const StackDescriptor& param, const char* name)
+ : LayerWithParameters(param.m_NumInputs, 1, LayerType::Stack, param, name)
+{
+}
+
+std::unique_ptr<IWorkload> StackLayer::CreateWorkload(const Graph& graph, const IWorkloadFactory& factory) const
+{
+ StackQueueDescriptor descriptor;
+ return factory.CreateStack(descriptor, PrepInfoAndDesc(descriptor, graph));
+}
+
+StackLayer* StackLayer::Clone(Graph& graph) const
+{
+ return CloneBase<StackLayer>(graph, m_Param, GetName());
+}
+
+std::vector<TensorShape> StackLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
+{
+ const TensorShape& inputShape = m_Param.m_InputShape;
+ const unsigned int inputNumDimensions = inputShape.GetNumDimensions();
+ const unsigned int axis = m_Param.m_Axis;
+
+ BOOST_ASSERT(axis <= inputNumDimensions);
+
+ unsigned int dimensionSizes[inputNumDimensions + 1];
+ for (unsigned int i = 0; i < axis; ++i)
+ {
+ dimensionSizes[i] = inputShape[i];
+ }
+
+ dimensionSizes[axis] = m_Param.m_NumInputs;
+
+ for (unsigned int i = axis + 1; i < inputNumDimensions + 1; ++i)
+ {
+ dimensionSizes[i] = inputShape[i-1];
+ }
+
+ TensorShape targetShape = TensorShape(inputNumDimensions + 1, dimensionSizes);
+
+ return std::vector<TensorShape>({ targetShape });
+}
+
+void StackLayer::ValidateTensorShapesFromInputs()
+{
+ // Validates Stack layer.
+ ConditionalThrowIfNotEqual<LayerValidationException>(
+ "StackLayer: Num Input Slots must match Num Inputs.",
+ m_Param.m_NumInputs,
+ GetNumInputSlots());
+
+ VerifyLayerConnections(m_Param.m_NumInputs, CHECK_LOCATION());
+
+ // Constructs and validates input shapes
+ std::vector<TensorShape> inputShapes;
+ for (unsigned int i = 0; i < GetNumInputSlots(); ++i)
+ {
+ TensorShape inputShape = GetInputSlot(i).GetConnection()->GetTensorInfo().GetShape();
+ if (inputShape != m_Param.m_InputShape)
+ {
+ throw LayerValidationException("ConcatLayer: TensorShape set on InputSlot[" +
+ std::to_string(i) +
+ "] does not match defined input shape");
+ }
+ inputShapes.push_back(inputShape);
+ }
+
+ auto inferredShapes = InferOutputShapes(inputShapes);
+
+ BOOST_ASSERT(inferredShapes.size() == 1);
+
+ ConditionalThrowIfNotEqual<LayerValidationException>(
+ "StackLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
+ GetOutputSlot(0).GetTensorInfo().GetShape(),
+ inferredShapes[0]);
+}
+
+void StackLayer::Accept(ILayerVisitor& visitor) const
+{
+ visitor.VisitStackLayer(this, GetParameters(), GetName());
+}
+
+} // namespace armnn armnn