aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/BatchMatMulImpl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/BatchMatMulImpl.cpp')
-rw-r--r--src/backends/reference/workloads/BatchMatMulImpl.cpp346
1 files changed, 291 insertions, 55 deletions
diff --git a/src/backends/reference/workloads/BatchMatMulImpl.cpp b/src/backends/reference/workloads/BatchMatMulImpl.cpp
index 6693f15760..c592b3b76c 100644
--- a/src/backends/reference/workloads/BatchMatMulImpl.cpp
+++ b/src/backends/reference/workloads/BatchMatMulImpl.cpp
@@ -7,46 +7,53 @@
#include <armnn/backends/WorkloadData.hpp>
#include <armnn/Logging.hpp>
+#include <armnnUtils/Permute.hpp>
namespace armnn
{
-void BatchMatMul::BatchMatMulImpl()
+BatchMatMul::BatchMatMul(const BatchMatMulDescriptor& params,
+ const TensorInfo& inputXInfo,
+ const TensorInfo& inputYInfo,
+ const TensorInfo& outputInfo,
+ Decoder<float>& inputXDecoder,
+ Decoder<float>& inputYDecoder,
+ Encoder<float>& outputEncoder)
+ : params(params),
+ inputXInfo(inputXInfo),
+ inputYInfo(inputYInfo),
+ outputInfo(outputInfo),
+ inputXDecoder(inputXDecoder),
+ inputYDecoder(inputYDecoder),
+ outputEncoder(outputEncoder)
{
- inputXData = inputXDecoder.DecodeTensor(inputXInfo.GetShape());
- inputYData = inputYDecoder.DecodeTensor(inputYInfo.GetShape());
+ inputXData = this->inputXDecoder.DecodeTensor(inputXInfo.GetShape());
+ inputYData = this->inputYDecoder.DecodeTensor(inputYInfo.GetShape());
// At this point, we don't touch the input decoders - just the resultant vectors
- // Pre-transpose and pre-adjoint if their vectors aren't empty
- // and also DataLayouts which may change with permutations/adjoints
+ ApplyParams();
- // Todo: Have you updated input validation and inferred output shapes to accommodate for these pre-permutes?
-
- auto idx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0);
- RecurseBMM(idx, 0);
+ ApplyBatchMatMul();
}
-void BatchMatMul::RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim)
+void BatchMatMul::ApplyBatchMatMul()
{
- // We're working off of the indexes of the output tensor (the max possible shape)
-
- if(!(curDim < outputInfo.GetNumDimensions()))
- {
- // We're at the leaf level of this call tree, so we operate here (each leaf is a data point)
+ auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutX,
+ inputXInfo.GetShape());
+ auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutY,
+ inputYInfo.GetShape());
+ AdjustAxesToMulForUnequalRanks(axesXToMul, axesYToMul);
- auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(params,
- inputXInfo.GetShape(),
- inputYInfo.GetShape());
- AdjustAxesToMulForUnequalRanks(axesToMul);
+ unsigned int inputXColDim = axesXToMul.second;
+ unsigned int inputYRowDim = axesYToMul.first;
- unsigned int inputXColDim = axesToMul.first.second;
- unsigned int inputYRowDim = axesToMul.second.first;
-
- unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim];
+ unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim];
+ auto batchMatMulOperation = [&](const std::vector<unsigned int>& curIdx)
+ {
float sum = 0.0f;
- // You could also use inputXColSize
+ // InputYRowSize is synonymous with inputXColSize
for (unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++) {
auto xIdx = curIdx;
xIdx[inputXColDim] = inputYRowIdx;
@@ -54,24 +61,271 @@ void BatchMatMul::RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int cur
auto yIdx = curIdx;
yIdx[inputYRowDim] = inputYRowIdx;
- sum += (GetValueAt(DataSlot::InputX, xIdx)
- * GetValueAt(DataSlot::InputY, yIdx));
+ sum += (GetValueAt(DataSlot::InputX, xIdx) * GetValueAt(DataSlot::InputY, yIdx));
}
SetValueAt(sum, DataSlot::Output, curIdx);
+ };
+
+ auto startIdx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0);
+ RecurseTensor(outputInfo,
+ batchMatMulOperation,
+ startIdx,
+ 0);
+}
+
+void BatchMatMul::ApplyParams()
+{
+ if(params.m_TransposeX)
+ {
+ Transpose(DataSlot::InputX);
+ }
+ else if(params.m_AdjointX)
+ {
+ Adjoint(DataSlot::InputX);
+ }
+ if(params.m_TransposeY)
+ {
+ Transpose(DataSlot::InputY);
+ }
+ else if(params.m_AdjointY)
+ {
+ Adjoint(DataSlot::InputY);
+ }
+}
+
+void BatchMatMul::Transpose(DataSlot type)
+{
+ // AKA the permute of the tensor
+ // This modifies the tensor's info.
+
+ switch(type)
+ {
+ case DataSlot::InputX:
+ {
+ auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutX,
+ inputXInfo.GetShape());
+ inputXInfo = armnnUtils::Permuted(inputXInfo, permuteVec);
+ std::vector<float> temp(inputXData.size());
+ armnnUtils::Permute(inputXInfo.GetShape(),
+ permuteVec,
+ inputXData.data(),
+ temp.data(),
+ sizeof(float));
+ inputXData = temp;
+ break;
+ }
+ case DataSlot::InputY:
+ {
+ auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutY,
+ inputYInfo.GetShape());
+ inputYInfo = armnnUtils::Permuted(inputYInfo, permuteVec);
+ std::vector<float> temp(inputYData.size());
+ armnnUtils::Permute(inputYInfo.GetShape(),
+ permuteVec,
+ inputYData.data(),
+ temp.data(),
+ sizeof(float));
+ inputYData = temp;
+ break;
+ }
+ case DataSlot::Output: // We needn't transpose the output tensor
+ default:
+ break;
+ }
+}
+
+void BatchMatMul::Adjoint(DataSlot type)
+{
+ // Finding the adjoint of a square matrix:
+ // Calculate the cofactor of each element (using Gauss elimination here)
+ // Apply a transpose to it (this also modifies the tensor's info)
+
+ TensorInfo& inputInfo = (type == DataSlot::InputX) ? inputXInfo : inputYInfo;
+ const auto& dataLayout = (type == DataSlot::InputX) ? params.m_DataLayoutX : params.m_DataLayoutY;
+ const auto axesToAdjoint = BatchMatMulDescriptor::GetAxesToMul(dataLayout,inputInfo.GetShape());
+
+ ARMNN_ASSERT(inputInfo.GetShape()[axesToAdjoint.first] == inputInfo.GetShape()[axesToAdjoint.second]);
+ // We grab a copy of the tensor data to prevent overwriting
+ std::vector<float> inputDataClone = (type == DataSlot::InputX) ? inputXData : inputYData;
+
+ // The sub-matrix is the resultant matrix when the row and column of the current index is removed
+ unsigned int subMatAxisSize = inputInfo.GetShape()[axesToAdjoint.first] - 1;
+ std::vector<std::vector<float>> subMat(subMatAxisSize,
+ std::vector<float>(subMatAxisSize));
+
+ // Lambdas for each sub-step of the cofactor operation
+ auto almostEquals = [&](const float& a, const float& b, float unitsInLastPlace = 2.0f)
+ {
+ float diff = std::fabs(a-b);
+ float bound = diff * std::numeric_limits<float>::epsilon() * unitsInLastPlace;
+ return (diff <= bound) || (diff < std::numeric_limits<float>::min());
+ };
+
+ float swapMultiplier = std::numeric_limits<float>::max();
+ auto swapRows = [&](unsigned int rowIdxA, unsigned int rowIdxB)
+ {
+ // Every row swap flips this around by the negative (set to 1 at the beginning of each cofactor op run)
+ for(unsigned int colIdx = 0; colIdx < subMatAxisSize; colIdx++)
+ {
+ float tmp = subMat[rowIdxA][colIdx];
+ subMat[rowIdxA][colIdx] = subMat[rowIdxB][colIdx];
+ subMat[rowIdxB][colIdx] = tmp;
+ }
+ swapMultiplier *= -1.0f;
+ };
+
+ auto findNextValidPivotRowIdx = [&](unsigned int colIdx)
+ {
+ unsigned int result = std::numeric_limits<unsigned int>::max();
+
+ // The original diagonal has been checked and is invalid
+ for(unsigned int rowIdx = colIdx+1; rowIdx < subMatAxisSize; rowIdx++)
+ {
+ if(!almostEquals(subMat[rowIdx][colIdx], 0.0f))
+ {
+ result = rowIdx;
+ break;
+ }
+ }
+ return result;
+ };
+
+ auto eliminate = [&](const float& pivot, unsigned int pivotPos)
+ {
+ for(unsigned int rowIdx = pivotPos+1; rowIdx < subMatAxisSize; rowIdx++)
+ {
+ float multiplierNumerator = subMat[rowIdx][pivotPos];
+ if(almostEquals(multiplierNumerator, 0.0f))
+ {
+ continue;
+ }
+ float multiplier = multiplierNumerator / pivot; // Susceptible to floating point inaccuracies
+ // Hence the almostEquals usage to counteract this
+ for(unsigned int colIdx = pivotPos; colIdx < subMatAxisSize; colIdx++)
+ {
+ // We start at col=pivotPos as we have assumed that all elements
+ // to our left have been eliminated to zero already
+
+ // We subtract based on the element directly above us in our pivot row
+ subMat[rowIdx][colIdx] -= multiplier * subMat[pivotPos][colIdx];
+ }
+ }
+ };
+
+ auto cofactorOperation = [&](const std::vector<unsigned int>& curIdx)
+ {
+ auto row = curIdx[axesToAdjoint.first];
+ auto col = curIdx[axesToAdjoint.second];
+
+ float minorMultiplier = static_cast<float>(std::pow(-1, (row + 1 + col + 1)));
+
+ for(unsigned int subRow = 0; subRow < subMatAxisSize; subRow++)
+ {
+ for(unsigned int subCol = 0; subCol < subMatAxisSize; subCol++)
+ {
+ unsigned int outerRow = (subRow >= row)?subRow + 1:subRow;
+ unsigned int outerCol = (subCol >= col)?subCol + 1:subCol;
+ auto cloneIdx = curIdx;
+ cloneIdx[axesToAdjoint.first] = outerRow;
+ cloneIdx[axesToAdjoint.second] = outerCol;
+ subMat[subRow][subCol] = GetValueAt(type,cloneIdx,inputDataClone);
+ }
+ }
+
+ float determinant = 1.0f;
+
+ // Cover the edge cases and simple base cases before resorting to Gauss elimination for larger matrices
+ switch(subMatAxisSize)
+ {
+ case 0:
+ {
+ determinant = GetValueAt(type, curIdx, inputDataClone);
+ break;
+ }
+ case 1:
+ {
+ // If the resultant sub-matrix is just one element - that's the determinant
+ determinant = subMat[0][0];
+ break;
+ }
+ case 2:
+ {
+ // For a 2x2 sub-matrix, the determinant is just a*d-b*c
+ determinant = subMat[0][0] * subMat[1][1] -
+ subMat[0][1] * subMat[1][0];
+ break;
+ }
+ default:
+ {
+ // Gaussian elimination to find the determinant of this sub-matrix
+ swapMultiplier = 1.0f;
+ // March diagonally down the pivots and if it's invalid (a zero), swap the row with the
+ // nearest non-zero down within the column
+ for(unsigned int pivotRow = 0, pivotCol = 0;
+ pivotRow < subMatAxisSize;
+ pivotRow++, pivotCol++)
+ {
+ float& pivot = subMat[pivotRow][pivotCol];
+
+ if(almostEquals(pivot, 0.0f))
+ {
+ unsigned int nextValidPivotRowIdx = findNextValidPivotRowIdx(pivotCol);
+ if(nextValidPivotRowIdx == std::numeric_limits<unsigned int>::max())
+ {
+ // No valid pivot down this column, which means that this pivot remains a zero.
+ // This results in the determinant for this entire sub-matrix to just be zero.
+ determinant = 0.0f;
+ break;
+ }
+ swapRows(pivotRow, nextValidPivotRowIdx);
+ }
+ determinant *= pivot;
+ // The actual elimination bit (which will update/propagate to the pivots down the line)
+ eliminate(pivot, pivotRow); // Synonymous with pivotCol
+ }
+
+ determinant *= swapMultiplier;
+ break;
+ }
+ }
+ float cofactor = minorMultiplier * determinant;
+ SetValueAt(cofactor, type, curIdx);
+ };
+
+ auto startIdx = std::vector<unsigned int>(inputInfo.GetNumDimensions(), 0);
+ RecurseTensor(inputInfo,
+ cofactorOperation,
+ startIdx,
+ 0);
+
+ Transpose(type);
+}
+void BatchMatMul::RecurseTensor(const TensorInfo& tensorInfo,
+ const std::function<void(const std::vector<unsigned int>&)>& operation,
+ std::vector<unsigned int>& curIdx,
+ unsigned int curDim)
+{
+ if(!(curDim < tensorInfo.GetNumDimensions()))
+ {
+ // We're at the leaf level of this call tree, so we operate here (each leaf is a data point)
+ operation(curIdx);
return;
}
- for (unsigned int i = 0; i < outputInfo.GetShape()[curDim]; i++)
+ for(unsigned int i = 0; i < tensorInfo.GetShape()[curDim]; i++)
{
curIdx[curDim] = i;
- RecurseBMM(curIdx, curDim+1);
+ RecurseTensor(tensorInfo,
+ operation,
+ curIdx,
+ curDim + 1);
}
}
-void BatchMatMul::AdjustAxesToMulForUnequalRanks(
- std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul)
+void BatchMatMul::AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
+ std::pair<unsigned int, unsigned int>& axesYToMul)
{
int rankDiff = static_cast<int>(inputXInfo.GetNumDimensions()) -
static_cast<int>(inputYInfo.GetNumDimensions());
@@ -82,18 +336,18 @@ void BatchMatMul::AdjustAxesToMulForUnequalRanks(
else if(rankDiff < 0)
{
// Y is the larger one
- axesToMul.first.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
- axesToMul.first.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
+ axesXToMul.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
+ axesXToMul.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
}
else if(rankDiff > 0)
{
// X is the larger one
- axesToMul.second.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
- axesToMul.second.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
+ axesYToMul.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
+ axesYToMul.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
}
}
-float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx)
+float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx, const std::vector<float>& customData)
{
// This gets the data from the input vector that we have, Not the decoder
// But for the output, it is operating on the encoder itself
@@ -101,14 +355,13 @@ float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx)
AdjustToSafeIdx(type, idx);
unsigned int flatIdx = CalcFlatIdx(type, idx);
float value = 0.0f;
-
switch(type)
{
case DataSlot::InputX:
- value = inputXData[flatIdx];
+ value = customData.empty() ? inputXData[flatIdx] : customData[flatIdx];
break;
case DataSlot::InputY:
- value = inputYData[flatIdx];
+ value = customData.empty() ? inputYData[flatIdx] : customData[flatIdx];
break;
case DataSlot::Output:
outputEncoder[flatIdx];
@@ -124,9 +377,7 @@ float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx)
void BatchMatMul::SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx)
{
AdjustToSafeIdx(type, idx);
-
unsigned int flatIdx = CalcFlatIdx(type, idx);
-
switch(type)
{
case DataSlot::InputX:
@@ -186,9 +437,7 @@ void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx)
unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx)
{
unsigned int result = idx[idx.size()-1];
-
unsigned int dimMultiplier = 1;
-
unsigned int offset;
// -2 because final dim is already accounted for in the multiplier (last dim is just a multiplier of 1x)
@@ -215,17 +464,4 @@ unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned
return result;
}
-template <typename T>
-std::string BatchMatMul::StringifyVec(const std::vector<T>& vec)
-{
- std::string res = "{ ";
- for(auto x : vec)
- {
- res += std::to_string(x);
- res += " ";
- }
- res += "}";
- return res;
-}
-
} // namespace armnn \ No newline at end of file