From ee423cee7d7753790d0d82c5c2fd12a262b412a2 Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Wed, 5 Jun 2019 09:02:41 +0100 Subject: IVGCVSW-3142 Refactor DataLayoutIndexed and TensorBufferArrayView for convenience * Added GetIndex method to DataLayoutIndexed * Refactored TensorBufferArrayView::Get to use the new method Change-Id: Iae08b2761bddeda9e935b25e6bc4985f2d386cd3 Signed-off-by: Matteo Martincigh --- src/armnnUtils/DataLayoutIndexed.cpp | 38 ++++++++++++++++++++-- src/armnnUtils/DataLayoutIndexed.hpp | 6 ++++ .../reference/workloads/TensorBufferArrayView.hpp | 29 +---------------- 3 files changed, 43 insertions(+), 30 deletions(-) diff --git a/src/armnnUtils/DataLayoutIndexed.cpp b/src/armnnUtils/DataLayoutIndexed.cpp index db27de4bdd..b02f07ec85 100644 --- a/src/armnnUtils/DataLayoutIndexed.cpp +++ b/src/armnnUtils/DataLayoutIndexed.cpp @@ -5,6 +5,8 @@ #include "DataLayoutIndexed.hpp" +#include + using namespace armnn; namespace armnnUtils @@ -31,13 +33,45 @@ DataLayoutIndexed::DataLayoutIndexed(armnn::DataLayout dataLayout) } } -// Definition in include/armnn/Types.hpp +unsigned int DataLayoutIndexed::GetIndex(const TensorShape& shape, + unsigned int batchIndex, unsigned int channelIndex, + unsigned int heightIndex, unsigned int widthIndex) const +{ + BOOST_ASSERT( batchIndex < shape[0] || ( shape[0] == 0 && batchIndex == 0 ) ); + BOOST_ASSERT( channelIndex < shape[m_ChannelsIndex] || + ( shape[m_ChannelsIndex] == 0 && channelIndex == 0) ); + BOOST_ASSERT( heightIndex < shape[m_HeightIndex] || + ( shape[m_HeightIndex] == 0 && heightIndex == 0) ); + BOOST_ASSERT( widthIndex < shape[m_WidthIndex] || + ( shape[m_WidthIndex] == 0 && widthIndex == 0) ); + + // Offset the given indices appropriately depending on the data layout + switch (m_DataLayout) + { + case DataLayout::NHWC: + batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex + heightIndex *= shape[m_WidthIndex] * shape[m_ChannelsIndex]; + widthIndex *= shape[m_ChannelsIndex]; + // channelIndex stays unchanged + break; + case DataLayout::NCHW: + default: + batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex + channelIndex *= shape[m_HeightIndex] * shape[m_WidthIndex]; + heightIndex *= shape[m_WidthIndex]; + // widthIndex stays unchanged + break; + } + + // Get the value using the correct offset + return batchIndex + channelIndex + heightIndex + widthIndex; +} + bool operator==(const DataLayout& dataLayout, const DataLayoutIndexed& indexed) { return dataLayout == indexed.GetDataLayout(); } -// Definition in include/armnn/Types.hpp bool operator==(const DataLayoutIndexed& indexed, const DataLayout& dataLayout) { return indexed.GetDataLayout() == dataLayout; diff --git a/src/armnnUtils/DataLayoutIndexed.hpp b/src/armnnUtils/DataLayoutIndexed.hpp index 1cf2a09e32..5bb8e0d93f 100644 --- a/src/armnnUtils/DataLayoutIndexed.hpp +++ b/src/armnnUtils/DataLayoutIndexed.hpp @@ -2,8 +2,11 @@ // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // + #pragma once + #include +#include namespace armnnUtils { @@ -18,6 +21,9 @@ public: unsigned int GetChannelsIndex() const { return m_ChannelsIndex; } unsigned int GetHeightIndex() const { return m_HeightIndex; } unsigned int GetWidthIndex() const { return m_WidthIndex; } + unsigned int GetIndex(const armnn::TensorShape& shape, + unsigned int batchIndex, unsigned int channelIndex, + unsigned int heightIndex, unsigned int widthIndex) const; private: armnn::DataLayout m_DataLayout; diff --git a/src/backends/reference/workloads/TensorBufferArrayView.hpp b/src/backends/reference/workloads/TensorBufferArrayView.hpp index aecec6757a..c06407241d 100644 --- a/src/backends/reference/workloads/TensorBufferArrayView.hpp +++ b/src/backends/reference/workloads/TensorBufferArrayView.hpp @@ -30,34 +30,7 @@ public: DataType& Get(unsigned int b, unsigned int c, unsigned int h, unsigned int w) const { - BOOST_ASSERT( b < m_Shape[0] || ( m_Shape[0] == 0 && b == 0 ) ); - BOOST_ASSERT( c < m_Shape[m_DataLayout.GetChannelsIndex()] || - ( m_Shape[m_DataLayout.GetChannelsIndex()] == 0 && c == 0) ); - BOOST_ASSERT( h < m_Shape[m_DataLayout.GetHeightIndex()] || - ( m_Shape[m_DataLayout.GetHeightIndex()] == 0 && h == 0) ); - BOOST_ASSERT( w < m_Shape[m_DataLayout.GetWidthIndex()] || - ( m_Shape[m_DataLayout.GetWidthIndex()] == 0 && w == 0) ); - - // Offset the given indices appropriately depending on the data layout. - switch (m_DataLayout.GetDataLayout()) - { - case DataLayout::NHWC: - b *= m_Shape[1] * m_Shape[2] * m_Shape[3]; // b *= height_index * width_index * channel_index; - h *= m_Shape[m_DataLayout.GetWidthIndex()] * m_Shape[m_DataLayout.GetChannelsIndex()]; - w *= m_Shape[m_DataLayout.GetChannelsIndex()]; - // c stays unchanged - break; - case DataLayout::NCHW: - default: - b *= m_Shape[1] * m_Shape[2] * m_Shape[3]; // b *= height_index * width_index * channel_index; - c *= m_Shape[m_DataLayout.GetHeightIndex()] * m_Shape[m_DataLayout.GetWidthIndex()]; - h *= m_Shape[m_DataLayout.GetWidthIndex()]; - // w stays unchanged - break; - } - - // Get the value using the correct offset. - return m_Data[b + c + h + w]; + return m_Data[m_DataLayout.GetIndex(m_Shape, b, c, h, w)]; } private: -- cgit v1.2.1