From 0e3fe10bfe1b4f006f6e0c5c2fae8fb5515c7544 Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Mon, 23 Jan 2023 19:32:06 +0000 Subject: IVGCVSW-7455 Workaround to allow CLBatchMatMul to parse some 4D models * Added ability to reduce the number of dimensions when calling BuildArmComputeTensorInfo or BuildArmComputeTensorShape; this will attempt to remove leading 1s in order to squeeze the number of dimensions but retain the size. * Changed ClBatchMatMulWorkload to attempt to squeeze the number of dimensions to 3, as the CL Gemm Kernel can only support up to 3 dimensions. Signed-off-by: Mike Kelly Change-Id: I6b3d0886c5b97fdb686838fc3dc292833ddc4643 --- src/backends/cl/test/ClLayerTests.cpp | 5 ++- .../cl/workloads/ClBatchMatMulWorkload.cpp | 47 +++++++++++++++------- 2 files changed, 36 insertions(+), 16 deletions(-) (limited to 'src/backends/cl') diff --git a/src/backends/cl/test/ClLayerTests.cpp b/src/backends/cl/test/ClLayerTests.cpp index 4ba2a9ec3b..10e2a54c9f 100644 --- a/src/backends/cl/test/ClLayerTests.cpp +++ b/src/backends/cl/test/ClLayerTests.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. 
// SPDX-License-Identifier: MIT // @@ -80,6 +80,9 @@ ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul2DSimpleFloat32, ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul3DSimpleFloat32, ClContextControlFixture, BatchMatMul3DSimpleTest); +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMulNCHWSimpleFloat32, + ClContextControlFixture, + BatchMatMulNCHWSimpleTest); ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul3DBatchFloat32, ClContextControlFixture, BatchMatMul3DBatchTest); diff --git a/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp b/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp index ece87c2672..f21666b90a 100644 --- a/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp +++ b/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp @@ -13,6 +13,7 @@ #include #include +#include #include @@ -24,6 +25,7 @@ namespace armnn { + arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputX, const TensorInfo& inputY, const TensorInfo& output, @@ -42,36 +44,41 @@ arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputX, arm_compute::Status statusPermuteX = arm_compute::Status(arm_compute::ErrorCode::OK); arm_compute::Status statusPermuteY = arm_compute::Status(arm_compute::ErrorCode::OK); - const auto aclInputXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputX, descriptor.m_DataLayoutX); - const auto aclInputYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputY, descriptor.m_DataLayoutY); - const auto aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output); + // ClGemmMatrixMultiplyNativeKernel used by CLGEMM can only support 3 dimensional + // tensors so try to reduce the dimensions to 3 + const auto aclInputXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputX, descriptor.m_DataLayoutX, 3); + const auto aclInputYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputY, descriptor.m_DataLayoutY, 3); + const auto aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output, descriptor.m_DataLayoutY, 3); 
arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo(); arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo(); if (descriptor.m_TransposeX == true) { - auto permutationXVector = GeneratePermutationVectorOnLastTwoDimensions(inputX.GetNumDimensions()); + armnn::TensorInfo inputXStripped = armnnUtils::ReduceDims(inputX, 3); + + auto permutationXVector = GeneratePermutationVectorOnLastTwoDimensions(inputXStripped.GetNumDimensions()); const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector); - const TensorInfo permutedXInfo = armnnUtils::Permuted(inputX, permutationXVector); - aclPermutedXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedXInfo); + const TensorInfo permutedXInfo = armnnUtils::Permuted(inputXStripped, permutationXVector); + aclPermutedXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedXInfo, 3); statusPermuteX = arm_compute::CLPermute::validate(&aclInputXInfo, &aclPermutedXInfo, aclPermutationXVector); } - if ( descriptor.m_TransposeY == true) + if (descriptor.m_TransposeY == true) { - auto permutationYVector = GeneratePermutationVectorOnLastTwoDimensions(inputY.GetNumDimensions()); + armnn::TensorInfo inputYStripped = armnnUtils::ReduceDims(inputY, 3); + + auto permutationYVector = GeneratePermutationVectorOnLastTwoDimensions(inputYStripped.GetNumDimensions()); const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector); - const TensorInfo permutedYInfo = armnnUtils::Permuted(inputY, permutationYVector); - aclPermutedYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedYInfo); + const TensorInfo permutedYInfo = armnnUtils::Permuted(inputYStripped, permutationYVector); + aclPermutedYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedYInfo, 3); statusPermuteY = arm_compute::CLPermute::validate(&aclInputYInfo, &aclPermutedYInfo, aclPermutationYVector); - } const 
arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false, // is inputX reshaped @@ -133,16 +140,24 @@ ClBatchMatMulWorkload::ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& d arm_compute::ICLTensor& output = PolymorphicDowncast(m_Data.m_Outputs[0])->GetTensor(); inputX.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutX)); + arm_compute::TensorShape inputXTensorInfo = armcomputetensorutils::BuildArmComputeTensorShape( + info.m_InputTensorInfos[0].GetShape(), 3); + inputX.info()->set_tensor_shape(inputXTensorInfo); inputY.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutY)); + arm_compute::TensorShape inputYTensorInfo = armcomputetensorutils::BuildArmComputeTensorShape( + info.m_InputTensorInfos[1].GetShape(), 3); + inputY.info()->set_tensor_shape(inputYTensorInfo); arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo(); arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo(); if (descriptor.m_Parameters.m_TransposeX == true) { + armnn::TensorInfo strippedInfo = armnnUtils::ReduceDims(info.m_InputTensorInfos[0], 3); + armnn::PermutationVector permutationXVector - = GeneratePermutationVectorOnLastTwoDimensions(info.m_InputTensorInfos[0].GetNumDimensions()); - const TensorInfo permutedXInfo = armnnUtils::Permuted(info.m_InputTensorInfos[0], permutationXVector); + = GeneratePermutationVectorOnLastTwoDimensions(strippedInfo.GetNumDimensions()); + const TensorInfo permutedXInfo = armnnUtils::Permuted(strippedInfo, permutationXVector); const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector); armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorX, permutedXInfo); armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorX); @@ -157,9 +172,11 @@ ClBatchMatMulWorkload::ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& d if (descriptor.m_Parameters.m_TransposeY 
== true) { + armnn::TensorInfo strippedInfo = armnnUtils::ReduceDims(info.m_InputTensorInfos[1], 3); + armnn::PermutationVector permutationYVector - = GeneratePermutationVectorOnLastTwoDimensions(info.m_InputTensorInfos[1].GetNumDimensions()); - const TensorInfo permutedYInfo = armnnUtils::Permuted(info.m_InputTensorInfos[1], permutationYVector); + = GeneratePermutationVectorOnLastTwoDimensions(strippedInfo.GetNumDimensions()); + const TensorInfo permutedYInfo = armnnUtils::Permuted(strippedInfo, permutationYVector); const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector); armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorY, permutedYInfo); armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorY); -- cgit v1.2.1