Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--    src/backends/backendsCommon/WorkloadData.cpp    236
1 file changed, 113 insertions(+), 123 deletions(-)
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 9a4c60f551..f4afbd9a84 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -8,6 +8,7 @@
#include <armnn/backends/WorkloadInfo.hpp>
#include <armnnUtils/DataLayoutIndexed.hpp>
#include <armnnUtils/TensorUtils.hpp>
+#include <armnnUtils/Permute.hpp>
#include <armnn/utility/NumericCast.hpp>
#include <armnn/Logging.hpp>
@@ -4154,9 +4155,10 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
// For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively,
// axes N and I must be the same size
- const auto& inputTensorXInfo = workloadInfo.m_InputTensorInfos[0];
- const auto& inputTensorYInfo = workloadInfo.m_InputTensorInfos[1];
- const auto& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
+ const auto& inputXInfoBeforeParams = workloadInfo.m_InputTensorInfos[0];
+ const auto& inputYInfoBeforeParams = workloadInfo.m_InputTensorInfos[1];
+ const auto& outputInfo = workloadInfo.m_OutputTensorInfos[0];
+ // Output info has already been inferred
std::vector<DataType> supportedTypes =
{
@@ -4168,108 +4170,127 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
DataType::QSymmS16
};
- ValidateDataTypes(inputTensorXInfo, supportedTypes, descriptorName);
- ValidateDataTypes(inputTensorYInfo, supportedTypes, descriptorName);
- ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
+ ValidateDataTypes(inputXInfoBeforeParams, supportedTypes, descriptorName);
+ ValidateDataTypes(inputYInfoBeforeParams, supportedTypes, descriptorName);
+ ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
- if ((inputTensorXInfo.GetNumDimensions() < 2) ||
- (inputTensorYInfo.GetNumDimensions() < 2))
+ if ((inputXInfoBeforeParams.GetNumDimensions() < 2) ||
+ (inputYInfoBeforeParams.GetNumDimensions() < 2))
{
throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
}
- if(m_Parameters.m_DataLayoutX.has_value())
+ TensorInfo inputXInfoAfterParams;
+ TensorInfo inputYInfoAfterParams;
+
+ if((m_Parameters.m_TransposeX && m_Parameters.m_AdjointX) ||
+ (m_Parameters.m_TransposeY && m_Parameters.m_AdjointY))
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Invalid descriptor parameters - Transpose and Adjoint "
+ "cannot both be true for a given input tensor.");
+ }
+ if(m_Parameters.m_TransposeX)
+ {
+ inputXInfoAfterParams = armnnUtils::Permuted(inputXInfoBeforeParams,
+ BatchMatMulDescriptor::GetPermuteVec(
+ m_Parameters.m_DataLayoutX,
+ inputXInfoBeforeParams.GetShape()));
+ }
+ else if(m_Parameters.m_AdjointX)
{
- switch(m_Parameters.m_DataLayoutX.value())
+ auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
+ inputXInfoBeforeParams.GetShape());
+ if(inputXInfoBeforeParams.GetShape()[axesToMul.first] !=
+ inputXInfoBeforeParams.GetShape()[axesToMul.second])
{
- case DataLayout::NCHW:
- case DataLayout::NHWC:
- if(inputTensorXInfo.GetNumDimensions() != 4)
- {
- throw InvalidArgumentException(descriptorName +
- ": Input tensor X does not have the correct "
- "number of dimensions for the Data Layout that it has been assigned.");
- }
- break;
- case DataLayout::NCDHW:
- case DataLayout::NDHWC:
- if(inputTensorXInfo.GetNumDimensions() != 5)
- {
- throw InvalidArgumentException(descriptorName +
- ": Input tensor X does not have the correct "
- "number of dimensions for the Data Layout that it has been assigned.");
- }
- break;
- default:
- break;
+ throw InvalidArgumentException(descriptorName +
+ ": Adjoint is set to true for input tensor X, but the axes to be adjointed are not square." );
}
+ // Shape remains the same as it's square
+ inputXInfoAfterParams = inputXInfoBeforeParams;
+ }
+ else
+ {
+ inputXInfoAfterParams = inputXInfoBeforeParams;
}
- if(m_Parameters.m_DataLayoutY.has_value())
+ if(m_Parameters.m_TransposeY)
{
- switch(m_Parameters.m_DataLayoutY.value())
+ inputYInfoAfterParams = armnnUtils::Permuted(inputYInfoBeforeParams,
+ BatchMatMulDescriptor::GetPermuteVec(
+ m_Parameters.m_DataLayoutY,
+ inputYInfoBeforeParams.GetShape()));
+ }
+ else if(m_Parameters.m_AdjointY)
+ {
+ auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
+ inputYInfoBeforeParams.GetShape());
+ if(inputYInfoBeforeParams.GetShape()[axesToMul.first] !=
+ inputYInfoBeforeParams.GetShape()[axesToMul.second])
{
- case DataLayout::NCHW:
- case DataLayout::NHWC:
- if(inputTensorYInfo.GetNumDimensions() != 4)
- {
- throw InvalidArgumentException(descriptorName +
- ": Input tensor Y does not have the correct "
- "number of dimensions for the Data Layout that it has been assigned.");
- }
- break;
- case DataLayout::NCDHW:
- case DataLayout::NDHWC:
- if(inputTensorYInfo.GetNumDimensions() != 5)
- {
- throw InvalidArgumentException(descriptorName +
- ": Input tensor Y does not have the correct "
- "number of dimensions for the Data Layout that it has been assigned.");
- }
- break;
- default:
- break;
+ throw InvalidArgumentException(descriptorName +
+ ": Adjoint is set to true for input tensor Y, but the axes to be adjointed are not square." );
}
+ // Shape remains the same as it's square
+ inputYInfoAfterParams = inputYInfoBeforeParams;
+ }
+ else
+ {
+ inputYInfoAfterParams = inputYInfoBeforeParams;
+ }
+
+ switch(m_Parameters.m_DataLayoutX)
+ {
+ case DataLayout::NCDHW:
+ case DataLayout::NDHWC:
+ if(inputXInfoAfterParams.GetNumDimensions() < 3)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Input tensor X does not have the correct "
+ "number of dimensions for the Data Layout that it has been assigned.");
+ }
+ break;
+ case DataLayout::NCHW:
+ case DataLayout::NHWC:
+ default:
+ break;
+ }
+
+ switch(m_Parameters.m_DataLayoutY)
+ {
+ case DataLayout::NCDHW:
+ case DataLayout::NDHWC:
+ if(inputYInfoAfterParams.GetNumDimensions() < 3)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Input tensor Y does not have the correct "
+ "number of dimensions for the Data Layout that it has been assigned.");
+ }
+ break;
+ case DataLayout::NCHW:
+ case DataLayout::NHWC:
+ default:
+ break;
}
- auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters,
- inputTensorXInfo.GetShape(),
- inputTensorYInfo.GetShape());
+ auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
+ inputXInfoAfterParams.GetShape());
+ auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
+ inputYInfoAfterParams.GetShape());
- if(inputTensorXInfo.GetShape()[axesToMul.first.second]
- != inputTensorYInfo.GetShape()[axesToMul.second.first])
+ if(inputXInfoAfterParams.GetShape()[axesXToMul.second]
+ != inputYInfoAfterParams.GetShape()[axesYToMul.first])
{
throw InvalidArgumentException(descriptorName +
": The final axis of input tensor X must be the same size as "
"the second last axis of input tensor Y.");
}
- auto axesNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters,
- inputTensorXInfo.GetShape(),
- inputTensorYInfo.GetShape());
-
{ // Separate scope so we don't pollute the rest of the scope with our temp variables
// e.g. NHWC isn't compatible with NCHW as of now
- DataLayout xLayout;
- DataLayout yLayout;
-
- if(m_Parameters.m_DataLayoutX == EmptyOptional())
- {
- xLayout = DataLayout::NCHW; // Not equivalent - I'm just concerned with the last 2 axes
- }
- else
- {
- xLayout = m_Parameters.m_DataLayoutX.value();
- }
-
- if(m_Parameters.m_DataLayoutY == EmptyOptional())
- {
- yLayout = DataLayout::NCHW;
- }
- else
- {
- yLayout = m_Parameters.m_DataLayoutY.value();
- }
+ DataLayout xLayout = m_Parameters.m_DataLayoutX;
+ DataLayout yLayout = m_Parameters.m_DataLayoutY;
if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW)
{
@@ -4290,8 +4311,8 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
}
// Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one
- unsigned int outputTensorDimSize = std::max(inputTensorXInfo.GetNumDimensions(),
- inputTensorYInfo.GetNumDimensions());
+ unsigned int outputTensorDimSize = std::max(inputXInfoAfterParams.GetNumDimensions(),
+ inputYInfoAfterParams.GetNumDimensions());
if(outputTensorDimSize-2 > 0)
{
TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
@@ -4312,12 +4333,17 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
for(unsigned int i = 0; i < ti.GetNumDimensions(); i++)
{
- ti.GetShape()[i] = inputTensorXInfo.GetShape()[i];
+ ti.GetShape()[i] = inputXInfoAfterParams.GetShape()[i];
}
};
- doAxisExtension(axesNotMul.first, tiXNotMul);
- doAxisExtension(axesNotMul.second, tiYNotMul);
+ auto axesXNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutX,
+ inputXInfoAfterParams.GetShape());
+ auto axesYNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutY,
+ inputYInfoAfterParams.GetShape());
+
+ doAxisExtension(axesXNotMul, tiXNotMul);
+ doAxisExtension(axesYNotMul, tiYNotMul);
for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++)
{
@@ -4332,42 +4358,6 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
"input_X",
"input_Y");
}
-
- // Also check descriptor parameter validity
- // This will eventually be moved to the start of the function as explained below
- if ((!m_Parameters.m_TransposeX.empty() && !m_Parameters.m_AdjointX.empty()) ||
- (!m_Parameters.m_TransposeY.empty() && !m_Parameters.m_AdjointY.empty()))
- {
- throw InvalidArgumentException(descriptorName +
- ": Invalid descriptor parameters - Transpose and Adjoint "
- "vectors cannot both be true for a given input tensor.");
- }
-
- if(m_Parameters.m_TransposeX.size() != 0 && m_Parameters.m_TransposeX.size() != inputTensorXInfo.GetNumDimensions())
- {
- throw InvalidArgumentException(descriptorName +
- ": Invalid descriptor parameter - Transpose X vector must be "
- "the same size as tensor input X's dimensionality.");
- }
- if(m_Parameters.m_AdjointX.size() != 0 && m_Parameters.m_AdjointX.size() != inputTensorXInfo.GetNumDimensions())
- {
- throw InvalidArgumentException(descriptorName +
- ": Invalid descriptor parameter - Adjoint X vector must be "
- "the same size as tensor input X's dimensionality.");
- }
- if(m_Parameters.m_TransposeY.size() != 0 && m_Parameters.m_TransposeY.size() != inputTensorYInfo.GetNumDimensions())
- {
- throw InvalidArgumentException(descriptorName +
- ": Invalid descriptor parameter - Transpose Y vector must be "
- "the same size as tensor input Y's dimensionality.");
- }
- if(m_Parameters.m_AdjointY.size() != 0 && m_Parameters.m_AdjointY.size() != inputTensorXInfo.GetNumDimensions())
- {
- throw InvalidArgumentException(descriptorName +
- ": Invalid descriptor parameter - Adjoint Y vector must be "
- "the same size as tensor input Y's dimensionality.");
- }
- // Note: for adjoint/transpose, you'll need to do the validation atop the resultant permutation.
}
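
For readers following the shape logic in this validation, the sketch below is a minimal, self-contained illustration of the compatibility rule the new code enforces: the final axis of X must equal the second-to-last axis of Y, and the remaining leading axes are aligned at the ends with 1s prepended to the shorter shape. It is plain C++ with a hypothetical helper name (BatchMatMulShapesCompatible), not Arm NN API, and it assumes the default case where the last two axes are the matrix axes and the leading axes broadcast when one of them is 1; the real validation first applies the data layout, transpose and adjoint handling shown in the hunks above.

// Standalone sketch, not Arm NN API.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

bool BatchMatMulShapesCompatible(std::vector<unsigned int> x, std::vector<unsigned int> y)
{
    if (x.size() < 2 || y.size() < 2)
    {
        return false; // Mirrors the "Input tensors are not 2D or greater" check.
    }

    // The final axis of X must match the second-to-last axis of Y.
    if (x[x.size() - 1] != y[y.size() - 2])
    {
        return false;
    }

    // Align the ends and prepend 1s to the shorter shape, then require the
    // leading (batch) axes to match or be broadcastable (one of them is 1).
    const std::size_t rank = std::max(x.size(), y.size());
    x.insert(x.begin(), rank - x.size(), 1u);
    y.insert(y.begin(), rank - y.size(), 1u);
    for (std::size_t i = 0; i + 2 < rank; ++i)
    {
        if (x[i] != y[i] && x[i] != 1u && y[i] != 1u)
        {
            return false;
        }
    }
    return true;
}

int main()
{
    // [2,3,4] x [2,4,5]: compatible (batch 2, multiplied axes 4 == 4).
    std::cout << BatchMatMulShapesCompatible({2, 3, 4}, {2, 4, 5}) << "\n"; // prints 1
    // [2,3,4] x [2,3,5]: incompatible (4 != 3 on the multiplied axes).
    std::cout << BatchMatMulShapesCompatible({2, 3, 4}, {2, 3, 5}) << "\n"; // prints 0
}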