From 97a06fd57e7864a882ef5e37a1bf7286f5be5185 Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Tue, 16 Oct 2018 16:17:34 +0100 Subject: IVGCVSW-2018 Support NHWC in the current ref implementation * Added NHWC support to TensorBufferArrayView class Change-Id: I41e1d0acd226a471ec834e380389631d9236cb00 --- .../reference/workloads/TensorBufferArrayView.hpp | 25 ++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/backends/reference/workloads/TensorBufferArrayView.hpp b/src/backends/reference/workloads/TensorBufferArrayView.hpp index aba44e4593..b149073ab7 100644 --- a/src/backends/reference/workloads/TensorBufferArrayView.hpp +++ b/src/backends/reference/workloads/TensorBufferArrayView.hpp @@ -20,6 +20,7 @@ public: , m_Data(data) , m_DataLayout(dataLayout) { + BOOST_ASSERT(m_Shape.GetNumDimensions() == 4); } DataType& Get(unsigned int b, unsigned int c, unsigned int h, unsigned int w) const @@ -32,10 +33,26 @@ public: BOOST_ASSERT( w < m_Shape[m_DataLayout.GetWidthIndex()] || ( m_Shape[m_DataLayout.GetWidthIndex()] == 0 && w == 0) ); - return m_Data[b * m_Shape[1] * m_Shape[2] * m_Shape[3] - + c * m_Shape[m_DataLayout.GetHeightIndex()] * m_Shape[m_DataLayout.GetWidthIndex()] - + h * m_Shape[m_DataLayout.GetWidthIndex()] - + w]; + // Offset the given indices appropriately depending on the data layout. + switch (m_DataLayout.GetDataLayout()) + { + case DataLayout::NHWC: + b *= m_Shape[1] * m_Shape[2] * m_Shape[3]; // b *= height_index * width_index * channel_index; + h *= m_Shape[m_DataLayout.GetWidthIndex()] * m_Shape[m_DataLayout.GetChannelsIndex()]; + w *= m_Shape[m_DataLayout.GetChannelsIndex()]; + // c stays unchanged + break; + case DataLayout::NCHW: + default: + b *= m_Shape[1] * m_Shape[2] * m_Shape[3]; // b *= height_index * width_index * channel_index; + c *= m_Shape[m_DataLayout.GetHeightIndex()] * m_Shape[m_DataLayout.GetWidthIndex()]; + h *= m_Shape[m_DataLayout.GetWidthIndex()]; + // w stays unchanged + break; + } + + // Get the value using the correct offset. + return m_Data[b + c + h + w]; } private: -- cgit v1.2.1