ArmNN
 23.02
ClBatchMatMulWorkload.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
7 
8 #include "ClWorkloadUtils.hpp"
9 
12 
14 
15 #include <armnnUtils/Permute.hpp>
17 
19 
20 #include <cl/ClTensorHandle.hpp>
21 
22 #include <arm_compute/runtime/CL/functions/CLGEMM.h>
23 #include <arm_compute/runtime/CL/functions/CLPermute.h>
24 
25 
26 namespace armnn
27 {
28 
30  const TensorInfo& inputY,
31  const TensorInfo& output,
32  const BatchMatMulDescriptor& descriptor)
33 {
34  if (descriptor.m_AdjointX || descriptor.m_AdjointY )
35  {
36  throw Exception("Support for adjoint not implemented.");
37  }
39  {
40  throw Exception("Only supported the MatMul in the last 2 dimensions");
41  }
42 
43  arm_compute::Status statusGEMM = arm_compute::Status(arm_compute::ErrorCode::OK);
44  arm_compute::Status statusPermuteX = arm_compute::Status(arm_compute::ErrorCode::OK);
45  arm_compute::Status statusPermuteY = arm_compute::Status(arm_compute::ErrorCode::OK);
46 
47  // ClGemmMatrixMultiplyNativeKernel used by CLGEMM can only support 3 dimensional
48  // tensors so try to reduce the dimensions to 3
49  const auto aclInputXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputX, descriptor.m_DataLayoutX, 3);
50  const auto aclInputYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputY, descriptor.m_DataLayoutY, 3);
51  const auto aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output, descriptor.m_DataLayoutY, 3);
52 
53  arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
54  arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
55 
56  if (descriptor.m_TransposeX == true)
57  {
58  armnn::TensorInfo inputXStripped = armnnUtils::ReduceDims(inputX, 3);
59 
60  auto permutationXVector = GeneratePermutationVectorOnLastTwoDimensions(inputXStripped.GetNumDimensions());
61  const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
62  const TensorInfo permutedXInfo = armnnUtils::Permuted(inputXStripped, permutationXVector);
63  aclPermutedXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedXInfo, 3);
64 
65  statusPermuteX = arm_compute::CLPermute::validate(&aclInputXInfo,
66  &aclPermutedXInfo,
67  aclPermutationXVector);
68  }
69 
70  if (descriptor.m_TransposeY == true)
71  {
72  armnn::TensorInfo inputYStripped = armnnUtils::ReduceDims(inputY, 3);
73 
74  auto permutationYVector = GeneratePermutationVectorOnLastTwoDimensions(inputYStripped.GetNumDimensions());
75  const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
76  const TensorInfo permutedYInfo = armnnUtils::Permuted(inputYStripped, permutationYVector);
77  aclPermutedYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedYInfo, 3);
78 
79  statusPermuteY = arm_compute::CLPermute::validate(&aclInputYInfo,
80  &aclPermutedYInfo,
81  aclPermutationYVector);
82  }
83 
84  const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false, // is inputX reshaped
85  false, // is inputY reshaped
86  false); // is inputY reshaped only 1st run
87 
88 
89  statusGEMM = arm_compute::CLGEMM::validate(descriptor.m_TransposeX ? &aclPermutedXInfo : &aclInputXInfo,
90  descriptor.m_TransposeY ? &aclPermutedYInfo : &aclInputYInfo,
91  nullptr,
92  &aclOutputInfo,
93  1.0,
94  0,
95  gemm_info);
96 
97  if (statusPermuteX.error_code() == arm_compute::ErrorCode::OK &&
98  statusPermuteY.error_code() == arm_compute::ErrorCode::OK &&
99  statusGEMM.error_code() == arm_compute::ErrorCode::OK)
100  {
101  return arm_compute::Status(arm_compute::ErrorCode::OK,
102  "All Batch Mat Mul layers validate status OK.");
103  }
104  else
105  {
106  return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR,
107  "BatchMatMul layer validate status failed."
108  + statusGEMM.error_description()
109  + statusPermuteX.error_description()
110  + statusPermuteY.error_description());
111  }
112 
113 }
114 
116  const WorkloadInfo& info,
117  const arm_compute::CLCompileContext& clCompileContext)
119 {
120  // Report Profiling Details
121  ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClBatchMatMulWorkload_Construct",
122  descriptor.m_Parameters,
123  info,
124  this->GetGuid());
125 
126  if (descriptor.m_Parameters.m_AdjointX || descriptor.m_Parameters.m_AdjointY )
127  {
128  throw Exception("Support for adjoint not implemented.");
129  }
132  {
133  throw Exception("Only supported the MatMul in the last 2 dimensions");
134  }
135 
136  m_Data.ValidateInputsOutputs("ClBatchMatMulWorkload", 2, 1);
137 
138  const arm_compute::ICLTensor& inputX = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
139  const arm_compute::ICLTensor& inputY = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
140  arm_compute::ICLTensor& output = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
141 
142  inputX.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutX));
143  arm_compute::TensorShape inputXTensorInfo = armcomputetensorutils::BuildArmComputeTensorShape(
144  info.m_InputTensorInfos[0].GetShape(), 3);
145  inputX.info()->set_tensor_shape(inputXTensorInfo);
146  inputY.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutY));
147  arm_compute::TensorShape inputYTensorInfo = armcomputetensorutils::BuildArmComputeTensorShape(
148  info.m_InputTensorInfos[1].GetShape(), 3);
149  inputY.info()->set_tensor_shape(inputYTensorInfo);
150 
151  arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
152  arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
153 
154  if (descriptor.m_Parameters.m_TransposeX == true)
155  {
156  armnn::TensorInfo strippedInfo = armnnUtils::ReduceDims(info.m_InputTensorInfos[0], 3);
157 
158  armnn::PermutationVector permutationXVector
160  const TensorInfo permutedXInfo = armnnUtils::Permuted(strippedInfo, permutationXVector);
161  const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
162  armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorX, permutedXInfo);
163  armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorX);
164 
165  auto permuteLayerX = std::make_unique<arm_compute::CLPermute>();
166  permuteLayerX->configure(clCompileContext,
167  &inputX,
168  &m_PermutedTensorX,
169  aclPermutationXVector);
170  m_PermuteLayerX.reset(permuteLayerX.release());
171  }
172 
173  if (descriptor.m_Parameters.m_TransposeY == true)
174  {
175  armnn::TensorInfo strippedInfo = armnnUtils::ReduceDims(info.m_InputTensorInfos[1], 3);
176 
177  armnn::PermutationVector permutationYVector
179  const TensorInfo permutedYInfo = armnnUtils::Permuted(strippedInfo, permutationYVector);
180  const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
181  armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorY, permutedYInfo);
182  armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorY);
183 
184  auto permuteLayerY = std::make_unique<arm_compute::CLPermute>();
185  permuteLayerY->configure(clCompileContext,
186  &inputY,
187  &m_PermutedTensorY,
188  aclPermutationYVector);
189  m_PermuteLayerY.reset(permuteLayerY.release());
190  }
191 
192  const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false, // is inputX reshaped
193  false, // is inputY reshaped
194  false); // is inputY reshaped only 1st run
195  auto gemmLayer = std::make_unique<arm_compute::CLGEMM>();
196  gemmLayer->configure(clCompileContext,
197  descriptor.m_Parameters.m_TransposeX ? &m_PermutedTensorX : &inputX,
198  descriptor.m_Parameters.m_TransposeY ? &m_PermutedTensorY : &inputY,
199  nullptr,
200  &output,
201  1.0,
202  0,
203  gemm_info);
204  m_GEMMLayer.reset(gemmLayer.release());
205 }
206 
208 {
209  ARMNN_SCOPED_PROFILING_EVENT_CL_GUID("ClBatchMatMulWorkload_Execute", this->GetGuid());
210  if (m_PermuteLayerX)
211  {
212  m_PermuteLayerX->run();
213  }
214  if (m_PermuteLayerY)
215  {
216  m_PermuteLayerY->run();
217  }
218  m_GEMMLayer->run();
219 }
220 } //namespace armnn
armnn::BatchMatMulDescriptor::m_TransposeX
bool m_TransposeX
Transpose the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.
Definition: Descriptors.hpp:1559
armnn::BaseWorkload< BatchMatMulQueueDescriptor >::GetGuid
arm::pipe::ProfilingGuid GetGuid() const final
Definition: Workload.hpp:61
armnn::ClBatchMatMulValidate
arm_compute::Status ClBatchMatMulValidate(const TensorInfo &inputX, const TensorInfo &inputY, const TensorInfo &output, const BatchMatMulDescriptor &descriptor)
Definition: ClBatchMatMulWorkload.cpp:29
armnn::QueueDescriptor::ValidateInputsOutputs
void ValidateInputsOutputs(const std::string &descName, unsigned int numExpectedIn, unsigned int numExpectedOut) const
Definition: WorkloadData.cpp:475
armnn::Exception
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
armnn::BatchMatMulQueueDescriptor
Definition: WorkloadData.hpp:743
armnn::BatchMatMulDescriptor
A BatchMatMulDescriptor for the BatchMatMul operator.
Definition: Descriptors.hpp:1531
PolymorphicDowncast.hpp
TensorUtils.hpp
armnn::BaseWorkload< BatchMatMulQueueDescriptor >::m_Data
BatchMatMulQueueDescriptor m_Data
Definition: Workload.hpp:83
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
armnnUtils::Permuted
armnn::TensorShape Permuted(const armnn::TensorShape &srcShape, const armnn::PermutationVector &mappings)
Definition: Permute.cpp:98
armnn::DataLayout::NCHW
@ NCHW
armnn::BatchMatMulDescriptor::m_DataLayoutX
DataLayout m_DataLayoutX
Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout)
Definition: Descriptors.hpp:1568
armnn::TensorInfo::GetNumDimensions
unsigned int GetNumDimensions() const
Definition: Tensor.hpp:195
ArmComputeTensorUtils.hpp
armnn::GeneratePermutationVectorOnLastTwoDimensions
armnn::PermutationVector GeneratePermutationVectorOnLastTwoDimensions(unsigned int rank)
Generates a permutation vector of size rank that permutes the two rightmost dimensions.
Definition: WorkloadUtils.cpp:344
armnn::BatchMatMulDescriptor::m_TransposeY
bool m_TransposeY
Definition: Descriptors.hpp:1560
ArmComputeUtils.hpp
armnn::TensorInfo
Definition: Tensor.hpp:152
Permute.hpp
armnn::Status
Status
Definition: Types.hpp:42
armnn::ClBaseWorkload
Definition: ClBaseWorkload.hpp:13
armnnUtils::ReduceDims
armnn::TensorShape ReduceDims(const armnn::TensorShape &tensorInfo, unsigned int dimensions)
Definition: TensorUtils.cpp:106
armnn::ClBatchMatMulWorkload::ClBatchMatMulWorkload
ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor &descriptor, const WorkloadInfo &info, const arm_compute::CLCompileContext &clCompileContext)
Definition: ClBatchMatMulWorkload.cpp:115
armnn::WorkloadInfo
Contains information about TensorInfos of a layer.
Definition: WorkloadInfo.hpp:16
armnn::PermutationVector
Definition: Types.hpp:295
armnn::QueueDescriptorWithParameters::m_Parameters
LayerDescriptor m_Parameters
Definition: WorkloadData.hpp:66
ARMNN_REPORT_PROFILING_WORKLOAD_DESC
#define ARMNN_REPORT_PROFILING_WORKLOAD_DESC(name, desc, infos, guid)
Definition: Profiling.hpp:227
ClBatchMatMulWorkload.hpp
armnn::BatchMatMulDescriptor::m_DataLayoutY
DataLayout m_DataLayoutY
Definition: Descriptors.hpp:1569
ClWorkloadUtils.hpp
armnn::QueueDescriptor::m_Outputs
std::vector< ITensorHandle * > m_Outputs
Definition: WorkloadData.hpp:27
WorkloadUtils.hpp
armnn::BatchMatMulDescriptor::m_AdjointY
bool m_AdjointY
Definition: Descriptors.hpp:1565
armnn::BatchMatMulDescriptor::m_AdjointX
bool m_AdjointX
Adjoint the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.
Definition: Descriptors.hpp:1564
ARMNN_SCOPED_PROFILING_EVENT_CL_GUID
#define ARMNN_SCOPED_PROFILING_EVENT_CL_GUID(name, guid)
Definition: ClWorkloadUtils.hpp:28
armnn::ClBatchMatMulWorkload::Execute
virtual void Execute() const override
Definition: ClBatchMatMulWorkload.cpp:207
ClTensorHandle.hpp
armnn::QueueDescriptor::m_Inputs
std::vector< ITensorHandle * > m_Inputs
Definition: WorkloadData.hpp:26
armnn::BoostLogSeverityMapping::info
@ info