aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/cl/workloads/ClBatchMatMulWorkload.cpp')
-rw-r--r--src/backends/cl/workloads/ClBatchMatMulWorkload.cpp47
1 files changed, 32 insertions, 15 deletions
diff --git a/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp b/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp
index ece87c2672..f21666b90a 100644
--- a/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp
+++ b/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp
@@ -13,6 +13,7 @@
#include <armnn/utility/PolymorphicDowncast.hpp>
#include <armnnUtils/Permute.hpp>
+#include <armnnUtils/TensorUtils.hpp>
#include <backendsCommon/WorkloadUtils.hpp>
@@ -24,6 +25,7 @@
namespace armnn
{
+
arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputX,
const TensorInfo& inputY,
const TensorInfo& output,
@@ -42,36 +44,41 @@ arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputX,
arm_compute::Status statusPermuteX = arm_compute::Status(arm_compute::ErrorCode::OK);
arm_compute::Status statusPermuteY = arm_compute::Status(arm_compute::ErrorCode::OK);
- const auto aclInputXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputX, descriptor.m_DataLayoutX);
- const auto aclInputYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputY, descriptor.m_DataLayoutY);
- const auto aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
+ // ClGemmMatrixMultiplyNativeKernel used by CLGEMM can only support 3 dimensional
+ // tensors so try to reduce the dimensions to 3
+ const auto aclInputXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputX, descriptor.m_DataLayoutX, 3);
+ const auto aclInputYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputY, descriptor.m_DataLayoutY, 3);
+ const auto aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output, descriptor.m_DataLayoutY, 3);
arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
if (descriptor.m_TransposeX == true)
{
- auto permutationXVector = GeneratePermutationVectorOnLastTwoDimensions(inputX.GetNumDimensions());
+ armnn::TensorInfo inputXStripped = armnnUtils::ReduceDims(inputX, 3);
+
+ auto permutationXVector = GeneratePermutationVectorOnLastTwoDimensions(inputXStripped.GetNumDimensions());
const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
- const TensorInfo permutedXInfo = armnnUtils::Permuted(inputX, permutationXVector);
- aclPermutedXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedXInfo);
+ const TensorInfo permutedXInfo = armnnUtils::Permuted(inputXStripped, permutationXVector);
+ aclPermutedXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedXInfo, 3);
statusPermuteX = arm_compute::CLPermute::validate(&aclInputXInfo,
&aclPermutedXInfo,
aclPermutationXVector);
}
- if ( descriptor.m_TransposeY == true)
+ if (descriptor.m_TransposeY == true)
{
- auto permutationYVector = GeneratePermutationVectorOnLastTwoDimensions(inputY.GetNumDimensions());
+ armnn::TensorInfo inputYStripped = armnnUtils::ReduceDims(inputY, 3);
+
+ auto permutationYVector = GeneratePermutationVectorOnLastTwoDimensions(inputYStripped.GetNumDimensions());
const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
- const TensorInfo permutedYInfo = armnnUtils::Permuted(inputY, permutationYVector);
- aclPermutedYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedYInfo);
+ const TensorInfo permutedYInfo = armnnUtils::Permuted(inputYStripped, permutationYVector);
+ aclPermutedYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedYInfo, 3);
statusPermuteY = arm_compute::CLPermute::validate(&aclInputYInfo,
&aclPermutedYInfo,
aclPermutationYVector);
-
}
const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false, // is inputX reshaped
@@ -133,16 +140,24 @@ ClBatchMatMulWorkload::ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& d
arm_compute::ICLTensor& output = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
inputX.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutX));
+ arm_compute::TensorShape inputXTensorInfo = armcomputetensorutils::BuildArmComputeTensorShape(
+ info.m_InputTensorInfos[0].GetShape(), 3);
+ inputX.info()->set_tensor_shape(inputXTensorInfo);
inputY.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutY));
+ arm_compute::TensorShape inputYTensorInfo = armcomputetensorutils::BuildArmComputeTensorShape(
+ info.m_InputTensorInfos[1].GetShape(), 3);
+ inputY.info()->set_tensor_shape(inputYTensorInfo);
arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
if (descriptor.m_Parameters.m_TransposeX == true)
{
+ armnn::TensorInfo strippedInfo = armnnUtils::ReduceDims(info.m_InputTensorInfos[0], 3);
+
armnn::PermutationVector permutationXVector
- = GeneratePermutationVectorOnLastTwoDimensions(info.m_InputTensorInfos[0].GetNumDimensions());
- const TensorInfo permutedXInfo = armnnUtils::Permuted(info.m_InputTensorInfos[0], permutationXVector);
+ = GeneratePermutationVectorOnLastTwoDimensions(strippedInfo.GetNumDimensions());
+ const TensorInfo permutedXInfo = armnnUtils::Permuted(strippedInfo, permutationXVector);
const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorX, permutedXInfo);
armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorX);
@@ -157,9 +172,11 @@ ClBatchMatMulWorkload::ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& d
if (descriptor.m_Parameters.m_TransposeY == true)
{
+ armnn::TensorInfo strippedInfo = armnnUtils::ReduceDims(info.m_InputTensorInfos[1], 3);
+
armnn::PermutationVector permutationYVector
- = GeneratePermutationVectorOnLastTwoDimensions(info.m_InputTensorInfos[1].GetNumDimensions());
- const TensorInfo permutedYInfo = armnnUtils::Permuted(info.m_InputTensorInfos[1], permutationYVector);
+ = GeneratePermutationVectorOnLastTwoDimensions(strippedInfo.GetNumDimensions());
+ const TensorInfo permutedYInfo = armnnUtils::Permuted(strippedInfo, permutationYVector);
const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorY, permutedYInfo);
armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorY);