diff options
Diffstat (limited to 'src/backends/cl/workloads/ClBatchMatMulWorkload.cpp')
-rw-r--r-- | src/backends/cl/workloads/ClBatchMatMulWorkload.cpp | 187 |
1 files changed, 40 insertions, 147 deletions
diff --git a/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp b/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp index f21666b90a..bd0fd51617 100644 --- a/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp +++ b/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp @@ -12,24 +12,19 @@ #include <armnn/utility/PolymorphicDowncast.hpp> -#include <armnnUtils/Permute.hpp> -#include <armnnUtils/TensorUtils.hpp> - #include <backendsCommon/WorkloadUtils.hpp> #include <cl/ClTensorHandle.hpp> -#include <arm_compute/runtime/CL/functions/CLGEMM.h> -#include <arm_compute/runtime/CL/functions/CLPermute.h> - namespace armnn { -arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputX, - const TensorInfo& inputY, - const TensorInfo& output, - const BatchMatMulDescriptor& descriptor) +arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputInfoX, + const TensorInfo& inputInfoY, + const TensorInfo& outputInfo, + const BatchMatMulDescriptor& descriptor, + const ActivationDescriptor* activationDescriptor) { if (descriptor.m_AdjointX || descriptor.m_AdjointY ) { @@ -40,76 +35,23 @@ arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputX, throw Exception("Only supported the MatMul in the last 2 dimensions"); } - arm_compute::Status statusGEMM = arm_compute::Status(arm_compute::ErrorCode::OK); - arm_compute::Status statusPermuteX = arm_compute::Status(arm_compute::ErrorCode::OK); - arm_compute::Status statusPermuteY = arm_compute::Status(arm_compute::ErrorCode::OK); - - // ClGemmMatrixMultiplyNativeKernel used by CLGEMM can only support 3 dimensional - // tensors so try to reduce the dimensions to 3 - const auto aclInputXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputX, descriptor.m_DataLayoutX, 3); - const auto aclInputYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputY, descriptor.m_DataLayoutY, 3); - const auto aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output, descriptor.m_DataLayoutY, 3); - - arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo(); - arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo(); - - if (descriptor.m_TransposeX == true) - { - armnn::TensorInfo inputXStripped = armnnUtils::ReduceDims(inputX, 3); - - auto permutationXVector = GeneratePermutationVectorOnLastTwoDimensions(inputXStripped.GetNumDimensions()); - const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector); - const TensorInfo permutedXInfo = armnnUtils::Permuted(inputXStripped, permutationXVector); - aclPermutedXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedXInfo, 3); - - statusPermuteX = arm_compute::CLPermute::validate(&aclInputXInfo, - &aclPermutedXInfo, - aclPermutationXVector); - } - - if (descriptor.m_TransposeY == true) - { - armnn::TensorInfo inputYStripped = armnnUtils::ReduceDims(inputY, 3); - - auto permutationYVector = GeneratePermutationVectorOnLastTwoDimensions(inputYStripped.GetNumDimensions()); - const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector); - const TensorInfo permutedYInfo = armnnUtils::Permuted(inputYStripped, permutationYVector); - aclPermutedYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedYInfo, 3); - - statusPermuteY = arm_compute::CLPermute::validate(&aclInputYInfo, - &aclPermutedYInfo, - aclPermutationYVector); - } - - const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false, // is inputX reshaped - false, // is inputY reshaped - false); // is inputY reshaped only 1st run + arm_compute::TensorInfo aclInputInfoX = armcomputetensorutils::BuildArmComputeTensorInfo(inputInfoX); + arm_compute::TensorInfo aclInputInfoY = armcomputetensorutils::BuildArmComputeTensorInfo(inputInfoY); + const arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(outputInfo); + // GeMM dispatches kernel handles dynamic inputs differently to static so this flag needs to be set + aclInputInfoX.set_are_values_constant(false); + aclInputInfoY.set_are_values_constant(false); - statusGEMM = arm_compute::CLGEMM::validate(descriptor.m_TransposeX ? &aclPermutedXInfo : &aclInputXInfo, - descriptor.m_TransposeY ? &aclPermutedYInfo : &aclInputYInfo, - nullptr, - &aclOutputInfo, - 1.0, - 0, - gemm_info); + const arm_compute::ActivationLayerInfo activationInfo = ConvertActivationDescriptorToAclActivationLayerInfo( + activationDescriptor); - if (statusPermuteX.error_code() == arm_compute::ErrorCode::OK && - statusPermuteY.error_code() == arm_compute::ErrorCode::OK && - statusGEMM.error_code() == arm_compute::ErrorCode::OK) - { - return arm_compute::Status(arm_compute::ErrorCode::OK, - "All Batch Mat Mul layers validate status OK."); - } - else - { - return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR, - "BatchMatMul layer validate status failed." - + statusGEMM.error_description() - + statusPermuteX.error_description() - + statusPermuteY.error_description()); - } + arm_compute::MatMulInfo matMulInfo; + matMulInfo.adj_lhs(descriptor.m_TransposeX); + matMulInfo.adj_rhs(descriptor.m_TransposeY); + matMulInfo.fused_activation(activationInfo); + return arm_compute::CLMatMul::validate(&aclInputInfoX, &aclInputInfoY, &aclOutputInfo, matMulInfo); } ClBatchMatMulWorkload::ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor, @@ -135,86 +77,37 @@ ClBatchMatMulWorkload::ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& d m_Data.ValidateInputsOutputs("ClBatchMatMulWorkload", 2, 1); - const arm_compute::ICLTensor& inputX = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor(); - const arm_compute::ICLTensor& inputY = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor(); - arm_compute::ICLTensor& output = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor(); + arm_compute::ICLTensor& inputX = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor(); + arm_compute::ICLTensor& inputY = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor(); + auto outputHandle = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Outputs[0]); + arm_compute::ICLTensor& output = outputHandle->GetTensor(); - inputX.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutX)); - arm_compute::TensorShape inputXTensorInfo = armcomputetensorutils::BuildArmComputeTensorShape( - info.m_InputTensorInfos[0].GetShape(), 3); - inputX.info()->set_tensor_shape(inputXTensorInfo); - inputY.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutY)); - arm_compute::TensorShape inputYTensorInfo = armcomputetensorutils::BuildArmComputeTensorShape( - info.m_InputTensorInfos[1].GetShape(), 3); - inputY.info()->set_tensor_shape(inputYTensorInfo); + // GeMM dispatches kernel handles dynamic inputs differently to static so this flag needs to be set + inputX.info()->set_are_values_constant(false); + inputY.info()->set_are_values_constant(false); - arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo(); - arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo(); + const arm_compute::ActivationLayerInfo activationInfo = ConvertAdditionalInfoToAclActivationLayerInfo(descriptor); - if (descriptor.m_Parameters.m_TransposeX == true) - { - armnn::TensorInfo strippedInfo = armnnUtils::ReduceDims(info.m_InputTensorInfos[0], 3); - - armnn::PermutationVector permutationXVector - = GeneratePermutationVectorOnLastTwoDimensions(strippedInfo.GetNumDimensions()); - const TensorInfo permutedXInfo = armnnUtils::Permuted(strippedInfo, permutationXVector); - const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector); - armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorX, permutedXInfo); - armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorX); - - auto permuteLayerX = std::make_unique<arm_compute::CLPermute>(); - permuteLayerX->configure(clCompileContext, - &inputX, - &m_PermutedTensorX, - aclPermutationXVector); - m_PermuteLayerX.reset(permuteLayerX.release()); - } + arm_compute::MatMulInfo matMulInfo; + matMulInfo.adj_lhs(descriptor.m_Parameters.m_TransposeX); + matMulInfo.adj_rhs(descriptor.m_Parameters.m_TransposeY); + matMulInfo.fused_activation(activationInfo); - if (descriptor.m_Parameters.m_TransposeY == true) - { - armnn::TensorInfo strippedInfo = armnnUtils::ReduceDims(info.m_InputTensorInfos[1], 3); - - armnn::PermutationVector permutationYVector - = GeneratePermutationVectorOnLastTwoDimensions(strippedInfo.GetNumDimensions()); - const TensorInfo permutedYInfo = armnnUtils::Permuted(strippedInfo, permutationYVector); - const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector); - armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorY, permutedYInfo); - armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorY); - - auto permuteLayerY = std::make_unique<arm_compute::CLPermute>(); - permuteLayerY->configure(clCompileContext, - &inputY, - &m_PermutedTensorY, - aclPermutationYVector); - m_PermuteLayerY.reset(permuteLayerY.release()); - } + m_MatMulLayer.configure(clCompileContext, &inputX, &inputY, &output, matMulInfo); - const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false, // is inputX reshaped - false, // is inputY reshaped - false); // is inputY reshaped only 1st run - auto gemmLayer = std::make_unique<arm_compute::CLGEMM>(); - gemmLayer->configure(clCompileContext, - descriptor.m_Parameters.m_TransposeX ? &m_PermutedTensorX : &inputX, - descriptor.m_Parameters.m_TransposeY ? &m_PermutedTensorY : &inputY, - nullptr, - &output, - 1.0, - 0, - gemm_info); - m_GEMMLayer.reset(gemmLayer.release()); + // Report Profiling Details + WorkloadInfo detailsInfo; + detailsInfo.m_InputTensorInfos = info.m_InputTensorInfos; + detailsInfo.m_OutputTensorInfos = info.m_OutputTensorInfos; + ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClBatchMatMulWorkload_Construct", + descriptor.m_Parameters, + detailsInfo, + GetGuid()); } void ClBatchMatMulWorkload::Execute() const { ARMNN_SCOPED_PROFILING_EVENT_CL_GUID("ClBatchMatMulWorkload_Execute", this->GetGuid()); - if (m_PermuteLayerX) - { - m_PermuteLayerX->run(); - } - if (m_PermuteLayerY) - { - m_PermuteLayerY->run(); - } - m_GEMMLayer->run(); + RunClFunction(m_MatMulLayer, CHECK_LOCATION()); } } //namespace armnn |