23 inputXInfo(inputXInfo),
24 inputYInfo(inputYInfo),
25 outputInfo(outputInfo),
26 inputXDecoder(inputXDecoder),
27 inputYDecoder(inputYDecoder),
28 outputEncoder(outputEncoder)
// NOTE(review): this listing keeps the original file's line numbers as a prefix and omits some
// lines (braces, a few declarations); the comments below describe only what is visible here.
// Accumulates, for each processed index, the sum of InputX * InputY products along the shared axis.
39 void BatchMatMul::ApplyBatchMatMul()
// Shift the axis pairs when the two inputs have different ranks (see AdjustAxesToMulForUnequalRanks).
45 AdjustAxesToMulForUnequalRanks(axesXToMul, axesYToMul);
// The sum runs over the second multiplication axis of X and the first multiplication axis of Y.
47 unsigned int inputXColDim = axesXToMul.second;
48 unsigned int inputYRowDim = axesYToMul.first;
// Number of terms in each sum: the extent of Y along its row axis.
50 unsigned int inputYRowSize = inputYInfo.
GetShape()[inputYRowDim];
// Operation applied per index; curIdx is the index currently being processed.
52 auto batchMatMulOperation = [&](
const std::vector<unsigned int>& curIdx)
57 for (
unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++) {
// xIdx / yIdx are declared on lines not shown here; presumably copies of curIdx - TODO confirm.
// Only the contracted axis of each index is overwritten with the running position.
59 xIdx[inputXColDim] = inputYRowIdx;
62 yIdx[inputYRowDim] = inputYRowIdx;
64 sum += (GetValueAt(DataSlot::InputX, xIdx) * GetValueAt(DataSlot::InputY, yIdx));
// The remaining arguments of this call are not visible in this listing.
71 RecurseTensor(outputInfo,
// Pre-processes the input data by calling Transpose and/or Adjoint on each input slot.
// NOTE(review): the conditions guarding each call are not visible in this listing; presumably they
// test descriptor flags such as m_TransposeX / m_AdjointX - confirm against the full source.
77 void BatchMatMul::ApplyParams()
81 Transpose(DataSlot::InputX);
85 Adjoint(DataSlot::InputX);
89 Transpose(DataSlot::InputY);
93 Adjoint(DataSlot::InputY);
// Handles the InputX and InputY slots in separate cases.
// NOTE(review): most of the body is missing from this listing; only the scratch buffers are visible.
97 void BatchMatMul::Transpose(DataSlot type)
104 case DataSlot::InputX:
// Scratch buffer with the same element count as the X data.
109 std::vector<float> temp(inputXData.size());
118 case DataSlot::InputY:
// Scratch buffer with the same element count as the Y data.
123 std::vector<float> temp(inputYData.size());
// Replaces each element of the selected input with its cofactor (see cofactorOperation below).
// NOTE(review): any step after the final RecurseTensor call is not visible in this listing.
138 void BatchMatMul::Adjoint(DataSlot type)
// Tensor info of the slot being processed.
144 TensorInfo& inputInfo = (type == DataSlot::InputX) ? inputXInfo : inputYInfo;
// Snapshot of the slot's data: values are read from this copy while SetValueAt overwrites the original.
150 std::vector<float> inputDataClone = (type == DataSlot::InputX) ? inputXData : inputYData;
// Side length of the sub-matrix left after removing one row and one column.
// axesToAdjoint is declared on a line not shown here.
153 unsigned int subMatAxisSize = inputInfo.
GetShape()[axesToAdjoint.first] - 1;
// Square scratch matrix reused for every element's sub-matrix.
154 std::vector<std::vector<float>> subMat(subMatAxisSize,
155 std::vector<float>(subMatAxisSize));
// Approximate float equality used by the pivot search and elimination steps.
//
// Returns true when |a - b| is within `unitsInLastPlace` ULPs scaled to the magnitude of the
// operands (|a + b|), or when the difference is smaller than the smallest normal float.
//
// The tolerance has to scale with the operands, not with the difference itself: a bound of
// diff * epsilon * unitsInLastPlace is smaller than any non-zero diff, so the relative test
// could only ever pass for exactly equal values.
// Note that when one operand is 0.0f the relative bound is still below the difference, so only
// the absolute (sub-normal) test can succeed in that case.
// Nothing from the enclosing scope is used, so the lambda captures nothing.
auto almostEquals = [](const float& a, const float& b, float unitsInLastPlace = 2.0f)
{
    const float diff = std::fabs(a - b);
    const float bound = std::fabs(a + b) * std::numeric_limits<float>::epsilon() * unitsInLastPlace;
    return (diff <= bound) || (diff < std::numeric_limits<float>::min());
};
// Sign accumulator for row swaps. It starts at float max here and is reset to 1.0f
// (listing line 262) before the elimination loop uses it.
165 float swapMultiplier = std::numeric_limits<float>::max();
// Swaps rows rowIdxA and rowIdxB of subMat element by element, then flips the sign of
// swapMultiplier (exchanging two rows negates a determinant).
166 auto swapRows = [&](
unsigned int rowIdxA,
unsigned int rowIdxB)
169 for(
unsigned int colIdx = 0; colIdx < subMatAxisSize; colIdx++)
171 float tmp = subMat[rowIdxA][colIdx];
172 subMat[rowIdxA][colIdx] = subMat[rowIdxB][colIdx];
173 subMat[rowIdxB][colIdx] = tmp;
175 swapMultiplier *= -1.0f;
// Looks down column colIdx, starting one row below the diagonal, for an entry that is not
// approximately zero.
// NOTE(review): the statements that record the hit and return the result are not visible here.
178 auto findNextValidPivotRowIdx = [&](
unsigned int colIdx)
// Unsigned max is the initial value; the caller (listing line 274) compares against it to detect "no row found".
180 unsigned int result = std::numeric_limits<unsigned int>::max();
183 for(
unsigned int rowIdx = colIdx+1; rowIdx < subMatAxisSize; rowIdx++)
185 if(!almostEquals(subMat[rowIdx][colIdx], 0.0f))
// Row-reduction step: for every row below pivotPos, subtracts a multiple of the pivot row so
// that the entry in column pivotPos is cancelled.
194 auto eliminate = [&](
const float& pivot,
unsigned int pivotPos)
196 for(
unsigned int rowIdx = pivotPos+1; rowIdx < subMatAxisSize; rowIdx++)
198 float multiplierNumerator = subMat[rowIdx][pivotPos];
// NOTE(review): the body of this check is not visible; presumably the row is skipped because
// its entry is already (approximately) zero - confirm.
199 if(almostEquals(multiplierNumerator, 0.0f))
// Factor by which the pivot row is scaled before it is subtracted from this row.
203 float multiplier = multiplierNumerator / pivot;
// Only columns from pivotPos onwards are updated.
205 for(
unsigned int colIdx = pivotPos; colIdx < subMatAxisSize; colIdx++)
211 subMat[rowIdx][colIdx] -= multiplier * subMat[pivotPos][colIdx];
// Computes the cofactor of the element at curIdx and writes it back into the slot's data.
// NOTE(review): the case labels and break statements of the switch are missing from this listing.
216 auto cofactorOperation = [&](
const std::vector<unsigned int>& curIdx)
// Position of the current element within its 2D slice.
218 auto row = curIdx[axesToAdjoint.first];
219 auto col = curIdx[axesToAdjoint.second];
// Cofactor sign: (-1)^(row + col + 2), i.e. +1 when row + col is even and -1 when it is odd.
221 float minorMultiplier =
static_cast<float>(std::pow(-1, (row + 1 + col + 1)));
// Fill subMat with the slice values that remain once row `row` and column `col` are skipped.
223 for(
unsigned int subRow = 0; subRow < subMatAxisSize; subRow++)
225 for(
unsigned int subCol = 0; subCol < subMatAxisSize; subCol++)
// Indices at or past the removed row/column map to the next row/column of the full slice.
227 unsigned int outerRow = (subRow >= row)?subRow + 1:subRow;
228 unsigned int outerCol = (subCol >= col)?subCol + 1:subCol;
229 auto cloneIdx = curIdx;
230 cloneIdx[axesToAdjoint.first] = outerRow;
231 cloneIdx[axesToAdjoint.second] = outerCol;
// Read from the snapshot, not the live data that SetValueAt below overwrites.
232 subMat[subRow][subCol] = GetValueAt(type,cloneIdx,inputDataClone);
236 float determinant = 1.0f;
// The determinant computation is chosen by sub-matrix size (labels not visible in this listing).
239 switch(subMatAxisSize)
// Takes the element's own value from the snapshot (presumably the size-0 case - confirm).
243 determinant = GetValueAt(type, curIdx, inputDataClone);
// A single-element sub-matrix: that element is the determinant.
249 determinant = subMat[0][0];
// 2x2 sub-matrix: a*d - b*c.
255 determinant = subMat[0][0] * subMat[1][1] -
256 subMat[0][1] * subMat[1][0];
// Larger sizes: elimination along the diagonal. Reset the swap sign for this run.
262 swapMultiplier = 1.0f;
265 for(
unsigned int pivotRow = 0, pivotCol = 0;
266 pivotRow < subMatAxisSize;
267 pivotRow++, pivotCol++)
// Held by reference so that a row swap below is reflected in the value multiplied into the determinant.
269 float& pivot = subMat[pivotRow][pivotCol];
271 if(almostEquals(pivot, 0.0f))
273 unsigned int nextValidPivotRowIdx = findNextValidPivotRowIdx(pivotCol);
// NOTE(review): the handling of "no usable row found" is not visible in this listing.
274 if(nextValidPivotRowIdx == std::numeric_limits<unsigned int>::max())
281 swapRows(pivotRow, nextValidPivotRowIdx);
// The determinant accumulates the product of the pivots.
283 determinant *= pivot;
285 eliminate(pivot, pivotRow);
// Apply the sign accumulated from row swaps.
288 determinant *= swapMultiplier;
292 float cofactor = minorMultiplier * determinant;
293 SetValueAt(cofactor, type, curIdx);
// The remaining arguments of this call are not visible in this listing.
297 RecurseTensor(inputInfo,
// Recursive traversal over the dimensions of tensorInfo's shape.
// `operation` is a callback that receives an index vector.
// NOTE(review): the rest of the parameter list (including curDim), the base case and the
// arguments of the recursive call are not visible in this listing.
305 void BatchMatMul::RecurseTensor(
const TensorInfo& tensorInfo,
306 const std::function<
void(
const std::vector<unsigned int>&)>& operation,
307 std::vector<unsigned int>& curIdx,
// Iterate over every position along dimension curDim and recurse for each one.
317 for(
unsigned int i = 0; i < tensorInfo.
GetShape()[curDim]; i++)
320 RecurseTensor(tensorInfo,
// Shifts one input's axis pair by |rankDiff| in place; both parameters are in/out.
// NOTE(review): the declaration of rankDiff is not visible here; presumably it is
// rank(X) - rank(Y) as a signed value - confirm against the full source.
327 void BatchMatMul::AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
328 std::pair<unsigned int, unsigned int>& axesYToMul)
// Negative difference: both X axes are moved up by |rankDiff|.
336 else if(rankDiff < 0)
339 axesXToMul.first +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
340 axesXToMul.second +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
// Positive difference: both Y axes are moved up by |rankDiff|.
342 else if(rankDiff > 0)
345 axesYToMul.first +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
346 axesYToMul.second +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
// Reads one value from the given slot at multi-dimensional index idx.
// idx is taken by value because AdjustToSafeIdx modifies it in place.
// When customData is non-empty it is read instead of the slot's own input data.
350 float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx,
const std::vector<float>& customData)
355 AdjustToSafeIdx(type, idx);
356 unsigned int flatIdx = CalcFlatIdx(type, idx);
360 case DataSlot::InputX:
361 value = customData.empty() ? inputXData[flatIdx] : customData[flatIdx];
363 case DataSlot::InputY:
364 value = customData.empty() ? inputYData[flatIdx] : customData[flatIdx];
// Remaining case (label not visible in this listing): index the output encoder, then read through it.
367 outputEncoder[flatIdx];
368 value = outputEncoder.
Get();
// Writes value into the given slot at multi-dimensional index idx.
// idx is taken by value because AdjustToSafeIdx modifies it in place.
377 void BatchMatMul::SetValueAt(
float value, DataSlot type, std::vector<unsigned int> idx)
379 AdjustToSafeIdx(type, idx);
380 unsigned int flatIdx = CalcFlatIdx(type, idx);
// The input slots are written directly into their decoded data vectors.
383 case DataSlot::InputX:
384 inputXData[flatIdx] = value;
386 case DataSlot::InputY:
387 inputYData[flatIdx] = value;
// Remaining case (label not visible in this listing): index the output encoder, then write through it.
390 outputEncoder[flatIdx];
391 outputEncoder.
Set(value);
// Walks every dimension of idx and checks it against the shape of the selected input.
// NOTE(review): the start of each condition, the definitions of xDiff / yDiff and the action taken
// are not visible in this listing; presumably an out-of-range index is reset so the input is
// broadcast - confirm against the full source.
398 void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx)
400 for(
unsigned int dim = 0; dim < idx.size(); dim++)
404 case DataSlot::InputX:
// True when idx[dim] is past the last valid position of the X dimension at dim - xDiff.
409 idx[dim] > inputXInfo.
GetShape()[dim-xDiff]-1)
415 case DataSlot::InputY:
// True when idx[dim] is past the last valid position of the Y dimension at dim - yDiff.
420 idx[dim] > inputYInfo.
GetShape()[dim-yDiff]-1)
// Converts the multi-dimensional index idx into a flat offset for the given slot, with the
// last dimension varying fastest (row-major).
// NOTE(review): the declaration/assignment of `offset` and the return statement are not visible here.
437 unsigned int BatchMatMul::CalcFlatIdx(DataSlot type,
const std::vector<unsigned int>& idx)
// The last dimension contributes its index directly (stride 1).
439 unsigned int result = idx[idx.size()-1];
440 unsigned int dimMultiplier = 1;
// Walk from the second-to-last dimension down to 0. The counter is unsigned, so the loop
// ends when its value, reinterpreted as int, becomes negative.
444 for(
unsigned int i = static_cast<unsigned int>(idx.size()-2);
static_cast<int>(i) >= 0; i--)
// The stride of dimension i grows by the extent of dimension i + 1 (shifted by offset for the inputs).
448 case DataSlot::InputX:
450 dimMultiplier *= inputXInfo.
GetShape()[i + 1 - offset];
452 case DataSlot::InputY:
454 dimMultiplier *= inputYInfo.
GetShape()[i + 1 - offset];
// Remaining case (label not visible in this listing): uses the output shape without an offset.
457 dimMultiplier *= outputInfo.
GetShape()[i+1];
462 result += (idx[i] * dimMultiplier);
const TensorShape & GetShape() const
bool m_TransposeX
Transpose the slices of each input tensor Transpose and Adjoint can not both be set to true for the s...
virtual std::vector< float > DecodeTensor(const TensorShape &tensorShape, bool isDepthwise=false)=0
virtual void Set(IType right)=0
bool m_AdjointX
Adjoint the slices of each input tensor Transpose and Adjoint can not both be set to true for the sam...
Copyright (c) 2021 ARM Limited and Contributors.
static PermutationVector GetPermuteVec(DataLayout dataLayout, const TensorShape &tensorShape)
Static helper to get the axes which will be transposed.
DataLayout m_DataLayoutX
Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout) ...
static std::pair< std::pair< unsigned int, unsigned int >, std::pair< unsigned int, unsigned int > > GetAxesToMul(const BatchMatMulDescriptor &desc, const TensorShape &tensorXShape, const TensorShape &tensorYShape)
#define ARMNN_ASSERT(COND)
BatchMatMul(const BatchMatMulDescriptor &params, const TensorInfo &inputXInfo, const TensorInfo &inputYInfo, const TensorInfo &outputInfo, Decoder< float > &inputXDecoder, Decoder< float > &inputYDecoder, Encoder< float > &outputEncoder)
A BatchMatMulDescriptor for the BatchMatMul operator.
virtual IType Get() const =0
unsigned int GetNumDimensions() const
armnn::TensorShape Permuted(const armnn::TensorShape &srcShape, const armnn::PermutationVector &mappings)