diff options
Diffstat (limited to 'src/backends/reference/workloads/TensorBufferArrayView.hpp')
-rw-r--r-- | src/backends/reference/workloads/TensorBufferArrayView.hpp | 21 |
1 files changed, 13 insertions, 8 deletions
diff --git a/src/backends/reference/workloads/TensorBufferArrayView.hpp b/src/backends/reference/workloads/TensorBufferArrayView.hpp index e19810ca87..aba44e4593 100644 --- a/src/backends/reference/workloads/TensorBufferArrayView.hpp +++ b/src/backends/reference/workloads/TensorBufferArrayView.hpp @@ -15,28 +15,33 @@ template <typename DataType> class TensorBufferArrayView { public: - TensorBufferArrayView(const TensorShape& shape, DataType* data) + TensorBufferArrayView(const TensorShape& shape, DataType* data, DataLayoutIndexed dataLayout = DataLayout::NCHW) : m_Shape(shape) , m_Data(data) + , m_DataLayout(dataLayout) { } DataType& Get(unsigned int b, unsigned int c, unsigned int h, unsigned int w) const { - BOOST_ASSERT( b < m_Shape[0] || (m_Shape[0] == 0 && b == 0) ); - BOOST_ASSERT( c < m_Shape[1] || (m_Shape[1] == 0 && c == 0) ); - BOOST_ASSERT( h < m_Shape[2] || (m_Shape[2] == 0 && h == 0) ); - BOOST_ASSERT( w < m_Shape[3] || (m_Shape[3] == 0 && w == 0) ); + BOOST_ASSERT( b < m_Shape[0] || ( m_Shape[0] == 0 && b == 0 ) ); + BOOST_ASSERT( c < m_Shape[m_DataLayout.GetChannelsIndex()] || + ( m_Shape[m_DataLayout.GetChannelsIndex()] == 0 && c == 0) ); + BOOST_ASSERT( h < m_Shape[m_DataLayout.GetHeightIndex()] || + ( m_Shape[m_DataLayout.GetHeightIndex()] == 0 && h == 0) ); + BOOST_ASSERT( w < m_Shape[m_DataLayout.GetWidthIndex()] || + ( m_Shape[m_DataLayout.GetWidthIndex()] == 0 && w == 0) ); return m_Data[b * m_Shape[1] * m_Shape[2] * m_Shape[3] - + c * m_Shape[2] * m_Shape[3] - + h * m_Shape[3] + + c * m_Shape[m_DataLayout.GetHeightIndex()] * m_Shape[m_DataLayout.GetWidthIndex()] + + h * m_Shape[m_DataLayout.GetWidthIndex()] + w]; } private: const TensorShape m_Shape; - DataType* m_Data; + DataType* m_Data; + DataLayoutIndexed m_DataLayout; }; } //namespace armnn |