aboutsummaryrefslogtreecommitdiff
path: root/include/armnn/Types.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'include/armnn/Types.hpp')
-rw-r--r--include/armnn/Types.hpp44
1 files changed, 44 insertions, 0 deletions
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index 4afc50b1b1..bb0b1e6ca7 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -31,12 +31,56 @@ enum class DataType
Signed32 = 3
};
+// Begin: DataLayout
+
enum class DataLayout
{
NCHW = 1,
NHWC = 2
};
+// Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout
+class DataLayoutIndexed
+{
+public:
+ DataLayoutIndexed(DataLayout dataLayout) : m_DataLayout(dataLayout)
+ {
+ switch (dataLayout)
+ {
+ case DataLayout::NHWC:
+ m_ChannelsIndex = 3;
+ m_HeightIndex = 1;
+ m_WidthIndex = 2;
+ break;
+ case DataLayout::NCHW:
+ m_ChannelsIndex = 1;
+ m_HeightIndex = 2;
+ m_WidthIndex = 3;
+ break;
+ default:
+ throw InvalidArgumentException("Unknown DataLayout value: " +
+ std::to_string(static_cast<int>(dataLayout)));
+ }
+ }
+
+ DataLayout GetDataLayout() const { return m_DataLayout; }
+ unsigned int GetChannelsIndex() const { return m_ChannelsIndex; }
+ unsigned int GetHeightIndex() const { return m_HeightIndex; }
+ unsigned int GetWidthIndex() const { return m_WidthIndex; }
+
+private:
+ DataLayout m_DataLayout;
+ unsigned int m_ChannelsIndex;
+ unsigned int m_HeightIndex;
+ unsigned int m_WidthIndex;
+};
+
+// Conversion methods - implementations in src/armnn/InternalTypes.cpp
+bool operator==(const DataLayout& dataLayout, const DataLayoutIndexed& indexed);
+bool operator==(const DataLayoutIndexed& indexed, const DataLayout& dataLayout);
+
+// End: DataLayout
+
enum class ActivationFunction
{
Sigmoid = 0,