16 #include <arm_compute/runtime/NEON/functions/NEGEMM.h>
18 #include <arm_compute/runtime/NEON/functions/NEPermute.h>
30 throw Exception(
"Support for adjoint not implemented.");
34 throw Exception(
"Only supported the MatMul in the last 2 dimensions");
37 const auto aclInputXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputX, descriptor.
m_DataLayoutX);
38 const auto aclInputYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputY, descriptor.
m_DataLayoutY);
39 const auto aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
45 arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
46 arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
51 const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
53 aclPermutedXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedXInfo);
55 statusPermuteX = arm_compute::NEPermute::validate(&aclInputXInfo,
57 aclPermutationXVector);
63 const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
65 aclPermutedYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedYInfo);
67 statusPermuteY = arm_compute::NEPermute::validate(&aclInputYInfo,
69 aclPermutationYVector);
72 const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(
false,
76 statusGEMM = arm_compute::NEGEMM::validate(descriptor.
m_TransposeX ? &aclPermutedXInfo : &aclInputXInfo,
77 descriptor.
m_TransposeY ? &aclPermutedYInfo : &aclInputYInfo,
84 if (statusPermuteX.error_code() == arm_compute::ErrorCode::OK &&
85 statusPermuteY.error_code() == arm_compute::ErrorCode::OK &&
86 statusGEMM.error_code() == arm_compute::ErrorCode::OK)
89 "All BatchMatMul layers validate status OK.");
94 "BatchMatMul layer validate status failed."
95 + statusGEMM.error_description()
96 + statusPermuteX.error_description()
97 + statusPermuteY.error_description());
108 throw Exception(
"Support for adjoint not implemented.");
113 throw Exception(
"Only supported the MatMul in the last 2 dimensions");
124 arm_compute::ITensor& inputX = PolymorphicDowncast<IAclTensorHandle*>(
m_Data.
m_Inputs[0])->GetTensor();
125 arm_compute::ITensor& inputY = PolymorphicDowncast<IAclTensorHandle*>(
m_Data.
m_Inputs[1])->GetTensor();
126 auto outputHandle = PolymorphicDowncast<IAclTensorHandle*>(
m_Data.
m_Outputs[0]);
127 arm_compute::ITensor& output = outputHandle->GetTensor();
132 inputX.info()->set_data_layout(aclDataLayoutX);
133 inputY.info()->set_data_layout(aclDataLayoutY);
140 const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
142 auto permuteLayerX = std::make_unique<arm_compute::NEPermute>();
143 BuildArmComputeTensor(m_PermutedTensorX, permutedXInfo);
144 InitialiseArmComputeTensorEmpty(m_PermutedTensorX);
145 permuteLayerX->configure(&inputX, &m_PermutedTensorX, aclPermutationXVector);
146 m_PermuteLayerX.reset(permuteLayerX.release());
154 const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
156 auto permuteLayerY = std::make_unique<arm_compute::NEPermute>();
157 BuildArmComputeTensor(m_PermutedTensorY, permutedYInfo);
158 InitialiseArmComputeTensorEmpty(m_PermutedTensorY);
159 permuteLayerY->configure(&inputY, &m_PermutedTensorY, aclPermutationYVector);
160 m_PermuteLayerY.reset(permuteLayerY.release());
163 const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(
false,
166 auto gemmLayer = std::make_unique<arm_compute::NEGEMM>();
174 m_GEMMLayer.reset(gemmLayer.release());
182 m_PermuteLayerX->run();
186 m_PermuteLayerY->run();