ArmNN
 21.11
TensorBufferArrayView.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/Tensor.hpp>
9 
10 #include <armnnUtils/DataLayoutIndexed.hpp>
11 
12 #include <armnn/utility/Assert.hpp>
13 
14 namespace armnn
15 {
16 
17 // Utility class providing access to raw tensor memory based on indices along each dimension.
18 template <typename DataType>
19 class TensorBufferArrayView
20 {
21 public:
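22  TensorBufferArrayView(const TensorShape& shape, DataType* data,
23  armnnUtils::DataLayoutIndexed dataLayout = DataLayout::NCHW)
24  : m_Shape(shape)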
25  , m_Data(data)
26  , m_DataLayout(dataLayout)
27  {
28  ARMNN_ASSERT(m_Shape.GetNumDimensions() == 4);
29  }
30 
31  DataType& Get(unsigned int b, unsigned int c, unsigned int h, unsigned int w) const
32  {
33  return m_Data[m_DataLayout.GetIndex(m_Shape, b, c, h, w)];
34  }
35 
36 private:
37  const TensorShape m_Shape;
38  DataType* m_Data;
39  armnnUtils::DataLayoutIndexed m_DataLayout;
40 };
41 
42 } //namespace armnn
Copyright (c) 2021 ARM Limited and Contributors.
DataType & Get(unsigned int b, unsigned int c, unsigned int h, unsigned int w) const
DataType
Definition: Types.hpp:35
Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout...
unsigned int GetIndex(const armnn::TensorShape &shape, unsigned int batchIndex, unsigned int channelIndex, unsigned int heightIndex, unsigned int widthIndex) const
#define ARMNN_ASSERT(COND)
Definition: Assert.hpp:14
unsigned int GetNumDimensions() const
Function that returns the tensor rank.
Definition: Tensor.cpp:174
TensorBufferArrayView(const TensorShape &shape, DataType *data, armnnUtils::DataLayoutIndexed dataLayout=DataLayout::NCHW)