From dc8ed9d75e54e914a970e137900930fa64a0782b Mon Sep 17 00:00:00 2001 From: Samuel Yap Date: Mon, 8 Aug 2022 14:07:42 +0100 Subject: IVGCVSW-7105: BatchMatMul Optional Parameter Support * Added transpose parameters to pre-transpose each input tensor's slices * Added adjoint parameters to pre-adjoint each input tensor's slices * Small refactoring (BatchMatMulDescriptor static helpers and BatchMatMulImpl constructor) * Updated input validation and output shape inference for parameters * Additional layer unit tests for parameters added * Versionings incremented Signed-off-by: Samuel Yap Change-Id: Ibe5242a8a5bf604c13de0dc65844fd6c421cc667 --- src/backends/backendsCommon/WorkloadData.cpp | 236 +++++++++++++-------------- 1 file changed, 113 insertions(+), 123 deletions(-) (limited to 'src/backends/backendsCommon/WorkloadData.cpp') diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 9a4c60f551..f4afbd9a84 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -4154,9 +4155,10 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons // For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively, // axes N and I must be the same size - const auto& inputTensorXInfo = workloadInfo.m_InputTensorInfos[0]; - const auto& inputTensorYInfo = workloadInfo.m_InputTensorInfos[1]; - const auto& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0]; + const auto& inputXInfoBeforeParams = workloadInfo.m_InputTensorInfos[0]; + const auto& inputYInfoBeforeParams = workloadInfo.m_InputTensorInfos[1]; + const auto& outputInfo = workloadInfo.m_OutputTensorInfos[0]; + // Output info has already been inferred std::vector supportedTypes = { @@ -4168,108 +4170,127 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons DataType::QSymmS16 }; - ValidateDataTypes(inputTensorXInfo, supportedTypes, descriptorName); - ValidateDataTypes(inputTensorYInfo, supportedTypes, descriptorName); - ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName); + ValidateDataTypes(inputXInfoBeforeParams, supportedTypes, descriptorName); + ValidateDataTypes(inputYInfoBeforeParams, supportedTypes, descriptorName); + ValidateDataTypes(outputInfo, supportedTypes, descriptorName); - if ((inputTensorXInfo.GetNumDimensions() < 2) || - (inputTensorYInfo.GetNumDimensions() < 2)) + if ((inputXInfoBeforeParams.GetNumDimensions() < 2) || + (inputYInfoBeforeParams.GetNumDimensions() < 2)) { throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater."); } - if(m_Parameters.m_DataLayoutX.has_value()) + TensorInfo inputXInfoAfterParams; + TensorInfo inputYInfoAfterParams; + + if((m_Parameters.m_TransposeX && m_Parameters.m_AdjointX) || + (m_Parameters.m_TransposeY && m_Parameters.m_AdjointY)) + { + throw InvalidArgumentException(descriptorName + + ": Invalid descriptor parameters - Transpose and Adjoint " + "cannot both be true for a given input tensor."); + } + if(m_Parameters.m_TransposeX) + { + inputXInfoAfterParams = armnnUtils::Permuted(inputXInfoBeforeParams, + BatchMatMulDescriptor::GetPermuteVec( + m_Parameters.m_DataLayoutX, + inputXInfoBeforeParams.GetShape())); + } + else if(m_Parameters.m_AdjointX) { - switch(m_Parameters.m_DataLayoutX.value()) + auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX, + inputXInfoBeforeParams.GetShape()); + if(inputXInfoBeforeParams.GetShape()[axesToMul.first] != + inputXInfoBeforeParams.GetShape()[axesToMul.second]) { - case DataLayout::NCHW: - case DataLayout::NHWC: - if(inputTensorXInfo.GetNumDimensions() != 4) - { - throw InvalidArgumentException(descriptorName + - ": Input tensor X does not have the correct " - "number of dimensions for the Data Layout that it has been assigned."); - } - break; - case DataLayout::NCDHW: - case DataLayout::NDHWC: - if(inputTensorXInfo.GetNumDimensions() != 5) - { - throw InvalidArgumentException(descriptorName + - ": Input tensor X does not have the correct " - "number of dimensions for the Data Layout that it has been assigned."); - } - break; - default: - break; + throw InvalidArgumentException(descriptorName + + ": Adjoint is set to true for input tensor X, but the axes to be adjointed are not square." ); } + // Shape remains the same as it's square + inputXInfoAfterParams = inputXInfoBeforeParams; + } + else + { + inputXInfoAfterParams = inputXInfoBeforeParams; } - if(m_Parameters.m_DataLayoutY.has_value()) + if(m_Parameters.m_TransposeY) { - switch(m_Parameters.m_DataLayoutY.value()) + inputYInfoAfterParams = armnnUtils::Permuted(inputYInfoBeforeParams, + BatchMatMulDescriptor::GetPermuteVec( + m_Parameters.m_DataLayoutY, + inputYInfoBeforeParams.GetShape())); + } + else if(m_Parameters.m_AdjointY) + { + auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY, + inputYInfoBeforeParams.GetShape()); + if(inputYInfoBeforeParams.GetShape()[axesToMul.first] != + inputYInfoBeforeParams.GetShape()[axesToMul.second]) { - case DataLayout::NCHW: - case DataLayout::NHWC: - if(inputTensorYInfo.GetNumDimensions() != 4) - { - throw InvalidArgumentException(descriptorName + - ": Input tensor Y does not have the correct " - "number of dimensions for the Data Layout that it has been assigned."); - } - break; - case DataLayout::NCDHW: - case DataLayout::NDHWC: - if(inputTensorYInfo.GetNumDimensions() != 5) - { - throw InvalidArgumentException(descriptorName + - ": Input tensor Y does not have the correct " - "number of dimensions for the Data Layout that it has been assigned."); - } - break; - default: - break; + throw InvalidArgumentException(descriptorName + + ": Adjoint is set to true for input tensor Y, but the axes to be adjointed are not square." ); } + // Shape remains the same as it's square + inputYInfoAfterParams = inputYInfoBeforeParams; + } + else + { + inputYInfoAfterParams = inputYInfoBeforeParams; + } + + switch(m_Parameters.m_DataLayoutX) + { + case DataLayout::NCDHW: + case DataLayout::NDHWC: + if(inputXInfoAfterParams.GetNumDimensions() < 3) + { + throw InvalidArgumentException(descriptorName + + ": Input tensor X does not have the correct " + "number of dimensions for the Data Layout that it has been assigned."); + } + break; + case DataLayout::NCHW: + case DataLayout::NHWC: + default: + break; + } + + switch(m_Parameters.m_DataLayoutY) + { + case DataLayout::NCDHW: + case DataLayout::NDHWC: + if(inputYInfoAfterParams.GetNumDimensions() < 3) + { + throw InvalidArgumentException(descriptorName + + ": Input tensor Y does not have the correct " + "number of dimensions for the Data Layout that it has been assigned."); + } + break; + case DataLayout::NCHW: + case DataLayout::NHWC: + default: + break; } - auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters, - inputTensorXInfo.GetShape(), - inputTensorYInfo.GetShape()); + auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX, + inputXInfoAfterParams.GetShape()); + auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY, + inputXInfoBeforeParams.GetShape()); - if(inputTensorXInfo.GetShape()[axesToMul.first.second] - != inputTensorYInfo.GetShape()[axesToMul.second.first]) + if(inputXInfoAfterParams.GetShape()[axesXToMul.second] + != inputYInfoAfterParams.GetShape()[axesYToMul.first]) { throw InvalidArgumentException(descriptorName + ": The final axis of input tensor X must be the same size as " "the second last axis of input tensor Y."); } - auto axesNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters, - inputTensorXInfo.GetShape(), - inputTensorYInfo.GetShape()); - { // Separate scope so we don't pollute the rest of the scope with our temp variables // e.g. NHWC isnt compatible with NCHW as of now - DataLayout xLayout; - DataLayout yLayout; - - if(m_Parameters.m_DataLayoutX == EmptyOptional()) - { - xLayout = DataLayout::NCHW; // Not equivalent - I'm just concerned with the last 2 axes - } - else - { - xLayout = m_Parameters.m_DataLayoutX.value(); - } - - if(m_Parameters.m_DataLayoutY == EmptyOptional()) - { - yLayout = DataLayout::NCHW; - } - else - { - yLayout = m_Parameters.m_DataLayoutY.value(); - } + DataLayout xLayout = m_Parameters.m_DataLayoutX; + DataLayout yLayout = m_Parameters.m_DataLayoutY; if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW) { @@ -4290,8 +4311,8 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons } // Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one - unsigned int outputTensorDimSize = std::max(inputTensorXInfo.GetNumDimensions(), - inputTensorYInfo.GetNumDimensions()); + unsigned int outputTensorDimSize = std::max(inputXInfoAfterParams.GetNumDimensions(), + inputYInfoAfterParams.GetNumDimensions()); if(outputTensorDimSize-2 > 0) { TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2), @@ -4312,12 +4333,17 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons for(unsigned int i = 0; i < ti.GetNumDimensions(); i++) { - ti.GetShape()[i] = inputTensorXInfo.GetShape()[i]; + ti.GetShape()[i] = inputXInfoAfterParams.GetShape()[i]; } }; - doAxisExtension(axesNotMul.first, tiXNotMul); - doAxisExtension(axesNotMul.second, tiYNotMul); + auto axesXNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutX, + inputXInfoAfterParams.GetShape()); + auto axesYNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutY, + inputYInfoAfterParams.GetShape()); + + doAxisExtension(axesXNotMul, tiXNotMul); + doAxisExtension(axesYNotMul, tiYNotMul); for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++) { @@ -4332,42 +4358,6 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons "input_X", "input_Y"); } - - // Also check descriptor parameter validity - // This will eventually be moved to the start of the function as explained below - if ((!m_Parameters.m_TransposeX.empty() && !m_Parameters.m_AdjointX.empty()) || - (!m_Parameters.m_TransposeY.empty() && !m_Parameters.m_AdjointY.empty())) - { - throw InvalidArgumentException(descriptorName + - ": Invalid descriptor parameters - Transpose and Adjoint " - "vectors cannot both be true for a given input tensor."); - } - - if(m_Parameters.m_TransposeX.size() != 0 && m_Parameters.m_TransposeX.size() != inputTensorXInfo.GetNumDimensions()) - { - throw InvalidArgumentException(descriptorName + - ": Invalid descriptor parameter - Transpose X vector must be " - "the same size as tensor input X's dimensionality."); - } - if(m_Parameters.m_AdjointX.size() != 0 && m_Parameters.m_AdjointX.size() != inputTensorXInfo.GetNumDimensions()) - { - throw InvalidArgumentException(descriptorName + - ": Invalid descriptor parameter - Adjoint X vector must be " - "the same size as tensor input X's dimensionality."); - } - if(m_Parameters.m_TransposeY.size() != 0 && m_Parameters.m_TransposeY.size() != inputTensorYInfo.GetNumDimensions()) - { - throw InvalidArgumentException(descriptorName + - ": Invalid descriptor parameter - Transpose Y vector must be " - "the same size as tensor input Y's dimensionality."); - } - if(m_Parameters.m_AdjointY.size() != 0 && m_Parameters.m_AdjointY.size() != inputTensorXInfo.GetNumDimensions()) - { - throw InvalidArgumentException(descriptorName + - ": Invalid descriptor parameter - Adjoint Y vector must be " - "the same size as tensor input Y's dimensionality."); - } - // Note: for adjoint/transpose, you'll need to do the validation atop the resultant permutation. } -- cgit v1.2.1