ArmNN
 22.08
BatchMatMulImpl.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "BatchMatMulImpl.hpp"
7 
9 #include <armnn/Logging.hpp>
10 
11 namespace armnn
12 {
13 
{
    // Eagerly decode both input tensors into plain float vectors; every
    // read below goes through these vectors, never through the decoders.
    inputXData = inputXDecoder.DecodeTensor(inputXInfo.GetShape());
    inputYData = inputYDecoder.DecodeTensor(inputYInfo.GetShape());
    // At this point, we don't touch the input decoders - just the resultant vectors

    // Pre-transpose and pre-adjoint if their vectors aren't empty
    // and also DataLayouts which may change with permutations/adjoints

    // Todo: Have you updated input validation and inferred output shapes to accommodate for these pre-permutes?

    // Kick off the depth-first walk with an all-zero index of the output's
    // rank; RecurseBMM fills in every element of the output tensor.
    auto idx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0);
    RecurseBMM(idx, 0);
}
28 
void BatchMatMul::RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim)
{
    // Depth-first recursion over every coordinate of the output tensor.
    // Each leaf of the call tree is one output element, computed as the dot
    // product of a row of X against a column of Y.
    // We're working off of the indexes of the output tensor (the max possible shape)

    if(!(curDim < outputInfo.GetNumDimensions()))
    {
        // We're at the leaf level of this call tree, so we operate here (each leaf is a data point)

        // Which axis of X is contracted (first.second = X's column axis) and
        // which axis of Y (second.first = Y's row axis).
        auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(params,
                                                             inputXInfo.GetShape(),
                                                             inputYInfo.GetShape());
        // NOTE(review): when X and Y have unequal ranks these axes look like
        // they need AdjustAxesToMulForUnequalRanks() applied before use -
        // confirm whether a call is missing here or is handled upstream.

        unsigned int inputXColDim = axesToMul.first.second;
        unsigned int inputYRowDim = axesToMul.second.first;

        // Length of the contracted dimension.
        unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim];

        float sum = 0.0f;

        // You could also use inputXColSize
        for (unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++) {
            // Walk along X's columns and Y's rows in lockstep, keeping all
            // batch/broadcast dimensions fixed at the current output index.
            auto xIdx = curIdx;
            xIdx[inputXColDim] = inputYRowIdx;

            auto yIdx = curIdx;
            yIdx[inputYRowDim] = inputYRowIdx;

            sum += (GetValueAt(DataSlot::InputX, xIdx)
                  * GetValueAt(DataSlot::InputY, yIdx));
        }

        SetValueAt(sum, DataSlot::Output, curIdx);

        return;
    }

    // Not at a leaf yet: iterate over this output dimension and recurse one
    // level deeper.
    for (unsigned int i = 0; i < outputInfo.GetShape()[curDim]; i++)
    {
        curIdx[curDim] = i;
        RecurseBMM(curIdx, curDim+1);
    }
}
72 
    std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul)
{
    // Re-expresses the multiplication axes of the lower-rank input relative
    // to the higher-rank tensor's dimension numbering (inputs are
    // right-aligned for broadcasting, so the smaller input's axes shift up
    // by the rank difference). axesToMul.first is X's axis pair,
    // axesToMul.second is Y's, matching GetAxesToMul's layout.
    int rankDiff = static_cast<int>(inputXInfo.GetNumDimensions()) -
                   static_cast<int>(inputYInfo.GetNumDimensions());
    if(rankDiff == 0)
    {
        // Equal ranks - the axes already line up; nothing to do.
        return;
    }
    else if(rankDiff < 0)
    {
        // Y is the larger one
        axesToMul.first.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
        axesToMul.first.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
    }
    else if(rankDiff > 0)
    {
        // X is the larger one
        axesToMul.second.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
        axesToMul.second.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
    }
}
95 
96 float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx)
97 {
98  // This gets the data from the input vector that we have, Not the decoder
99  // But for the output, it is operating on the encoder itself
100 
101  AdjustToSafeIdx(type, idx);
102  unsigned int flatIdx = CalcFlatIdx(type, idx);
103  float value = 0.0f;
104 
105  switch(type)
106  {
107  case DataSlot::InputX:
108  value = inputXData[flatIdx];
109  break;
110  case DataSlot::InputY:
111  value = inputYData[flatIdx];
112  break;
113  case DataSlot::Output:
114  outputEncoder[flatIdx];
115  value = outputEncoder.Get();
116  break;
117  default:
118  break;
119  }
120 
121  return value;
122 }
123 
124 void BatchMatMul::SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx)
125 {
126  AdjustToSafeIdx(type, idx);
127 
128  unsigned int flatIdx = CalcFlatIdx(type, idx);
129 
130  switch(type)
131  {
132  case DataSlot::InputX:
133  inputXData[flatIdx] = value;
134  break;
135  case DataSlot::InputY:
136  inputYData[flatIdx] = value;
137  break;
138  case DataSlot::Output:
139  outputEncoder[flatIdx];
140  outputEncoder.Set(value);
141  break;
142  default:
143  break;
144  }
145 }
146 
147 void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx)
148 {
149  for(unsigned int dim = 0; dim < idx.size(); dim++)
150  {
151  switch(type)
152  {
153  case DataSlot::InputX:
154  {
155  auto xRank = inputXInfo.GetNumDimensions();
156  auto xDiff = outputInfo.GetNumDimensions() - xRank;
157  if (dim < xDiff ||
158  idx[dim] > inputXInfo.GetShape()[dim-xDiff]-1)
159  {
160  idx[dim] = 0; // Broadcasting
161  }
162  break;
163  }
164  case DataSlot::InputY:
165  {
166  auto yRank = inputYInfo.GetNumDimensions();
167  auto yDiff = outputInfo.GetNumDimensions() - yRank;
168  if (dim < yDiff ||
169  idx[dim] > inputYInfo.GetShape()[dim-yDiff]-1)
170  {
171  idx[dim] = 0;
172  }
173  break;
174  }
175  case DataSlot::Output:
176  {
177  // Our indices are based off the output
178  break;
179  }
180  default:
181  break;
182  }
183  }
184 }
185 
186 unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx)
187 {
188  unsigned int result = idx[idx.size()-1];
189 
190  unsigned int dimMultiplier = 1;
191 
192  unsigned int offset;
193 
194  // -2 because final dim is already accounted for in the multiplier (last dim is just a multiplier of 1x)
195  for(unsigned int i = static_cast<unsigned int>(idx.size()-2); static_cast<int>(i) >= 0; i--)
196  {
197  switch(type)
198  {
199  case DataSlot::InputX:
200  offset = outputInfo.GetNumDimensions() - inputXInfo.GetNumDimensions();
201  dimMultiplier *= inputXInfo.GetShape()[i + 1 - offset];
202  break;
203  case DataSlot::InputY:
204  offset = outputInfo.GetNumDimensions() - inputYInfo.GetNumDimensions();
205  dimMultiplier *= inputYInfo.GetShape()[i + 1 - offset];
206  break;
207  case DataSlot::Output:
208  dimMultiplier *= outputInfo.GetShape()[i+1];
209  break;
210  default:
211  break;
212  }
213  result += (idx[i] * dimMultiplier);
214  }
215  return result;
216 }
217 
218 template <typename T>
219 std::string BatchMatMul::StringifyVec(const std::vector<T>& vec)
220 {
221  std::string res = "{ ";
222  for(auto x : vec)
223  {
224  res += std::to_string(x);
225  res += " ";
226  }
227  res += "}";
228  return res;
229 }
230 
231 } // namespace armnn
const TensorShape & GetShape() const
Definition: Tensor.hpp:191
unsigned int CalcFlatIdx(DataSlot type, const std::vector< unsigned int > &idx)
virtual std::vector< float > DecodeTensor(const TensorShape &tensorShape, bool isDepthwise=false)=0
virtual void Set(IType right)=0
void AdjustToSafeIdx(DataSlot type, std::vector< unsigned int > &idx)
Copyright (c) 2021 ARM Limited and Contributors.
std::string StringifyVec(const std::vector< T > &vec)
static std::pair< std::pair< unsigned int, unsigned int >, std::pair< unsigned int, unsigned int > > GetAxesToMul(const BatchMatMulDescriptor &desc, const TensorShape &tensorXShape, const TensorShape &tensorYShape)
Static helper to get the two axes (for each input) for multiplication.
virtual IType Get() const =0
void RecurseBMM(std::vector< unsigned int > &curIdx, unsigned int curDim)
void AdjustAxesToMulForUnequalRanks(std::pair< std::pair< unsigned int, unsigned int >, std::pair< unsigned int, unsigned int >> &axesToMul)
void SetValueAt(float value, DataSlot type, std::vector< unsigned int > idx)
unsigned int GetNumDimensions() const
Definition: Tensor.hpp:195
float GetValueAt(DataSlot type, std::vector< unsigned int > idx)