ArmNN
 24.02
DataLayoutIndexed.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2018-2021,2023 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/Types.hpp>
9 #include <armnn/Tensor.hpp>
10 
11 #include <armnn/utility/Assert.hpp>
12 
13 namespace armnnUtils
14 {
15 
16 /// Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout
18 {
19 public:
21 
22  armnn::DataLayout GetDataLayout() const { return m_DataLayout; }
23  unsigned int GetChannelsIndex() const { return m_ChannelsIndex; }
24  unsigned int GetHeightIndex() const { return m_HeightIndex; }
25  unsigned int GetWidthIndex() const { return m_WidthIndex; }
26  unsigned int GetDepthIndex() const { return m_DepthIndex; }
27 
28  inline unsigned int GetIndex(const armnn::TensorShape& shape,
29  unsigned int batchIndex, unsigned int channelIndex,
30  unsigned int heightIndex, unsigned int widthIndex) const
31  {
32  if (batchIndex >= shape[0] && !( shape[0] == 0 && batchIndex == 0))
33  {
34  throw armnn::Exception("Unable to get batch index", CHECK_LOCATION());
35  }
36  if (channelIndex >= shape[m_ChannelsIndex] &&
37  !(shape[m_ChannelsIndex] == 0 && channelIndex == 0))
38  {
39  throw armnn::Exception("Unable to get channel index", CHECK_LOCATION());
40 
41  }
42  if (heightIndex >= shape[m_HeightIndex] &&
43  !( shape[m_HeightIndex] == 0 && heightIndex == 0))
44  {
45  throw armnn::Exception("Unable to get height index", CHECK_LOCATION());
46  }
47  if (widthIndex >= shape[m_WidthIndex] &&
48  ( shape[m_WidthIndex] == 0 && widthIndex == 0))
49  {
50  throw armnn::Exception("Unable to get width index", CHECK_LOCATION());
51  }
52 
53  /// Offset the given indices appropriately depending on the data layout
54  switch (m_DataLayout)
55  {
57  batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
58  heightIndex *= shape[m_WidthIndex] * shape[m_ChannelsIndex];
59  widthIndex *= shape[m_ChannelsIndex];
60  /// channelIndex stays unchanged
61  break;
63  default:
64  batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
65  channelIndex *= shape[m_HeightIndex] * shape[m_WidthIndex];
66  heightIndex *= shape[m_WidthIndex];
67  /// widthIndex stays unchanged
68  break;
69  }
70 
71  /// Get the value using the correct offset
72  return batchIndex + channelIndex + heightIndex + widthIndex;
73  }
74 
75 private:
76  armnn::DataLayout m_DataLayout;
77  unsigned int m_ChannelsIndex;
78  unsigned int m_HeightIndex;
79  unsigned int m_WidthIndex;
80  unsigned int m_DepthIndex;
81 };
82 
83 /// Equality methods
84 bool operator==(const armnn::DataLayout& dataLayout, const DataLayoutIndexed& indexed);
85 bool operator==(const DataLayoutIndexed& indexed, const armnn::DataLayout& dataLayout);
86 
87 } // namespace armnnUtils
armnn::DataLayout
DataLayout
Definition: Types.hpp:62
armnn::DataLayout::NHWC
@ NHWC
CHECK_LOCATION
#define CHECK_LOCATION()
Definition: Exceptions.hpp:203
armnnUtils::DataLayoutIndexed
Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout.
Definition: DataLayoutIndexed.hpp:17
armnnUtils::DataLayoutIndexed::GetDataLayout
armnn::DataLayout GetDataLayout() const
Definition: DataLayoutIndexed.hpp:22
armnnUtils::DataLayoutIndexed::GetHeightIndex
unsigned int GetHeightIndex() const
Definition: DataLayoutIndexed.hpp:24
Assert.hpp
armnn::TensorShape
Definition: Tensor.hpp:20
armnnUtils
Definition: CompatibleTypes.hpp:10
armnnUtils::operator==
bool operator==(const armnn::DataLayout &dataLayout, const DataLayoutIndexed &indexed)
Equality methods.
Definition: DataLayoutIndexed.cpp:46
armnn::Exception
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
armnnUtils::DataLayoutIndexed::GetWidthIndex
unsigned int GetWidthIndex() const
Definition: DataLayoutIndexed.hpp:25
Tensor.hpp
Types.hpp
armnnUtils::DataLayoutIndexed::GetChannelsIndex
unsigned int GetChannelsIndex() const
Definition: DataLayoutIndexed.hpp:23
armnnUtils::DataLayoutIndexed::DataLayoutIndexed
DataLayoutIndexed(armnn::DataLayout dataLayout)
Definition: DataLayoutIndexed.cpp:13
armnnUtils::DataLayoutIndexed::GetIndex
unsigned int GetIndex(const armnn::TensorShape &shape, unsigned int batchIndex, unsigned int channelIndex, unsigned int heightIndex, unsigned int widthIndex) const
Definition: DataLayoutIndexed.hpp:28
armnnUtils::DataLayoutIndexed::GetDepthIndex
unsigned int GetDepthIndex() const
Definition: DataLayoutIndexed.hpp:26
armnn::DataLayout::NCHW
@ NCHW