diff options
Diffstat (limited to 'src/armnnUtils/DataLayoutIndexed.cpp')
-rw-r--r-- | src/armnnUtils/DataLayoutIndexed.cpp | 38 |
1 files changed, 36 insertions, 2 deletions
diff --git a/src/armnnUtils/DataLayoutIndexed.cpp b/src/armnnUtils/DataLayoutIndexed.cpp index db27de4bdd..b02f07ec85 100644 --- a/src/armnnUtils/DataLayoutIndexed.cpp +++ b/src/armnnUtils/DataLayoutIndexed.cpp @@ -5,6 +5,8 @@ #include "DataLayoutIndexed.hpp" +#include <boost/assert.hpp> + using namespace armnn; namespace armnnUtils @@ -31,13 +33,45 @@ DataLayoutIndexed::DataLayoutIndexed(armnn::DataLayout dataLayout) } } -// Definition in include/armnn/Types.hpp +unsigned int DataLayoutIndexed::GetIndex(const TensorShape& shape, + unsigned int batchIndex, unsigned int channelIndex, + unsigned int heightIndex, unsigned int widthIndex) const +{ + BOOST_ASSERT( batchIndex < shape[0] || ( shape[0] == 0 && batchIndex == 0 ) ); + BOOST_ASSERT( channelIndex < shape[m_ChannelsIndex] || + ( shape[m_ChannelsIndex] == 0 && channelIndex == 0) ); + BOOST_ASSERT( heightIndex < shape[m_HeightIndex] || + ( shape[m_HeightIndex] == 0 && heightIndex == 0) ); + BOOST_ASSERT( widthIndex < shape[m_WidthIndex] || + ( shape[m_WidthIndex] == 0 && widthIndex == 0) ); + + // Offset the given indices appropriately depending on the data layout + switch (m_DataLayout) + { + case DataLayout::NHWC: + batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex + heightIndex *= shape[m_WidthIndex] * shape[m_ChannelsIndex]; + widthIndex *= shape[m_ChannelsIndex]; + // channelIndex stays unchanged + break; + case DataLayout::NCHW: + default: + batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex + channelIndex *= shape[m_HeightIndex] * shape[m_WidthIndex]; + heightIndex *= shape[m_WidthIndex]; + // widthIndex stays unchanged + break; + } + + // Get the value using the correct offset + return batchIndex + channelIndex + heightIndex + widthIndex; +} + bool operator==(const DataLayout& dataLayout, const DataLayoutIndexed& indexed) { return dataLayout == indexed.GetDataLayout(); } -// Definition in include/armnn/Types.hpp bool operator==(const DataLayoutIndexed& indexed, const DataLayout& dataLayout) { return indexed.GetDataLayout() == dataLayout; |