author      Samuel Yap <samuel.yap@arm.com>        2022-08-08 14:07:42 +0100
committer   Nikhil Raj <nikhil.raj@arm.com>        2022-08-30 17:03:33 +0100
commit      dc8ed9d75e54e914a970e137900930fa64a0782b (patch)
tree        8bcaedaae81a6afbdbe3c9a4e69e45840f18cdb4 /src/backends
parent      9c9d5b9d796d243d88bd7a7aebb2e7e6c467e3a4 (diff)
download    armnn-dc8ed9d75e54e914a970e137900930fa64a0782b.tar.gz
IVGCVSW-7105: BatchMatMul Optional Parameter Support
* Added transpose parameters to pre-transpose each input tensor's slices
* Added adjoint parameters to pre-adjoint each input tensor's slices
* Small refactoring (BatchMatMulDescriptor static helpers and BatchMatMulImpl constructor)
* Updated input validation and output shape inference for parameters
* Additional layer unit tests for parameters added
* Versionings incremented

Signed-off-by: Samuel Yap <samuel.yap@arm.com>
Change-Id: Ibe5242a8a5bf604c13de0dc65844fd6c421cc667
Diffstat (limited to 'src/backends')
-rw-r--r--  src/backends/backendsCommon/WorkloadData.cpp                           236
-rw-r--r--  src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp    364
-rw-r--r--  src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp     18
-rw-r--r--  src/backends/reference/test/RefLayerTests.cpp                            21
-rw-r--r--  src/backends/reference/workloads/BatchMatMulImpl.cpp                    346
-rw-r--r--  src/backends/reference/workloads/BatchMatMulImpl.hpp                     69
-rw-r--r--  src/backends/reference/workloads/RefBatchMatMulWorkload.cpp               3
7 files changed, 823 insertions, 234 deletions
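
The headline change: BatchMatMulDescriptor now takes plain boolean transpose/adjoint flags and non-optional DataLayouts in place of the old per-axis vectors and Optional<DataLayout> fields. A minimal construction sketch, with the parameter order (transposeX, transposeY, adjointX, adjointY, dataLayoutX, dataLayoutY) inferred from the updated tests in this diff; an illustration, not a canonical API reference:

#include <armnn/Descriptors.hpp>

// Pre-transpose the slices of input X and leave Y untouched.
// Note: Validate() rejects enabling both transpose and adjoint on the same input.
armnn::BatchMatMulDescriptor descriptor(/*transposeX=*/true,
                                        /*transposeY=*/false,
                                        /*adjointX=*/false,
                                        /*adjointY=*/false,
                                        armnn::DataLayout::NCHW,
                                        armnn::DataLayout::NCHW);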
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 9a4c60f551..f4afbd9a84 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -8,6 +8,7 @@
#include <armnn/backends/WorkloadInfo.hpp>
#include <armnnUtils/DataLayoutIndexed.hpp>
#include <armnnUtils/TensorUtils.hpp>
+#include <armnnUtils/Permute.hpp>
#include <armnn/utility/NumericCast.hpp>
#include <armnn/Logging.hpp>
@@ -4154,9 +4155,10 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
// For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively,
// axes N and I must be the same size
- const auto& inputTensorXInfo = workloadInfo.m_InputTensorInfos[0];
- const auto& inputTensorYInfo = workloadInfo.m_InputTensorInfos[1];
- const auto& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
+ const auto& inputXInfoBeforeParams = workloadInfo.m_InputTensorInfos[0];
+ const auto& inputYInfoBeforeParams = workloadInfo.m_InputTensorInfos[1];
+ const auto& outputInfo = workloadInfo.m_OutputTensorInfos[0];
+ // Output info has already been inferred
std::vector<DataType> supportedTypes =
{
@@ -4168,108 +4170,127 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
DataType::QSymmS16
};
- ValidateDataTypes(inputTensorXInfo, supportedTypes, descriptorName);
- ValidateDataTypes(inputTensorYInfo, supportedTypes, descriptorName);
- ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
+ ValidateDataTypes(inputXInfoBeforeParams, supportedTypes, descriptorName);
+ ValidateDataTypes(inputYInfoBeforeParams, supportedTypes, descriptorName);
+ ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
- if ((inputTensorXInfo.GetNumDimensions() < 2) ||
- (inputTensorYInfo.GetNumDimensions() < 2))
+ if ((inputXInfoBeforeParams.GetNumDimensions() < 2) ||
+ (inputYInfoBeforeParams.GetNumDimensions() < 2))
{
throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
}
- if(m_Parameters.m_DataLayoutX.has_value())
+ TensorInfo inputXInfoAfterParams;
+ TensorInfo inputYInfoAfterParams;
+
+ if((m_Parameters.m_TransposeX && m_Parameters.m_AdjointX) ||
+ (m_Parameters.m_TransposeY && m_Parameters.m_AdjointY))
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Invalid descriptor parameters - Transpose and Adjoint "
+ "cannot both be true for a given input tensor.");
+ }
+ if(m_Parameters.m_TransposeX)
+ {
+ inputXInfoAfterParams = armnnUtils::Permuted(inputXInfoBeforeParams,
+ BatchMatMulDescriptor::GetPermuteVec(
+ m_Parameters.m_DataLayoutX,
+ inputXInfoBeforeParams.GetShape()));
+ }
+ else if(m_Parameters.m_AdjointX)
{
- switch(m_Parameters.m_DataLayoutX.value())
+ auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
+ inputXInfoBeforeParams.GetShape());
+ if(inputXInfoBeforeParams.GetShape()[axesToMul.first] !=
+ inputXInfoBeforeParams.GetShape()[axesToMul.second])
{
- case DataLayout::NCHW:
- case DataLayout::NHWC:
- if(inputTensorXInfo.GetNumDimensions() != 4)
- {
- throw InvalidArgumentException(descriptorName +
- ": Input tensor X does not have the correct "
- "number of dimensions for the Data Layout that it has been assigned.");
- }
- break;
- case DataLayout::NCDHW:
- case DataLayout::NDHWC:
- if(inputTensorXInfo.GetNumDimensions() != 5)
- {
- throw InvalidArgumentException(descriptorName +
- ": Input tensor X does not have the correct "
- "number of dimensions for the Data Layout that it has been assigned.");
- }
- break;
- default:
- break;
+ throw InvalidArgumentException(descriptorName +
+ ": Adjoint is set to true for input tensor X, but the axes to be adjointed are not square." );
}
+ // Shape remains the same as it's square
+ inputXInfoAfterParams = inputXInfoBeforeParams;
+ }
+ else
+ {
+ inputXInfoAfterParams = inputXInfoBeforeParams;
}
- if(m_Parameters.m_DataLayoutY.has_value())
+ if(m_Parameters.m_TransposeY)
{
- switch(m_Parameters.m_DataLayoutY.value())
+ inputYInfoAfterParams = armnnUtils::Permuted(inputYInfoBeforeParams,
+ BatchMatMulDescriptor::GetPermuteVec(
+ m_Parameters.m_DataLayoutY,
+ inputYInfoBeforeParams.GetShape()));
+ }
+ else if(m_Parameters.m_AdjointY)
+ {
+ auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
+ inputYInfoBeforeParams.GetShape());
+ if(inputYInfoBeforeParams.GetShape()[axesToMul.first] !=
+ inputYInfoBeforeParams.GetShape()[axesToMul.second])
{
- case DataLayout::NCHW:
- case DataLayout::NHWC:
- if(inputTensorYInfo.GetNumDimensions() != 4)
- {
- throw InvalidArgumentException(descriptorName +
- ": Input tensor Y does not have the correct "
- "number of dimensions for the Data Layout that it has been assigned.");
- }
- break;
- case DataLayout::NCDHW:
- case DataLayout::NDHWC:
- if(inputTensorYInfo.GetNumDimensions() != 5)
- {
- throw InvalidArgumentException(descriptorName +
- ": Input tensor Y does not have the correct "
- "number of dimensions for the Data Layout that it has been assigned.");
- }
- break;
- default:
- break;
+ throw InvalidArgumentException(descriptorName +
+ ": Adjoint is set to true for input tensor Y, but the axes to be adjointed are not square." );
}
+ // Shape remains the same as it's square
+ inputYInfoAfterParams = inputYInfoBeforeParams;
+ }
+ else
+ {
+ inputYInfoAfterParams = inputYInfoBeforeParams;
+ }
+
+ switch(m_Parameters.m_DataLayoutX)
+ {
+ case DataLayout::NCDHW:
+ case DataLayout::NDHWC:
+ if(inputXInfoAfterParams.GetNumDimensions() < 3)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Input tensor X does not have the correct "
+ "number of dimensions for the Data Layout that it has been assigned.");
+ }
+ break;
+ case DataLayout::NCHW:
+ case DataLayout::NHWC:
+ default:
+ break;
+ }
+
+ switch(m_Parameters.m_DataLayoutY)
+ {
+ case DataLayout::NCDHW:
+ case DataLayout::NDHWC:
+ if(inputYInfoAfterParams.GetNumDimensions() < 3)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Input tensor Y does not have the correct "
+ "number of dimensions for the Data Layout that it has been assigned.");
+ }
+ break;
+ case DataLayout::NCHW:
+ case DataLayout::NHWC:
+ default:
+ break;
}
- auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters,
- inputTensorXInfo.GetShape(),
- inputTensorYInfo.GetShape());
+ auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
+ inputXInfoAfterParams.GetShape());
+ auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
+ inputYInfoAfterParams.GetShape());
- if(inputTensorXInfo.GetShape()[axesToMul.first.second]
- != inputTensorYInfo.GetShape()[axesToMul.second.first])
+ if(inputXInfoAfterParams.GetShape()[axesXToMul.second]
+ != inputYInfoAfterParams.GetShape()[axesYToMul.first])
{
throw InvalidArgumentException(descriptorName +
": The final axis of input tensor X must be the same size as "
"the second last axis of input tensor Y.");
}
- auto axesNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters,
- inputTensorXInfo.GetShape(),
- inputTensorYInfo.GetShape());
-
{ // Separate scope so we don't pollute the rest of the scope with our temp variables
// e.g. NHWC isn't compatible with NCHW as of now
- DataLayout xLayout;
- DataLayout yLayout;
-
- if(m_Parameters.m_DataLayoutX == EmptyOptional())
- {
- xLayout = DataLayout::NCHW; // Not equivalent - I'm just concerned with the last 2 axes
- }
- else
- {
- xLayout = m_Parameters.m_DataLayoutX.value();
- }
-
- if(m_Parameters.m_DataLayoutY == EmptyOptional())
- {
- yLayout = DataLayout::NCHW;
- }
- else
- {
- yLayout = m_Parameters.m_DataLayoutY.value();
- }
+ DataLayout xLayout = m_Parameters.m_DataLayoutX;
+ DataLayout yLayout = m_Parameters.m_DataLayoutY;
if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW)
{
@@ -4290,8 +4311,8 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
}
// Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one
- unsigned int outputTensorDimSize = std::max(inputTensorXInfo.GetNumDimensions(),
- inputTensorYInfo.GetNumDimensions());
+ unsigned int outputTensorDimSize = std::max(inputXInfoAfterParams.GetNumDimensions(),
+ inputYInfoAfterParams.GetNumDimensions());
if(outputTensorDimSize-2 > 0)
{
TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
@@ -4312,12 +4333,17 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
for(unsigned int i = 0; i < ti.GetNumDimensions(); i++)
{
- ti.GetShape()[i] = inputTensorXInfo.GetShape()[i];
+ ti.GetShape()[i] = inputXInfoAfterParams.GetShape()[i];
}
};
- doAxisExtension(axesNotMul.first, tiXNotMul);
- doAxisExtension(axesNotMul.second, tiYNotMul);
+ auto axesXNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutX,
+ inputXInfoAfterParams.GetShape());
+ auto axesYNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutY,
+ inputYInfoAfterParams.GetShape());
+
+ doAxisExtension(axesXNotMul, tiXNotMul);
+ doAxisExtension(axesYNotMul, tiYNotMul);
for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++)
{
@@ -4332,42 +4358,6 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
"input_X",
"input_Y");
}
-
- // Also check descriptor parameter validity
- // This will eventually be moved to the start of the function as explained below
- if ((!m_Parameters.m_TransposeX.empty() && !m_Parameters.m_AdjointX.empty()) ||
- (!m_Parameters.m_TransposeY.empty() && !m_Parameters.m_AdjointY.empty()))
- {
- throw InvalidArgumentException(descriptorName +
- ": Invalid descriptor parameters - Transpose and Adjoint "
- "vectors cannot both be true for a given input tensor.");
- }
-
- if(m_Parameters.m_TransposeX.size() != 0 && m_Parameters.m_TransposeX.size() != inputTensorXInfo.GetNumDimensions())
- {
- throw InvalidArgumentException(descriptorName +
- ": Invalid descriptor parameter - Transpose X vector must be "
- "the same size as tensor input X's dimensionality.");
- }
- if(m_Parameters.m_AdjointX.size() != 0 && m_Parameters.m_AdjointX.size() != inputTensorXInfo.GetNumDimensions())
- {
- throw InvalidArgumentException(descriptorName +
- ": Invalid descriptor parameter - Adjoint X vector must be "
- "the same size as tensor input X's dimensionality.");
- }
- if(m_Parameters.m_TransposeY.size() != 0 && m_Parameters.m_TransposeY.size() != inputTensorYInfo.GetNumDimensions())
- {
- throw InvalidArgumentException(descriptorName +
- ": Invalid descriptor parameter - Transpose Y vector must be "
- "the same size as tensor input Y's dimensionality.");
- }
- if(m_Parameters.m_AdjointY.size() != 0 && m_Parameters.m_AdjointY.size() != inputTensorXInfo.GetNumDimensions())
- {
- throw InvalidArgumentException(descriptorName +
- ": Invalid descriptor parameter - Adjoint Y vector must be "
- "the same size as tensor input Y's dimensionality.");
- }
- // Note: for adjoint/transpose, you'll need to do the validation atop the resultant permutation.
}
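
Net effect of the rewritten validation above: the per-input parameters are applied first, and the usual matmul compatibility rule is then checked on the resulting shapes. In our notation (not the source's), with X' and Y' denoting the post-parameter tensors:

$$X' : (\ldots, m, k), \quad Y' : (\ldots, k, n) \;\Longrightarrow\; \text{Out} : (\ldots, m, n)$$

Transpose swaps the two layout-designated axes of its input, while adjoint requires them to already be equal (square slices, m = k) and leaves the shape unchanged.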
diff --git a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp
index 41add6e6da..6fcc35ab52 100644
--- a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp
@@ -191,7 +191,7 @@ LayerTestResult<T, 3> BatchMatMul3DSimpleTest(
std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
19, 22,
43, 50
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
memoryManager,
@@ -247,9 +247,7 @@ LayerTestResult<T, 4> BatchMatMulNCHWSimpleTest(
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
const armnn::ITensorHandleFactory& tensorHandleFactory)
{
- auto descriptor = armnn::BatchMatMulDescriptor(
- armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NCHW),
- armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NCHW));
+ auto descriptor = armnn::BatchMatMulDescriptor(); // Default arbitrary layout is treated the same as NCHW
float qScale = 0.0f;
int32_t qOffset = 0;
@@ -282,7 +280,7 @@ LayerTestResult<T, 4> BatchMatMulNCHWSimpleTest(
std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
19, 22,
43, 50
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 4>(workloadFactory,
memoryManager,
@@ -338,9 +336,12 @@ LayerTestResult<T, 4> BatchMatMulNHWCSimpleTest(
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
const armnn::ITensorHandleFactory& tensorHandleFactory)
{
- auto descriptor = armnn::BatchMatMulDescriptor(
- armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NHWC),
- armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NHWC));
+ auto descriptor = armnn::BatchMatMulDescriptor(false,
+ false,
+ false,
+ false,
+ armnn::DataLayout::NHWC,
+ armnn::DataLayout::NHWC);
float qScale = 0.0f;
int32_t qOffset = 0;
@@ -373,7 +374,7 @@ LayerTestResult<T, 4> BatchMatMulNHWCSimpleTest(
std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
19, 22,
43, 50
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 4>(workloadFactory,
memoryManager,
@@ -471,7 +472,7 @@ LayerTestResult<T, 3> BatchMatMul3DBatchTest(
267, 286,
323, 346
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
memoryManager,
@@ -566,7 +567,7 @@ LayerTestResult<T, 3> BatchMatMul3DBroadcastTest(
267, 286,
323, 346
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
memoryManager,
@@ -661,7 +662,7 @@ LayerTestResult<T, 3> BatchMatMul3D2DBroadcastTest(
267, 286,
323, 346
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
memoryManager,
@@ -717,9 +718,12 @@ LayerTestResult<T, 5> BatchMatMulNDHWCNHWCTest(
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
const armnn::ITensorHandleFactory& tensorHandleFactory)
{
- auto descriptor = armnn::BatchMatMulDescriptor(
- armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NDHWC),
- armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NHWC));
+ auto descriptor = armnn::BatchMatMulDescriptor(false,
+ false,
+ false,
+ false,
+ armnn::DataLayout::NDHWC,
+ armnn::DataLayout::NHWC);
float qScale = 0.0f;
int32_t qOffset = 0;
@@ -761,7 +765,7 @@ LayerTestResult<T, 5> BatchMatMulNDHWCNHWCTest(
34, 1079,
46, 1167
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 5>(workloadFactory,
memoryManager,
@@ -959,7 +963,7 @@ LayerTestResult<T, 3> BatchMatMul3DNonSquareTest(
88, 100, 142, 106,
39, 61, 78, 56,
72, 52, 98, 70
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
memoryManager,
@@ -1007,4 +1011,330 @@ template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 3>
BatchMatMul3DNonSquareTest<armnn::DataType::QSymmS16>(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 2> BatchMatMul2DTranspSimpleTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+ auto descriptor = armnn::BatchMatMulDescriptor(true,
+ false,
+ false,
+ false);
+
+ float qScale = 0.0f;
+ int32_t qOffset = 0;
+
+ switch(ArmnnType)
+ {
+ case armnn::DataType::QAsymmS8:
+ case armnn::DataType::QAsymmU8:
+ case armnn::DataType::QSymmS16:
+ qScale = 1.0f;
+ break;
+ default:
+ break;
+ }
+
+ armnn::TensorInfo inputXInfo({2,3}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo inputYInfo({2,3}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo outputInfo({3,3}, ArmnnType, qScale, qOffset);
+
+ std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+ 1, 2, 3,
+ 4, 5, 6
+ }, qScale, qOffset);
+
+ std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+ 7, 8, 9,
+ 10, 11, 12
+ }, qScale, qOffset);
+
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+ 47, 52, 57,
+ 64, 71, 78,
+ 81, 90, 99
+ }, qScale, qOffset);
+
+ return BatchMatMulTestImpl<ArmnnType, T, 2>(workloadFactory,
+ memoryManager,
+ tensorHandleFactory,
+ descriptor,
+ inputX,
+ inputY,
+ outputExpected,
+ inputXInfo,
+ inputYInfo,
+ outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::Float32>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::Float16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::QAsymmS8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::QAsymmU8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::QSymmS16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 2> BatchMatMul2DAdjointSimpleTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+ auto descriptor = armnn::BatchMatMulDescriptor(false,
+ false,
+ true,
+ false);
+
+ float qScale = 0.0f;
+ int32_t qOffset = 0;
+
+ switch(ArmnnType)
+ {
+ case armnn::DataType::QAsymmS8:
+ case armnn::DataType::QAsymmU8:
+ case armnn::DataType::QSymmS16:
+ qScale = 1.0f;
+ break;
+ default:
+ break;
+ }
+
+ armnn::TensorInfo inputXInfo({3,3}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo inputYInfo({3,3}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo outputInfo({3,3}, ArmnnType, qScale, qOffset);
+
+ std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+ 3, 1, 1,
+ 1, 3, -1,
+ 2, 4, 1
+ }, qScale, qOffset);
+
+ std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+ 1, 0, 0,
+ 0, 1, 0,
+ 0, 0, 1
+ }, qScale, qOffset);
+
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+ 7, 3, -4,
+ -3, 1, 4,
+ -2, -10, 8
+ }, qScale, qOffset);
+
+ switch (ArmnnType)
+ {
+ case armnn::DataType::QAsymmU8:
+ outputExpected = armnnUtils::QuantizedVector<T>({
+ 3, 3, 0,
+ 0, 1, 1,
+ 0, 0, 8
+ }, qScale, qOffset);
+ break;
+ default:
+ break;
+ }
+
+ return BatchMatMulTestImpl<ArmnnType, T, 2>(workloadFactory,
+ memoryManager,
+ tensorHandleFactory,
+ descriptor,
+ inputX,
+ inputY,
+ outputExpected,
+ inputXInfo,
+ inputYInfo,
+ outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::Float32>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::Float16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::QAsymmS8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::QAsymmU8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::QSymmS16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 4> BatchMatMulNHWCParamsTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+ auto descriptor = armnn::BatchMatMulDescriptor(false,
+ true,
+ true,
+ false,
+ armnn::DataLayout::NHWC,
+ armnn::DataLayout::NHWC);
+
+ float qScale = 0.0f;
+ int32_t qOffset = 0;
+
+ switch(ArmnnType)
+ {
+ case armnn::DataType::QAsymmS8:
+ case armnn::DataType::QAsymmU8:
+ case armnn::DataType::QSymmS16:
+ qScale = 1.0f;
+ break;
+ default:
+ break;
+ }
+
+ armnn::TensorInfo inputXInfo({1,4,4,2}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo inputYInfo({2,2,4,1}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo outputInfo({2,4,2,2}, ArmnnType, qScale, qOffset);
+
+ std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+ 1, -3, 1, 4, 4, 9, 1, 2,
+ 2, 4, 2, 2, 10, 7, 6, -5,
+ 3, 8, 9, 9, 21, 1, 17, 7,
+ 5, 11, 11, 8, 29, 3, 23, 6
+ }, qScale, qOffset);
+
+ std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+
+ 9, 10, 11, 12,
+ 13, 14, 15, 16
+ }, qScale, qOffset);
+
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+ 28, 625, 140, 585,
+ 8, 110, -8, 1662,
+ -24, 401, -120, 921,
+ 12, 131, 108, -501,
+
+ 252, 545, 364, 505,
+ -24, 3214, -40, 4766,
+ -216, 1441, -312, 1961,
+ 204, -1133, 300, -1765
+ }, qScale, qOffset);
+
+ switch (ArmnnType)
+ {
+ case armnn::DataType::QAsymmU8:
+ outputExpected = armnnUtils::QuantizedVector<T>({
+ 28, 80, 140, 80,
+ 8, 45, 0, 255,
+ 0, 18, 0, 18,
+ 12, 0, 108, 0,
+
+ 252, 80, 255, 80,
+ 0, 255, 0, 255,
+ 0, 18, 0, 18,
+ 204, 0, 255, 0
+ }, qScale, qOffset);
+ break;
+ default:
+ break;
+ }
+
+ return BatchMatMulTestImpl<ArmnnType, T, 4>(workloadFactory,
+ memoryManager,
+ tensorHandleFactory,
+ descriptor,
+ inputX,
+ inputY,
+ outputExpected,
+ inputXInfo,
+ inputYInfo,
+ outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::Float32>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::Float16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::QAsymmS8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::QAsymmU8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::QSymmS16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
\ No newline at end of file
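
As a sanity check on BatchMatMul2DTranspSimpleTest above, the expected output follows directly from pre-transposing X (our arithmetic, using the test's data):

$$X^\top Y = \begin{bmatrix} 1 & 4 \\ 2 & 5 \\ 3 & 6 \end{bmatrix} \begin{bmatrix} 7 & 8 & 9 \\ 10 & 11 & 12 \end{bmatrix} = \begin{bmatrix} 47 & 52 & 57 \\ 64 & 71 & 78 \\ 81 & 90 & 99 \end{bmatrix}$$

For instance, the (0,0) entry is $1 \cdot 7 + 4 \cdot 10 = 47$, matching outputExpected.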
diff --git a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp
index 9e2139667b..0b261fba37 100644
--- a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp
+++ b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp
@@ -82,4 +82,22 @@ template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
LayerTestResult<T, 3> BatchMatMul3DNonSquareTest(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 2> BatchMatMul2DTranspSimpleTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 2> BatchMatMul2DAdjointSimpleTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 4> BatchMatMulNHWCParamsTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
\ No newline at end of file
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 593dc7851e..ae40333658 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -1133,6 +1133,27 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQAsymmS8, BatchMatMul3DNonSq
ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQAsymmU8, BatchMatMul3DNonSquareTest<DataType::QAsymmU8>);
ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQASymmS16, BatchMatMul3DNonSquareTest<DataType::QSymmS16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleBFloat16, BatchMatMul2DTranspSimpleTest<DataType::BFloat16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleFloat32, BatchMatMul2DTranspSimpleTest<DataType::Float32>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleFloat16, BatchMatMul2DTranspSimpleTest<DataType::Float16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleQAsymmS8, BatchMatMul2DTranspSimpleTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleQAsymmU8, BatchMatMul2DTranspSimpleTest<DataType::QAsymmU8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleQASymmS16, BatchMatMul2DTranspSimpleTest<DataType::QSymmS16>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleBFloat16, BatchMatMul2DAdjointSimpleTest<DataType::BFloat16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleFloat32, BatchMatMul2DAdjointSimpleTest<DataType::Float32>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleFloat16, BatchMatMul2DAdjointSimpleTest<DataType::Float16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleQAsymmS8, BatchMatMul2DAdjointSimpleTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleQAsymmU8, BatchMatMul2DAdjointSimpleTest<DataType::QAsymmU8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleQASymmS16, BatchMatMul2DAdjointSimpleTest<DataType::QSymmS16>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsBFloat16, BatchMatMulNHWCParamsTest<DataType::BFloat16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsFloat32, BatchMatMulNHWCParamsTest<DataType::Float32>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsFloat16, BatchMatMulNHWCParamsTest<DataType::Float16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsQAsymmS8, BatchMatMulNHWCParamsTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsQAsymmU8, BatchMatMulNHWCParamsTest<DataType::QAsymmU8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsQASymmS16, BatchMatMulNHWCParamsTest<DataType::QSymmS16>);
+
// Batch Norm
ARMNN_AUTO_TEST_CASE_WITH_THF(BatchNormFloat32, BatchNormFloat32Test)
ARMNN_AUTO_TEST_CASE_WITH_THF(BatchNormFloat32Nhwc, BatchNormFloat32NhwcTest)
diff --git a/src/backends/reference/workloads/BatchMatMulImpl.cpp b/src/backends/reference/workloads/BatchMatMulImpl.cpp
index 6693f15760..c592b3b76c 100644
--- a/src/backends/reference/workloads/BatchMatMulImpl.cpp
+++ b/src/backends/reference/workloads/BatchMatMulImpl.cpp
@@ -7,46 +7,53 @@
#include <armnn/backends/WorkloadData.hpp>
#include <armnn/Logging.hpp>
+#include <armnnUtils/Permute.hpp>
namespace armnn
{
-void BatchMatMul::BatchMatMulImpl()
+BatchMatMul::BatchMatMul(const BatchMatMulDescriptor& params,
+ const TensorInfo& inputXInfo,
+ const TensorInfo& inputYInfo,
+ const TensorInfo& outputInfo,
+ Decoder<float>& inputXDecoder,
+ Decoder<float>& inputYDecoder,
+ Encoder<float>& outputEncoder)
+ : params(params),
+ inputXInfo(inputXInfo),
+ inputYInfo(inputYInfo),
+ outputInfo(outputInfo),
+ inputXDecoder(inputXDecoder),
+ inputYDecoder(inputYDecoder),
+ outputEncoder(outputEncoder)
{
- inputXData = inputXDecoder.DecodeTensor(inputXInfo.GetShape());
- inputYData = inputYDecoder.DecodeTensor(inputYInfo.GetShape());
+ inputXData = this->inputXDecoder.DecodeTensor(inputXInfo.GetShape());
+ inputYData = this->inputYDecoder.DecodeTensor(inputYInfo.GetShape());
// At this point, we don't touch the input decoders - just the resultant vectors
- // Pre-transpose and pre-adjoint if their vectors aren't empty
- // and also DataLayouts which may change with permutations/adjoints
+ ApplyParams();
- // Todo: Have you updated input validation and inferred output shapes to accommodate for these pre-permutes?
-
- auto idx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0);
- RecurseBMM(idx, 0);
+ ApplyBatchMatMul();
}
-void BatchMatMul::RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim)
+void BatchMatMul::ApplyBatchMatMul()
{
- // We're working off of the indexes of the output tensor (the max possible shape)
-
- if(!(curDim < outputInfo.GetNumDimensions()))
- {
- // We're at the leaf level of this call tree, so we operate here (each leaf is a data point)
+ auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutX,
+ inputXInfo.GetShape());
+ auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutY,
+ inputYInfo.GetShape());
+ AdjustAxesToMulForUnequalRanks(axesXToMul, axesYToMul);
- auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(params,
- inputXInfo.GetShape(),
- inputYInfo.GetShape());
- AdjustAxesToMulForUnequalRanks(axesToMul);
+ unsigned int inputXColDim = axesXToMul.second;
+ unsigned int inputYRowDim = axesYToMul.first;
- unsigned int inputXColDim = axesToMul.first.second;
- unsigned int inputYRowDim = axesToMul.second.first;
-
- unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim];
+ unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim];
+ auto batchMatMulOperation = [&](const std::vector<unsigned int>& curIdx)
+ {
float sum = 0.0f;
- // You could also use inputXColSize
+ // InputYRowSize is synonymous with inputXColSize
for (unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++) {
auto xIdx = curIdx;
xIdx[inputXColDim] = inputYRowIdx;
@@ -54,24 +61,271 @@ void BatchMatMul::RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int cur
auto yIdx = curIdx;
yIdx[inputYRowDim] = inputYRowIdx;
- sum += (GetValueAt(DataSlot::InputX, xIdx)
- * GetValueAt(DataSlot::InputY, yIdx));
+ sum += (GetValueAt(DataSlot::InputX, xIdx) * GetValueAt(DataSlot::InputY, yIdx));
}
SetValueAt(sum, DataSlot::Output, curIdx);
+ };
+
+ auto startIdx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0);
+ RecurseTensor(outputInfo,
+ batchMatMulOperation,
+ startIdx,
+ 0);
+}
+
+void BatchMatMul::ApplyParams()
+{
+ if(params.m_TransposeX)
+ {
+ Transpose(DataSlot::InputX);
+ }
+ else if(params.m_AdjointX)
+ {
+ Adjoint(DataSlot::InputX);
+ }
+ if(params.m_TransposeY)
+ {
+ Transpose(DataSlot::InputY);
+ }
+ else if(params.m_AdjointY)
+ {
+ Adjoint(DataSlot::InputY);
+ }
+}
+
+void BatchMatMul::Transpose(DataSlot type)
+{
+ // AKA the permute of the tensor
+ // This modifies the tensor's info.
+
+ switch(type)
+ {
+ case DataSlot::InputX:
+ {
+ auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutX,
+ inputXInfo.GetShape());
+ inputXInfo = armnnUtils::Permuted(inputXInfo, permuteVec);
+ std::vector<float> temp(inputXData.size());
+ armnnUtils::Permute(inputXInfo.GetShape(),
+ permuteVec,
+ inputXData.data(),
+ temp.data(),
+ sizeof(float));
+ inputXData = temp;
+ break;
+ }
+ case DataSlot::InputY:
+ {
+ auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutY,
+ inputYInfo.GetShape());
+ inputYInfo = armnnUtils::Permuted(inputYInfo, permuteVec);
+ std::vector<float> temp(inputYData.size());
+ armnnUtils::Permute(inputYInfo.GetShape(),
+ permuteVec,
+ inputYData.data(),
+ temp.data(),
+ sizeof(float));
+ inputYData = temp;
+ break;
+ }
+ case DataSlot::Output: // We needn't transpose the output tensor
+ default:
+ break;
+ }
+}
+
+void BatchMatMul::Adjoint(DataSlot type)
+{
+ // Finding the adjoint of a square matrix:
+ // Calculate the cofactor of each element (using Gauss elimination here)
+ // Apply a transpose to it (this also modifies the tensor's info)
+
+ TensorInfo& inputInfo = (type == DataSlot::InputX) ? inputXInfo : inputYInfo;
+ const auto& dataLayout = (type == DataSlot::InputX) ? params.m_DataLayoutX : params.m_DataLayoutY;
+ const auto axesToAdjoint = BatchMatMulDescriptor::GetAxesToMul(dataLayout, inputInfo.GetShape());
+
+ ARMNN_ASSERT(inputInfo.GetShape()[axesToAdjoint.first] == inputInfo.GetShape()[axesToAdjoint.second]);
+ // We grab a copy of the tensor data to prevent overwriting
+ std::vector<float> inputDataClone = (type == DataSlot::InputX) ? inputXData : inputYData;
+
+ // The sub-matrix is the resultant matrix when the row and column of the current index is removed
+ unsigned int subMatAxisSize = inputInfo.GetShape()[axesToAdjoint.first] - 1;
+ std::vector<std::vector<float>> subMat(subMatAxisSize,
+ std::vector<float>(subMatAxisSize));
+
+ // Lambdas for each sub-step of the cofactor operation
+ auto almostEquals = [&](const float& a, const float& b, float unitsInLastPlace = 2.0f)
+ {
+ float diff = std::fabs(a-b);
+ float bound = diff * std::numeric_limits<float>::epsilon() * unitsInLastPlace;
+ return (diff <= bound) || (diff < std::numeric_limits<float>::min());
+ };
+
+ float swapMultiplier = std::numeric_limits<float>::max();
+ auto swapRows = [&](unsigned int rowIdxA, unsigned int rowIdxB)
+ {
+ // Every row swap flips this around by the negative (set to 1 at the beginning of each cofactor op run)
+ for(unsigned int colIdx = 0; colIdx < subMatAxisSize; colIdx++)
+ {
+ float tmp = subMat[rowIdxA][colIdx];
+ subMat[rowIdxA][colIdx] = subMat[rowIdxB][colIdx];
+ subMat[rowIdxB][colIdx] = tmp;
+ }
+ swapMultiplier *= -1.0f;
+ };
+
+ auto findNextValidPivotRowIdx = [&](unsigned int colIdx)
+ {
+ unsigned int result = std::numeric_limits<unsigned int>::max();
+
+ // The original diagonal has been checked and is invalid
+ for(unsigned int rowIdx = colIdx+1; rowIdx < subMatAxisSize; rowIdx++)
+ {
+ if(!almostEquals(subMat[rowIdx][colIdx], 0.0f))
+ {
+ result = rowIdx;
+ break;
+ }
+ }
+ return result;
+ };
+
+ auto eliminate = [&](const float& pivot, unsigned int pivotPos)
+ {
+ for(unsigned int rowIdx = pivotPos+1; rowIdx < subMatAxisSize; rowIdx++)
+ {
+ float multiplierNumerator = subMat[rowIdx][pivotPos];
+ if(almostEquals(multiplierNumerator, 0.0f))
+ {
+ continue;
+ }
+ float multiplier = multiplierNumerator / pivot; // Susceptible to floating point inaccuracies
+ // Hence the almostEquals usage to counteract this
+ for(unsigned int colIdx = pivotPos; colIdx < subMatAxisSize; colIdx++)
+ {
+ // We start at col=pivotPos as we have assumed that all elements
+ // to our left have been eliminated to zero already
+
+ // We subtract based on the element directly above us in our pivot row
+ subMat[rowIdx][colIdx] -= multiplier * subMat[pivotPos][colIdx];
+ }
+ }
+ };
+
+ auto cofactorOperation = [&](const std::vector<unsigned int>& curIdx)
+ {
+ auto row = curIdx[axesToAdjoint.first];
+ auto col = curIdx[axesToAdjoint.second];
+
+ float minorMultiplier = static_cast<float>(std::pow(-1, (row + 1 + col + 1)));
+
+ for(unsigned int subRow = 0; subRow < subMatAxisSize; subRow++)
+ {
+ for(unsigned int subCol = 0; subCol < subMatAxisSize; subCol++)
+ {
+ unsigned int outerRow = (subRow >= row)?subRow + 1:subRow;
+ unsigned int outerCol = (subCol >= col)?subCol + 1:subCol;
+ auto cloneIdx = curIdx;
+ cloneIdx[axesToAdjoint.first] = outerRow;
+ cloneIdx[axesToAdjoint.second] = outerCol;
+ subMat[subRow][subCol] = GetValueAt(type, cloneIdx, inputDataClone);
+ }
+ }
+
+ float determinant = 1.0f;
+
+ // Cover the edge cases and simple base cases before resorting to Gauss elimination for larger matrices
+ switch(subMatAxisSize)
+ {
+ case 0:
+ {
+ determinant = GetValueAt(type, curIdx, inputDataClone);
+ break;
+ }
+ case 1:
+ {
+ // If the resultant sub-matrix is just one element - that's the determinant
+ determinant = subMat[0][0];
+ break;
+ }
+ case 2:
+ {
+ // For a 2x2 sub-matrix, the determinant is just a*d-b*c
+ determinant = subMat[0][0] * subMat[1][1] -
+ subMat[0][1] * subMat[1][0];
+ break;
+ }
+ default:
+ {
+ // Gaussian elimination to find the determinant of this sub-matrix
+ swapMultiplier = 1.0f;
+ // March diagonally down the pivots and if it's invalid (a zero), swap the row with the
+ // nearest non-zero down within the column
+ for(unsigned int pivotRow = 0, pivotCol = 0;
+ pivotRow < subMatAxisSize;
+ pivotRow++, pivotCol++)
+ {
+ float& pivot = subMat[pivotRow][pivotCol];
+
+ if(almostEquals(pivot, 0.0f))
+ {
+ unsigned int nextValidPivotRowIdx = findNextValidPivotRowIdx(pivotCol);
+ if(nextValidPivotRowIdx == std::numeric_limits<unsigned int>::max())
+ {
+ // No valid pivot down this column, which means that this pivot remains a zero.
+ // This results in the determinant for this entire sub-matrix to just be zero.
+ determinant = 0.0f;
+ break;
+ }
+ swapRows(pivotRow, nextValidPivotRowIdx);
+ }
+ determinant *= pivot;
+ // The actual elimination bit (which will update/propagate to the pivots down the line)
+ eliminate(pivot, pivotRow); // Synonymous with pivotCol
+ }
+
+ determinant *= swapMultiplier;
+ break;
+ }
+ }
+ float cofactor = minorMultiplier * determinant;
+ SetValueAt(cofactor, type, curIdx);
+ };
+
+ auto startIdx = std::vector<unsigned int>(inputInfo.GetNumDimensions(), 0);
+ RecurseTensor(inputInfo,
+ cofactorOperation,
+ startIdx,
+ 0);
+
+ Transpose(type);
+}
+void BatchMatMul::RecurseTensor(const TensorInfo& tensorInfo,
+ const std::function<void(const std::vector<unsigned int>&)>& operation,
+ std::vector<unsigned int>& curIdx,
+ unsigned int curDim)
+{
+ if(!(curDim < tensorInfo.GetNumDimensions()))
+ {
+ // We're at the leaf level of this call tree, so we operate here (each leaf is a data point)
+ operation(curIdx);
return;
}
- for (unsigned int i = 0; i < outputInfo.GetShape()[curDim]; i++)
+ for(unsigned int i = 0; i < tensorInfo.GetShape()[curDim]; i++)
{
curIdx[curDim] = i;
- RecurseBMM(curIdx, curDim+1);
+ RecurseTensor(tensorInfo,
+ operation,
+ curIdx,
+ curDim + 1);
}
}
-void BatchMatMul::AdjustAxesToMulForUnequalRanks(
- std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul)
+void BatchMatMul::AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
+ std::pair<unsigned int, unsigned int>& axesYToMul)
{
int rankDiff = static_cast<int>(inputXInfo.GetNumDimensions()) -
static_cast<int>(inputYInfo.GetNumDimensions());
@@ -82,18 +336,18 @@ void BatchMatMul::AdjustAxesToMulForUnequalRanks(
else if(rankDiff < 0)
{
// Y is the larger one
- axesToMul.first.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
- axesToMul.first.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
+ axesXToMul.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
+ axesXToMul.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
}
else if(rankDiff > 0)
{
// X is the larger one
- axesToMul.second.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
- axesToMul.second.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
+ axesYToMul.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
+ axesYToMul.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
}
}
-float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx)
+float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx, const std::vector<float>& customData)
{
// This gets the data from the input vector that we have, Not the decoder
// But for the output, it is operating on the encoder itself
@@ -101,14 +355,13 @@ float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx)
AdjustToSafeIdx(type, idx);
unsigned int flatIdx = CalcFlatIdx(type, idx);
float value = 0.0f;
-
switch(type)
{
case DataSlot::InputX:
- value = inputXData[flatIdx];
+ value = customData.empty() ? inputXData[flatIdx] : customData[flatIdx];
break;
case DataSlot::InputY:
- value = inputYData[flatIdx];
+ value = customData.empty() ? inputYData[flatIdx] : customData[flatIdx];
break;
case DataSlot::Output:
outputEncoder[flatIdx];
@@ -124,9 +377,7 @@ float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx)
void BatchMatMul::SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx)
{
AdjustToSafeIdx(type, idx);
-
unsigned int flatIdx = CalcFlatIdx(type, idx);
-
switch(type)
{
case DataSlot::InputX:
@@ -186,9 +437,7 @@ void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx)
unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx)
{
unsigned int result = idx[idx.size()-1];
-
unsigned int dimMultiplier = 1;
-
unsigned int offset;
// -2 because final dim is already accounted for in the multiplier (last dim is just a multiplier of 1x)
@@ -215,17 +464,4 @@ unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned
return result;
}
-template <typename T>
-std::string BatchMatMul::StringifyVec(const std::vector<T>& vec)
-{
- std::string res = "{ ";
- for(auto x : vec)
- {
- res += std::to_string(x);
- res += " ";
- }
- res += "}";
- return res;
-}
-
} // namespace armnn
\ No newline at end of file
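
"Adjoint" here means the classical adjoint (adjugate): the transpose of the cofactor matrix, which is exactly what the cofactor pass followed by Transpose() computes above. Checking against the data in BatchMatMul2DAdjointSimpleTest (our arithmetic):

$$A = \begin{bmatrix} 3 & 1 & 1 \\ 1 & 3 & -1 \\ 2 & 4 & 1 \end{bmatrix}, \qquad \operatorname{adj}(A) = C^\top = \begin{bmatrix} 7 & 3 & -4 \\ -3 & 1 & 4 \\ -2 & -10 & 8 \end{bmatrix}$$

e.g. $C_{11} = \det\begin{bmatrix} 3 & -1 \\ 4 & 1 \end{bmatrix} = 7$; since the test multiplies by the identity, adj(A) is also the expected output.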
diff --git a/src/backends/reference/workloads/BatchMatMulImpl.hpp b/src/backends/reference/workloads/BatchMatMulImpl.hpp
index 25b6c85d77..19971a4af3 100644
--- a/src/backends/reference/workloads/BatchMatMulImpl.hpp
+++ b/src/backends/reference/workloads/BatchMatMulImpl.hpp
@@ -15,6 +15,15 @@ namespace armnn
class BatchMatMul {
public:
+ BatchMatMul(const BatchMatMulDescriptor& params,
+ const TensorInfo& inputXInfo,
+ const TensorInfo& inputYInfo,
+ const TensorInfo& outputInfo,
+ Decoder<float>& inputXDecoder,
+ Decoder<float>& inputYDecoder,
+ Encoder<float>& outputEncoder);
+
+private:
enum DataSlot
{
InputX = 0,
@@ -22,31 +31,35 @@ public:
Output = 2
};
- BatchMatMul(const BatchMatMulDescriptor& params,
- const TensorInfo& inputXInfo,
- const TensorInfo& inputYInfo,
- const TensorInfo& outputInfo,
- Decoder<float>& inputXDecoder,
- Decoder<float>& inputYDecoder,
- Encoder<float>& outputEncoder)
- : params(params),
- inputXInfo(inputXInfo),
- inputYInfo(inputYInfo),
- outputInfo(outputInfo),
- inputXDecoder(inputXDecoder),
- inputYDecoder(inputYDecoder),
- outputEncoder(outputEncoder)
- {}
+ const BatchMatMulDescriptor& params;
+ TensorInfo inputXInfo;
+ TensorInfo inputYInfo;
+ TensorInfo outputInfo;
+ Decoder<float>& inputXDecoder;
+ Decoder<float>& inputYDecoder;
+ Encoder<float>& outputEncoder;
- void BatchMatMulImpl();
+ std::vector<float> inputXData;
+ std::vector<float> inputYData;
+
+ void ApplyBatchMatMul();
+
+ void ApplyParams();
+
+ void Transpose(DataSlot type);
- void RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim);
+ void Adjoint(DataSlot type);
+
+ void RecurseTensor(const TensorInfo& tensorInfo,
+ std::function<void(const std::vector<unsigned int>&)> const& operation,
+ std::vector<unsigned int>& curIdx,
+ unsigned int curDim);
// Adjusts it for when input tensors are of unequal rank
- void AdjustAxesToMulForUnequalRanks(
- std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul);
+ void AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
+ std::pair<unsigned int, unsigned int>& axesYToMul);
- float GetValueAt(DataSlot type, std::vector<unsigned int> idx);
+ float GetValueAt(DataSlot type, std::vector<unsigned int> idx, const std::vector<float>& customData = {});
void SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx);
@@ -54,22 +67,6 @@ public:
void AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx);
unsigned int CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx);
-
- template <typename T>
- std::string StringifyVec(const std::vector<T>& vec);
-
-private:
- const BatchMatMulDescriptor& params;
- const TensorInfo& inputXInfo;
- const TensorInfo& inputYInfo;
- const TensorInfo& outputInfo;
- Decoder<float>& inputXDecoder;
- Decoder<float>& inputYDecoder;
- Encoder<float>& outputEncoder;
-
- std::vector<float> inputXData;
- std::vector<float> inputYData;
-
};
} // namespace armnn
\ No newline at end of file
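
The header now exposes only the constructor; everything else, including the generic index walker that replaced RecurseBMM, is private. A self-contained sketch of that traversal pattern over a plain shape vector (simplified types and a hypothetical free function, not the Arm NN class itself):

#include <functional>
#include <vector>

// Depth-first visit of every coordinate in an N-dimensional shape, applying
// `op` once per fully-specified index. Mirrors BatchMatMul::RecurseTensor.
void RecurseShape(const std::vector<unsigned int>& shape,
                  const std::function<void(const std::vector<unsigned int>&)>& op,
                  std::vector<unsigned int>& curIdx,
                  unsigned int curDim)
{
    if (curDim >= shape.size())
    {
        op(curIdx); // Leaf: curIdx now names exactly one element
        return;
    }
    for (unsigned int i = 0; i < shape[curDim]; i++)
    {
        curIdx[curDim] = i;
        RecurseShape(shape, op, curIdx, curDim + 1);
    }
}

ApplyBatchMatMul and Adjoint each hand a different lambda to this walker: the dot-product accumulation and the per-element cofactor computation, respectively.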
diff --git a/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp b/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp
index 388190c4ef..027b93b5d9 100644
--- a/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp
+++ b/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp
@@ -51,9 +51,6 @@ void RefBatchMatMulWorkload::Execute(std::vector<ITensorHandle*> inputs, std::ve
*inputXDecoder,
*inputYDecoder,
*outputEncoder);
-
- bmm.BatchMatMulImpl();
-
}
} // namespace armnn
\ No newline at end of file