aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
Diffstat (limited to 'include')
-rw-r--r--include/armnn/Descriptors.hpp12
-rw-r--r--include/armnn/Types.hpp11
-rw-r--r--include/armnn/TypesUtils.hpp11
3 files changed, 31 insertions, 3 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 39ea824045..a8ad12ff8f 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -1060,17 +1060,20 @@ struct MeanDescriptor : BaseDescriptor
/// A PadDescriptor for the PadLayer.
struct PadDescriptor : BaseDescriptor
{
- PadDescriptor() : m_PadValue(0)
+ PadDescriptor() : m_PadValue(0), m_PaddingMode(PaddingMode::Constant)
{}
- PadDescriptor(const std::vector<std::pair<unsigned int, unsigned int>>& padList, const float& padValue = 0)
+ PadDescriptor(const std::vector<std::pair<unsigned int, unsigned int>>& padList,
+ const float& padValue = 0,
+ const PaddingMode& paddingMode = PaddingMode::Constant)
: m_PadList(padList)
, m_PadValue(padValue)
+ , m_PaddingMode(paddingMode)
{}
bool operator ==(const PadDescriptor& rhs) const
{
- return m_PadList == rhs.m_PadList && m_PadValue == rhs.m_PadValue;
+ return m_PadList == rhs.m_PadList && m_PadValue == rhs.m_PadValue && m_PaddingMode == rhs.m_PaddingMode;
}
/// @brief Specifies the padding for input dimension.
@@ -1081,6 +1084,9 @@ struct PadDescriptor : BaseDescriptor
/// Optional value to use for padding, defaults to 0
float m_PadValue;
+
+ /// Specifies the Padding mode (Constant, Reflect or Symmetric)
+ PaddingMode m_PaddingMode;
};
/// A SliceDescriptor for the SliceLayer.
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index 4f39ebe16a..deaa0b3a50 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -166,6 +166,17 @@ enum class PaddingMethod
Exclude = 1
};
+///
+/// The padding mode controls whether the padding should be filled with constant values (Constant), or
+/// reflect the input, either including the border values (Symmetric) or not (Reflect).
+///
+enum class PaddingMode
+{
+ Constant = 0,
+ Reflect = 1,
+ Symmetric = 2
+};
+
enum class NormalizationAlgorithmChannel
{
Across = 0,
diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp
index a1c11b74df..ccb0280457 100644
--- a/include/armnn/TypesUtils.hpp
+++ b/include/armnn/TypesUtils.hpp
@@ -125,6 +125,17 @@ constexpr char const* GetPaddingMethodAsCString(PaddingMethod method)
}
}
+constexpr char const* GetPaddingModeAsCString(PaddingMode mode)
+{
+ switch (mode)
+ {
+ case PaddingMode::Constant: return "Exclude";
+ case PaddingMode::Symmetric: return "Symmetric";
+ case PaddingMode::Reflect: return "Reflect";
+ default: return "Unknown";
+ }
+}
+
constexpr char const* GetReduceOperationAsCString(ReduceOperation reduce_operation)
{
switch (reduce_operation)