aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/DataLayoutIndexed.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnUtils/DataLayoutIndexed.hpp')
-rw-r--r--src/armnnUtils/DataLayoutIndexed.hpp39
1 files changed, 36 insertions, 3 deletions
diff --git a/src/armnnUtils/DataLayoutIndexed.hpp b/src/armnnUtils/DataLayoutIndexed.hpp
index 5bb8e0d93f..8bd9701a5e 100644
--- a/src/armnnUtils/DataLayoutIndexed.hpp
+++ b/src/armnnUtils/DataLayoutIndexed.hpp
@@ -8,6 +8,8 @@
#include <armnn/Types.hpp>
#include <armnn/Tensor.hpp>
+#include <boost/assert.hpp>
+
namespace armnnUtils
{
@@ -21,9 +23,40 @@ public:
unsigned int GetChannelsIndex() const { return m_ChannelsIndex; }
unsigned int GetHeightIndex() const { return m_HeightIndex; }
unsigned int GetWidthIndex() const { return m_WidthIndex; }
- unsigned int GetIndex(const armnn::TensorShape& shape,
- unsigned int batchIndex, unsigned int channelIndex,
- unsigned int heightIndex, unsigned int widthIndex) const;
+
+ inline unsigned int GetIndex(const armnn::TensorShape& shape,
+ unsigned int batchIndex, unsigned int channelIndex,
+ unsigned int heightIndex, unsigned int widthIndex) const
+ {
+ BOOST_ASSERT( batchIndex < shape[0] || ( shape[0] == 0 && batchIndex == 0 ) );
+ BOOST_ASSERT( channelIndex < shape[m_ChannelsIndex] ||
+ ( shape[m_ChannelsIndex] == 0 && channelIndex == 0) );
+ BOOST_ASSERT( heightIndex < shape[m_HeightIndex] ||
+ ( shape[m_HeightIndex] == 0 && heightIndex == 0) );
+ BOOST_ASSERT( widthIndex < shape[m_WidthIndex] ||
+ ( shape[m_WidthIndex] == 0 && widthIndex == 0) );
+
+ // Offset the given indices appropriately depending on the data layout
+ switch (m_DataLayout)
+ {
+ case armnn::DataLayout::NHWC:
+ batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
+ heightIndex *= shape[m_WidthIndex] * shape[m_ChannelsIndex];
+ widthIndex *= shape[m_ChannelsIndex];
+ // channelIndex stays unchanged
+ break;
+ case armnn::DataLayout::NCHW:
+ default:
+ batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
+ channelIndex *= shape[m_HeightIndex] * shape[m_WidthIndex];
+ heightIndex *= shape[m_WidthIndex];
+ // widthIndex stays unchanged
+ break;
+ }
+
+ // Get the value using the correct offset
+ return batchIndex + channelIndex + heightIndex + widthIndex;
+ }
private:
armnn::DataLayout m_DataLayout;