aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/TensorBufferArrayView.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/TensorBufferArrayView.hpp')
-rw-r--r--src/backends/reference/workloads/TensorBufferArrayView.hpp21
1 files changed, 13 insertions, 8 deletions
diff --git a/src/backends/reference/workloads/TensorBufferArrayView.hpp b/src/backends/reference/workloads/TensorBufferArrayView.hpp
index e19810ca87..aba44e4593 100644
--- a/src/backends/reference/workloads/TensorBufferArrayView.hpp
+++ b/src/backends/reference/workloads/TensorBufferArrayView.hpp
@@ -15,28 +15,33 @@ template <typename DataType>
class TensorBufferArrayView
{
public:
- TensorBufferArrayView(const TensorShape& shape, DataType* data)
+ TensorBufferArrayView(const TensorShape& shape, DataType* data, DataLayoutIndexed dataLayout = DataLayout::NCHW)
: m_Shape(shape)
, m_Data(data)
+ , m_DataLayout(dataLayout)
{
}
DataType& Get(unsigned int b, unsigned int c, unsigned int h, unsigned int w) const
{
- BOOST_ASSERT( b < m_Shape[0] || (m_Shape[0] == 0 && b == 0) );
- BOOST_ASSERT( c < m_Shape[1] || (m_Shape[1] == 0 && c == 0) );
- BOOST_ASSERT( h < m_Shape[2] || (m_Shape[2] == 0 && h == 0) );
- BOOST_ASSERT( w < m_Shape[3] || (m_Shape[3] == 0 && w == 0) );
+ BOOST_ASSERT( b < m_Shape[0] || ( m_Shape[0] == 0 && b == 0 ) );
+ BOOST_ASSERT( c < m_Shape[m_DataLayout.GetChannelsIndex()] ||
+ ( m_Shape[m_DataLayout.GetChannelsIndex()] == 0 && c == 0) );
+ BOOST_ASSERT( h < m_Shape[m_DataLayout.GetHeightIndex()] ||
+ ( m_Shape[m_DataLayout.GetHeightIndex()] == 0 && h == 0) );
+ BOOST_ASSERT( w < m_Shape[m_DataLayout.GetWidthIndex()] ||
+ ( m_Shape[m_DataLayout.GetWidthIndex()] == 0 && w == 0) );
return m_Data[b * m_Shape[1] * m_Shape[2] * m_Shape[3]
- + c * m_Shape[2] * m_Shape[3]
- + h * m_Shape[3]
+ + c * m_Shape[m_DataLayout.GetHeightIndex()] * m_Shape[m_DataLayout.GetWidthIndex()]
+ + h * m_Shape[m_DataLayout.GetWidthIndex()]
+ w];
}
private:
const TensorShape m_Shape;
- DataType* m_Data;
+ DataType* m_Data;
+ DataLayoutIndexed m_DataLayout;
};
} //namespace armnn