ArmNN
 24.02
BatchToSpaceNd.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017-2020,2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "BatchToSpaceNd.hpp"
7 
9 
10 using namespace armnnUtils;
11 
12 namespace armnn
13 {
14 
15 unsigned int Offset(const TensorShape& shape,
16  unsigned int batch,
17  unsigned int height,
18  unsigned int width,
19  unsigned int channels,
20  const DataLayoutIndexed& dataLayout)
21 {
22  // 3D Tensors
23  unsigned int channelDimension3D = dataLayout.GetDataLayout() == DataLayout::NCHW ? 1 : 2;
24  if (shape.GetNumDimensions() == 3)
25  {
26  return (batch * shape[dataLayout.GetHeightIndex()] + height) * shape[channelDimension3D] + channels;
27  }
28  // 4D Tensors
29  else if (shape.GetNumDimensions() == 4)
30  {
31  if (dataLayout.GetDataLayout() == DataLayout::NHWC)
32  {
33  return ((batch * shape[dataLayout.GetHeightIndex()] + height) *
34  shape[dataLayout.GetWidthIndex()] + width) *
35  shape[dataLayout.GetChannelsIndex()] + channels;
36  }
37  else
38  {
39  return ((batch * shape[dataLayout.GetChannelsIndex()] + channels) *
40  shape[dataLayout.GetHeightIndex()] + height) *
41  shape[dataLayout.GetWidthIndex()] + width;
42  }
43  }
44  else
45  {
46  throw InvalidArgumentException("Tensor rank must be either 3 or 4", CHECK_LOCATION());
47  }
48 }
49 
50 void BatchToSpaceNd(const TensorInfo& inputInfo,
51  const TensorInfo& outputInfo,
52  const BatchToSpaceNdDescriptor& params,
53  Decoder<float>& inputData,
54  Encoder<float>& outputData)
55 {
56  unsigned int rank = inputInfo.GetNumDimensions();
57  if (rank != 3 && rank != 4 )
58  {
59  throw InvalidArgumentException("Tensor rank must be either 3 or 4, but it is " + std::to_string(rank),
60  CHECK_LOCATION());
61  }
62 
63  DataLayoutIndexed dataLayout = params.m_DataLayout;
64  unsigned int channelDimension3D = params.m_DataLayout == DataLayout::NCHW ? 1 : 2;
65 
66  TensorShape inputShape = inputInfo.GetShape();
67  TensorShape outputShape = outputInfo.GetShape();
68 
69  const unsigned int inputBatchSize = inputShape[0];
70  const unsigned int outputBatchSize = outputShape[0];
71 
72  const unsigned int channels = (rank == 3) ? inputShape[channelDimension3D]
73  : inputShape[dataLayout.GetChannelsIndex()];
74 
75  const unsigned int inputHeight = inputShape[dataLayout.GetHeightIndex()];
76  const unsigned int inputWidth = (rank == 3) ? 1 : inputShape[dataLayout.GetWidthIndex()];
77  const unsigned int outputHeight = outputShape[dataLayout.GetHeightIndex()];
78  const unsigned int outputWidth = (rank == 3) ? 1 : outputShape[dataLayout.GetWidthIndex()];
79 
80  const unsigned int blockHeight = params.m_BlockShape[0];
81  const unsigned int blockWidth = (rank == 3) ? 1 : params.m_BlockShape[1];
82 
83  const unsigned int cropsTop = params.m_Crops[0].first;
84  const unsigned int cropsLeft = (rank == 3) ? 0 : params.m_Crops[1].first;
85 
86  for (unsigned int inBatch = 0; inBatch < inputBatchSize; ++inBatch)
87  {
88  const unsigned int outBatch = inBatch % outputBatchSize;
89  const unsigned int spatialOffset = inBatch / outputBatchSize;
90 
91  for (unsigned int inH = 0; inH < inputHeight; ++inH)
92  {
93  const unsigned int outH = inH * blockHeight + spatialOffset / blockWidth - cropsTop;
94 
95  if (outH >= outputHeight)
96  {
97  continue;
98  }
99 
100  for (unsigned int inW = 0; inW < inputWidth; ++inW)
101  {
102  const unsigned int outW = inW * blockWidth + spatialOffset % blockWidth - cropsLeft;
103 
104  if (outW >= outputWidth)
105  {
106  continue;
107  }
108 
109  for (unsigned int c = 0; c < channels; c++)
110  {
111  unsigned int outOffset = Offset(outputShape, outBatch, outH, outW, c, dataLayout);
112  unsigned int inOffset = Offset(inputShape, inBatch, inH, inW, c, dataLayout);
113 
114  outputData[outOffset];
115  inputData[inOffset];
116  outputData.Set(inputData.Get());
117  }
118  }
119  }
120  }
121 }
122 
123 } //namespace armnn
armnn::Decoder< float >
armnn::Encoder::Set
virtual void Set(IType right)=0
BatchToSpaceNd.hpp
armnn::TensorInfo
Definition: Tensor.hpp:152
armnn::TensorInfo::GetNumDimensions
unsigned int GetNumDimensions() const
Definition: Tensor.hpp:197
armnn::BatchToSpaceNdDescriptor::m_BlockShape
std::vector< unsigned int > m_BlockShape
Block shape values.
Definition: Descriptors.hpp:898
CHECK_LOCATION
#define CHECK_LOCATION()
Definition: Exceptions.hpp:203
armnnUtils::DataLayoutIndexed
Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout.
Definition: DataLayoutIndexed.hpp:17
armnnUtils::DataLayoutIndexed::GetDataLayout
armnn::DataLayout GetDataLayout() const
Definition: DataLayoutIndexed.hpp:22
armnn::BatchToSpaceNd
void BatchToSpaceNd(const TensorInfo &inputInfo, const TensorInfo &outputInfo, const BatchToSpaceNdDescriptor &params, Decoder< float > &inputData, Encoder< float > &outputData)
Definition: BatchToSpaceNd.cpp:50
armnn::BatchToSpaceNdDescriptor::m_Crops
std::vector< std::pair< unsigned int, unsigned int > > m_Crops
The values to crop from the input dimension.
Definition: Descriptors.hpp:900
armnnUtils::DataLayoutIndexed::GetHeightIndex
unsigned int GetHeightIndex() const
Definition: DataLayoutIndexed.hpp:24
armnn::TensorShape
Definition: Tensor.hpp:20
armnn::Encoder< float >
armnn::TensorShape::GetNumDimensions
unsigned int GetNumDimensions() const
Function that returns the tensor rank.
Definition: Tensor.cpp:174
armnnUtils
Definition: CompatibleTypes.hpp:10
armnn::InvalidArgumentException
Definition: Exceptions.hpp:80
armnnUtils::DataLayoutIndexed::GetWidthIndex
unsigned int GetWidthIndex() const
Definition: DataLayoutIndexed.hpp:25
armnn::Decoder::Get
virtual IType Get() const =0
armnn::BatchToSpaceNdDescriptor
A BatchToSpaceNdDescriptor for the BatchToSpaceNdLayer.
Definition: Descriptors.hpp:875
armnn::TensorInfo::GetShape
const TensorShape & GetShape() const
Definition: Tensor.hpp:193
armnn::Offset
unsigned int Offset(const TensorShape &shape, unsigned int batch, unsigned int height, unsigned int width, unsigned int channels, const DataLayoutIndexed &dataLayout)
Definition: BatchToSpaceNd.cpp:15
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
armnnUtils::DataLayoutIndexed::GetChannelsIndex
unsigned int GetChannelsIndex() const
Definition: DataLayoutIndexed.hpp:23
armnn::BatchToSpaceNdDescriptor::m_DataLayout
DataLayout m_DataLayout
The data layout to be used (NCHW, NHWC).
Definition: Descriptors.hpp:902
DataLayoutIndexed.hpp