Diffstat (limited to 'src/backends/cl/workloads/ClBatchMatMulWorkload.cpp')
-rw-r--r--  src/backends/cl/workloads/ClBatchMatMulWorkload.cpp | 187
1 file changed, 40 insertions(+), 147 deletions(-)
diff --git a/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp b/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp
index f21666b90a..bd0fd51617 100644
--- a/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp
+++ b/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp
@@ -12,24 +12,19 @@
#include <armnn/utility/PolymorphicDowncast.hpp>
-#include <armnnUtils/Permute.hpp>
-#include <armnnUtils/TensorUtils.hpp>
-
#include <backendsCommon/WorkloadUtils.hpp>
#include <cl/ClTensorHandle.hpp>
-#include <arm_compute/runtime/CL/functions/CLGEMM.h>
-#include <arm_compute/runtime/CL/functions/CLPermute.h>
-
namespace armnn
{
-arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputX,
- const TensorInfo& inputY,
- const TensorInfo& output,
- const BatchMatMulDescriptor& descriptor)
+arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputInfoX,
+ const TensorInfo& inputInfoY,
+ const TensorInfo& outputInfo,
+ const BatchMatMulDescriptor& descriptor,
+ const ActivationDescriptor* activationDescriptor)
{
if (descriptor.m_AdjointX || descriptor.m_AdjointY )
{
@@ -40,76 +35,23 @@ arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputX,
throw Exception("Only supported the MatMul in the last 2 dimensions");
}
- arm_compute::Status statusGEMM = arm_compute::Status(arm_compute::ErrorCode::OK);
- arm_compute::Status statusPermuteX = arm_compute::Status(arm_compute::ErrorCode::OK);
- arm_compute::Status statusPermuteY = arm_compute::Status(arm_compute::ErrorCode::OK);
-
- // ClGemmMatrixMultiplyNativeKernel used by CLGEMM can only support 3 dimensional
- // tensors so try to reduce the dimensions to 3
- const auto aclInputXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputX, descriptor.m_DataLayoutX, 3);
- const auto aclInputYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputY, descriptor.m_DataLayoutY, 3);
- const auto aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output, descriptor.m_DataLayoutY, 3);
-
- arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
- arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
-
- if (descriptor.m_TransposeX == true)
- {
- armnn::TensorInfo inputXStripped = armnnUtils::ReduceDims(inputX, 3);
-
- auto permutationXVector = GeneratePermutationVectorOnLastTwoDimensions(inputXStripped.GetNumDimensions());
- const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
- const TensorInfo permutedXInfo = armnnUtils::Permuted(inputXStripped, permutationXVector);
- aclPermutedXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedXInfo, 3);
-
- statusPermuteX = arm_compute::CLPermute::validate(&aclInputXInfo,
- &aclPermutedXInfo,
- aclPermutationXVector);
- }
-
- if (descriptor.m_TransposeY == true)
- {
- armnn::TensorInfo inputYStripped = armnnUtils::ReduceDims(inputY, 3);
-
- auto permutationYVector = GeneratePermutationVectorOnLastTwoDimensions(inputYStripped.GetNumDimensions());
- const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
- const TensorInfo permutedYInfo = armnnUtils::Permuted(inputYStripped, permutationYVector);
- aclPermutedYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedYInfo, 3);
-
- statusPermuteY = arm_compute::CLPermute::validate(&aclInputYInfo,
- &aclPermutedYInfo,
- aclPermutationYVector);
- }
-
- const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false, // is inputX reshaped
- false, // is inputY reshaped
- false); // is inputY reshaped only 1st run
+ arm_compute::TensorInfo aclInputInfoX = armcomputetensorutils::BuildArmComputeTensorInfo(inputInfoX);
+ arm_compute::TensorInfo aclInputInfoY = armcomputetensorutils::BuildArmComputeTensorInfo(inputInfoY);
+ const arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(outputInfo);
+ // The GEMM kernel dispatch handles dynamic inputs differently to static inputs, so this flag needs to be set
+ aclInputInfoX.set_are_values_constant(false);
+ aclInputInfoY.set_are_values_constant(false);
- statusGEMM = arm_compute::CLGEMM::validate(descriptor.m_TransposeX ? &aclPermutedXInfo : &aclInputXInfo,
- descriptor.m_TransposeY ? &aclPermutedYInfo : &aclInputYInfo,
- nullptr,
- &aclOutputInfo,
- 1.0,
- 0,
- gemm_info);
+ const arm_compute::ActivationLayerInfo activationInfo = ConvertActivationDescriptorToAclActivationLayerInfo(
+ activationDescriptor);
- if (statusPermuteX.error_code() == arm_compute::ErrorCode::OK &&
- statusPermuteY.error_code() == arm_compute::ErrorCode::OK &&
- statusGEMM.error_code() == arm_compute::ErrorCode::OK)
- {
- return arm_compute::Status(arm_compute::ErrorCode::OK,
- "All Batch Mat Mul layers validate status OK.");
- }
- else
- {
- return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR,
- "BatchMatMul layer validate status failed."
- + statusGEMM.error_description()
- + statusPermuteX.error_description()
- + statusPermuteY.error_description());
- }
+ arm_compute::MatMulInfo matMulInfo;
+ matMulInfo.adj_lhs(descriptor.m_TransposeX);
+ matMulInfo.adj_rhs(descriptor.m_TransposeY);
+ matMulInfo.fused_activation(activationInfo);
+ return arm_compute::CLMatMul::validate(&aclInputInfoX, &aclInputInfoY, &aclOutputInfo, matMulInfo);
}
ClBatchMatMulWorkload::ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor,
@@ -135,86 +77,37 @@ ClBatchMatMulWorkload::ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& d
m_Data.ValidateInputsOutputs("ClBatchMatMulWorkload", 2, 1);
- const arm_compute::ICLTensor& inputX = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
- const arm_compute::ICLTensor& inputY = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
- arm_compute::ICLTensor& output = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
+ arm_compute::ICLTensor& inputX = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
+ arm_compute::ICLTensor& inputY = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
+ auto outputHandle = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Outputs[0]);
+ arm_compute::ICLTensor& output = outputHandle->GetTensor();
- inputX.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutX));
- arm_compute::TensorShape inputXTensorInfo = armcomputetensorutils::BuildArmComputeTensorShape(
- info.m_InputTensorInfos[0].GetShape(), 3);
- inputX.info()->set_tensor_shape(inputXTensorInfo);
- inputY.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutY));
- arm_compute::TensorShape inputYTensorInfo = armcomputetensorutils::BuildArmComputeTensorShape(
- info.m_InputTensorInfos[1].GetShape(), 3);
- inputY.info()->set_tensor_shape(inputYTensorInfo);
+ // The GEMM kernel dispatch handles dynamic inputs differently to static inputs, so this flag needs to be set
+ inputX.info()->set_are_values_constant(false);
+ inputY.info()->set_are_values_constant(false);
- arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
- arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
+ const arm_compute::ActivationLayerInfo activationInfo = ConvertAdditionalInfoToAclActivationLayerInfo(descriptor);
- if (descriptor.m_Parameters.m_TransposeX == true)
- {
- armnn::TensorInfo strippedInfo = armnnUtils::ReduceDims(info.m_InputTensorInfos[0], 3);
-
- armnn::PermutationVector permutationXVector
- = GeneratePermutationVectorOnLastTwoDimensions(strippedInfo.GetNumDimensions());
- const TensorInfo permutedXInfo = armnnUtils::Permuted(strippedInfo, permutationXVector);
- const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
- armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorX, permutedXInfo);
- armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorX);
-
- auto permuteLayerX = std::make_unique<arm_compute::CLPermute>();
- permuteLayerX->configure(clCompileContext,
- &inputX,
- &m_PermutedTensorX,
- aclPermutationXVector);
- m_PermuteLayerX.reset(permuteLayerX.release());
- }
+ arm_compute::MatMulInfo matMulInfo;
+ matMulInfo.adj_lhs(descriptor.m_Parameters.m_TransposeX);
+ matMulInfo.adj_rhs(descriptor.m_Parameters.m_TransposeY);
+ matMulInfo.fused_activation(activationInfo);
- if (descriptor.m_Parameters.m_TransposeY == true)
- {
- armnn::TensorInfo strippedInfo = armnnUtils::ReduceDims(info.m_InputTensorInfos[1], 3);
-
- armnn::PermutationVector permutationYVector
- = GeneratePermutationVectorOnLastTwoDimensions(strippedInfo.GetNumDimensions());
- const TensorInfo permutedYInfo = armnnUtils::Permuted(strippedInfo, permutationYVector);
- const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
- armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorY, permutedYInfo);
- armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorY);
-
- auto permuteLayerY = std::make_unique<arm_compute::CLPermute>();
- permuteLayerY->configure(clCompileContext,
- &inputY,
- &m_PermutedTensorY,
- aclPermutationYVector);
- m_PermuteLayerY.reset(permuteLayerY.release());
- }
+ m_MatMulLayer.configure(clCompileContext, &inputX, &inputY, &output, matMulInfo);
- const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false, // is inputX reshaped
- false, // is inputY reshaped
- false); // is inputY reshaped only 1st run
- auto gemmLayer = std::make_unique<arm_compute::CLGEMM>();
- gemmLayer->configure(clCompileContext,
- descriptor.m_Parameters.m_TransposeX ? &m_PermutedTensorX : &inputX,
- descriptor.m_Parameters.m_TransposeY ? &m_PermutedTensorY : &inputY,
- nullptr,
- &output,
- 1.0,
- 0,
- gemm_info);
- m_GEMMLayer.reset(gemmLayer.release());
+ // Report Profiling Details
+ WorkloadInfo detailsInfo;
+ detailsInfo.m_InputTensorInfos = info.m_InputTensorInfos;
+ detailsInfo.m_OutputTensorInfos = info.m_OutputTensorInfos;
+ ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClBatchMatMulWorkload_Construct",
+ descriptor.m_Parameters,
+ detailsInfo,
+ GetGuid());
}
void ClBatchMatMulWorkload::Execute() const
{
ARMNN_SCOPED_PROFILING_EVENT_CL_GUID("ClBatchMatMulWorkload_Execute", this->GetGuid());
- if (m_PermuteLayerX)
- {
- m_PermuteLayerX->run();
- }
- if (m_PermuteLayerY)
- {
- m_PermuteLayerY->run();
- }
- m_GEMMLayer->run();
+ RunClFunction(m_MatMulLayer, CHECK_LOCATION());
}
} //namespace armnn
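
For reference, below is a minimal sketch of how a caller might exercise the updated ClBatchMatMulValidate overload, which now takes an optional fused ActivationDescriptor as its fifth parameter. The tensor shapes, the IsClBatchMatMulSupportedSketch function name and the include paths are illustrative assumptions and are not part of the change above.

#include <armnn/Descriptors.hpp>
#include <armnn/Tensor.hpp>
#include "ClBatchMatMulWorkload.hpp"   // assumed include path within the CL backend

bool IsClBatchMatMulSupportedSketch()
{
    using namespace armnn;

    // Illustrative 3D shapes: [batch, M, K] x [batch, K, N] -> [batch, M, N]
    TensorInfo inputX(TensorShape({2, 3, 4}), DataType::Float32);
    TensorInfo inputY(TensorShape({2, 4, 5}), DataType::Float32);
    TensorInfo output(TensorShape({2, 3, 5}), DataType::Float32);

    BatchMatMulDescriptor descriptor;
    descriptor.m_TransposeX = false;   // forwarded to MatMulInfo::adj_lhs by the workload
    descriptor.m_TransposeY = false;   // forwarded to MatMulInfo::adj_rhs

    // No fused activation in this sketch, so pass nullptr for the new parameter.
    arm_compute::Status status = ClBatchMatMulValidate(inputX, inputY, output, descriptor, nullptr);
    return status.error_code() == arm_compute::ErrorCode::OK;
}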
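
A second sketch, using the Arm Compute Library directly rather than going through Arm NN, shows how the pieces the rewritten workload relies on (CLMatMul, MatMulInfo and the non-constant-values flag) fit together at runtime. Shapes are illustrative, error handling is omitted, and RunMatMulSketch is a made-up name; this is not code from the change above.

#include <arm_compute/runtime/CL/CLScheduler.h>
#include <arm_compute/runtime/CL/CLTensor.h>
#include <arm_compute/runtime/CL/functions/CLMatMul.h>

void RunMatMulSketch()
{
    using namespace arm_compute;
    CLScheduler::get().default_init();

    // ACL shapes are (width, height, batch): [3x4] * [4x5] -> [3x5] per batch of 2.
    CLTensor lhs, rhs, dst;
    lhs.allocator()->init(TensorInfo(TensorShape(4U, 3U, 2U), 1, DataType::F32));
    rhs.allocator()->init(TensorInfo(TensorShape(5U, 4U, 2U), 1, DataType::F32));
    dst.allocator()->init(TensorInfo(TensorShape(5U, 3U, 2U), 1, DataType::F32));

    // As in the workload above: the kernels treat dynamic (non-constant)
    // inputs differently, so mark both operands as non-constant.
    lhs.info()->set_are_values_constant(false);
    rhs.info()->set_are_values_constant(false);

    MatMulInfo matMulInfo;
    matMulInfo.adj_lhs(false);   // counterpart of descriptor.m_TransposeX
    matMulInfo.adj_rhs(false);   // counterpart of descriptor.m_TransposeY

    CLMatMul matMul;
    matMul.configure(&lhs, &rhs, &dst, matMulInfo);

    lhs.allocator()->allocate();
    rhs.allocator()->allocate();
    dst.allocator()->allocate();

    matMul.run();
    CLScheduler::get().sync();
}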