diff options
Diffstat (limited to 'include/armnn')
-rw-r--r-- | include/armnn/Descriptors.hpp | 9 | ||||
-rw-r--r-- | include/armnn/DescriptorsFwd.hpp | 1 | ||||
-rw-r--r-- | include/armnn/INetwork.hpp | 4 | ||||
-rw-r--r-- | include/armnn/LayerSupport.hpp | 1 |
4 files changed, 14 insertions, 1 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index bc1b59bdf5..30c8144220 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -274,6 +274,15 @@ struct NormalizationDescriptor DataLayout m_DataLayout; }; +struct L2NormalizationDescriptor +{ + L2NormalizationDescriptor() + : m_DataLayout(DataLayout::NCHW) + {} + + DataLayout m_DataLayout; +}; + struct BatchNormalizationDescriptor { BatchNormalizationDescriptor() diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp index 9cb3463b28..739c12056c 100644 --- a/include/armnn/DescriptorsFwd.hpp +++ b/include/armnn/DescriptorsFwd.hpp @@ -15,6 +15,7 @@ struct FullyConnectedDescriptor; struct LstmDescriptor; struct PermuteDescriptor; struct NormalizationDescriptor; +struct L2NormalizationDescriptor; struct MeanDescriptor; struct PadDescriptor; struct Pooling2dDescriptor; diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp index 1d2679aabb..2c83909c83 100644 --- a/include/armnn/INetwork.hpp +++ b/include/armnn/INetwork.hpp @@ -228,9 +228,11 @@ public: /// Adds an L2 normalization layer to the network. /// Normalization is performed along dimension 1, but requires a 4d input. + /// @param desc - Parameters for the L2 normalization operation. /// @param name - Optional name for the layer. /// @return - Interface for configuring the layer. - virtual IConnectableLayer* AddL2NormalizationLayer(const char* name = nullptr) = 0; + virtual IConnectableLayer* AddL2NormalizationLayer(const L2NormalizationDescriptor& desc, + const char* name = nullptr) = 0; /// Adds a layer with no inputs and a single output, which always corresponds to /// the passed in constant tensor. diff --git a/include/armnn/LayerSupport.hpp b/include/armnn/LayerSupport.hpp index 3c7cce5c64..25e888e71e 100644 --- a/include/armnn/LayerSupport.hpp +++ b/include/armnn/LayerSupport.hpp @@ -104,6 +104,7 @@ bool IsFullyConnectedSupported(Compute compute, bool IsL2NormalizationSupported(Compute compute, const TensorInfo& input, const TensorInfo& output, + const L2NormalizationDescriptor& descriptor, char* reasonIfUnsupported = nullptr, size_t reasonIfUnsupportedMaxLength = 1024); |