From 7b885b3cce70154596b1994b013ea91527117c26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Ny=C3=ADri?= Date: Tue, 26 Oct 2021 14:47:57 +0100 Subject: IVGCVSW-6509 Front End + Reference Workload implementation Subtask of story: IVGCVSW-6164 Add a Pooling3d FrontEnd and Ref Implementation * Add front end * Add reference workload * Add corresponding unit tests Change-Id: Icce4146dd0a06a1da46a2def00a82d343e171750 Signed-off-by: Tamas Nyiri --- include/armnn/Descriptors.hpp | 76 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) (limited to 'include/armnn/Descriptors.hpp') diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index a8ad12ff8f..342d952277 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -377,6 +377,82 @@ struct Pooling2dDescriptor : BaseDescriptor DataLayout m_DataLayout; }; +/// A Pooling3dDescriptor for the Pooling3dLayer. +struct Pooling3dDescriptor : BaseDescriptor +{ + Pooling3dDescriptor() + : m_PoolType(PoolingAlgorithm::Max) + , m_PadLeft(0) + , m_PadRight(0) + , m_PadTop(0) + , m_PadBottom(0) + , m_PadFront(0) + , m_PadBack(0) + , m_PoolWidth(0) + , m_PoolHeight(0) + , m_PoolDepth(0) + , m_StrideX(0) + , m_StrideY(0) + , m_StrideZ(0) + , m_OutputShapeRounding(OutputShapeRounding::Floor) + , m_PaddingMethod(PaddingMethod::Exclude) + , m_DataLayout(DataLayout::NCDHW) + {} + + bool operator ==(const Pooling3dDescriptor& rhs) const + { + return m_PoolType == rhs.m_PoolType && + m_PadLeft == rhs.m_PadLeft && + m_PadRight == rhs.m_PadRight && + m_PadTop == rhs.m_PadTop && + m_PadBottom == rhs.m_PadBottom && + m_PadFront == rhs.m_PadFront && + m_PadBack == rhs.m_PadBack && + m_PoolWidth == rhs.m_PoolWidth && + m_PoolHeight == rhs.m_PoolHeight && + m_PoolDepth == rhs.m_PoolDepth && + m_StrideX == rhs.m_StrideX && + m_StrideY == rhs.m_StrideY && + m_StrideZ == rhs.m_StrideZ && + m_OutputShapeRounding == rhs.m_OutputShapeRounding && + m_PaddingMethod == rhs.m_PaddingMethod && + m_DataLayout == rhs.m_DataLayout; + } + + /// The pooling algorithm to use (Max. Average, L2). + PoolingAlgorithm m_PoolType; + /// Padding left value in the width dimension. + uint32_t m_PadLeft; + /// Padding right value in the width dimension. + uint32_t m_PadRight; + /// Padding top value in the height dimension. + uint32_t m_PadTop; + /// Padding bottom value in the height dimension. + uint32_t m_PadBottom; + /// Padding front value in the depth dimension. + uint32_t m_PadFront; + /// Padding back value in the depth dimension. + uint32_t m_PadBack; + /// Pooling width value. + uint32_t m_PoolWidth; + /// Pooling height value. + uint32_t m_PoolHeight; + /// Pooling depth value. + uint32_t m_PoolDepth; + /// Stride value when proceeding through input for the width dimension. + uint32_t m_StrideX; + /// Stride value when proceeding through input for the height dimension. + uint32_t m_StrideY; + /// Stride value when proceeding through input for the depth dimension. + uint32_t m_StrideZ; + /// The rounding method for the output shape. (Floor, Ceiling). + OutputShapeRounding m_OutputShapeRounding; + /// The padding method to be used. (Exclude, IgnoreValue). + PaddingMethod m_PaddingMethod; + /// The data layout to be used (NCDHW, NDHWC). + DataLayout m_DataLayout; +}; + /// A FullyConnectedDescriptor for the FullyConnectedLayer. struct FullyConnectedDescriptor : BaseDescriptor { -- cgit v1.2.1