aboutsummaryrefslogtreecommitdiff
path: root/include/armnn/Descriptors.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'include/armnn/Descriptors.hpp')
-rw-r--r--include/armnn/Descriptors.hpp76
1 files changed, 76 insertions, 0 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index a8ad12ff8f..342d952277 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -377,6 +377,82 @@ struct Pooling2dDescriptor : BaseDescriptor
DataLayout m_DataLayout;
};
+/// A Pooling3dDescriptor for the Pooling3dLayer.
+struct Pooling3dDescriptor : BaseDescriptor
+{
+ Pooling3dDescriptor()
+ : m_PoolType(PoolingAlgorithm::Max)
+ , m_PadLeft(0)
+ , m_PadRight(0)
+ , m_PadTop(0)
+ , m_PadBottom(0)
+ , m_PadFront(0)
+ , m_PadBack(0)
+ , m_PoolWidth(0)
+ , m_PoolHeight(0)
+ , m_PoolDepth(0)
+ , m_StrideX(0)
+ , m_StrideY(0)
+ , m_StrideZ(0)
+ , m_OutputShapeRounding(OutputShapeRounding::Floor)
+ , m_PaddingMethod(PaddingMethod::Exclude)
+ , m_DataLayout(DataLayout::NCDHW)
+ {}
+
+ bool operator ==(const Pooling3dDescriptor& rhs) const
+ {
+ return m_PoolType == rhs.m_PoolType &&
+ m_PadLeft == rhs.m_PadLeft &&
+ m_PadRight == rhs.m_PadRight &&
+ m_PadTop == rhs.m_PadTop &&
+ m_PadBottom == rhs.m_PadBottom &&
+ m_PadFront == rhs.m_PadFront &&
+ m_PadBack == rhs.m_PadBack &&
+ m_PoolWidth == rhs.m_PoolWidth &&
+ m_PoolHeight == rhs.m_PoolHeight &&
+ m_PoolDepth == rhs.m_PoolDepth &&
+ m_StrideX == rhs.m_StrideX &&
+ m_StrideY == rhs.m_StrideY &&
+ m_StrideZ == rhs.m_StrideZ &&
+ m_OutputShapeRounding == rhs.m_OutputShapeRounding &&
+ m_PaddingMethod == rhs.m_PaddingMethod &&
+ m_DataLayout == rhs.m_DataLayout;
+ }
+
+ /// The pooling algorithm to use (Max. Average, L2).
+ PoolingAlgorithm m_PoolType;
+ /// Padding left value in the width dimension.
+ uint32_t m_PadLeft;
+ /// Padding right value in the width dimension.
+ uint32_t m_PadRight;
+ /// Padding top value in the height dimension.
+ uint32_t m_PadTop;
+ /// Padding bottom value in the height dimension.
+ uint32_t m_PadBottom;
+ /// Padding front value in the depth dimension.
+ uint32_t m_PadFront;
+ /// Padding back value in the depth dimension.
+ uint32_t m_PadBack;
+ /// Pooling width value.
+ uint32_t m_PoolWidth;
+ /// Pooling height value.
+ uint32_t m_PoolHeight;
+ /// Pooling depth value.
+ uint32_t m_PoolDepth;
+ /// Stride value when proceeding through input for the width dimension.
+ uint32_t m_StrideX;
+ /// Stride value when proceeding through input for the height dimension.
+ uint32_t m_StrideY;
+ /// Stride value when proceeding through input for the depth dimension.
+ uint32_t m_StrideZ;
+ /// The rounding method for the output shape. (Floor, Ceiling).
+ OutputShapeRounding m_OutputShapeRounding;
+ /// The padding method to be used. (Exclude, IgnoreValue).
+ PaddingMethod m_PaddingMethod;
+ /// The data layout to be used (NCDHW, NDHWC).
+ DataLayout m_DataLayout;
+};
+
/// A FullyConnectedDescriptor for the FullyConnectedLayer.
struct FullyConnectedDescriptor : BaseDescriptor
{