ArmNN 24.02
BatchMatMulImpl.cpp
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "BatchMatMulImpl.hpp"

#include <armnn/backends/WorkloadData.hpp>
#include <armnn/Logging.hpp>
#include <armnnUtils/Permute.hpp>

namespace armnn
{

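// All of the work happens at construction time: decode both inputs into float vectors,
// apply any transpose/adjoint requested by the descriptor, then run the batched
// multiplication, writing the result through the output encoder.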
BatchMatMul::BatchMatMul(const BatchMatMulDescriptor& params,
                         const TensorInfo& inputXInfo,
                         const TensorInfo& inputYInfo,
                         const TensorInfo& outputInfo,
                         Decoder<float>& inputXDecoder,
                         Decoder<float>& inputYDecoder,
                         Encoder<float>& outputEncoder)
    : params(params),
      inputXInfo(inputXInfo),
      inputYInfo(inputYInfo),
      outputInfo(outputInfo),
      inputXDecoder(inputXDecoder),
      inputYDecoder(inputYDecoder),
      outputEncoder(outputEncoder)
{
    inputXData = this->inputXDecoder.DecodeTensor(inputXInfo.GetShape());
    inputYData = this->inputYDecoder.DecodeTensor(inputYInfo.GetShape());
    // At this point, we don't touch the input decoders - just the resultant vectors

    ApplyParams();

    ApplyBatchMatMul();
}

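// For each output coordinate curIdx, accumulate the dot product along the shared axis:
// out[..., i, j] = sum_k X[..., i, k] * Y[..., k, j]
// where k runs over X's column axis, which has the same length as Y's row axis.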
void BatchMatMul::ApplyBatchMatMul()
{
    auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutX,
                                                          inputXInfo.GetShape());
    auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutY,
                                                          inputYInfo.GetShape());
    AdjustAxesToMulForUnequalRanks(axesXToMul, axesYToMul);

    unsigned int inputXColDim = axesXToMul.second;
    unsigned int inputYRowDim = axesYToMul.first;

    unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim];

    auto batchMatMulOperation = [&](const std::vector<unsigned int>& curIdx)
    {
        float sum = 0.0f;

        // InputYRowSize is synonymous with inputXColSize
        for (unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++)
        {
            auto xIdx = curIdx;
            xIdx[inputXColDim] = inputYRowIdx;

            auto yIdx = curIdx;
            yIdx[inputYRowDim] = inputYRowIdx;

            sum += (GetValueAt(DataSlot::InputX, xIdx) * GetValueAt(DataSlot::InputY, yIdx));
        }

        SetValueAt(sum, DataSlot::Output, curIdx);
    };

    auto startIdx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0);
    RecurseTensor(outputInfo,
                  batchMatMulOperation,
                  startIdx,
                  0);
}

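// Transpose and Adjoint cannot both be set for the same input (see BatchMatMulDescriptor),
// hence the if/else-if pairing per input.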
void BatchMatMul::ApplyParams()
{
    if(params.m_TransposeX)
    {
        Transpose(DataSlot::InputX);
    }
    else if(params.m_AdjointX)
    {
        Adjoint(DataSlot::InputX);
    }
    if(params.m_TransposeY)
    {
        Transpose(DataSlot::InputY);
    }
    else if(params.m_AdjointY)
    {
        Adjoint(DataSlot::InputY);
    }
}

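// The permutation vector is applied to both the TensorInfo and the decoded data buffer,
// so the slices of the stored input are transposed in place as far as this class is concerned.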
void BatchMatMul::Transpose(DataSlot type)
{
    // AKA the permute of the tensor
    // This modifies the tensor's info.

    switch(type)
    {
        case DataSlot::InputX:
        {
            auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutX,
                                                                   inputXInfo.GetShape());
            inputXInfo = armnnUtils::Permuted(inputXInfo, permuteVec);
            std::vector<float> temp(inputXData.size());
            armnnUtils::Permute(inputXInfo.GetShape(),
                                permuteVec,
                                inputXData.data(),
                                temp.data(),
                                sizeof(float));
            inputXData = temp;
            break;
        }
        case DataSlot::InputY:
        {
            auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutY,
                                                                   inputYInfo.GetShape());
            inputYInfo = armnnUtils::Permuted(inputYInfo, permuteVec);
            std::vector<float> temp(inputYData.size());
            armnnUtils::Permute(inputYInfo.GetShape(),
                                permuteVec,
                                inputYData.data(),
                                temp.data(),
                                sizeof(float));
            inputYData = temp;
            break;
        }
        case DataSlot::Output: // We needn't transpose the output tensor
        default:
            break;
    }
}

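// "Adjoint" here means the adjugate: the transpose of the cofactor matrix, computed
// independently for each square slice selected by the two multiplication axes.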
void BatchMatMul::Adjoint(DataSlot type)
{
    // Finding the adjoint of a square matrix:
    // Calculate the cofactor of each element (using Gauss elimination here)
    // Apply a transpose to it (this also modifies the tensor's info)

    TensorInfo& inputInfo = (type == DataSlot::InputX) ? inputXInfo : inputYInfo;
    const auto& dataLayout = (type == DataSlot::InputX) ? params.m_DataLayoutX : params.m_DataLayoutY;
    const auto axesToAdjoint = BatchMatMulDescriptor::GetAxesToMul(dataLayout, inputInfo.GetShape());

    ARMNN_ASSERT(inputInfo.GetShape()[axesToAdjoint.first] == inputInfo.GetShape()[axesToAdjoint.second]);
    // We grab a copy of the tensor data to prevent overwriting
    std::vector<float> inputDataClone = (type == DataSlot::InputX) ? inputXData : inputYData;

    // The sub-matrix is the resultant matrix when the row and column of the current index is removed
    unsigned int subMatAxisSize = inputInfo.GetShape()[axesToAdjoint.first] - 1;
    std::vector<std::vector<float>> subMat(subMatAxisSize,
                                           std::vector<float>(subMatAxisSize));

    // Lambdas for each sub-step of the cofactor operation
    auto almostEquals = [&](const float& a, const float& b, float unitsInLastPlace = 2.0f)
    {
        float diff = std::fabs(a-b);
        float bound = diff * std::numeric_limits<float>::epsilon() * unitsInLastPlace;
        return (diff <= bound) || (diff < std::numeric_limits<float>::min());
    };

    float swapMultiplier = std::numeric_limits<float>::max();
    auto swapRows = [&](unsigned int rowIdxA, unsigned int rowIdxB)
    {
        // Every row swap flips this around by the negative (set to 1 at the beginning of each cofactor op run)
        for(unsigned int colIdx = 0; colIdx < subMatAxisSize; colIdx++)
        {
            float tmp = subMat[rowIdxA][colIdx];
            subMat[rowIdxA][colIdx] = subMat[rowIdxB][colIdx];
            subMat[rowIdxB][colIdx] = tmp;
        }
        swapMultiplier *= -1.0f;
    };

    auto findNextValidPivotRowIdx = [&](unsigned int colIdx)
    {
        unsigned int result = std::numeric_limits<unsigned int>::max();

        // The original diagonal has been checked and is invalid
        for(unsigned int rowIdx = colIdx+1; rowIdx < subMatAxisSize; rowIdx++)
        {
            if(!almostEquals(subMat[rowIdx][colIdx], 0.0f))
            {
                result = rowIdx;
                break;
            }
        }
        return result;
    };

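    // One elimination pass: for each row below the pivot, subtract
    // (subMat[rowIdx][pivotPos] / pivot) times the pivot row, zeroing the column below
    // the pivot without changing the determinant.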
    auto eliminate = [&](const float& pivot, unsigned int pivotPos)
    {
        for(unsigned int rowIdx = pivotPos+1; rowIdx < subMatAxisSize; rowIdx++)
        {
            float multiplierNumerator = subMat[rowIdx][pivotPos];
            if(almostEquals(multiplierNumerator, 0.0f))
            {
                continue;
            }
            float multiplier = multiplierNumerator / pivot; // Susceptible to floating point inaccuracies
                                                            // Hence the almostEquals usage to counteract this
            for(unsigned int colIdx = pivotPos; colIdx < subMatAxisSize; colIdx++)
            {
                // We start at col=pivotPos as we have assumed that all elements
                // to our left have been eliminated to zero already

                // We subtract based on the element directly above us in our pivot row
                subMat[rowIdx][colIdx] -= multiplier * subMat[pivotPos][colIdx];
            }
        }
    };

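    // Cofactor of element (row, col): (-1)^(row+col) times the determinant of the sub-matrix
    // obtained by deleting that row and column. E.g. in a 3x3 slice, the cofactor of
    // element (0, 1) is -1 times the determinant of the 2x2 matrix left after removing
    // row 0 and column 1.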
    auto cofactorOperation = [&](const std::vector<unsigned int>& curIdx)
    {
        auto row = curIdx[axesToAdjoint.first];
        auto col = curIdx[axesToAdjoint.second];

        float minorMultiplier = static_cast<float>(std::pow(-1, (row + 1 + col + 1)));

        for(unsigned int subRow = 0; subRow < subMatAxisSize; subRow++)
        {
            for(unsigned int subCol = 0; subCol < subMatAxisSize; subCol++)
            {
                unsigned int outerRow = (subRow >= row) ? subRow + 1 : subRow;
                unsigned int outerCol = (subCol >= col) ? subCol + 1 : subCol;
                auto cloneIdx = curIdx;
                cloneIdx[axesToAdjoint.first] = outerRow;
                cloneIdx[axesToAdjoint.second] = outerCol;
                subMat[subRow][subCol] = GetValueAt(type, cloneIdx, inputDataClone);
            }
        }

        float determinant = 1.0f;

        // Cover the edge cases and simple base cases before resorting to Gauss elimination for larger matrices
        switch(subMatAxisSize)
        {
            case 0:
            {
                determinant = GetValueAt(type, curIdx, inputDataClone);
                break;
            }
            case 1:
            {
                // If the resultant sub-matrix is just one element - that's the determinant
                determinant = subMat[0][0];
                break;
            }
            case 2:
            {
                // For a 2x2 sub-matrix, the determinant is just a*d - b*c
                determinant = subMat[0][0] * subMat[1][1] -
                              subMat[0][1] * subMat[1][0];
                break;
            }
            default:
            {
                // Gaussian elimination to find the determinant of this sub-matrix
                swapMultiplier = 1.0f;
                // March diagonally down the pivots and if it's invalid (a zero), swap the row with the
                // nearest non-zero down within the column
                for(unsigned int pivotRow = 0, pivotCol = 0;
                    pivotRow < subMatAxisSize;
                    pivotRow++, pivotCol++)
                {
                    float& pivot = subMat[pivotRow][pivotCol];

                    if(almostEquals(pivot, 0.0f))
                    {
                        unsigned int nextValidPivotRowIdx = findNextValidPivotRowIdx(pivotCol);
                        if(nextValidPivotRowIdx == std::numeric_limits<unsigned int>::max())
                        {
                            // No valid pivot down this column, which means that this pivot stays a zero,
                            // so the determinant of this entire sub-matrix is zero.
                            determinant = 0.0f;
                            break;
                        }
                        swapRows(pivotRow, nextValidPivotRowIdx);
                    }
                    determinant *= pivot;
                    // The actual elimination bit (which will update/propagate to the pivots down the line)
                    eliminate(pivot, pivotRow); // Synonymous with pivotCol
                }

                determinant *= swapMultiplier;
                break;
            }
        }
        float cofactor = minorMultiplier * determinant;
        SetValueAt(cofactor, type, curIdx);
    };

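    // Walk every element of the (possibly batched) input: each element is replaced by its
    // cofactor, and the final Transpose below turns the cofactor matrix into the adjugate.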
    auto startIdx = std::vector<unsigned int>(inputInfo.GetNumDimensions(), 0);
    RecurseTensor(inputInfo,
                  cofactorOperation,
                  startIdx,
                  0);

    Transpose(type);
}

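// Depth-first walk over every index tuple of the tensor, calling the supplied operation once
// per element; the last dimension varies fastest (row-major order).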
void BatchMatMul::RecurseTensor(const TensorInfo& tensorInfo,
                                const std::function<void(const std::vector<unsigned int>&)>& operation,
                                std::vector<unsigned int>& curIdx,
                                unsigned int curDim)
{
    if(!(curDim < tensorInfo.GetNumDimensions()))
    {
        // We're at the leaf level of this call tree, so we operate here (each leaf is a data point)
        operation(curIdx);
        return;
    }

    for(unsigned int i = 0; i < tensorInfo.GetShape()[curDim]; i++)
    {
        curIdx[curDim] = i;
        RecurseTensor(tensorInfo,
                      operation,
                      curIdx,
                      curDim + 1);
    }
}

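// GetAxesToMul returns axes relative to each input's own rank. When the ranks differ, the
// lower-rank input's axes are shifted up by the rank difference so that both pairs are
// expressed against the broadcast (output) rank. E.g. if X has rank 2 with axes (0,1) and
// Y has rank 4 with axes (2,3), X's axes become (2,3).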
void BatchMatMul::AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
                                                 std::pair<unsigned int, unsigned int>& axesYToMul)
{
    int rankDiff = static_cast<int>(inputXInfo.GetNumDimensions()) -
                   static_cast<int>(inputYInfo.GetNumDimensions());
    if(rankDiff == 0)
    {
        return;
    }
    else if(rankDiff < 0)
    {
        // Y is the larger one
        axesXToMul.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
        axesXToMul.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
    }
    else if(rankDiff > 0)
    {
        // X is the larger one
        axesYToMul.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
        axesYToMul.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
    }
}

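// idx is always an output-space index: AdjustToSafeIdx applies the broadcasting rules and
// CalcFlatIdx converts it to a flat offset. When customData is non-empty it is read instead
// of the stored input vector (Adjoint passes its untouched clone this way).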
float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx, const std::vector<float>& customData)
{
    // This reads from the input vectors that we hold, not from the decoders.
    // For the output, it operates on the encoder itself.

    AdjustToSafeIdx(type, idx);
    unsigned int flatIdx = CalcFlatIdx(type, idx);
    float value = 0.0f;
    switch(type)
    {
        case DataSlot::InputX:
            value = customData.empty() ? inputXData[flatIdx] : customData[flatIdx];
            break;
        case DataSlot::InputY:
            value = customData.empty() ? inputYData[flatIdx] : customData[flatIdx];
            break;
        case DataSlot::Output:
            outputEncoder[flatIdx];
            value = outputEncoder.Get();
            break;
        default:
            break;
    }

    return value;
}

void BatchMatMul::SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx)
{
    AdjustToSafeIdx(type, idx);
    unsigned int flatIdx = CalcFlatIdx(type, idx);
    switch(type)
    {
        case DataSlot::InputX:
            inputXData[flatIdx] = value;
            break;
        case DataSlot::InputY:
            inputYData[flatIdx] = value;
            break;
        case DataSlot::Output:
            outputEncoder[flatIdx];
            outputEncoder.Set(value);
            break;
        default:
            break;
    }
}

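// Broadcasting rules: leading output dimensions that an input does not have map to index 0,
// and an index that exceeds the input's extent in a dimension is reset to 0 (which is how
// size-1 dimensions broadcast). E.g. with an output of rank 4 and an X of rank 2, the first
// two components of idx are zeroed for X.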
void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx)
{
    for(unsigned int dim = 0; dim < idx.size(); dim++)
    {
        switch(type)
        {
            case DataSlot::InputX:
            {
                auto xRank = inputXInfo.GetNumDimensions();
                auto xDiff = outputInfo.GetNumDimensions() - xRank;
                if (dim < xDiff ||
                    idx[dim] > inputXInfo.GetShape()[dim-xDiff]-1)
                {
                    idx[dim] = 0; // Broadcasting
                }
                break;
            }
            case DataSlot::InputY:
            {
                auto yRank = inputYInfo.GetNumDimensions();
                auto yDiff = outputInfo.GetNumDimensions() - yRank;
                if (dim < yDiff ||
                    idx[dim] > inputYInfo.GetShape()[dim-yDiff]-1)
                {
                    idx[dim] = 0;
                }
                break;
            }
            case DataSlot::Output:
            {
                // Our indices are based off the output
                break;
            }
            default:
                break;
        }
    }
}

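// Row-major flattening: flatIdx = sum_i idx[i] * (product of the dimension sizes to the right
// of i). E.g. for an output of shape [2,3,4], idx {1,2,3} gives 1*12 + 2*4 + 3 = 23. For the
// inputs, 'offset' skips the leading output dimensions that the input does not have.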
unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx)
{
    unsigned int result = idx[idx.size()-1];
    unsigned int dimMultiplier = 1;
    unsigned int offset;

    // -2 because final dim is already accounted for in the multiplier (last dim is just a multiplier of 1x)
    for(unsigned int i = static_cast<unsigned int>(idx.size()-2); static_cast<int>(i) >= 0; i--)
    {
        switch(type)
        {
            case DataSlot::InputX:
                offset = outputInfo.GetNumDimensions() - inputXInfo.GetNumDimensions();
                dimMultiplier *= inputXInfo.GetShape()[i + 1 - offset];
                break;
            case DataSlot::InputY:
                offset = outputInfo.GetNumDimensions() - inputYInfo.GetNumDimensions();
                dimMultiplier *= inputYInfo.GetShape()[i + 1 - offset];
                break;
            case DataSlot::Output:
                dimMultiplier *= outputInfo.GetShape()[i+1];
                break;
            default:
                break;
        }
        result += (idx[i] * dimMultiplier);
    }
    return result;
}

} // namespace armnn