aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2018-10-16 16:17:34 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-10-22 16:57:54 +0100
commit97a06fd57e7864a882ef5e37a1bf7286f5be5185 (patch)
treede883e081b66c2a3fc6031f95133f252cd6828a1
parente6488719f58b5dd0e8e23d40b8a4c1337d07e9fd (diff)
downloadarmnn-97a06fd57e7864a882ef5e37a1bf7286f5be5185.tar.gz
IVGCVSW-2018 Support NHWC in the current ref implementation
* Added NHWC support to TensorBufferArrayView class Change-Id: I41e1d0acd226a471ec834e380389631d9236cb00
-rw-r--r--src/backends/reference/workloads/TensorBufferArrayView.hpp25
1 files changed, 21 insertions, 4 deletions
diff --git a/src/backends/reference/workloads/TensorBufferArrayView.hpp b/src/backends/reference/workloads/TensorBufferArrayView.hpp
index aba44e4593..b149073ab7 100644
--- a/src/backends/reference/workloads/TensorBufferArrayView.hpp
+++ b/src/backends/reference/workloads/TensorBufferArrayView.hpp
@@ -20,6 +20,7 @@ public:
, m_Data(data)
, m_DataLayout(dataLayout)
{
+ BOOST_ASSERT(m_Shape.GetNumDimensions() == 4);
}
DataType& Get(unsigned int b, unsigned int c, unsigned int h, unsigned int w) const
@@ -32,10 +33,26 @@ public:
BOOST_ASSERT( w < m_Shape[m_DataLayout.GetWidthIndex()] ||
( m_Shape[m_DataLayout.GetWidthIndex()] == 0 && w == 0) );
- return m_Data[b * m_Shape[1] * m_Shape[2] * m_Shape[3]
- + c * m_Shape[m_DataLayout.GetHeightIndex()] * m_Shape[m_DataLayout.GetWidthIndex()]
- + h * m_Shape[m_DataLayout.GetWidthIndex()]
- + w];
+ // Offset the given indices appropriately depending on the data layout.
+ switch (m_DataLayout.GetDataLayout())
+ {
+ case DataLayout::NHWC:
+ b *= m_Shape[1] * m_Shape[2] * m_Shape[3]; // b *= height_index * width_index * channel_index;
+ h *= m_Shape[m_DataLayout.GetWidthIndex()] * m_Shape[m_DataLayout.GetChannelsIndex()];
+ w *= m_Shape[m_DataLayout.GetChannelsIndex()];
+ // c stays unchanged
+ break;
+ case DataLayout::NCHW:
+ default:
+ b *= m_Shape[1] * m_Shape[2] * m_Shape[3]; // b *= height_index * width_index * channel_index;
+ c *= m_Shape[m_DataLayout.GetHeightIndex()] * m_Shape[m_DataLayout.GetWidthIndex()];
+ h *= m_Shape[m_DataLayout.GetWidthIndex()];
+ // w stays unchanged
+ break;
+ }
+
+ // Get the value using the correct offset.
+ return m_Data[b + c + h + w];
}
private: