diff options
author | Matteo Martincigh <matteo.martincigh@arm.com> | 2018-10-16 16:17:34 +0100 |
---|---|---|
committer | Matthew Bentham <matthew.bentham@arm.com> | 2018-10-22 16:57:54 +0100 |
commit | 97a06fd57e7864a882ef5e37a1bf7286f5be5185 (patch) | |
tree | de883e081b66c2a3fc6031f95133f252cd6828a1 /src | |
parent | e6488719f58b5dd0e8e23d40b8a4c1337d07e9fd (diff) | |
download | armnn-97a06fd57e7864a882ef5e37a1bf7286f5be5185.tar.gz |
IVGCVSW-2018 Support NHWC in the current ref implementation
* Added NHWC support to TensorBufferArrayView class
Change-Id: I41e1d0acd226a471ec834e380389631d9236cb00
Diffstat (limited to 'src')
-rw-r--r-- | src/backends/reference/workloads/TensorBufferArrayView.hpp | 25 |
1 files changed, 21 insertions, 4 deletions
diff --git a/src/backends/reference/workloads/TensorBufferArrayView.hpp b/src/backends/reference/workloads/TensorBufferArrayView.hpp index aba44e4593..b149073ab7 100644 --- a/src/backends/reference/workloads/TensorBufferArrayView.hpp +++ b/src/backends/reference/workloads/TensorBufferArrayView.hpp @@ -20,6 +20,7 @@ public: , m_Data(data) , m_DataLayout(dataLayout) { + BOOST_ASSERT(m_Shape.GetNumDimensions() == 4); } DataType& Get(unsigned int b, unsigned int c, unsigned int h, unsigned int w) const @@ -32,10 +33,26 @@ public: BOOST_ASSERT( w < m_Shape[m_DataLayout.GetWidthIndex()] || ( m_Shape[m_DataLayout.GetWidthIndex()] == 0 && w == 0) ); - return m_Data[b * m_Shape[1] * m_Shape[2] * m_Shape[3] - + c * m_Shape[m_DataLayout.GetHeightIndex()] * m_Shape[m_DataLayout.GetWidthIndex()] - + h * m_Shape[m_DataLayout.GetWidthIndex()] - + w]; + // Offset the given indices appropriately depending on the data layout. + switch (m_DataLayout.GetDataLayout()) + { + case DataLayout::NHWC: + b *= m_Shape[1] * m_Shape[2] * m_Shape[3]; // b *= height_index * width_index * channel_index; + h *= m_Shape[m_DataLayout.GetWidthIndex()] * m_Shape[m_DataLayout.GetChannelsIndex()]; + w *= m_Shape[m_DataLayout.GetChannelsIndex()]; + // c stays unchanged + break; + case DataLayout::NCHW: + default: + b *= m_Shape[1] * m_Shape[2] * m_Shape[3]; // b *= height_index * width_index * channel_index; + c *= m_Shape[m_DataLayout.GetHeightIndex()] * m_Shape[m_DataLayout.GetWidthIndex()]; + h *= m_Shape[m_DataLayout.GetWidthIndex()]; + // w stays unchanged + break; + } + + // Get the value using the correct offset. + return m_Data[b + c + h + w]; } private: |