diff options
author | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-06-05 09:02:41 +0100 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-06-05 09:10:50 +0100 |
commit | ee423cee7d7753790d0d82c5c2fd12a262b412a2 (patch) | |
tree | d3bfacd18fb1fc2d70af98845ce1e9089642581c /src/backends/reference/workloads | |
parent | 286080f0d4c4f8a1ca174888f48475e3ec9ac797 (diff) | |
download | armnn-ee423cee7d7753790d0d82c5c2fd12a262b412a2.tar.gz |
IVGCVSW-3142 Refactor DataLayoutIndexed and TensorBufferArrayView
for convenience
* Added GetIndex method to DataLayoutIndexed
* Refactored TensorBufferArrayView::Get to use the new method
Change-Id: Iae08b2761bddeda9e935b25e6bc4985f2d386cd3
Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r-- | src/backends/reference/workloads/TensorBufferArrayView.hpp | 29 |
1 files changed, 1 insertions, 28 deletions
diff --git a/src/backends/reference/workloads/TensorBufferArrayView.hpp b/src/backends/reference/workloads/TensorBufferArrayView.hpp index aecec6757a..c06407241d 100644 --- a/src/backends/reference/workloads/TensorBufferArrayView.hpp +++ b/src/backends/reference/workloads/TensorBufferArrayView.hpp @@ -30,34 +30,7 @@ public: DataType& Get(unsigned int b, unsigned int c, unsigned int h, unsigned int w) const { - BOOST_ASSERT( b < m_Shape[0] || ( m_Shape[0] == 0 && b == 0 ) ); - BOOST_ASSERT( c < m_Shape[m_DataLayout.GetChannelsIndex()] || - ( m_Shape[m_DataLayout.GetChannelsIndex()] == 0 && c == 0) ); - BOOST_ASSERT( h < m_Shape[m_DataLayout.GetHeightIndex()] || - ( m_Shape[m_DataLayout.GetHeightIndex()] == 0 && h == 0) ); - BOOST_ASSERT( w < m_Shape[m_DataLayout.GetWidthIndex()] || - ( m_Shape[m_DataLayout.GetWidthIndex()] == 0 && w == 0) ); - - // Offset the given indices appropriately depending on the data layout. - switch (m_DataLayout.GetDataLayout()) - { - case DataLayout::NHWC: - b *= m_Shape[1] * m_Shape[2] * m_Shape[3]; // b *= height_index * width_index * channel_index; - h *= m_Shape[m_DataLayout.GetWidthIndex()] * m_Shape[m_DataLayout.GetChannelsIndex()]; - w *= m_Shape[m_DataLayout.GetChannelsIndex()]; - // c stays unchanged - break; - case DataLayout::NCHW: - default: - b *= m_Shape[1] * m_Shape[2] * m_Shape[3]; // b *= height_index * width_index * channel_index; - c *= m_Shape[m_DataLayout.GetHeightIndex()] * m_Shape[m_DataLayout.GetWidthIndex()]; - h *= m_Shape[m_DataLayout.GetWidthIndex()]; - // w stays unchanged - break; - } - - // Get the value using the correct offset. - return m_Data[b + c + h + w]; + return m_Data[m_DataLayout.GetIndex(m_Shape, b, c, h, w)]; } private: |