22 #include <arm_compute/runtime/CL/functions/CLGEMM.h>
23 #include <arm_compute/runtime/CL/functions/CLPermute.h>
// NOTE(review): only selected lines of this validation routine are visible in this excerpt;
// its signature and the conditions guarding the statements below are not shown.

// Presumably raised when an adjoint operand is requested (guarding condition not visible).
36 throw Exception(
"Support for adjoint not implemented.");
// Presumably raised for an unsupported data-layout configuration (guarding condition not visible).
40 throw Exception(
"Only supported the MatMul in the last 2 dimensions");
// Build ACL tensor infos for both inputs and the output. The literal 3 is forwarded as the last
// argument of BuildArmComputeTensorInfo (looks like a dimension count -- TODO confirm).
49 const auto aclInputXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputX, descriptor.
m_DataLayoutX, 3);
50 const auto aclInputYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputY, descriptor.
m_DataLayoutY, 3);
// NOTE(review): the output info is built with m_DataLayoutY (the Y input's layout) -- confirm
// that this is intended.
51 const auto aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output, descriptor.
m_DataLayoutY, 3);
// Default-constructed placeholders; they are assigned further down from permutedXInfo and
// permutedYInfo and are only passed to the GEMM validation when m_TransposeX / m_TransposeY is set.
53 arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
54 arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
// X operand: convert the permutation vector to its ACL form, describe the permuted tensor and
// store the result of CLPermute::validate in statusPermuteX.
61 const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
63 aclPermutedXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedXInfo, 3);
65 statusPermuteX = arm_compute::CLPermute::validate(&aclInputXInfo,
67 aclPermutationXVector);
// Y operand: same sequence as for X, with the result stored in statusPermuteY.
75 const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
77 aclPermutedYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedYInfo, 3);
79 statusPermuteY = arm_compute::CLPermute::validate(&aclInputYInfo,
81 aclPermutationYVector);
// GEMM settings (only the first constructor argument, false, is visible here). The temporary is
// bound to a const reference, which keeps it alive for the lifetime of gemm_info.
84 const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(
false,
// Validate the GEMM against the permuted info when the matching transpose flag is set,
// otherwise against the original input info.
89 statusGEMM = arm_compute::CLGEMM::validate(descriptor.
m_TransposeX ? &aclPermutedXInfo : &aclInputXInfo,
90 descriptor.
m_TransposeY ? &aclPermutedYInfo : &aclInputYInfo,
// The OK message is selected only when both permute validations and the GEMM validation
// report ErrorCode::OK.
97 if (statusPermuteX.error_code() == arm_compute::ErrorCode::OK &&
98 statusPermuteY.error_code() == arm_compute::ErrorCode::OK &&
99 statusGEMM.error_code() == arm_compute::ErrorCode::OK)
102 "All Batch Mat Mul layers validate status OK.");
// NOTE(review): the failure text appends the three error descriptions directly to the leading
// sentence with no separator between any of the parts.
107 "BatchMatMul layer validate status failed."
108 + statusGEMM.error_description()
109 + statusPermuteX.error_description()
110 + statusPermuteY.error_description());
// NOTE(review): only selected lines of this routine are visible in this excerpt; the start of its
// signature and the conditions guarding the statements below are not shown.

// Last visible parameter of the enclosing signature: the CL compile context that is passed to
// every configure() call below.
117 const arm_compute::CLCompileContext& clCompileContext)
// Presumably raised when an adjoint operand is requested (guarding condition not visible).
128 throw Exception(
"Support for adjoint not implemented.");
// Presumably raised for an unsupported data-layout configuration (guarding condition not visible).
133 throw Exception(
"Only supported the MatMul in the last 2 dimensions");
// Resolve the two input handles and the output handle stored in m_Data to their ACL tensors by
// downcasting each handle to ClTensorHandle.
138 const arm_compute::ICLTensor& inputX = PolymorphicDowncast<ClTensorHandle*>(
m_Data.
m_Inputs[0])->GetTensor();
139 const arm_compute::ICLTensor& inputY = PolymorphicDowncast<ClTensorHandle*>(
m_Data.
m_Inputs[1])->GetTensor();
140 arm_compute::ICLTensor& output = PolymorphicDowncast<ClTensorHandle*>(
m_Data.
m_Outputs[0])->GetTensor();
// Replace each input tensor's shape with one rebuilt from info.m_InputTensorInfos; the literal 3
// is forwarded to BuildArmComputeTensorShape (looks like a dimension count -- TODO confirm).
// NOTE(review): inputX / inputY are const references, yet their tensor info is modified through
// the pointer returned by info().
// NOTE(review): despite their names, inputXTensorInfo / inputYTensorInfo hold TensorShape objects.
143 arm_compute::TensorShape inputXTensorInfo = armcomputetensorutils::BuildArmComputeTensorShape(
144 info.m_InputTensorInfos[0].GetShape(), 3);
145 inputX.info()->set_tensor_shape(inputXTensorInfo);
147 arm_compute::TensorShape inputYTensorInfo = armcomputetensorutils::BuildArmComputeTensorShape(
148 info.m_InputTensorInfos[1].GetShape(), 3);
149 inputY.info()->set_tensor_shape(inputYTensorInfo);
// NOTE(review): no later use of these two default-constructed infos is visible in this excerpt --
// possibly unused; confirm against the full file.
151 arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
152 arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
// X operand: build the ACL permutation vector, set up m_PermutedTensorX from permutedXInfo and
// initialise it via InitialiseArmComputeTensorEmpty, then configure a CLPermute for it.
161 const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
162 armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorX, permutedXInfo);
163 armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorX);
165 auto permuteLayerX = std::make_unique<arm_compute::CLPermute>();
166 permuteLayerX->configure(clCompileContext,
169 aclPermutationXVector);
// Hand ownership of the configured layer to the member.
// NOTE(review): reset(release()) could likely be a move-assignment, depending on the member's
// type, which is not visible here.
170 m_PermuteLayerX.reset(permuteLayerX.release());
// Y operand: same sequence as for X, targeting m_PermutedTensorY and m_PermuteLayerY.
180 const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
181 armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorY, permutedYInfo);
182 armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorY);
184 auto permuteLayerY = std::make_unique<arm_compute::CLPermute>();
185 permuteLayerY->configure(clCompileContext,
188 aclPermutationYVector);
189 m_PermuteLayerY.reset(permuteLayerY.release());
// GEMM settings (only the first constructor argument, false, is visible here). The temporary is
// bound to a const reference, which keeps it alive for the lifetime of gemm_info.
192 const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(
false,
// Create and configure the CLGEMM layer (the remaining configure() arguments are not visible),
// then hand ownership to m_GEMMLayer.
195 auto gemmLayer = std::make_unique<arm_compute::CLGEMM>();
196 gemmLayer->configure(clCompileContext,
204 m_GEMMLayer.reset(gemmLayer.release());
// Run the X-side and then the Y-side permute layer.
// NOTE(review): any guards around these calls (for example a null check on the layer pointers)
// are not visible in this excerpt -- confirm against the full file.
212 m_PermuteLayerX->run();
216 m_PermuteLayerY->run();