aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2023-01-23 19:32:06 +0000
committerTeresaARM <teresa.charlinreyes@arm.com>2023-01-24 17:01:30 +0000
commit0e3fe10bfe1b4f006f6e0c5c2fae8fb5515c7544 (patch)
tree222ff6eb1c034efa05bc5dcf4b255f80993987bf
parentd134c13ec9a0585bb7656654e0e65c57958d8833 (diff)
downloadarmnn-0e3fe10bfe1b4f006f6e0c5c2fae8fb5515c7544.tar.gz
IVGCVSW-7455 Workaround to allow CLBatchMatMul to parse some 4D models
* Added ability to reduce dimension sizes when calling BuildArmComputeTensorInfo or BuildArmComputeTensorShapes, this will attempt to remove leading 1s in order to squeeze the number of dimensions but retain the size. * Changed ClBatchMatMulWorkload to attempt to squeeze the number of dimensions to 3 as the CL Gemm Kernel can only support up to 3 dimensions. Signed-off-by: Mike Kelly <mike.kelly@arm.com> Change-Id: I6b3d0886c5b97fdb686838fc3dc292833ddc4643
-rw-r--r--delegate/src/test/BatchMatMulTest.cpp3
-rw-r--r--include/armnnUtils/TensorUtils.hpp7
-rw-r--r--src/armnnUtils/TensorUtils.cpp36
-rw-r--r--src/armnnUtils/test/TensorUtilsTest.cpp58
-rw-r--r--src/backends/aclCommon/ArmComputeTensorUtils.cpp69
-rw-r--r--src/backends/aclCommon/ArmComputeTensorUtils.hpp24
-rw-r--r--src/backends/cl/test/ClLayerTests.cpp5
-rw-r--r--src/backends/cl/workloads/ClBatchMatMulWorkload.cpp47
8 files changed, 227 insertions, 22 deletions
diff --git a/delegate/src/test/BatchMatMulTest.cpp b/delegate/src/test/BatchMatMulTest.cpp
index d13d8dcf43..06ad2c3be2 100644
--- a/delegate/src/test/BatchMatMulTest.cpp
+++ b/delegate/src/test/BatchMatMulTest.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -677,6 +677,7 @@ namespace armnnDelegate
std::vector <armnn::BackendId> backends = {armnn::Compute::GpuAcc};
BatchMatMul2DFp32SimpleTest (backends);
BatchMatMul3DFp32SimpleTest (backends);
+ BatchMatMul4DFp32SimpleTest (backends);
BatchMatMul3DFp32BatchTest (backends);
BatchMatMul3DFp32BroadcastTest (backends);
BatchMatMul3D2DFp32BroadcastTest (backends);
diff --git a/include/armnnUtils/TensorUtils.hpp b/include/armnnUtils/TensorUtils.hpp
index 2d6ec2fea4..a2aa9b0a98 100644
--- a/include/armnnUtils/TensorUtils.hpp
+++ b/include/armnnUtils/TensorUtils.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2019,2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2018-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -8,6 +8,7 @@
#include <armnn/TypesUtils.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>
+#include <armnnUtils/TensorUtils.hpp>
#include <utility>
#include <vector>
@@ -41,6 +42,10 @@ armnn::TensorInfo GetTensorInfo(unsigned int numberOfBatches,
std::pair<float, float> FindMinMax(armnn::ITensorHandle* tensorHandle);
+armnn::TensorShape ReduceDims(const armnn::TensorShape& tensorInfo, unsigned int dimensions);
+
+armnn::TensorInfo ReduceDims(const armnn::TensorInfo& tensorInfo, unsigned int dimensions);
+
armnn::TensorShape ExpandDims(const armnn::TensorShape& tensorShape, int axis);
std::vector<unsigned int> SqueezeDims(const armnn::TensorShape& tensorShape);
diff --git a/src/armnnUtils/TensorUtils.cpp b/src/armnnUtils/TensorUtils.cpp
index 9e3d719211..03109e0cee 100644
--- a/src/armnnUtils/TensorUtils.cpp
+++ b/src/armnnUtils/TensorUtils.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017-2023 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -103,6 +103,40 @@ std::pair<float, float> FindMinMax(ITensorHandle* tensorHandle)
return std::make_pair(min, max);
}
+TensorShape ReduceDims(const TensorShape& tensorShape, unsigned int dimensions)
+{
+ if (tensorShape.GetNumDimensions() <= dimensions)
+ {
+ return tensorShape;
+ }
+ std::vector<unsigned int> newShape;
+
+ unsigned int dimsToSkip = tensorShape.GetNumDimensions() - dimensions;
+ unsigned int dimsSkipped = 0;
+ bool insertRemainder = false;
+
+ for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); ++i)
+ {
+ if (tensorShape[i] == 1 && dimsSkipped < dimsToSkip && !insertRemainder)
+ {
+ ++dimsSkipped;
+ continue;
+ }
+ newShape.push_back(tensorShape[i]);
+ // Once we insert the first dimension we can't skip any more
+ insertRemainder = true;
+ }
+ return TensorShape(static_cast<unsigned int>(newShape.size()), newShape.data());
+}
+
+TensorInfo ReduceDims(const TensorInfo& tensorInfo, unsigned int dimensions)
+{
+ TensorInfo strippedTensor(tensorInfo);
+ TensorShape strippedShape = ReduceDims(tensorInfo.GetShape(), dimensions);
+ strippedTensor.SetShape(strippedShape);
+ return strippedTensor;
+}
+
TensorShape ExpandDims(const TensorShape& tensorShape, int axis)
{
unsigned int outputDim = tensorShape.GetNumDimensions() + 1;
diff --git a/src/armnnUtils/test/TensorUtilsTest.cpp b/src/armnnUtils/test/TensorUtilsTest.cpp
index 16349c554e..a69a0098ce 100644
--- a/src/armnnUtils/test/TensorUtilsTest.cpp
+++ b/src/armnnUtils/test/TensorUtilsTest.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2019,2021-2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2019,2021-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -126,6 +126,62 @@ TEST_CASE("ExpandDimsInvalidAxisTest")
CHECK_THROWS_AS(ExpandDims(inputShape, 4), armnn::InvalidArgumentException);
}
+TEST_CASE("ReduceDimsShapeAll1s")
+{
+ armnn::TensorShape inputShape({ 1, 1, 1 });
+
+ // All dimensions are 1, so reducing from 3 dimensions to 2 should succeed
+ armnn::TensorShape outputShape = ReduceDims(inputShape, 2);
+ CHECK(outputShape.GetNumDimensions() == 2);
+ CHECK(outputShape[0] == 1);
+ CHECK(outputShape[1] == 1);
+}
+
+TEST_CASE("ReduceDimsShapeNotEnough1s")
+{
+ armnn::TensorShape inputShape({ 1, 2, 1 });
+
+ // Only one leading 1 can be removed, so the shape reduces to 2 dimensions, not the requested 1
+ armnn::TensorShape outputShape = ReduceDims(inputShape, 1);
+ CHECK(outputShape.GetNumDimensions() == 2);
+ CHECK(outputShape[0] == 2);
+ CHECK(outputShape[1] == 1);
+}
+
+TEST_CASE("ReduceDimsInfoAll1s")
+{
+ armnn::TensorInfo inputInfo({ 1, 1, 1 }, DataType::Float32);
+
+ // All dimensions are 1, so reducing the TensorInfo from 3 dimensions to 2 should succeed
+ armnn::TensorInfo outputInfo = ReduceDims(inputInfo, 2);
+ CHECK(outputInfo.GetShape().GetNumDimensions() == 2);
+ CHECK(outputInfo.GetShape()[0] == 1);
+ CHECK(outputInfo.GetShape()[1] == 1);
+}
+
+TEST_CASE("ReduceDimsInfoNotEnough1s")
+{
+ armnn::TensorInfo inputInfo({ 1, 2, 1 }, DataType::Float32);
+
+ // Only one leading 1 can be removed, so the TensorInfo reduces to 2 dimensions, not the requested 1
+ armnn::TensorInfo outputInfo = ReduceDims(inputInfo, 1);
+ CHECK(outputInfo.GetNumDimensions() == 2);
+ CHECK(outputInfo.GetShape()[0] == 2);
+ CHECK(outputInfo.GetShape()[1] == 1);
+}
+
+TEST_CASE("ReduceDimsShapeDimensionGreaterThanSize")
+{
+ armnn::TensorShape inputShape({ 1, 1, 1 });
+
+ // Requesting more dimensions (4) than the shape has (3) should return the shape unchanged
+ armnn::TensorShape outputShape = ReduceDims(inputShape, 4);
+ CHECK(outputShape.GetNumDimensions() == 3);
+ CHECK(outputShape[0] == 1);
+ CHECK(outputShape[1] == 1);
+ CHECK(outputShape[2] == 1);
+}
+
TEST_CASE("ExpandDimsInvalidNegativeAxisTest")
{
armnn::TensorShape inputShape({ 2, 3, 4 });
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
index 38c7f70da5..e6c5a9b41c 100644
--- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp
+++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include <aclCommon/ArmComputeTensorUtils.hpp>
@@ -146,6 +146,51 @@ arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& te
return shape;
}
+std::vector<unsigned int> ReduceDimsForACL(const armnn::TensorShape tensorShape, unsigned int dimensions)
+{
+ std::vector<unsigned int> newShape;
+
+ unsigned int dimsToSkip = 0;
+
+ if (tensorShape.GetNumDimensions() > dimensions)
+ {
+ dimsToSkip = tensorShape.GetNumDimensions() - dimensions;
+ }
+ unsigned int dimsSkipped = 0;
+ bool insertRemainder = false;
+
+ for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); ++i)
+ {
+ if (tensorShape[i] == 1 && dimsSkipped < dimsToSkip && !insertRemainder)
+ {
+ ++dimsSkipped;
+ continue;
+ }
+ newShape.insert(newShape.begin(), tensorShape[i]);
+ // Once we insert the first dimension we can't skip any more
+ insertRemainder = true;
+ }
+ return newShape;
+}
+
+arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape, unsigned int dimensions)
+{
+ arm_compute::TensorShape shape;
+ std::vector<unsigned int> strippedShape = ReduceDimsForACL(tensorShape, dimensions);
+
+ for (unsigned int i = 0; i < strippedShape.size(); i++)
+ {
+ shape.set(i, strippedShape[i], false);
+ }
+
+ // prevent arm_compute issue where tensor is flattened to nothing
+ if (shape.num_dimensions() == 0)
+ {
+ shape.set_num_dimensions(1);
+ }
+ return shape;
+}
+
// Utility function used to build a TensorInfo object, that can be used to initialise
// ARM Compute Tensor and CLTensor allocators.
// Note: this utility ignores the value of armnn::TensorInfo.IsConstant(). ACL tensors
@@ -174,6 +219,28 @@ arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tenso
return aclTensorInfo;
}
+arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo, unsigned int dimensions)
+{
+ bool multiScales = tensorInfo.HasMultipleQuantizationScales();
+ const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape(), dimensions);
+ const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType(), multiScales);
+
+ const arm_compute::QuantizationInfo aclQuantizationInfo = multiScales ?
+ arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScales()) :
+ arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScale(), tensorInfo.GetQuantizationOffset());
+
+ return arm_compute::TensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
+}
+arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
+ armnn::DataLayout dataLayout, unsigned int dimensions)
+{
+ arm_compute::TensorInfo aclTensorInfo = BuildArmComputeTensorInfo(tensorInfo, dimensions);
+ aclTensorInfo.set_data_layout(ConvertDataLayout(dataLayout));
+
+ return aclTensorInfo;
+}
+
+
arm_compute::DataLayout ConvertDataLayout(armnn::DataLayout dataLayout)
{
switch(dataLayout)
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.hpp b/src/backends/aclCommon/ArmComputeTensorUtils.hpp
index 6ddecf2aaa..1f07fa949c 100644
--- a/src/backends/aclCommon/ArmComputeTensorUtils.hpp
+++ b/src/backends/aclCommon/ArmComputeTensorUtils.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -36,16 +36,38 @@ arm_compute::Coordinates BuildArmComputeReductionCoordinates(size_t inputDimensi
/// Utility function used to setup an arm_compute::TensorShape object from an armnn::TensorShape.
arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape);
+/// Utility function used to setup an arm_compute::TensorShape object from an armnn::TensorShape. This will
+/// attempt to reduce the number of leading 1s until the dimension length is equal to the dimensions passed in.
+arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape, unsigned int dimensions);
+
/// Utility function used to setup an arm_compute::ITensorInfo object whose dimensions are based on the given
/// armnn::ITensorInfo.
arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo);
/// Utility function used to setup an arm_compute::ITensorInfo object whose dimensions are based on the given
+/// armnn::ITensorInfo. This will attempt to reduce the number of leading 1s until the dimension length is equal
+/// to the dimensions passed in.
+arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo, unsigned int dimensions);
+
+/// Utility function used to setup an arm_compute::ITensorInfo object whose dimensions are based on the given
+/// armnn::ITensorInfo. This will attempt to reduce the number of leading 1s until the dimension length is equal
+/// to the dimensions passed in.
+arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
+ armnn::DataLayout dataLayout,
+ unsigned int dimensions);
+
+/// Utility function used to setup an arm_compute::ITensorInfo object whose dimensions are based on the given
/// armnn::ITensorInfo.
/// armnn::DataLayout.
arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
armnn::DataLayout dataLayout);
+/// Utility function used to setup an arm_compute::ITensorInfo object whose dimensions are based on the given
+/// armnn::ITensorInfo. This will attempt to reduce the number of leading 1s until the dimension length is equal
+/// to the dimensions passed in.
+arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
+ armnn::DataLayout dataLayout, unsigned int dimensions);
+
/// Utility function used to convert armnn::DataLayout to arm_compute::DataLayout
/// armnn::DataLayout.
arm_compute::DataLayout ConvertDataLayout(armnn::DataLayout dataLayout);
diff --git a/src/backends/cl/test/ClLayerTests.cpp b/src/backends/cl/test/ClLayerTests.cpp
index 4ba2a9ec3b..10e2a54c9f 100644
--- a/src/backends/cl/test/ClLayerTests.cpp
+++ b/src/backends/cl/test/ClLayerTests.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -80,6 +80,9 @@ ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul2DSimpleFloat32,
ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul3DSimpleFloat32,
ClContextControlFixture,
BatchMatMul3DSimpleTest<DataType::Float32>);
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMulNCHWSimpleFloat32,
+ ClContextControlFixture,
+ BatchMatMulNCHWSimpleTest<DataType::Float32>);
ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul3DBatchFloat32,
ClContextControlFixture,
BatchMatMul3DBatchTest<DataType::Float32>);
diff --git a/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp b/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp
index ece87c2672..f21666b90a 100644
--- a/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp
+++ b/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp
@@ -13,6 +13,7 @@
#include <armnn/utility/PolymorphicDowncast.hpp>
#include <armnnUtils/Permute.hpp>
+#include <armnnUtils/TensorUtils.hpp>
#include <backendsCommon/WorkloadUtils.hpp>
@@ -24,6 +25,7 @@
namespace armnn
{
+
arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputX,
const TensorInfo& inputY,
const TensorInfo& output,
@@ -42,36 +44,41 @@ arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputX,
arm_compute::Status statusPermuteX = arm_compute::Status(arm_compute::ErrorCode::OK);
arm_compute::Status statusPermuteY = arm_compute::Status(arm_compute::ErrorCode::OK);
- const auto aclInputXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputX, descriptor.m_DataLayoutX);
- const auto aclInputYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputY, descriptor.m_DataLayoutY);
- const auto aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
+ // ClGemmMatrixMultiplyNativeKernel used by CLGEMM can only support 3 dimensional
+ // tensors so try to reduce the dimensions to 3
+ const auto aclInputXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputX, descriptor.m_DataLayoutX, 3);
+ const auto aclInputYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputY, descriptor.m_DataLayoutY, 3);
+ const auto aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output, descriptor.m_DataLayoutY, 3);
arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
if (descriptor.m_TransposeX == true)
{
- auto permutationXVector = GeneratePermutationVectorOnLastTwoDimensions(inputX.GetNumDimensions());
+ armnn::TensorInfo inputXStripped = armnnUtils::ReduceDims(inputX, 3);
+
+ auto permutationXVector = GeneratePermutationVectorOnLastTwoDimensions(inputXStripped.GetNumDimensions());
const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
- const TensorInfo permutedXInfo = armnnUtils::Permuted(inputX, permutationXVector);
- aclPermutedXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedXInfo);
+ const TensorInfo permutedXInfo = armnnUtils::Permuted(inputXStripped, permutationXVector);
+ aclPermutedXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedXInfo, 3);
statusPermuteX = arm_compute::CLPermute::validate(&aclInputXInfo,
&aclPermutedXInfo,
aclPermutationXVector);
}
- if ( descriptor.m_TransposeY == true)
+ if (descriptor.m_TransposeY == true)
{
- auto permutationYVector = GeneratePermutationVectorOnLastTwoDimensions(inputY.GetNumDimensions());
+ armnn::TensorInfo inputYStripped = armnnUtils::ReduceDims(inputY, 3);
+
+ auto permutationYVector = GeneratePermutationVectorOnLastTwoDimensions(inputYStripped.GetNumDimensions());
const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
- const TensorInfo permutedYInfo = armnnUtils::Permuted(inputY, permutationYVector);
- aclPermutedYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedYInfo);
+ const TensorInfo permutedYInfo = armnnUtils::Permuted(inputYStripped, permutationYVector);
+ aclPermutedYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedYInfo, 3);
statusPermuteY = arm_compute::CLPermute::validate(&aclInputYInfo,
&aclPermutedYInfo,
aclPermutationYVector);
-
}
const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false, // is inputX reshaped
@@ -133,16 +140,24 @@ ClBatchMatMulWorkload::ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& d
arm_compute::ICLTensor& output = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
inputX.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutX));
+ arm_compute::TensorShape inputXTensorInfo = armcomputetensorutils::BuildArmComputeTensorShape(
+ info.m_InputTensorInfos[0].GetShape(), 3);
+ inputX.info()->set_tensor_shape(inputXTensorInfo);
inputY.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutY));
+ arm_compute::TensorShape inputYTensorInfo = armcomputetensorutils::BuildArmComputeTensorShape(
+ info.m_InputTensorInfos[1].GetShape(), 3);
+ inputY.info()->set_tensor_shape(inputYTensorInfo);
arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
if (descriptor.m_Parameters.m_TransposeX == true)
{
+ armnn::TensorInfo strippedInfo = armnnUtils::ReduceDims(info.m_InputTensorInfos[0], 3);
+
armnn::PermutationVector permutationXVector
- = GeneratePermutationVectorOnLastTwoDimensions(info.m_InputTensorInfos[0].GetNumDimensions());
- const TensorInfo permutedXInfo = armnnUtils::Permuted(info.m_InputTensorInfos[0], permutationXVector);
+ = GeneratePermutationVectorOnLastTwoDimensions(strippedInfo.GetNumDimensions());
+ const TensorInfo permutedXInfo = armnnUtils::Permuted(strippedInfo, permutationXVector);
const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorX, permutedXInfo);
armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorX);
@@ -157,9 +172,11 @@ ClBatchMatMulWorkload::ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& d
if (descriptor.m_Parameters.m_TransposeY == true)
{
+ armnn::TensorInfo strippedInfo = armnnUtils::ReduceDims(info.m_InputTensorInfos[1], 3);
+
armnn::PermutationVector permutationYVector
- = GeneratePermutationVectorOnLastTwoDimensions(info.m_InputTensorInfos[1].GetNumDimensions());
- const TensorInfo permutedYInfo = armnnUtils::Permuted(info.m_InputTensorInfos[1], permutationYVector);
+ = GeneratePermutationVectorOnLastTwoDimensions(strippedInfo.GetNumDimensions());
+ const TensorInfo permutedYInfo = armnnUtils::Permuted(strippedInfo, permutationYVector);
const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorY, permutedYInfo);
armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorY);