ArmNN
 22.11
BatchMatMulImpl.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "BatchMatMulImpl.hpp"
7 
9 #include <armnn/Logging.hpp>
10 #include <armnnUtils/Permute.hpp>
11 
12 namespace armnn
13 {
14 
16  const TensorInfo& inputXInfo,
17  const TensorInfo& inputYInfo,
18  const TensorInfo& outputInfo,
19  Decoder<float>& inputXDecoder,
20  Decoder<float>& inputYDecoder,
21  Encoder<float>& outputEncoder)
22  : params(params),
23  inputXInfo(inputXInfo),
24  inputYInfo(inputYInfo),
25  outputInfo(outputInfo),
26  inputXDecoder(inputXDecoder),
27  inputYDecoder(inputYDecoder),
28  outputEncoder(outputEncoder)
29 {
30  inputXData = this->inputXDecoder.DecodeTensor(inputXInfo.GetShape());
31  inputYData = this->inputYDecoder.DecodeTensor(inputYInfo.GetShape());
32  // At this point, we don't touch the input decoders - just the resultant vectors
33 
34  ApplyParams();
35 
36  ApplyBatchMatMul();
37 }
38 
39 void BatchMatMul::ApplyBatchMatMul()
40 {
41  auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutX,
42  inputXInfo.GetShape());
43  auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutY,
44  inputYInfo.GetShape());
45  AdjustAxesToMulForUnequalRanks(axesXToMul, axesYToMul);
46 
47  unsigned int inputXColDim = axesXToMul.second;
48  unsigned int inputYRowDim = axesYToMul.first;
49 
50  unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim];
51 
52  auto batchMatMulOperation = [&](const std::vector<unsigned int>& curIdx)
53  {
54  float sum = 0.0f;
55 
56  // InputYRowSize is synonymous with inputXColSize
57  for (unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++) {
58  auto xIdx = curIdx;
59  xIdx[inputXColDim] = inputYRowIdx;
60 
61  auto yIdx = curIdx;
62  yIdx[inputYRowDim] = inputYRowIdx;
63 
64  sum += (GetValueAt(DataSlot::InputX, xIdx) * GetValueAt(DataSlot::InputY, yIdx));
65  }
66 
67  SetValueAt(sum, DataSlot::Output, curIdx);
68  };
69 
70  auto startIdx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0);
71  RecurseTensor(outputInfo,
72  batchMatMulOperation,
73  startIdx,
74  0);
75 }
76 
77 void BatchMatMul::ApplyParams()
78 {
79  if(params.m_TransposeX)
80  {
81  Transpose(DataSlot::InputX);
82  }
83  else if(params.m_AdjointX)
84  {
85  Adjoint(DataSlot::InputX);
86  }
87  if(params.m_TransposeY)
88  {
89  Transpose(DataSlot::InputY);
90  }
91  else if(params.m_AdjointY)
92  {
93  Adjoint(DataSlot::InputY);
94  }
95 }
96 
97 void BatchMatMul::Transpose(DataSlot type)
98 {
99  // AKA the permute of the tensor
100  // This modifies the tensor's info.
101 
102  switch(type)
103  {
104  case DataSlot::InputX:
105  {
106  auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutX,
107  inputXInfo.GetShape());
108  inputXInfo = armnnUtils::Permuted(inputXInfo, permuteVec);
109  std::vector<float> temp(inputXData.size());
110  armnnUtils::Permute(inputXInfo.GetShape(),
111  permuteVec,
112  inputXData.data(),
113  temp.data(),
114  sizeof(float));
115  inputXData = temp;
116  break;
117  }
118  case DataSlot::InputY:
119  {
120  auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutY,
121  inputYInfo.GetShape());
122  inputYInfo = armnnUtils::Permuted(inputYInfo, permuteVec);
123  std::vector<float> temp(inputYData.size());
124  armnnUtils::Permute(inputYInfo.GetShape(),
125  permuteVec,
126  inputYData.data(),
127  temp.data(),
128  sizeof(float));
129  inputYData = temp;
130  break;
131  }
132  case DataSlot::Output: // We needn't transpose the output tensor
133  default:
134  break;
135  }
136 }
137 
138 void BatchMatMul::Adjoint(DataSlot type)
139 {
140  // Finding the adjoint of a square matrix:
141  // Calculate the cofactor of each element (using Gauss elimination here)
142  // Apply a transpose to it (this also modifies the tensor's info)
143 
144  TensorInfo& inputInfo = (type == DataSlot::InputX) ? inputXInfo : inputYInfo;
145  const auto& dataLayout = (type == DataSlot::InputX) ? params.m_DataLayoutX : params.m_DataLayoutY;
146  const auto axesToAdjoint = BatchMatMulDescriptor::GetAxesToMul(dataLayout,inputInfo.GetShape());
147 
148  ARMNN_ASSERT(inputInfo.GetShape()[axesToAdjoint.first] == inputInfo.GetShape()[axesToAdjoint.second]);
149  // We grab a copy of the tensor data to prevent overwriting
150  std::vector<float> inputDataClone = (type == DataSlot::InputX) ? inputXData : inputYData;
151 
152  // The sub-matrix is the resultant matrix when the row and column of the current index is removed
153  unsigned int subMatAxisSize = inputInfo.GetShape()[axesToAdjoint.first] - 1;
154  std::vector<std::vector<float>> subMat(subMatAxisSize,
155  std::vector<float>(subMatAxisSize));
156 
157  // Lambdas for each sub-step of the cofactor operation
158  auto almostEquals = [&](const float& a, const float& b, float unitsInLastPlace = 2.0f)
159  {
160  float diff = std::fabs(a-b);
161  float bound = diff * std::numeric_limits<float>::epsilon() * unitsInLastPlace;
162  return (diff <= bound) || (diff < std::numeric_limits<float>::min());
163  };
164 
165  float swapMultiplier = std::numeric_limits<float>::max();
166  auto swapRows = [&](unsigned int rowIdxA, unsigned int rowIdxB)
167  {
168  // Every row swap flips this around by the negative (set to 1 at the beginning of each cofactor op run)
169  for(unsigned int colIdx = 0; colIdx < subMatAxisSize; colIdx++)
170  {
171  float tmp = subMat[rowIdxA][colIdx];
172  subMat[rowIdxA][colIdx] = subMat[rowIdxB][colIdx];
173  subMat[rowIdxB][colIdx] = tmp;
174  }
175  swapMultiplier *= -1.0f;
176  };
177 
178  auto findNextValidPivotRowIdx = [&](unsigned int colIdx)
179  {
180  unsigned int result = std::numeric_limits<unsigned int>::max();
181 
182  // The original diagonal has been checked and is invalid
183  for(unsigned int rowIdx = colIdx+1; rowIdx < subMatAxisSize; rowIdx++)
184  {
185  if(!almostEquals(subMat[rowIdx][colIdx], 0.0f))
186  {
187  result = rowIdx;
188  break;
189  }
190  }
191  return result;
192  };
193 
194  auto eliminate = [&](const float& pivot, unsigned int pivotPos)
195  {
196  for(unsigned int rowIdx = pivotPos+1; rowIdx < subMatAxisSize; rowIdx++)
197  {
198  float multiplierNumerator = subMat[rowIdx][pivotPos];
199  if(almostEquals(multiplierNumerator, 0.0f))
200  {
201  continue;
202  }
203  float multiplier = multiplierNumerator / pivot; // Susceptible to floating point inaccuracies
204  // Hence the almostEquals usage to counteract this
205  for(unsigned int colIdx = pivotPos; colIdx < subMatAxisSize; colIdx++)
206  {
207  // We start at col=pivotPos as we have assumed that all elements
208  // to our left have been eliminated to zero already
209 
210  // We subtract based on the element directly above us in our pivot row
211  subMat[rowIdx][colIdx] -= multiplier * subMat[pivotPos][colIdx];
212  }
213  }
214  };
215 
216  auto cofactorOperation = [&](const std::vector<unsigned int>& curIdx)
217  {
218  auto row = curIdx[axesToAdjoint.first];
219  auto col = curIdx[axesToAdjoint.second];
220 
221  float minorMultiplier = static_cast<float>(std::pow(-1, (row + 1 + col + 1)));
222 
223  for(unsigned int subRow = 0; subRow < subMatAxisSize; subRow++)
224  {
225  for(unsigned int subCol = 0; subCol < subMatAxisSize; subCol++)
226  {
227  unsigned int outerRow = (subRow >= row)?subRow + 1:subRow;
228  unsigned int outerCol = (subCol >= col)?subCol + 1:subCol;
229  auto cloneIdx = curIdx;
230  cloneIdx[axesToAdjoint.first] = outerRow;
231  cloneIdx[axesToAdjoint.second] = outerCol;
232  subMat[subRow][subCol] = GetValueAt(type,cloneIdx,inputDataClone);
233  }
234  }
235 
236  float determinant = 1.0f;
237 
238  // Cover the edge cases and simple base cases before resorting to Gauss elimination for larger matrices
239  switch(subMatAxisSize)
240  {
241  case 0:
242  {
243  determinant = GetValueAt(type, curIdx, inputDataClone);
244  break;
245  }
246  case 1:
247  {
248  // If the resultant sub-matrix is just one element - that's the determinant
249  determinant = subMat[0][0];
250  break;
251  }
252  case 2:
253  {
254  // For a 2x2 sub-matrix, the determinant is just a*d-b*c
255  determinant = subMat[0][0] * subMat[1][1] -
256  subMat[0][1] * subMat[1][0];
257  break;
258  }
259  default:
260  {
261  // Gaussian elimination to find the determinant of this sub-matrix
262  swapMultiplier = 1.0f;
263  // March diagonally down the pivots and if it's invalid (a zero), swap the row with the
264  // nearest non-zero down within the column
265  for(unsigned int pivotRow = 0, pivotCol = 0;
266  pivotRow < subMatAxisSize;
267  pivotRow++, pivotCol++)
268  {
269  float& pivot = subMat[pivotRow][pivotCol];
270 
271  if(almostEquals(pivot, 0.0f))
272  {
273  unsigned int nextValidPivotRowIdx = findNextValidPivotRowIdx(pivotCol);
274  if(nextValidPivotRowIdx == std::numeric_limits<unsigned int>::max())
275  {
276  // No valid pivot down this column, which means that this pivot remains a zero.
277  // This results in the determinant for this entire sub-matrix to just be zero.
278  determinant = 0.0f;
279  break;
280  }
281  swapRows(pivotRow, nextValidPivotRowIdx);
282  }
283  determinant *= pivot;
284  // The actual elimination bit (which will update/propagate to the pivots down the line)
285  eliminate(pivot, pivotRow); // Synonymous with pivotCol
286  }
287 
288  determinant *= swapMultiplier;
289  break;
290  }
291  }
292  float cofactor = minorMultiplier * determinant;
293  SetValueAt(cofactor, type, curIdx);
294  };
295 
296  auto startIdx = std::vector<unsigned int>(inputInfo.GetNumDimensions(), 0);
297  RecurseTensor(inputInfo,
298  cofactorOperation,
299  startIdx,
300  0);
301 
302  Transpose(type);
303 }
304 
305 void BatchMatMul::RecurseTensor(const TensorInfo& tensorInfo,
306  const std::function<void(const std::vector<unsigned int>&)>& operation,
307  std::vector<unsigned int>& curIdx,
308  unsigned int curDim)
309 {
310  if(!(curDim < tensorInfo.GetNumDimensions()))
311  {
312  // We're at the leaf level of this call tree, so we operate here (each leaf is a data point)
313  operation(curIdx);
314  return;
315  }
316 
317  for(unsigned int i = 0; i < tensorInfo.GetShape()[curDim]; i++)
318  {
319  curIdx[curDim] = i;
320  RecurseTensor(tensorInfo,
321  operation,
322  curIdx,
323  curDim + 1);
324  }
325 }
326 
327 void BatchMatMul::AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
328  std::pair<unsigned int, unsigned int>& axesYToMul)
329 {
330  int rankDiff = static_cast<int>(inputXInfo.GetNumDimensions()) -
331  static_cast<int>(inputYInfo.GetNumDimensions());
332  if(rankDiff == 0)
333  {
334  return;
335  }
336  else if(rankDiff < 0)
337  {
338  // Y is the larger one
339  axesXToMul.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
340  axesXToMul.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
341  }
342  else if(rankDiff > 0)
343  {
344  // X is the larger one
345  axesYToMul.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
346  axesYToMul.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
347  }
348 }
349 
350 float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx, const std::vector<float>& customData)
351 {
352  // This gets the data from the input vector that we have, Not the decoder
353  // But for the output, it is operating on the encoder itself
354 
355  AdjustToSafeIdx(type, idx);
356  unsigned int flatIdx = CalcFlatIdx(type, idx);
357  float value = 0.0f;
358  switch(type)
359  {
360  case DataSlot::InputX:
361  value = customData.empty() ? inputXData[flatIdx] : customData[flatIdx];
362  break;
363  case DataSlot::InputY:
364  value = customData.empty() ? inputYData[flatIdx] : customData[flatIdx];
365  break;
366  case DataSlot::Output:
367  outputEncoder[flatIdx];
368  value = outputEncoder.Get();
369  break;
370  default:
371  break;
372  }
373 
374  return value;
375 }
376 
377 void BatchMatMul::SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx)
378 {
379  AdjustToSafeIdx(type, idx);
380  unsigned int flatIdx = CalcFlatIdx(type, idx);
381  switch(type)
382  {
383  case DataSlot::InputX:
384  inputXData[flatIdx] = value;
385  break;
386  case DataSlot::InputY:
387  inputYData[flatIdx] = value;
388  break;
389  case DataSlot::Output:
390  outputEncoder[flatIdx];
391  outputEncoder.Set(value);
392  break;
393  default:
394  break;
395  }
396 }
397 
398 void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx)
399 {
400  for(unsigned int dim = 0; dim < idx.size(); dim++)
401  {
402  switch(type)
403  {
404  case DataSlot::InputX:
405  {
406  auto xRank = inputXInfo.GetNumDimensions();
407  auto xDiff = outputInfo.GetNumDimensions() - xRank;
408  if (dim < xDiff ||
409  idx[dim] > inputXInfo.GetShape()[dim-xDiff]-1)
410  {
411  idx[dim] = 0; // Broadcasting
412  }
413  break;
414  }
415  case DataSlot::InputY:
416  {
417  auto yRank = inputYInfo.GetNumDimensions();
418  auto yDiff = outputInfo.GetNumDimensions() - yRank;
419  if (dim < yDiff ||
420  idx[dim] > inputYInfo.GetShape()[dim-yDiff]-1)
421  {
422  idx[dim] = 0;
423  }
424  break;
425  }
426  case DataSlot::Output:
427  {
428  // Our indices are based off the output
429  break;
430  }
431  default:
432  break;
433  }
434  }
435 }
436 
437 unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx)
438 {
439  unsigned int result = idx[idx.size()-1];
440  unsigned int dimMultiplier = 1;
441  unsigned int offset;
442 
443  // -2 because final dim is already accounted for in the multiplier (last dim is just a multiplier of 1x)
444  for(unsigned int i = static_cast<unsigned int>(idx.size()-2); static_cast<int>(i) >= 0; i--)
445  {
446  switch(type)
447  {
448  case DataSlot::InputX:
449  offset = outputInfo.GetNumDimensions() - inputXInfo.GetNumDimensions();
450  dimMultiplier *= inputXInfo.GetShape()[i + 1 - offset];
451  break;
452  case DataSlot::InputY:
453  offset = outputInfo.GetNumDimensions() - inputYInfo.GetNumDimensions();
454  dimMultiplier *= inputYInfo.GetShape()[i + 1 - offset];
455  break;
456  case DataSlot::Output:
457  dimMultiplier *= outputInfo.GetShape()[i+1];
458  break;
459  default:
460  break;
461  }
462  result += (idx[i] * dimMultiplier);
463  }
464  return result;
465 }
466 
467 } // namespace armnn
const TensorShape & GetShape() const
Definition: Tensor.hpp:191
bool m_TransposeX
Transpose the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the s...
virtual std::vector< float > DecodeTensor(const TensorShape &tensorShape, bool isDepthwise=false)=0
virtual void Set(IType right)=0
bool m_AdjointX
Adjoint the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the sam...
Copyright (c) 2021 ARM Limited and Contributors.
static PermutationVector GetPermuteVec(DataLayout dataLayout, const TensorShape &tensorShape)
Static helper to get the axes which will be transposed.
DataLayout m_DataLayoutX
Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout) ...
static std::pair< std::pair< unsigned int, unsigned int >, std::pair< unsigned int, unsigned int > > GetAxesToMul(const BatchMatMulDescriptor &desc, const TensorShape &tensorXShape, const TensorShape &tensorYShape)
#define ARMNN_ASSERT(COND)
Definition: Assert.hpp:14
BatchMatMul(const BatchMatMulDescriptor &params, const TensorInfo &inputXInfo, const TensorInfo &inputYInfo, const TensorInfo &outputInfo, Decoder< float > &inputXDecoder, Decoder< float > &inputYDecoder, Encoder< float > &outputEncoder)
A BatchMatMulDescriptor for the BatchMatMul operator.
virtual IType Get() const =0
unsigned int GetNumDimensions() const
Definition: Tensor.hpp:195
armnn::TensorShape Permuted(const armnn::TensorShape &srcShape, const armnn::PermutationVector &mappings)
Definition: Permute.cpp:98