ArmNN
 21.02
Broadcast.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "BaseIterator.hpp"
7 #include <armnn/Tensor.hpp>
8 
9 #include <functional>
10 
11 namespace armnn
12 {
13 
15 {
16  BroadcastLoop(const TensorShape& inShape0, const TensorShape& inShape1, const TensorShape& outShape);
17 
18  BroadcastLoop(const TensorShape& inShape, const TensorShape& outShape);
19 
20  unsigned int GetNumDimensions()
21  {
22  return static_cast<unsigned int>(m_DimData.size());
23  }
24 
25  template <typename Func, typename DecoderOp, typename EncoderOp>
26  void Unroll(Func operationFunc,
27  unsigned int dimension,
28  DecoderOp& inData0,
29  DecoderOp& inData1,
30  EncoderOp& outData)
31  {
32  if (dimension >= GetNumDimensions())
33  {
34  outData.Set(operationFunc(inData0.Get(), inData1.Get()));
35  return;
36  }
37 
38  unsigned int inData0Movement = 0;
39  unsigned int inData1Movement = 0;
40  unsigned int outDataMovement = 0;
41 
42  for (unsigned int i = 0; i < m_DimData[dimension].m_DimSize; i++)
43  {
44  Unroll(operationFunc, dimension + 1, inData0, inData1, outData);
45 
46  inData0 += m_DimData[dimension].m_Stride1;
47  inData1 += m_DimData[dimension].m_Stride2;
48  outData += m_DimData[dimension].m_StrideOut;
49 
50  inData0Movement += m_DimData[dimension].m_Stride1;
51  inData1Movement += m_DimData[dimension].m_Stride2;
52  outDataMovement += m_DimData[dimension].m_StrideOut;
53  }
54 
55  // move iterator back to the start
56  inData0 -= inData0Movement;
57  inData1 -= inData1Movement;
58  outData -= outDataMovement;
59  }
60 
61  template <typename Func, typename DecoderOp, typename EncoderOp>
62  void Unroll(Func operationFunc,
63  unsigned int dimension,
64  DecoderOp& inData,
65  EncoderOp& outData)
66  {
67  if (dimension >= GetNumDimensions())
68  {
69  outData.Set(operationFunc(inData.Get()));
70  return;
71  }
72 
73  unsigned int inDataMovement = 0;
74  unsigned int outDataMovement = 0;
75 
76  for (unsigned int i = 0; i < m_DimData[dimension].m_DimSize; i++)
77  {
78  Unroll(operationFunc, dimension + 1, inData, outData);
79 
80  inData += m_DimData[dimension].m_Stride1;
81  outData += m_DimData[dimension].m_StrideOut;
82 
83  inDataMovement += m_DimData[dimension].m_Stride1;
84  outDataMovement += m_DimData[dimension].m_StrideOut;
85  }
86 
87  // move iterator back to the start
88  inData -= inDataMovement;
89  outData -= outDataMovement;
90  }
91 
92 private:
93  // Struct to hold the dimension data.
94  struct BroadcastDimensionData
95  {
96  unsigned int m_DimSize;
97  unsigned int m_StrideOut;
98  unsigned int m_Stride1;
99  unsigned int m_Stride2;
100  };
101 
102  std::vector<BroadcastDimensionData> m_DimData;
103 };
104 
105 } //namespace armnn
Copyright (c) 2021 ARM Limited and Contributors.
BroadcastLoop(const TensorShape &inShape0, const TensorShape &inShape1, const TensorShape &outShape)
Definition: Broadcast.cpp:11
void Unroll(Func operationFunc, unsigned int dimension, DecoderOp &inData, EncoderOp &outData)
Definition: Broadcast.hpp:62
unsigned int GetNumDimensions()
Definition: Broadcast.hpp:20
void Unroll(Func operationFunc, unsigned int dimension, DecoderOp &inData0, DecoderOp &inData1, EncoderOp &outData)
Definition: Broadcast.hpp:26