diff options
Diffstat (limited to 'include')
-rw-r--r-- | include/armnn/Descriptors.hpp | 15 | ||||
-rw-r--r-- | include/armnn/DescriptorsFwd.hpp | 1 | ||||
-rw-r--r-- | include/armnn/INetwork.hpp | 6 | ||||
-rw-r--r-- | include/armnn/LayerSupport.hpp | 7 |
4 files changed, 29 insertions, 0 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index decbf99880..5f9df6b2c3 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -332,4 +332,19 @@ struct LstmDescriptor bool m_ProjectionEnabled; }; +struct MeanDescriptor +{ + MeanDescriptor() + : m_KeepDims(false) + {} + + MeanDescriptor(const std::vector<unsigned int>& axis, bool keepDims) + : m_Axis(axis) + , m_KeepDims(keepDims) + {} + + std::vector<unsigned int> m_Axis; + bool m_KeepDims; +}; + } diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp index ed958fc237..b161df8827 100644 --- a/include/armnn/DescriptorsFwd.hpp +++ b/include/armnn/DescriptorsFwd.hpp @@ -15,6 +15,7 @@ struct FullyConnectedDescriptor; struct LstmDescriptor; struct PermuteDescriptor; struct NormalizationDescriptor; +struct MeanDescriptor; struct Pooling2dDescriptor; struct ReshapeDescriptor; struct ResizeBilinearDescriptor; diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp index 0405074d3a..7fd7a25b60 100644 --- a/include/armnn/INetwork.hpp +++ b/include/armnn/INetwork.hpp @@ -279,6 +279,12 @@ public: /// @return - Interface for configuring the layer. virtual IConnectableLayer* AddSubtractionLayer(const char* name = nullptr) = 0; + /// Add a Mean layer to the network. + /// @param meanDescriptor - Parameters for the mean operation. + /// @param name - Optional name for the layer. + /// @ return - Interface for configuring the layer. + virtual IConnectableLayer* AddMeanLayer(const MeanDescriptor& meanDescriptor, const char* name = nullptr) = 0; + protected: ~INetwork() {} }; diff --git a/include/armnn/LayerSupport.hpp b/include/armnn/LayerSupport.hpp index ac7d08ff62..d00691fad1 100644 --- a/include/armnn/LayerSupport.hpp +++ b/include/armnn/LayerSupport.hpp @@ -196,4 +196,11 @@ bool IsFloorSupported(Compute compute, char* reasonIfUnsupported = nullptr, size_t reasonIfUnsupportedMaxLength = 1024); +bool IsMeanSupported(Compute compute, + const TensorInfo& input, + const TensorInfo& output, + const MeanDescriptor& descriptor, + char* reasonIfUnsupported = nullptr, + size_t reasonIfUnsupportedMaxLength = 1024); + } |