From 906f94631aa7ef590b9d8ff45507e818a0d1ac2c Mon Sep 17 00:00:00 2001 From: Jim Flynn Date: Fri, 10 May 2019 13:55:21 +0100 Subject: IVGCVSW-3076 Add ConcatLayer methods to public API !android-nn-driver:1120 Change-Id: I5192fa3deb4ea9766d38ad0bf4dfbfa0b4924c41 Signed-off-by: Jim Flynn --- include/armnn/ILayerSupport.hpp | 7 +++++++ include/armnn/ILayerVisitor.hpp | 15 +++++++++++++++ include/armnn/INetwork.hpp | 11 +++++++++++ include/armnn/LayerSupport.hpp | 10 ++++++++++ include/armnn/LayerVisitorBase.hpp | 4 ++++ 5 files changed, 47 insertions(+) (limited to 'include') diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp index b8e48c8704..c3fb7b016e 100644 --- a/include/armnn/ILayerSupport.hpp +++ b/include/armnn/ILayerSupport.hpp @@ -4,6 +4,7 @@ // #pragma once +#include #include #include @@ -47,6 +48,11 @@ public: const BatchToSpaceNdDescriptor& descriptor, Optional reasonIfUnsupported = EmptyOptional()) const = 0; + virtual bool IsConcatSupported(const std::vector inputs, + const TensorInfo& output, + const OriginsDescriptor& descriptor, + Optional reasonIfUnsupported = EmptyOptional()) const = 0; + virtual bool IsConstantSupported(const TensorInfo& output, Optional reasonIfUnsupported = EmptyOptional()) const = 0; @@ -184,6 +190,7 @@ public: const TensorInfo& output, Optional reasonIfUnsupported = EmptyOptional()) const = 0; + ARMNN_DEPRECATED_MSG("Use IsConcatSupported instead") virtual bool IsMergerSupported(const std::vector inputs, const TensorInfo& output, const OriginsDescriptor& descriptor, diff --git a/include/armnn/ILayerVisitor.hpp b/include/armnn/ILayerVisitor.hpp index ab793bc587..10d0cc6b63 100644 --- a/include/armnn/ILayerVisitor.hpp +++ b/include/armnn/ILayerVisitor.hpp @@ -58,6 +58,20 @@ public: const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor, const char* name = nullptr) = 0; + /// Function that a concat layer should call back to when its Accept(ILayerVisitor&) function is invoked. + /// @param layer - pointer to the layer which is calling back to this visit function. + /// @param mergerDescriptor - WindowsDescriptor to configure the concatenation process. Number of Views must be + /// equal to the number of inputs, and their order must match - e.g. first view + /// corresponds to the first input, second view to the second input, etc.... + /// @param name - Optional name for the layer. + virtual void VisitConcatLayer(const IConnectableLayer* layer, + const OriginsDescriptor& mergerDescriptor, + const char* name = nullptr) + { + // default implementation to ease transition while MergerLayer is being deprecated + VisitMergerLayer(layer, mergerDescriptor, name); + } + /// Function a layer with no inputs and a single output, which always corresponds to /// the passed in constant tensor should call back to when its Accept(ILayerVisitor&) function is invoked. /// @param layer - pointer to the layer which is calling back to this visit function. @@ -211,6 +225,7 @@ public: /// the number of inputs, and their order must match - e.g. first view corresponds to /// the first input, second view to the second input, etc.... /// @param name - Optional name for the layer. + // NOTE: this method will be deprecated and replaced by VisitConcatLayer virtual void VisitMergerLayer(const IConnectableLayer* layer, const OriginsDescriptor& mergerDescriptor, const char* name = nullptr) = 0; diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp index 7141770298..bae6e94955 100644 --- a/include/armnn/INetwork.hpp +++ b/include/armnn/INetwork.hpp @@ -4,6 +4,7 @@ // #pragma once +#include #include #include #include @@ -100,6 +101,15 @@ public: /// @return - Interface for configuring the layer. virtual IConnectableLayer* AddInputLayer(LayerBindingId id, const char* name = nullptr) = 0; + /// Adds a concatenation layer to the network. + /// @param mergerDescriptor - WindowsDescriptor to configure the concatenation process. Number of Views must + /// be equal to the number of inputs, and their order must match - e.g. first view + /// corresponds to the first input, second view to the second input, etc.... + /// @param name - Optional name for the layer. + /// @return - Interface for configuring the layer. + virtual IConnectableLayer* AddConcatLayer(const OriginsDescriptor& mergerDescriptor, + const char* name = nullptr) = 0; + /// Adds a 2D convolution layer to the network. /// @param convolution2dDescriptor - Description of the 2D convolution layer. /// @param weights - Tensor for the weights data. @@ -248,6 +258,7 @@ public: /// the first input, second view to the second input, etc.... /// @param name - Optional name for the layer. /// @return - Interface for configuring the layer. + ARMNN_DEPRECATED_MSG("Use AddConcatLayer instead") virtual IConnectableLayer* AddMergerLayer(const OriginsDescriptor& mergerDescriptor, const char* name = nullptr) = 0; diff --git a/include/armnn/LayerSupport.hpp b/include/armnn/LayerSupport.hpp index c9fc264e0c..e105b67740 100644 --- a/include/armnn/LayerSupport.hpp +++ b/include/armnn/LayerSupport.hpp @@ -4,6 +4,7 @@ // #pragma once +#include #include #include #include @@ -48,6 +49,14 @@ bool IsBatchToSpaceNdSupported(const BackendId& backend, char* reasonIfUnsupported = nullptr, size_t reasonIfUnsupportedMaxLength = 1024); +/// Deprecated in favor of IBackend and ILayerSupport interfaces +bool IsConcatSupported(const BackendId& backend, + const std::vector inputs, + const TensorInfo& output, + const OriginsDescriptor& descriptor, + char* reasonIfUnsupported = nullptr, + size_t reasonIfUnsupportedMaxLength = 1024); + /// Deprecated in favor of IBackend and ILayerSupport interfaces bool IsConstantSupported(const BackendId& backend, const TensorInfo& output, @@ -212,6 +221,7 @@ bool IsMergeSupported(const BackendId& backend, size_t reasonIfUnsupportedMaxLength = 1024); /// Deprecated in favor of IBackend and ILayerSupport interfaces +ARMNN_DEPRECATED_MSG("Use IsConcatSupported instead") bool IsMergerSupported(const BackendId& backend, const std::vector inputs, const TensorInfo& output, diff --git a/include/armnn/LayerVisitorBase.hpp b/include/armnn/LayerVisitorBase.hpp index 12eb225674..62673ace07 100644 --- a/include/armnn/LayerVisitorBase.hpp +++ b/include/armnn/LayerVisitorBase.hpp @@ -33,6 +33,10 @@ public: LayerBindingId, const char*) override { DefaultPolicy::Apply(); } + void VisitConcatLayer(const IConnectableLayer*, + const OriginsDescriptor&, + const char*) override { DefaultPolicy::Apply(); } + void VisitConvolution2dLayer(const IConnectableLayer*, const Convolution2dDescriptor&, const ConstTensor&, -- cgit v1.2.1