aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/DataLayoutIndexed.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnUtils/DataLayoutIndexed.cpp')
-rw-r--r--src/armnnUtils/DataLayoutIndexed.cpp38
1 files changed, 36 insertions, 2 deletions
diff --git a/src/armnnUtils/DataLayoutIndexed.cpp b/src/armnnUtils/DataLayoutIndexed.cpp
index db27de4bdd..b02f07ec85 100644
--- a/src/armnnUtils/DataLayoutIndexed.cpp
+++ b/src/armnnUtils/DataLayoutIndexed.cpp
@@ -5,6 +5,8 @@
#include "DataLayoutIndexed.hpp"
+#include <boost/assert.hpp>
+
using namespace armnn;
namespace armnnUtils
@@ -31,13 +33,45 @@ DataLayoutIndexed::DataLayoutIndexed(armnn::DataLayout dataLayout)
}
}
-// Definition in include/armnn/Types.hpp
+unsigned int DataLayoutIndexed::GetIndex(const TensorShape& shape,
+ unsigned int batchIndex, unsigned int channelIndex,
+ unsigned int heightIndex, unsigned int widthIndex) const
+{
+ BOOST_ASSERT( batchIndex < shape[0] || ( shape[0] == 0 && batchIndex == 0 ) );
+ BOOST_ASSERT( channelIndex < shape[m_ChannelsIndex] ||
+ ( shape[m_ChannelsIndex] == 0 && channelIndex == 0) );
+ BOOST_ASSERT( heightIndex < shape[m_HeightIndex] ||
+ ( shape[m_HeightIndex] == 0 && heightIndex == 0) );
+ BOOST_ASSERT( widthIndex < shape[m_WidthIndex] ||
+ ( shape[m_WidthIndex] == 0 && widthIndex == 0) );
+
+ // Offset the given indices appropriately depending on the data layout
+ switch (m_DataLayout)
+ {
+ case DataLayout::NHWC:
+ batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
+ heightIndex *= shape[m_WidthIndex] * shape[m_ChannelsIndex];
+ widthIndex *= shape[m_ChannelsIndex];
+ // channelIndex stays unchanged
+ break;
+ case DataLayout::NCHW:
+ default:
+ batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
+ channelIndex *= shape[m_HeightIndex] * shape[m_WidthIndex];
+ heightIndex *= shape[m_WidthIndex];
+ // widthIndex stays unchanged
+ break;
+ }
+
+ // Get the value using the correct offset
+ return batchIndex + channelIndex + heightIndex + widthIndex;
+}
+
bool operator==(const DataLayout& dataLayout, const DataLayoutIndexed& indexed)
{
return dataLayout == indexed.GetDataLayout();
}
-// Definition in include/armnn/Types.hpp
bool operator==(const DataLayoutIndexed& indexed, const DataLayout& dataLayout)
{
return indexed.GetDataLayout() == dataLayout;