ArmNN
 21.02
DataLayoutIndexed.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/Types.hpp>
9 #include <armnn/Tensor.hpp>
10 
11 #include <armnn/utility/Assert.hpp>
12 
13 namespace armnnUtils
14 {
15 
16 /// Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout
18 {
19 public:
21 
22  armnn::DataLayout GetDataLayout() const { return m_DataLayout; }
23  unsigned int GetChannelsIndex() const { return m_ChannelsIndex; }
24  unsigned int GetHeightIndex() const { return m_HeightIndex; }
25  unsigned int GetWidthIndex() const { return m_WidthIndex; }
26 
27  inline unsigned int GetIndex(const armnn::TensorShape& shape,
28  unsigned int batchIndex, unsigned int channelIndex,
29  unsigned int heightIndex, unsigned int widthIndex) const
30  {
31  ARMNN_ASSERT( batchIndex < shape[0] || ( shape[0] == 0 && batchIndex == 0 ) );
32  ARMNN_ASSERT( channelIndex < shape[m_ChannelsIndex] ||
33  ( shape[m_ChannelsIndex] == 0 && channelIndex == 0) );
34  ARMNN_ASSERT( heightIndex < shape[m_HeightIndex] ||
35  ( shape[m_HeightIndex] == 0 && heightIndex == 0) );
36  ARMNN_ASSERT( widthIndex < shape[m_WidthIndex] ||
37  ( shape[m_WidthIndex] == 0 && widthIndex == 0) );
38 
39  /// Offset the given indices appropriately depending on the data layout
40  switch (m_DataLayout)
41  {
43  batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
44  heightIndex *= shape[m_WidthIndex] * shape[m_ChannelsIndex];
45  widthIndex *= shape[m_ChannelsIndex];
46  /// channelIndex stays unchanged
47  break;
49  default:
50  batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
51  channelIndex *= shape[m_HeightIndex] * shape[m_WidthIndex];
52  heightIndex *= shape[m_WidthIndex];
53  /// widthIndex stays unchanged
54  break;
55  }
56 
57  /// Get the value using the correct offset
58  return batchIndex + channelIndex + heightIndex + widthIndex;
59  }
60 
61 private:
62  armnn::DataLayout m_DataLayout;
63  unsigned int m_ChannelsIndex;
64  unsigned int m_HeightIndex;
65  unsigned int m_WidthIndex;
66 };
67 
68 /// Equality methods
69 bool operator==(const armnn::DataLayout& dataLayout, const DataLayoutIndexed& indexed);
70 bool operator==(const DataLayoutIndexed& indexed, const armnn::DataLayout& dataLayout);
71 
72 } // namespace armnnUtils
DataLayout
Definition: Types.hpp:50
unsigned int GetWidthIndex() const
unsigned int GetHeightIndex() const
bool operator==(const armnn::DataLayout &dataLayout, const DataLayoutIndexed &indexed)
Equality methods.
DataLayoutIndexed(armnn::DataLayout dataLayout)
Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout...
unsigned int GetIndex(const armnn::TensorShape &shape, unsigned int batchIndex, unsigned int channelIndex, unsigned int heightIndex, unsigned int widthIndex) const
#define ARMNN_ASSERT(COND)
Definition: Assert.hpp:14
armnn::DataLayout GetDataLayout() const
unsigned int GetChannelsIndex() const