path: root/src/backends/backendsCommon
author     Samuel Yap <samuel.yap@arm.com>    2022-08-08 14:07:42 +0100
committer  Nikhil Raj <nikhil.raj@arm.com>    2022-08-30 17:03:33 +0100
commit     dc8ed9d75e54e914a970e137900930fa64a0782b (patch)
tree       8bcaedaae81a6afbdbe3c9a4e69e45840f18cdb4 /src/backends/backendsCommon
parent     9c9d5b9d796d243d88bd7a7aebb2e7e6c467e3a4 (diff)
download   armnn-dc8ed9d75e54e914a970e137900930fa64a0782b.tar.gz
IVGCVSW-7105: BatchMatMul Optional Parameter Support
* Added transpose parameters to pre-transpose each input tensor's slices
* Added adjoint parameters to pre-adjoint each input tensor's slices
* Small refactoring (BatchMatMulDescriptor static helpers and BatchMatMulImpl constructor)
* Updated input validation and output shape inference for parameters
* Additional layer unit tests for parameters added
* Versionings incremented

Signed-off-by: Samuel Yap <samuel.yap@arm.com>
Change-Id: Ibe5242a8a5bf604c13de0dc65844fd6c421cc667
Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r--  src/backends/backendsCommon/WorkloadData.cpp                        | 236
-rw-r--r--  src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp | 364
-rw-r--r--  src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp |  18
3 files changed, 478 insertions, 140 deletions
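For orientation before the diff itself: the tests below construct the updated BatchMatMulDescriptor with four booleans followed by two optional data layouts, in the order (transposeX, transposeY, adjointX, adjointY, dataLayoutX, dataLayoutY). A minimal sketch of that usage, assuming the descriptor is declared in armnn/Descriptors.hpp (the include path and function name here are illustrative, not part of the commit):

#include <armnn/Descriptors.hpp> // assumed header for BatchMatMulDescriptor

// Sketch: multiply transpose(X) by Y, leaving Y untouched and keeping the
// default data layouts (treated the same as NCHW by the validation code below).
armnn::BatchMatMulDescriptor MakeTransposeXDescriptor()
{
    return armnn::BatchMatMulDescriptor(/*transposeX=*/ true,
                                        /*transposeY=*/ false,
                                        /*adjointX=*/   false,
                                        /*adjointY=*/   false);
}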
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 9a4c60f551..f4afbd9a84 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -8,6 +8,7 @@
#include <armnn/backends/WorkloadInfo.hpp>
#include <armnnUtils/DataLayoutIndexed.hpp>
#include <armnnUtils/TensorUtils.hpp>
+#include <armnnUtils/Permute.hpp>
#include <armnn/utility/NumericCast.hpp>
#include <armnn/Logging.hpp>
@@ -4154,9 +4155,10 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
// For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively,
// axes N and I must be the same size
- const auto& inputTensorXInfo = workloadInfo.m_InputTensorInfos[0];
- const auto& inputTensorYInfo = workloadInfo.m_InputTensorInfos[1];
- const auto& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
+ const auto& inputXInfoBeforeParams = workloadInfo.m_InputTensorInfos[0];
+ const auto& inputYInfoBeforeParams = workloadInfo.m_InputTensorInfos[1];
+ const auto& outputInfo = workloadInfo.m_OutputTensorInfos[0];
+ // Output info has already been inferred
std::vector<DataType> supportedTypes =
{
@@ -4168,108 +4170,127 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
DataType::QSymmS16
};
- ValidateDataTypes(inputTensorXInfo, supportedTypes, descriptorName);
- ValidateDataTypes(inputTensorYInfo, supportedTypes, descriptorName);
- ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
+ ValidateDataTypes(inputXInfoBeforeParams, supportedTypes, descriptorName);
+ ValidateDataTypes(inputYInfoBeforeParams, supportedTypes, descriptorName);
+ ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
- if ((inputTensorXInfo.GetNumDimensions() < 2) ||
- (inputTensorYInfo.GetNumDimensions() < 2))
+ if ((inputXInfoBeforeParams.GetNumDimensions() < 2) ||
+ (inputYInfoBeforeParams.GetNumDimensions() < 2))
{
throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
}
- if(m_Parameters.m_DataLayoutX.has_value())
+ TensorInfo inputXInfoAfterParams;
+ TensorInfo inputYInfoAfterParams;
+
+ if((m_Parameters.m_TransposeX && m_Parameters.m_AdjointX) ||
+ (m_Parameters.m_TransposeY && m_Parameters.m_AdjointY))
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Invalid descriptor parameters - Transpose and Adjoint "
+ "cannot both be true for a given input tensor.");
+ }
+ if(m_Parameters.m_TransposeX)
+ {
+ inputXInfoAfterParams = armnnUtils::Permuted(inputXInfoBeforeParams,
+ BatchMatMulDescriptor::GetPermuteVec(
+ m_Parameters.m_DataLayoutX,
+ inputXInfoBeforeParams.GetShape()));
+ }
+ else if(m_Parameters.m_AdjointX)
{
- switch(m_Parameters.m_DataLayoutX.value())
+ auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
+ inputXInfoBeforeParams.GetShape());
+ if(inputXInfoBeforeParams.GetShape()[axesToMul.first] !=
+ inputXInfoBeforeParams.GetShape()[axesToMul.second])
{
- case DataLayout::NCHW:
- case DataLayout::NHWC:
- if(inputTensorXInfo.GetNumDimensions() != 4)
- {
- throw InvalidArgumentException(descriptorName +
- ": Input tensor X does not have the correct "
- "number of dimensions for the Data Layout that it has been assigned.");
- }
- break;
- case DataLayout::NCDHW:
- case DataLayout::NDHWC:
- if(inputTensorXInfo.GetNumDimensions() != 5)
- {
- throw InvalidArgumentException(descriptorName +
- ": Input tensor X does not have the correct "
- "number of dimensions for the Data Layout that it has been assigned.");
- }
- break;
- default:
- break;
+ throw InvalidArgumentException(descriptorName +
+ ": Adjoint is set to true for input tensor X, but the axes to be adjointed are not square." );
}
+ // Shape remains the same as it's square
+ inputXInfoAfterParams = inputXInfoBeforeParams;
+ }
+ else
+ {
+ inputXInfoAfterParams = inputXInfoBeforeParams;
}
- if(m_Parameters.m_DataLayoutY.has_value())
+ if(m_Parameters.m_TransposeY)
{
- switch(m_Parameters.m_DataLayoutY.value())
+ inputYInfoAfterParams = armnnUtils::Permuted(inputYInfoBeforeParams,
+ BatchMatMulDescriptor::GetPermuteVec(
+ m_Parameters.m_DataLayoutY,
+ inputYInfoBeforeParams.GetShape()));
+ }
+ else if(m_Parameters.m_AdjointY)
+ {
+ auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
+ inputYInfoBeforeParams.GetShape());
+ if(inputYInfoBeforeParams.GetShape()[axesToMul.first] !=
+ inputYInfoBeforeParams.GetShape()[axesToMul.second])
{
- case DataLayout::NCHW:
- case DataLayout::NHWC:
- if(inputTensorYInfo.GetNumDimensions() != 4)
- {
- throw InvalidArgumentException(descriptorName +
- ": Input tensor Y does not have the correct "
- "number of dimensions for the Data Layout that it has been assigned.");
- }
- break;
- case DataLayout::NCDHW:
- case DataLayout::NDHWC:
- if(inputTensorYInfo.GetNumDimensions() != 5)
- {
- throw InvalidArgumentException(descriptorName +
- ": Input tensor Y does not have the correct "
- "number of dimensions for the Data Layout that it has been assigned.");
- }
- break;
- default:
- break;
+ throw InvalidArgumentException(descriptorName +
+ ": Adjoint is set to true for input tensor Y, but the axes to be adjointed are not square." );
}
+ // Shape remains the same as it's square
+ inputYInfoAfterParams = inputYInfoBeforeParams;
+ }
+ else
+ {
+ inputYInfoAfterParams = inputYInfoBeforeParams;
+ }
+
+ switch(m_Parameters.m_DataLayoutX)
+ {
+ case DataLayout::NCDHW:
+ case DataLayout::NDHWC:
+ if(inputXInfoAfterParams.GetNumDimensions() < 3)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Input tensor X does not have the correct "
+ "number of dimensions for the Data Layout that it has been assigned.");
+ }
+ break;
+ case DataLayout::NCHW:
+ case DataLayout::NHWC:
+ default:
+ break;
+ }
+
+ switch(m_Parameters.m_DataLayoutY)
+ {
+ case DataLayout::NCDHW:
+ case DataLayout::NDHWC:
+ if(inputYInfoAfterParams.GetNumDimensions() < 3)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Input tensor Y does not have the correct "
+ "number of dimensions for the Data Layout that it has been assigned.");
+ }
+ break;
+ case DataLayout::NCHW:
+ case DataLayout::NHWC:
+ default:
+ break;
}
- auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters,
- inputTensorXInfo.GetShape(),
- inputTensorYInfo.GetShape());
+ auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
+ inputXInfoAfterParams.GetShape());
+ auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
+                                                          inputYInfoAfterParams.GetShape());
- if(inputTensorXInfo.GetShape()[axesToMul.first.second]
- != inputTensorYInfo.GetShape()[axesToMul.second.first])
+ if(inputXInfoAfterParams.GetShape()[axesXToMul.second]
+ != inputYInfoAfterParams.GetShape()[axesYToMul.first])
{
throw InvalidArgumentException(descriptorName +
": The final axis of input tensor X must be the same size as "
"the second last axis of input tensor Y.");
}
- auto axesNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters,
- inputTensorXInfo.GetShape(),
- inputTensorYInfo.GetShape());
-
{ // Separate scope so we don't pollute the rest of the scope with our temp variables
        // e.g. NHWC isn't compatible with NCHW as of now
- DataLayout xLayout;
- DataLayout yLayout;
-
- if(m_Parameters.m_DataLayoutX == EmptyOptional())
- {
- xLayout = DataLayout::NCHW; // Not equivalent - I'm just concerned with the last 2 axes
- }
- else
- {
- xLayout = m_Parameters.m_DataLayoutX.value();
- }
-
- if(m_Parameters.m_DataLayoutY == EmptyOptional())
- {
- yLayout = DataLayout::NCHW;
- }
- else
- {
- yLayout = m_Parameters.m_DataLayoutY.value();
- }
+ DataLayout xLayout = m_Parameters.m_DataLayoutX;
+ DataLayout yLayout = m_Parameters.m_DataLayoutY;
if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW)
{
@@ -4290,8 +4311,8 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
}
// Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one
- unsigned int outputTensorDimSize = std::max(inputTensorXInfo.GetNumDimensions(),
- inputTensorYInfo.GetNumDimensions());
+ unsigned int outputTensorDimSize = std::max(inputXInfoAfterParams.GetNumDimensions(),
+ inputYInfoAfterParams.GetNumDimensions());
if(outputTensorDimSize-2 > 0)
{
TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
@@ -4312,12 +4333,17 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
for(unsigned int i = 0; i < ti.GetNumDimensions(); i++)
{
- ti.GetShape()[i] = inputTensorXInfo.GetShape()[i];
+ ti.GetShape()[i] = inputXInfoAfterParams.GetShape()[i];
}
};
- doAxisExtension(axesNotMul.first, tiXNotMul);
- doAxisExtension(axesNotMul.second, tiYNotMul);
+ auto axesXNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutX,
+ inputXInfoAfterParams.GetShape());
+ auto axesYNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutY,
+ inputYInfoAfterParams.GetShape());
+
+ doAxisExtension(axesXNotMul, tiXNotMul);
+ doAxisExtension(axesYNotMul, tiYNotMul);
for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++)
{
@@ -4332,42 +4358,6 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
"input_X",
"input_Y");
}
-
- // Also check descriptor parameter validity
- // This will eventually be moved to the start of the function as explained below
- if ((!m_Parameters.m_TransposeX.empty() && !m_Parameters.m_AdjointX.empty()) ||
- (!m_Parameters.m_TransposeY.empty() && !m_Parameters.m_AdjointY.empty()))
- {
- throw InvalidArgumentException(descriptorName +
- ": Invalid descriptor parameters - Transpose and Adjoint "
- "vectors cannot both be true for a given input tensor.");
- }
-
- if(m_Parameters.m_TransposeX.size() != 0 && m_Parameters.m_TransposeX.size() != inputTensorXInfo.GetNumDimensions())
- {
- throw InvalidArgumentException(descriptorName +
- ": Invalid descriptor parameter - Transpose X vector must be "
- "the same size as tensor input X's dimensionality.");
- }
- if(m_Parameters.m_AdjointX.size() != 0 && m_Parameters.m_AdjointX.size() != inputTensorXInfo.GetNumDimensions())
- {
- throw InvalidArgumentException(descriptorName +
- ": Invalid descriptor parameter - Adjoint X vector must be "
- "the same size as tensor input X's dimensionality.");
- }
- if(m_Parameters.m_TransposeY.size() != 0 && m_Parameters.m_TransposeY.size() != inputTensorYInfo.GetNumDimensions())
- {
- throw InvalidArgumentException(descriptorName +
- ": Invalid descriptor parameter - Transpose Y vector must be "
- "the same size as tensor input Y's dimensionality.");
- }
- if(m_Parameters.m_AdjointY.size() != 0 && m_Parameters.m_AdjointY.size() != inputTensorXInfo.GetNumDimensions())
- {
- throw InvalidArgumentException(descriptorName +
- ": Invalid descriptor parameter - Adjoint Y vector must be "
- "the same size as tensor input Y's dimensionality.");
- }
- // Note: for adjoint/transpose, you'll need to do the validation atop the resultant permutation.
}
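To summarise the validation rules the hunk above introduces, here is a small standalone sketch (not the Arm NN implementation; the struct and function names are hypothetical, and it assumes the default layout where the last two axes of each tensor are the matrix axes): transpose and adjoint are mutually exclusive per input, adjoint requires the multiplied axes to be square, and after any pre-transpose the inner dimensions must agree.

#include <stdexcept>
#include <vector>

// Hypothetical stand-in for the relevant BatchMatMulDescriptor fields.
struct MatMulParams
{
    bool transposeX = false;
    bool transposeY = false;
    bool adjointX   = false;
    bool adjointY   = false;
};

// Checks a (shapeX, shapeY) pair the same way the hunk above does, assuming
// the default layout where the last two axes of each tensor are the matrix axes.
void CheckBatchMatMulParams(const MatMulParams& p,
                            const std::vector<unsigned int>& shapeX,
                            const std::vector<unsigned int>& shapeY)
{
    if (shapeX.size() < 2 || shapeY.size() < 2)
    {
        throw std::invalid_argument("Input tensors are not 2D or greater");
    }
    // Transpose and adjoint cannot both be requested for the same input.
    if ((p.transposeX && p.adjointX) || (p.transposeY && p.adjointY))
    {
        throw std::invalid_argument("Transpose and Adjoint cannot both be true for one input");
    }
    // Adjoint leaves the shape unchanged, so the multiplied axes must be square.
    auto lastTwoSquare = [](const std::vector<unsigned int>& s)
    {
        return s[s.size() - 2] == s[s.size() - 1];
    };
    if ((p.adjointX && !lastTwoSquare(shapeX)) || (p.adjointY && !lastTwoSquare(shapeY)))
    {
        throw std::invalid_argument("Adjoint requested but the axes to be adjointed are not square");
    }
    // After any pre-transpose, X's last axis must match Y's second-to-last axis.
    unsigned int xCols = p.transposeX ? shapeX[shapeX.size() - 2] : shapeX[shapeX.size() - 1];
    unsigned int yRows = p.transposeY ? shapeY[shapeY.size() - 1] : shapeY[shapeY.size() - 2];
    if (xCols != yRows)
    {
        throw std::invalid_argument("X's final axis must match Y's second-to-last axis");
    }
}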
diff --git a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp
index 41add6e6da..6fcc35ab52 100644
--- a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp
@@ -191,7 +191,7 @@ LayerTestResult<T, 3> BatchMatMul3DSimpleTest(
std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
19, 22,
43, 50
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
memoryManager,
@@ -247,9 +247,7 @@ LayerTestResult<T, 4> BatchMatMulNCHWSimpleTest(
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
const armnn::ITensorHandleFactory& tensorHandleFactory)
{
- auto descriptor = armnn::BatchMatMulDescriptor(
- armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NCHW),
- armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NCHW));
+ auto descriptor = armnn::BatchMatMulDescriptor(); // Default arbitrary layout is treated the same as NCHW
float qScale = 0.0f;
int32_t qOffset = 0;
@@ -282,7 +280,7 @@ LayerTestResult<T, 4> BatchMatMulNCHWSimpleTest(
std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
19, 22,
43, 50
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 4>(workloadFactory,
memoryManager,
@@ -338,9 +336,12 @@ LayerTestResult<T, 4> BatchMatMulNHWCSimpleTest(
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
const armnn::ITensorHandleFactory& tensorHandleFactory)
{
- auto descriptor = armnn::BatchMatMulDescriptor(
- armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NHWC),
- armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NHWC));
+ auto descriptor = armnn::BatchMatMulDescriptor(false,
+ false,
+ false,
+ false,
+ armnn::DataLayout::NHWC,
+ armnn::DataLayout::NHWC);
float qScale = 0.0f;
int32_t qOffset = 0;
@@ -373,7 +374,7 @@ LayerTestResult<T, 4> BatchMatMulNHWCSimpleTest(
std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
19, 22,
43, 50
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 4>(workloadFactory,
memoryManager,
@@ -471,7 +472,7 @@ LayerTestResult<T, 3> BatchMatMul3DBatchTest(
267, 286,
323, 346
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
memoryManager,
@@ -566,7 +567,7 @@ LayerTestResult<T, 3> BatchMatMul3DBroadcastTest(
267, 286,
323, 346
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
memoryManager,
@@ -661,7 +662,7 @@ LayerTestResult<T, 3> BatchMatMul3D2DBroadcastTest(
267, 286,
323, 346
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
memoryManager,
@@ -717,9 +718,12 @@ LayerTestResult<T, 5> BatchMatMulNDHWCNHWCTest(
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
const armnn::ITensorHandleFactory& tensorHandleFactory)
{
- auto descriptor = armnn::BatchMatMulDescriptor(
- armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NDHWC),
- armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NHWC));
+ auto descriptor = armnn::BatchMatMulDescriptor(false,
+ false,
+ false,
+ false,
+ armnn::DataLayout::NDHWC,
+ armnn::DataLayout::NHWC);
float qScale = 0.0f;
int32_t qOffset = 0;
@@ -761,7 +765,7 @@ LayerTestResult<T, 5> BatchMatMulNDHWCNHWCTest(
34, 1079,
46, 1167
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 5>(workloadFactory,
memoryManager,
@@ -959,7 +963,7 @@ LayerTestResult<T, 3> BatchMatMul3DNonSquareTest(
88, 100, 142, 106,
39, 61, 78, 56,
72, 52, 98, 70
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
memoryManager,
@@ -1007,4 +1011,330 @@ template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 3>
BatchMatMul3DNonSquareTest<armnn::DataType::QSymmS16>(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 2> BatchMatMul2DTranspSimpleTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+ auto descriptor = armnn::BatchMatMulDescriptor(true,
+ false,
+ false,
+ false);
+
+ float qScale = 0.0f;
+ int32_t qOffset = 0;
+
+ switch(ArmnnType)
+ {
+ case armnn::DataType::QAsymmS8:
+ case armnn::DataType::QAsymmU8:
+ case armnn::DataType::QSymmS16:
+ qScale = 1.0f;
+ break;
+ default:
+ break;
+ }
+
+ armnn::TensorInfo inputXInfo({2,3}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo inputYInfo({2,3}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo outputInfo({3,3}, ArmnnType, qScale, qOffset);
+
+ std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+ 1, 2, 3,
+ 4, 5, 6
+ }, qScale, qOffset);
+
+ std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+ 7, 8, 9,
+ 10, 11, 12
+ }, qScale, qOffset);
+
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+ 47, 52, 57,
+ 64, 71, 78,
+ 81, 90, 99
+ }, qScale, qOffset);
+
+ return BatchMatMulTestImpl<ArmnnType, T, 2>(workloadFactory,
+ memoryManager,
+ tensorHandleFactory,
+ descriptor,
+ inputX,
+ inputY,
+ outputExpected,
+ inputXInfo,
+ inputYInfo,
+ outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::Float32>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::Float16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::QAsymmS8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::QAsymmU8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::QSymmS16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 2> BatchMatMul2DAdjointSimpleTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+ auto descriptor = armnn::BatchMatMulDescriptor(false,
+ false,
+ true,
+ false);
+
+ float qScale = 0.0f;
+ int32_t qOffset = 0;
+
+ switch(ArmnnType)
+ {
+ case armnn::DataType::QAsymmS8:
+ case armnn::DataType::QAsymmU8:
+ case armnn::DataType::QSymmS16:
+ qScale = 1.0f;
+ break;
+ default:
+ break;
+ }
+
+ armnn::TensorInfo inputXInfo({3,3}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo inputYInfo({3,3}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo outputInfo({3,3}, ArmnnType, qScale, qOffset);
+
+ std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+ 3, 1, 1,
+ 1, 3, -1,
+ 2, 4, 1
+ }, qScale, qOffset);
+
+ std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+ 1, 0, 0,
+ 0, 1, 0,
+ 0, 0, 1
+ }, qScale, qOffset);
+
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+ 7, 3, -4,
+ -3, 1, 4,
+ -2, -10, 8
+ }, qScale, qOffset);
+
+ switch (ArmnnType)
+ {
+ case armnn::DataType::QAsymmU8:
+ outputExpected = armnnUtils::QuantizedVector<T>({
+ 3, 3, 0,
+ 0, 1, 1,
+ 0, 0, 8
+ }, qScale, qOffset);
+ break;
+ default:
+ break;
+ }
+
+ return BatchMatMulTestImpl<ArmnnType, T, 2>(workloadFactory,
+ memoryManager,
+ tensorHandleFactory,
+ descriptor,
+ inputX,
+ inputY,
+ outputExpected,
+ inputXInfo,
+ inputYInfo,
+ outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::Float32>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::Float16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::QAsymmS8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::QAsymmU8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::QSymmS16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 4> BatchMatMulNHWCParamsTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+ auto descriptor = armnn::BatchMatMulDescriptor(false,
+ true,
+ true,
+ false,
+ armnn::DataLayout::NHWC,
+ armnn::DataLayout::NHWC);
+
+ float qScale = 0.0f;
+ int32_t qOffset = 0;
+
+ switch(ArmnnType)
+ {
+ case armnn::DataType::QAsymmS8:
+ case armnn::DataType::QAsymmU8:
+ case armnn::DataType::QSymmS16:
+ qScale = 1.0f;
+ break;
+ default:
+ break;
+ }
+
+ armnn::TensorInfo inputXInfo({1,4,4,2}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo inputYInfo({2,2,4,1}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo outputInfo({2,4,2,2}, ArmnnType, qScale, qOffset);
+
+ std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+ 1, -3, 1, 4, 4, 9, 1, 2,
+ 2, 4, 2, 2, 10, 7, 6, -5,
+ 3, 8, 9, 9, 21, 1, 17, 7,
+ 5, 11, 11, 8, 29, 3, 23, 6
+ }, qScale, qOffset);
+
+ std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+
+ 9, 10, 11, 12,
+ 13, 14, 15, 16
+ }, qScale, qOffset);
+
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+ 28, 625, 140, 585,
+ 8, 110, -8, 1662,
+ -24, 401, -120, 921,
+ 12, 131, 108, -501,
+
+ 252, 545, 364, 505,
+ -24, 3214, -40, 4766,
+ -216, 1441, -312, 1961,
+ 204, -1133, 300, -1765
+ }, qScale, qOffset);
+
+ switch (ArmnnType)
+ {
+ case armnn::DataType::QAsymmU8:
+ outputExpected = armnnUtils::QuantizedVector<T>({
+ 28, 80, 140, 80,
+ 8, 45, 0, 255,
+ 0, 18, 0, 18,
+ 12, 0, 108, 0,
+
+ 252, 80, 255, 80,
+ 0, 255, 0, 255,
+ 0, 18, 0, 18,
+ 204, 0, 255, 0
+ }, qScale, qOffset);
+ break;
+ default:
+ break;
+ }
+
+ return BatchMatMulTestImpl<ArmnnType, T, 4>(workloadFactory,
+ memoryManager,
+ tensorHandleFactory,
+ descriptor,
+ inputX,
+ inputY,
+ outputExpected,
+ inputXInfo,
+ inputYInfo,
+ outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::Float32>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::Float16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::QAsymmS8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::QAsymmU8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::QSymmS16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
const armnn::ITensorHandleFactory& tensorHandleFactory); \ No newline at end of file
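As a sanity check on the expected values in BatchMatMul2DTranspSimpleTest above, the following standalone sketch (independent of Arm NN, names illustrative) computes transpose(X) * Y for the 2x3 inputs used in that test and reproduces the 3x3 expected output.

#include <array>
#include <cstdio>

int main()
{
    // Inputs from BatchMatMul2DTranspSimpleTest: X is pre-transposed, Y is not.
    const std::array<std::array<float, 3>, 2> x = {{ {1, 2, 3}, {4, 5, 6} }};
    const std::array<std::array<float, 3>, 2> y = {{ {7, 8, 9}, {10, 11, 12} }};

    // out[i][j] = sum_k xT[i][k] * y[k][j] = sum_k x[k][i] * y[k][j]
    for (int i = 0; i < 3; ++i)
    {
        for (int j = 0; j < 3; ++j)
        {
            float sum = 0.0f;
            for (int k = 0; k < 2; ++k)
            {
                sum += x[k][i] * y[k][j];
            }
            std::printf("%g ", sum); // prints 47 52 57 / 64 71 78 / 81 90 99
        }
        std::printf("\n");
    }
    return 0;
}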
diff --git a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp
index 9e2139667b..0b261fba37 100644
--- a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp
+++ b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp
@@ -82,4 +82,22 @@ template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
LayerTestResult<T, 3> BatchMatMul3DNonSquareTest(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 2> BatchMatMul2DTranspSimpleTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 2> BatchMatMul2DAdjointSimpleTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 4> BatchMatMulNHWCParamsTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
const armnn::ITensorHandleFactory& tensorHandleFactory); \ No newline at end of file
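For completeness, the expected output of BatchMatMul2DAdjointSimpleTest (declared above) can be checked by hand: with Y set to the identity matrix, the result is simply the adjugate (classical adjoint) of X, i.e. the transpose of its cofactor matrix.

\[
X = \begin{pmatrix} 3 & 1 & 1 \\ 1 & 3 & -1 \\ 2 & 4 & 1 \end{pmatrix},
\qquad
\operatorname{adj}(X) = \operatorname{cof}(X)^{\mathsf{T}}
= \begin{pmatrix} 7 & 3 & -4 \\ -3 & 1 & 4 \\ -2 & -10 & 8 \end{pmatrix},
\]
where, for example, the $(1,1)$ cofactor is $\det\begin{pmatrix} 3 & -1 \\ 4 & 1 \end{pmatrix} = 3 \cdot 1 - (-1) \cdot 4 = 7$.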