aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorMatthew Jackson <matthew.jackson@arm.com>2019-07-04 14:59:16 +0100
committerÁron Virginás-Tar <aron.virginas-tar@arm.com>2019-07-10 12:06:51 +0000
commit2b8c1da565871b3e69567c2cfc46c8dcbef301aa (patch)
tree682327de212e273405cb257028568db997644c35 /include
parentad5293a86e315049de36afd723dcd1a7e70681a7 (diff)
downloadarmnn-2b8c1da565871b3e69567c2cfc46c8dcbef301aa.tar.gz
IVGCVSW-3418 Add Arm NN front end support for the new Stack layer
* Added new StackLayer class * Made necessary changes to Descriptors, ILayerSupport, ILayerVisitor, etc. * Added unit tests Signed-off-by: Matthew Jackson <matthew.jackson@arm.com> Change-Id: Ieb97a928a342ffe1901c6058eb895711c358fd3d
Diffstat (limited to 'include')
-rw-r--r--include/armnn/Descriptors.hpp23
-rw-r--r--include/armnn/DescriptorsFwd.hpp1
-rw-r--r--include/armnn/ILayerSupport.hpp5
-rw-r--r--include/armnn/ILayerVisitor.hpp8
-rw-r--r--include/armnn/INetwork.hpp7
-rw-r--r--include/armnn/LayerSupport.hpp8
-rw-r--r--include/armnn/LayerVisitorBase.hpp4
7 files changed, 56 insertions, 0 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index cb76615889..377f0705d7 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -648,6 +648,29 @@ struct PadDescriptor
float m_PadValue;
};
+/// A StackDescriptor for the StackLayer.
+struct StackDescriptor
+{
+ StackDescriptor()
+ : m_Axis(0)
+ , m_NumInputs(0)
+ , m_InputShape()
+ {}
+
+ StackDescriptor(uint32_t axis, uint32_t numInputs, const TensorShape& inputShape)
+ : m_Axis(axis)
+ , m_NumInputs(numInputs)
+ , m_InputShape(inputShape)
+ {}
+
+ /// 0-based axis along which to stack the input tensors.
+ uint32_t m_Axis;
+ /// Number of input tensors.
+ uint32_t m_NumInputs;
+ /// Required shape of all input tensors.
+ TensorShape m_InputShape;
+};
+
/// A StridedSliceDescriptor for the StridedSliceLayer.
struct StridedSliceDescriptor
{
diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp
index 9627ddc7f1..eddf91f4ce 100644
--- a/include/armnn/DescriptorsFwd.hpp
+++ b/include/armnn/DescriptorsFwd.hpp
@@ -30,6 +30,7 @@ struct ResizeDescriptor;
struct SoftmaxDescriptor;
struct SpaceToBatchNdDescriptor;
struct SpaceToDepthDescriptor;
+struct StackDescriptor;
struct StridedSliceDescriptor;
struct TransposeConvolution2dDescriptor;
struct ViewsDescriptor;
diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp
index 53dd29d87e..3cc6eabe9f 100644
--- a/include/armnn/ILayerSupport.hpp
+++ b/include/armnn/ILayerSupport.hpp
@@ -271,6 +271,11 @@ public:
const ViewsDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+ virtual bool IsStackSupported(const std::vector<const TensorInfo*> inputs,
+ const TensorInfo& output,
+ const StackDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+
virtual bool IsStridedSliceSupported(const TensorInfo& input,
const TensorInfo& output,
const StridedSliceDescriptor& descriptor,
diff --git a/include/armnn/ILayerVisitor.hpp b/include/armnn/ILayerVisitor.hpp
index 86cf4a33cf..6e5b5463ac 100644
--- a/include/armnn/ILayerVisitor.hpp
+++ b/include/armnn/ILayerVisitor.hpp
@@ -370,6 +370,14 @@ public:
const ViewsDescriptor& splitterDescriptor,
const char* name = nullptr) = 0;
+ /// Function a stack layer should call back to when its Accept(ILayerVisitor&) function is invoked.
+ /// @param layer - pointer to the layer which is calling back to this visit function.
+ /// @param stackDescriptor - Parameters for the stack operation.
+ /// @param name - Optional name for the layer.
+ virtual void VisitStackLayer(const IConnectableLayer* layer,
+ const StackDescriptor& stackDescriptor,
+ const char* name = nullptr) = 0;
+
/// Function a strided slice layer should call back to when its Accept(ILayerVisitor&) function is invoked.
/// @param layer - pointer to the layer which is calling back to this visit function.
/// @param stridedSliceDescriptor - Parameters for the strided slice operation.
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index 1f1b51095b..9e88c9279d 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -451,6 +451,13 @@ public:
const Optional<ConstTensor>& biases,
const char* name = nullptr) = 0;
+ /// Adds a stack layer to the network.
+ /// @param descriptor - Description of the stack layer.
+ /// @param name - Optional name for the layer.
+ /// @return - Interface for configuring the layer.
+ virtual IConnectableLayer* AddStackLayer(const StackDescriptor& descriptor,
+ const char* name = nullptr) = 0;
+
virtual void Accept(ILayerVisitor& visitor) const = 0;
protected:
diff --git a/include/armnn/LayerSupport.hpp b/include/armnn/LayerSupport.hpp
index 65f9d089ba..6a3f1774bd 100644
--- a/include/armnn/LayerSupport.hpp
+++ b/include/armnn/LayerSupport.hpp
@@ -360,6 +360,14 @@ bool IsSplitterSupported(const BackendId& backend,
size_t reasonIfUnsupportedMaxLength = 1024);
/// Deprecated in favor of IBackend and ILayerSupport interfaces
+bool IsStackSupported(const BackendId& backend,
+ const std::vector<const TensorInfo*> inputs,
+ const TensorInfo& output,
+ const StackDescriptor& descriptor,
+ char* reasonIfUnsupported = nullptr,
+ size_t reasonIfUnsupportedMaxLength = 1024);
+
+/// Deprecated in favor of IBackend and ILayerSupport interfaces
bool IsStridedSliceSupported(const BackendId& backend,
const TensorInfo& input,
const TensorInfo& output,
diff --git a/include/armnn/LayerVisitorBase.hpp b/include/armnn/LayerVisitorBase.hpp
index d657154b47..f107e9fb68 100644
--- a/include/armnn/LayerVisitorBase.hpp
+++ b/include/armnn/LayerVisitorBase.hpp
@@ -188,6 +188,10 @@ public:
const ViewsDescriptor&,
const char*) override { DefaultPolicy::Apply(__func__); }
+ void VisitStackLayer(const IConnectableLayer*,
+ const StackDescriptor&,
+ const char*) override { DefaultPolicy::Apply(__func__); }
+
void VisitStridedSliceLayer(const IConnectableLayer*,
const StridedSliceDescriptor&,
const char*) override { DefaultPolicy::Apply(__func__); }