From dc8ed9d75e54e914a970e137900930fa64a0782b Mon Sep 17 00:00:00 2001 From: Samuel Yap Date: Mon, 8 Aug 2022 14:07:42 +0100 Subject: IVGCVSW-7105: BatchMatMul Optional Parameter Support * Added transpose parameters to pre-transpose each input tensor's slices * Added adjoint parameters to pre-adjoint each input tensor's slices * Small refactoring (BatchMatMulDescriptor static helpers and BatchMatMulImpl constructor) * Updated input validation and output shape inference for parameters * Additional layer unit tests for parameters added * Versionings incremented Signed-off-by: Samuel Yap Change-Id: Ibe5242a8a5bf604c13de0dc65844fd6c421cc667 --- .../reference/workloads/BatchMatMulImpl.cpp | 346 +++++++++++++++++---- .../reference/workloads/BatchMatMulImpl.hpp | 69 ++-- .../reference/workloads/RefBatchMatMulWorkload.cpp | 3 - 3 files changed, 324 insertions(+), 94 deletions(-) (limited to 'src/backends/reference/workloads') diff --git a/src/backends/reference/workloads/BatchMatMulImpl.cpp b/src/backends/reference/workloads/BatchMatMulImpl.cpp index 6693f15760..c592b3b76c 100644 --- a/src/backends/reference/workloads/BatchMatMulImpl.cpp +++ b/src/backends/reference/workloads/BatchMatMulImpl.cpp @@ -7,46 +7,53 @@ #include #include +#include namespace armnn { -void BatchMatMul::BatchMatMulImpl() +BatchMatMul::BatchMatMul(const BatchMatMulDescriptor& params, + const TensorInfo& inputXInfo, + const TensorInfo& inputYInfo, + const TensorInfo& outputInfo, + Decoder& inputXDecoder, + Decoder& inputYDecoder, + Encoder& outputEncoder) + : params(params), + inputXInfo(inputXInfo), + inputYInfo(inputYInfo), + outputInfo(outputInfo), + inputXDecoder(inputXDecoder), + inputYDecoder(inputYDecoder), + outputEncoder(outputEncoder) { - inputXData = inputXDecoder.DecodeTensor(inputXInfo.GetShape()); - inputYData = inputYDecoder.DecodeTensor(inputYInfo.GetShape()); + inputXData = this->inputXDecoder.DecodeTensor(inputXInfo.GetShape()); + inputYData = this->inputYDecoder.DecodeTensor(inputYInfo.GetShape()); // At this point, we don't touch the input decoders - just the resultant vectors - // Pre-transpose and pre-adjoint if their vectors aren't empty - // and also DataLayouts which may change with permutations/adjoints + ApplyParams(); - // Todo: Have you updated input validation and inferred output shapes to accommodate for these pre-permutes? - - auto idx = std::vector(outputInfo.GetNumDimensions(), 0); - RecurseBMM(idx, 0); + ApplyBatchMatMul(); } -void BatchMatMul::RecurseBMM(std::vector& curIdx, unsigned int curDim) +void BatchMatMul::ApplyBatchMatMul() { - // We're working off of the indexes of the output tensor (the max possible shape) - - if(!(curDim < outputInfo.GetNumDimensions())) - { - // We're at the leaf level of this call tree, so we operate here (each leaf is a data point) + auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutX, + inputXInfo.GetShape()); + auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutY, + inputYInfo.GetShape()); + AdjustAxesToMulForUnequalRanks(axesXToMul, axesYToMul); - auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(params, - inputXInfo.GetShape(), - inputYInfo.GetShape()); - AdjustAxesToMulForUnequalRanks(axesToMul); + unsigned int inputXColDim = axesXToMul.second; + unsigned int inputYRowDim = axesYToMul.first; - unsigned int inputXColDim = axesToMul.first.second; - unsigned int inputYRowDim = axesToMul.second.first; - - unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim]; + unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim]; + auto batchMatMulOperation = [&](const std::vector& curIdx) + { float sum = 0.0f; - // You could also use inputXColSize + // InputYRowSize is synonymous with inputXColSize for (unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++) { auto xIdx = curIdx; xIdx[inputXColDim] = inputYRowIdx; @@ -54,24 +61,271 @@ void BatchMatMul::RecurseBMM(std::vector& curIdx, unsigned int cur auto yIdx = curIdx; yIdx[inputYRowDim] = inputYRowIdx; - sum += (GetValueAt(DataSlot::InputX, xIdx) - * GetValueAt(DataSlot::InputY, yIdx)); + sum += (GetValueAt(DataSlot::InputX, xIdx) * GetValueAt(DataSlot::InputY, yIdx)); } SetValueAt(sum, DataSlot::Output, curIdx); + }; + + auto startIdx = std::vector(outputInfo.GetNumDimensions(), 0); + RecurseTensor(outputInfo, + batchMatMulOperation, + startIdx, + 0); +} + +void BatchMatMul::ApplyParams() +{ + if(params.m_TransposeX) + { + Transpose(DataSlot::InputX); + } + else if(params.m_AdjointX) + { + Adjoint(DataSlot::InputX); + } + if(params.m_TransposeY) + { + Transpose(DataSlot::InputY); + } + else if(params.m_AdjointY) + { + Adjoint(DataSlot::InputY); + } +} + +void BatchMatMul::Transpose(DataSlot type) +{ + // AKA the permute of the tensor + // This modifies the tensor's info. + + switch(type) + { + case DataSlot::InputX: + { + auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutX, + inputXInfo.GetShape()); + inputXInfo = armnnUtils::Permuted(inputXInfo, permuteVec); + std::vector temp(inputXData.size()); + armnnUtils::Permute(inputXInfo.GetShape(), + permuteVec, + inputXData.data(), + temp.data(), + sizeof(float)); + inputXData = temp; + break; + } + case DataSlot::InputY: + { + auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutY, + inputYInfo.GetShape()); + inputYInfo = armnnUtils::Permuted(inputYInfo, permuteVec); + std::vector temp(inputYData.size()); + armnnUtils::Permute(inputYInfo.GetShape(), + permuteVec, + inputYData.data(), + temp.data(), + sizeof(float)); + inputYData = temp; + break; + } + case DataSlot::Output: // We needn't transpose the output tensor + default: + break; + } +} + +void BatchMatMul::Adjoint(DataSlot type) +{ + // Finding the adjoint of a square matrix: + // Calculate the cofactor of each element (using Gauss elimination here) + // Apply a transpose to it (this also modifies the tensor's info) + + TensorInfo& inputInfo = (type == DataSlot::InputX) ? inputXInfo : inputYInfo; + const auto& dataLayout = (type == DataSlot::InputX) ? params.m_DataLayoutX : params.m_DataLayoutY; + const auto axesToAdjoint = BatchMatMulDescriptor::GetAxesToMul(dataLayout,inputInfo.GetShape()); + + ARMNN_ASSERT(inputInfo.GetShape()[axesToAdjoint.first] == inputInfo.GetShape()[axesToAdjoint.second]); + // We grab a copy of the tensor data to prevent overwriting + std::vector inputDataClone = (type == DataSlot::InputX) ? inputXData : inputYData; + + // The sub-matrix is the resultant matrix when the row and column of the current index is removed + unsigned int subMatAxisSize = inputInfo.GetShape()[axesToAdjoint.first] - 1; + std::vector> subMat(subMatAxisSize, + std::vector(subMatAxisSize)); + + // Lambdas for each sub-step of the cofactor operation + auto almostEquals = [&](const float& a, const float& b, float unitsInLastPlace = 2.0f) + { + float diff = std::fabs(a-b); + float bound = diff * std::numeric_limits::epsilon() * unitsInLastPlace; + return (diff <= bound) || (diff < std::numeric_limits::min()); + }; + + float swapMultiplier = std::numeric_limits::max(); + auto swapRows = [&](unsigned int rowIdxA, unsigned int rowIdxB) + { + // Every row swap flips this around by the negative (set to 1 at the beginning of each cofactor op run) + for(unsigned int colIdx = 0; colIdx < subMatAxisSize; colIdx++) + { + float tmp = subMat[rowIdxA][colIdx]; + subMat[rowIdxA][colIdx] = subMat[rowIdxB][colIdx]; + subMat[rowIdxB][colIdx] = tmp; + } + swapMultiplier *= -1.0f; + }; + + auto findNextValidPivotRowIdx = [&](unsigned int colIdx) + { + unsigned int result = std::numeric_limits::max(); + + // The original diagonal has been checked and is invalid + for(unsigned int rowIdx = colIdx+1; rowIdx < subMatAxisSize; rowIdx++) + { + if(!almostEquals(subMat[rowIdx][colIdx], 0.0f)) + { + result = rowIdx; + break; + } + } + return result; + }; + + auto eliminate = [&](const float& pivot, unsigned int pivotPos) + { + for(unsigned int rowIdx = pivotPos+1; rowIdx < subMatAxisSize; rowIdx++) + { + float multiplierNumerator = subMat[rowIdx][pivotPos]; + if(almostEquals(multiplierNumerator, 0.0f)) + { + continue; + } + float multiplier = multiplierNumerator / pivot; // Susceptible to floating point inaccuracies + // Hence the almostEquals usage to counteract this + for(unsigned int colIdx = pivotPos; colIdx < subMatAxisSize; colIdx++) + { + // We start at col=pivotPos as we have assumed that all elements + // to our left have been eliminated to zero already + + // We subtract based on the element directly above us in our pivot row + subMat[rowIdx][colIdx] -= multiplier * subMat[pivotPos][colIdx]; + } + } + }; + + auto cofactorOperation = [&](const std::vector& curIdx) + { + auto row = curIdx[axesToAdjoint.first]; + auto col = curIdx[axesToAdjoint.second]; + + float minorMultiplier = static_cast(std::pow(-1, (row + 1 + col + 1))); + + for(unsigned int subRow = 0; subRow < subMatAxisSize; subRow++) + { + for(unsigned int subCol = 0; subCol < subMatAxisSize; subCol++) + { + unsigned int outerRow = (subRow >= row)?subRow + 1:subRow; + unsigned int outerCol = (subCol >= col)?subCol + 1:subCol; + auto cloneIdx = curIdx; + cloneIdx[axesToAdjoint.first] = outerRow; + cloneIdx[axesToAdjoint.second] = outerCol; + subMat[subRow][subCol] = GetValueAt(type,cloneIdx,inputDataClone); + } + } + + float determinant = 1.0f; + + // Cover the edge cases and simple base cases before resorting to Gauss elimination for larger matrices + switch(subMatAxisSize) + { + case 0: + { + determinant = GetValueAt(type, curIdx, inputDataClone); + break; + } + case 1: + { + // If the resultant sub-matrix is just one element - that's the determinant + determinant = subMat[0][0]; + break; + } + case 2: + { + // For a 2x2 sub-matrix, the determinant is just a*d-b*c + determinant = subMat[0][0] * subMat[1][1] - + subMat[0][1] * subMat[1][0]; + break; + } + default: + { + // Gaussian elimination to find the determinant of this sub-matrix + swapMultiplier = 1.0f; + // March diagonally down the pivots and if it's invalid (a zero), swap the row with the + // nearest non-zero down within the column + for(unsigned int pivotRow = 0, pivotCol = 0; + pivotRow < subMatAxisSize; + pivotRow++, pivotCol++) + { + float& pivot = subMat[pivotRow][pivotCol]; + + if(almostEquals(pivot, 0.0f)) + { + unsigned int nextValidPivotRowIdx = findNextValidPivotRowIdx(pivotCol); + if(nextValidPivotRowIdx == std::numeric_limits::max()) + { + // No valid pivot down this column, which means that this pivot remains a zero. + // This results in the determinant for this entire sub-matrix to just be zero. + determinant = 0.0f; + break; + } + swapRows(pivotRow, nextValidPivotRowIdx); + } + determinant *= pivot; + // The actual elimination bit (which will update/propagate to the pivots down the line) + eliminate(pivot, pivotRow); // Synonymous with pivotCol + } + + determinant *= swapMultiplier; + break; + } + } + float cofactor = minorMultiplier * determinant; + SetValueAt(cofactor, type, curIdx); + }; + + auto startIdx = std::vector(inputInfo.GetNumDimensions(), 0); + RecurseTensor(inputInfo, + cofactorOperation, + startIdx, + 0); + + Transpose(type); +} +void BatchMatMul::RecurseTensor(const TensorInfo& tensorInfo, + const std::function&)>& operation, + std::vector& curIdx, + unsigned int curDim) +{ + if(!(curDim < tensorInfo.GetNumDimensions())) + { + // We're at the leaf level of this call tree, so we operate here (each leaf is a data point) + operation(curIdx); return; } - for (unsigned int i = 0; i < outputInfo.GetShape()[curDim]; i++) + for(unsigned int i = 0; i < tensorInfo.GetShape()[curDim]; i++) { curIdx[curDim] = i; - RecurseBMM(curIdx, curDim+1); + RecurseTensor(tensorInfo, + operation, + curIdx, + curDim + 1); } } -void BatchMatMul::AdjustAxesToMulForUnequalRanks( - std::pair, std::pair>& axesToMul) +void BatchMatMul::AdjustAxesToMulForUnequalRanks(std::pair& axesXToMul, + std::pair& axesYToMul) { int rankDiff = static_cast(inputXInfo.GetNumDimensions()) - static_cast(inputYInfo.GetNumDimensions()); @@ -82,18 +336,18 @@ void BatchMatMul::AdjustAxesToMulForUnequalRanks( else if(rankDiff < 0) { // Y is the larger one - axesToMul.first.first += static_cast::type>(std::abs(rankDiff)); - axesToMul.first.second += static_cast::type>(std::abs(rankDiff)); + axesXToMul.first += static_cast::type>(std::abs(rankDiff)); + axesXToMul.second += static_cast::type>(std::abs(rankDiff)); } else if(rankDiff > 0) { // X is the larger one - axesToMul.second.first += static_cast::type>(std::abs(rankDiff)); - axesToMul.second.second += static_cast::type>(std::abs(rankDiff)); + axesYToMul.first += static_cast::type>(std::abs(rankDiff)); + axesYToMul.second += static_cast::type>(std::abs(rankDiff)); } } -float BatchMatMul::GetValueAt(DataSlot type, std::vector idx) +float BatchMatMul::GetValueAt(DataSlot type, std::vector idx, const std::vector& customData) { // This gets the data from the input vector that we have, Not the decoder // But for the output, it is operating on the encoder itself @@ -101,14 +355,13 @@ float BatchMatMul::GetValueAt(DataSlot type, std::vector idx) AdjustToSafeIdx(type, idx); unsigned int flatIdx = CalcFlatIdx(type, idx); float value = 0.0f; - switch(type) { case DataSlot::InputX: - value = inputXData[flatIdx]; + value = customData.empty() ? inputXData[flatIdx] : customData[flatIdx]; break; case DataSlot::InputY: - value = inputYData[flatIdx]; + value = customData.empty() ? inputYData[flatIdx] : customData[flatIdx]; break; case DataSlot::Output: outputEncoder[flatIdx]; @@ -124,9 +377,7 @@ float BatchMatMul::GetValueAt(DataSlot type, std::vector idx) void BatchMatMul::SetValueAt(float value, DataSlot type, std::vector idx) { AdjustToSafeIdx(type, idx); - unsigned int flatIdx = CalcFlatIdx(type, idx); - switch(type) { case DataSlot::InputX: @@ -186,9 +437,7 @@ void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector& idx) unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector& idx) { unsigned int result = idx[idx.size()-1]; - unsigned int dimMultiplier = 1; - unsigned int offset; // -2 because final dim is already accounted for in the multiplier (last dim is just a multiplier of 1x) @@ -215,17 +464,4 @@ unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector -std::string BatchMatMul::StringifyVec(const std::vector& vec) -{ - std::string res = "{ "; - for(auto x : vec) - { - res += std::to_string(x); - res += " "; - } - res += "}"; - return res; -} - } // namespace armnn \ No newline at end of file diff --git a/src/backends/reference/workloads/BatchMatMulImpl.hpp b/src/backends/reference/workloads/BatchMatMulImpl.hpp index 25b6c85d77..19971a4af3 100644 --- a/src/backends/reference/workloads/BatchMatMulImpl.hpp +++ b/src/backends/reference/workloads/BatchMatMulImpl.hpp @@ -15,6 +15,15 @@ namespace armnn class BatchMatMul { public: + BatchMatMul(const BatchMatMulDescriptor& params, + const TensorInfo& inputXInfo, + const TensorInfo& inputYInfo, + const TensorInfo& outputInfo, + Decoder& inputXDecoder, + Decoder& inputYDecoder, + Encoder& outputEncoder); + +private: enum DataSlot { InputX = 0, @@ -22,31 +31,35 @@ public: Output = 2 }; - BatchMatMul(const BatchMatMulDescriptor& params, - const TensorInfo& inputXInfo, - const TensorInfo& inputYInfo, - const TensorInfo& outputInfo, - Decoder& inputXDecoder, - Decoder& inputYDecoder, - Encoder& outputEncoder) - : params(params), - inputXInfo(inputXInfo), - inputYInfo(inputYInfo), - outputInfo(outputInfo), - inputXDecoder(inputXDecoder), - inputYDecoder(inputYDecoder), - outputEncoder(outputEncoder) - {} + const BatchMatMulDescriptor& params; + TensorInfo inputXInfo; + TensorInfo inputYInfo; + TensorInfo outputInfo; + Decoder& inputXDecoder; + Decoder& inputYDecoder; + Encoder& outputEncoder; - void BatchMatMulImpl(); + std::vector inputXData; + std::vector inputYData; + + void ApplyBatchMatMul(); + + void ApplyParams(); + + void Transpose(DataSlot type); - void RecurseBMM(std::vector& curIdx, unsigned int curDim); + void Adjoint(DataSlot type); + + void RecurseTensor(const TensorInfo& tensorInfo, + std::function&)> const& operation, + std::vector& curIdx, + unsigned int curDim); // Adjusts it for when input tensors are of unequal rank - void AdjustAxesToMulForUnequalRanks( - std::pair, std::pair>& axesToMul); + void AdjustAxesToMulForUnequalRanks(std::pair& axesXToMul, + std::pair& axesYToMul); - float GetValueAt(DataSlot type, std::vector idx); + float GetValueAt(DataSlot type, std::vector idx, const std::vector& customData = {}); void SetValueAt(float value, DataSlot type, std::vector idx); @@ -54,22 +67,6 @@ public: void AdjustToSafeIdx(DataSlot type, std::vector& idx); unsigned int CalcFlatIdx(DataSlot type, const std::vector& idx); - - template - std::string StringifyVec(const std::vector& vec); - -private: - const BatchMatMulDescriptor& params; - const TensorInfo& inputXInfo; - const TensorInfo& inputYInfo; - const TensorInfo& outputInfo; - Decoder& inputXDecoder; - Decoder& inputYDecoder; - Encoder& outputEncoder; - - std::vector inputXData; - std::vector inputYData; - }; } // namespace armnn \ No newline at end of file diff --git a/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp b/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp index 388190c4ef..027b93b5d9 100644 --- a/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp +++ b/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp @@ -51,9 +51,6 @@ void RefBatchMatMulWorkload::Execute(std::vector inputs, std::ve *inputXDecoder, *inputYDecoder, *outputEncoder); - - bmm.BatchMatMulImpl(); - } } // namespace armnn \ No newline at end of file -- cgit v1.2.1