author     Teresa Charlin <teresa.charlinreyes@arm.com>    2023-01-10 10:32:51 +0000
committer  TeresaARM <teresa.charlinreyes@arm.com>         2023-05-08 13:16:25 +0000
commit     97a3aefff63ae081ae62aa5bac17d6e9c401937e (patch)
tree       4cda3515b8718215be14ae95283a51a49b372e68
parent     1fe6c8170ae2fe90b53fb71b7570aec9dfe75c45 (diff)
download   armnn-97a3aefff63ae081ae62aa5bac17d6e9c401937e.tar.gz
IVGCVSW-7308 Add GpuAcc Batch MatMul workload
* Call dedicated MatMul kernel in ACL
* Add int8 tests
* Add int8 to documentation
* Force tensors to be dynamic (nonConst) as per request of ACL

Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: I7b7ac20deec8637dc46ca990d339d92c4587cbe4
-rw-r--r--  delegate/test/BatchMatMulTest.cpp                    |  16
-rw-r--r--  docs/02_operator_list.dox                            |   1
-rw-r--r--  src/backends/cl/ClLayerSupport.cpp                   |   3
-rw-r--r--  src/backends/cl/test/ClEndToEndTests.cpp             |  12
-rw-r--r--  src/backends/cl/test/ClLayerTests.cpp                |  28
-rw-r--r--  src/backends/cl/workloads/ClBatchMatMulWorkload.cpp  | 187
-rw-r--r--  src/backends/cl/workloads/ClBatchMatMulWorkload.hpp  |  46
7 files changed, 112 insertions(+), 181 deletions(-)
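
Note: the substance of the change is in ClBatchMatMulWorkload.cpp/.hpp below. The previous CLPermute + CLGEMM composition is dropped in favour of ACL's dedicated CLMatMul function, transposes are expressed through MatMulInfo, and both operands are forced to be dynamic (non-constant) as ACL requested. A minimal standalone sketch of that ACL call pattern follows; the shapes, data type, scheduler setup and non-compile-context overloads are illustrative assumptions, not taken from this patch.

#include <arm_compute/core/TensorInfo.h>
#include <arm_compute/core/Types.h>
#include <arm_compute/runtime/CL/CLScheduler.h>
#include <arm_compute/runtime/CL/CLTensor.h>
#include <arm_compute/runtime/CL/functions/CLMatMul.h>

int main()
{
    using namespace arm_compute;
    CLScheduler::get().default_init();

    // LHS/RHS/output: two batches of 2x2 matrices (shapes are illustrative).
    TensorInfo lhsInfo(TensorShape(2U, 2U, 2U), 1, DataType::F32);
    TensorInfo rhsInfo(TensorShape(2U, 2U, 2U), 1, DataType::F32);
    TensorInfo dstInfo(TensorShape(2U, 2U, 2U), 1, DataType::F32);

    // The MatMul kernel treats dynamic (non-constant) inputs differently,
    // which is why the workload below forces this flag on both operands.
    lhsInfo.set_are_values_constant(false);
    rhsInfo.set_are_values_constant(false);

    MatMulInfo matMulInfo;
    matMulInfo.adj_lhs(false);   // maps to BatchMatMulDescriptor::m_TransposeX in the workload
    matMulInfo.adj_rhs(false);   // maps to BatchMatMulDescriptor::m_TransposeY in the workload

    Status status = CLMatMul::validate(&lhsInfo, &rhsInfo, &dstInfo, matMulInfo);
    if (status.error_code() != ErrorCode::OK)
    {
        return 1;
    }

    CLTensor lhs, rhs, dst;
    lhs.allocator()->init(lhsInfo);
    rhs.allocator()->init(rhsInfo);
    dst.allocator()->init(dstInfo);

    CLMatMul matMul;
    matMul.configure(&lhs, &rhs, &dst, matMulInfo);

    lhs.allocator()->allocate();
    rhs.allocator()->allocate();
    dst.allocator()->allocate();
    // ... map and fill lhs/rhs, then:
    matMul.run();
    CLScheduler::get().sync();
    return 0;
}

The workload below makes the same calls via the compile-context overload, with adj_lhs/adj_rhs driven by the BatchMatMulDescriptor and an optional fused activation attached to MatMulInfo.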
diff --git a/delegate/test/BatchMatMulTest.cpp b/delegate/test/BatchMatMulTest.cpp
index 5cd1a70141..fcf1ec2b7c 100644
--- a/delegate/test/BatchMatMulTest.cpp
+++ b/delegate/test/BatchMatMulTest.cpp
@@ -303,7 +303,7 @@ namespace armnnDelegate
{
// Set input data
std::vector<int32_t> LHSInputShape { 2,2,2 };
- std::vector<int32_t> RHSInputShape { 1,2,2 };
+ std::vector<int32_t> RHSInputShape { 2,2 };
std::vector<int32_t> outputShape { 2,2,2 };
std::vector<int8_t> LHSInputValues = { 1, 2,
@@ -680,6 +680,7 @@ namespace armnnDelegate
BatchMatMul2DInt8SimpleAdjointTest(backends);
}
}
+
TEST_SUITE("BATCH_MATMUL_GpuAccTests")
{
TEST_CASE("BATCH_MATMUL_Fp32_GpuAccTests")
@@ -689,11 +690,20 @@ namespace armnnDelegate
BatchMatMul3DFp32SimpleTest (backends);
BatchMatMul4DFp32SimpleTest (backends);
BatchMatMul3DFp32BatchTest (backends);
- BatchMatMul3DFp32BroadcastTest (backends);
- BatchMatMul3D2DFp32BroadcastTest (backends);
BatchMatMul2DFp32TinyTest (backends);
BatchMatMulNonSquareFp32Test (backends);
BatchMatMul2DFp32SimpleAdjointTest(backends);
}
+
+ TEST_CASE("BATCH_MATMUL_Int8_GpuAccTests")
+ {
+ std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
+ BatchMatMul2DInt8SimpleTest (backends);
+ BatchMatMul3DInt8SimpleTest (backends);
+ BatchMatMul3DInt8BatchTest (backends);
+ BatchMatMul2DInt8TinyTest (backends);
+ BatchMatMulNonSquareInt8Test (backends);
+ BatchMatMul2DInt8SimpleAdjointTest(backends);
+ }
}
}
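
The new Int8 suites above exercise the QASYMMS8 support added to the GpuAcc backend below. For orientation, a quantized Arm NN tensor description of the kind such tests rely on looks roughly like this; the scale and offset values are illustrative, not taken from the test helpers.

#include <armnn/Tensor.hpp>

// Illustrative QAsymmS8 tensor info: two batches of 2x2, with assumed
// quantization parameters (scale 1.0, zero point 0).
armnn::TensorInfo MakeInt8BatchMatMulInfo()
{
    armnn::TensorInfo info({ 2, 2, 2 }, armnn::DataType::QAsymmS8);
    info.SetQuantizationScale(1.0f);
    info.SetQuantizationOffset(0);
    return info;
}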
diff --git a/docs/02_operator_list.dox b/docs/02_operator_list.dox
index 791565a985..53e37e2b4e 100644
--- a/docs/02_operator_list.dox
+++ b/docs/02_operator_list.dox
@@ -311,6 +311,7 @@ where N = batches, C = channels, H = height, W = width
<table>
<tr><th>
<tr><td>FLOAT32
+ <tr><td>QASYMMS8
</table>
<tr>
<td rowspan="3">BatchNormalizationLayer
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp
index b63837539e..6fa4f3ce51 100644
--- a/src/backends/cl/ClLayerSupport.cpp
+++ b/src/backends/cl/ClLayerSupport.cpp
@@ -720,7 +720,8 @@ bool ClLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX,
inputX,
inputY,
output,
- descriptor);
+ descriptor,
+ nullptr);
}
bool ClLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
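
The validate helper called above now takes an extra ActivationDescriptor* parameter (declared in ClBatchMatMulWorkload.hpp below); the support check passes nullptr because no activation is fused at that point. A hedged sketch of invoking it directly, with the include path assumed as seen from src/backends/cl:

#include "workloads/ClBatchMatMulWorkload.hpp"  // path assumed relative to src/backends/cl

// Returns true when the GpuAcc backend reports this BatchMatMul configuration as supported.
bool IsClBatchMatMulOk(const armnn::TensorInfo& inputX,
                       const armnn::TensorInfo& inputY,
                       const armnn::TensorInfo& output,
                       const armnn::BatchMatMulDescriptor& descriptor)
{
    // nullptr = no fused activation, matching the call in ClLayerSupport.cpp above.
    arm_compute::Status status =
        armnn::ClBatchMatMulValidate(inputX, inputY, output, descriptor, nullptr);
    return status.error_code() == arm_compute::ErrorCode::OK;
}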
diff --git a/src/backends/cl/test/ClEndToEndTests.cpp b/src/backends/cl/test/ClEndToEndTests.cpp
index 4ff2b79d08..a6ddd97ecf 100644
--- a/src/backends/cl/test/ClEndToEndTests.cpp
+++ b/src/backends/cl/test/ClEndToEndTests.cpp
@@ -8,6 +8,7 @@
#include <backendsCommon/test/ActivationEndToEndTestImpl.hpp>
#include <backendsCommon/test/AdditionEndToEndTestImpl.hpp>
#include <backendsCommon/test/ArgMinMaxEndToEndTestImpl.hpp>
+#include <backendsCommon/test/BatchMatMulEndToEndTestImpl.hpp>
#include <backendsCommon/test/ComparisonEndToEndTestImpl.hpp>
#include <backendsCommon/test/ConcatEndToEndTestImpl.hpp>
#include <backendsCommon/test/DepthToSpaceEndToEndTestImpl.hpp>
@@ -56,6 +57,17 @@ TEST_CASE("ClAdditionEndToEndUint8Test")
AdditionEndToEnd<armnn::DataType::QAsymmU8>(clDefaultBackends);
}
+// Batch Mat Mul
+TEST_CASE("ClBatchMatMulEndToEndFloat32Test")
+{
+ BatchMatMulEndToEnd<armnn::DataType::Float32>(clDefaultBackends);
+}
+
+TEST_CASE("ClBatchMatMulEndToEndInt8Test")
+{
+ BatchMatMulEndToEnd<armnn::DataType::QAsymmS8>(clDefaultBackends);
+}
+
// Constant
TEST_CASE("ConstantUsage_Cl_Float32")
{
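
BatchMatMulEndToEnd drives the layer through the public Arm NN API; its internals are not part of this patch. The sketch below is an illustrative approximation of such a run on GpuAcc, with made-up shapes, values and names.

#include <armnn/ArmNN.hpp>

#include <utility>
#include <vector>

int main()
{
    using namespace armnn;

    IRuntime::CreationOptions options;
    IRuntimePtr runtime = IRuntime::Create(options);

    INetworkPtr network = INetwork::Create();

    BatchMatMulDescriptor descriptor;  // transposes/adjoints default to false
    TensorInfo inputInfo({ 2, 2, 2 }, DataType::Float32);
    TensorInfo outputInfo({ 2, 2, 2 }, DataType::Float32);

    IConnectableLayer* inputX = network->AddInputLayer(0, "inputX");
    IConnectableLayer* inputY = network->AddInputLayer(1, "inputY");
    IConnectableLayer* bmm    = network->AddBatchMatMulLayer(descriptor, "batchMatMul");
    IConnectableLayer* out    = network->AddOutputLayer(0, "output");

    inputX->GetOutputSlot(0).SetTensorInfo(inputInfo);
    inputY->GetOutputSlot(0).SetTensorInfo(inputInfo);
    bmm->GetOutputSlot(0).SetTensorInfo(outputInfo);

    inputX->GetOutputSlot(0).Connect(bmm->GetInputSlot(0));
    inputY->GetOutputSlot(0).Connect(bmm->GetInputSlot(1));
    bmm->GetOutputSlot(0).Connect(out->GetInputSlot(0));

    IOptimizedNetworkPtr optNet = Optimize(*network, { Compute::GpuAcc }, runtime->GetDeviceSpec());

    NetworkId netId = 0;
    runtime->LoadNetwork(netId, std::move(optNet));

    std::vector<float> xData{ 1, 2, 3, 4, 5, 6, 7, 8 };
    std::vector<float> yData{ 1, 0, 0, 1, 1, 0, 0, 1 };  // two 2x2 identity matrices
    std::vector<float> result(8, 0.0f);

    TensorInfo constInputInfo(inputInfo);
    constInputInfo.SetConstant(true);  // required for ConstTensor since Arm NN 22.02

    InputTensors  inputs{ { 0, ConstTensor(constInputInfo, xData.data()) },
                          { 1, ConstTensor(constInputInfo, yData.data()) } };
    OutputTensors outputs{ { 0, Tensor(outputInfo, result.data()) } };

    runtime->EnqueueWorkload(netId, inputs, outputs);  // with the identity RHS, result == xData
    return 0;
}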
diff --git a/src/backends/cl/test/ClLayerTests.cpp b/src/backends/cl/test/ClLayerTests.cpp
index 1ad1de8e04..a84ecc9f9f 100644
--- a/src/backends/cl/test/ClLayerTests.cpp
+++ b/src/backends/cl/test/ClLayerTests.cpp
@@ -77,27 +77,51 @@ ARMNN_AUTO_TEST_FIXTURE_WITH_THF(Elu, ClContextControlFixture, EluTest)
ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul2DSimpleFloat32,
ClContextControlFixture,
BatchMatMul2DSimpleTest<DataType::Float32>);
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul2DSimpleInt8,
+ ClContextControlFixture,
+ BatchMatMul2DSimpleTest<DataType::QAsymmS8>);
ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul3DSimpleFloat32,
ClContextControlFixture,
BatchMatMul3DSimpleTest<DataType::Float32>);
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul3DSimpleInt8,
+ ClContextControlFixture,
+ BatchMatMul3DSimpleTest<DataType::QAsymmS8>);
ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMulNCHWSimpleFloat32,
ClContextControlFixture,
BatchMatMulNCHWSimpleTest<DataType::Float32>);
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMulNCHWSimpleInt8,
+ ClContextControlFixture,
+ BatchMatMulNCHWSimpleTest<DataType::QAsymmS8>);
ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul3DBatchFloat32,
ClContextControlFixture,
BatchMatMul3DBatchTest<DataType::Float32>);
-ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul3DBroadcastFloat32,
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul3DBatchInt8,
+ ClContextControlFixture,
+ BatchMatMul3DBatchTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(UNSUPPORTED_BatchMatMul3DBroadcastFloat32,
ClContextControlFixture,
BatchMatMul3DBroadcastTest<DataType::Float32>);
-ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul3D2DBroadcastFloat32,
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(UNSUPPORTED_BatchMatMul3DBroadcastInt8,
+ ClContextControlFixture,
+ BatchMatMul3DBroadcastTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(UNSUPPORTED_BatchMatMul3D2DBroadcastFloat32,
ClContextControlFixture,
BatchMatMul3D2DBroadcastTest<DataType::Float32>);
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(UNSUPPORTED_BatchMatMul3D2DBroadcastInt8,
+ ClContextControlFixture,
+ BatchMatMul3D2DBroadcastTest<DataType::QAsymmS8>);
ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul2DTinyFloat32,
ClContextControlFixture,
BatchMatMul2DTinyTest<DataType::Float32>);
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul2DTinyInt8,
+ ClContextControlFixture,
+ BatchMatMul2DTinyTest<DataType::QAsymmS8>);
ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul2DTranspSimpleFloat32,
ClContextControlFixture,
BatchMatMul2DTranspSimpleTest<DataType::Float32>);
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul2DTranspSimpleInt8,
+ ClContextControlFixture,
+ BatchMatMul2DTranspSimpleTest<DataType::QAsymmS8>);
// Batch To Space
ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchToSpaceNdNhwcFloat321,
diff --git a/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp b/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp
index f21666b90a..bd0fd51617 100644
--- a/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp
+++ b/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp
@@ -12,24 +12,19 @@
#include <armnn/utility/PolymorphicDowncast.hpp>
-#include <armnnUtils/Permute.hpp>
-#include <armnnUtils/TensorUtils.hpp>
-
#include <backendsCommon/WorkloadUtils.hpp>
#include <cl/ClTensorHandle.hpp>
-#include <arm_compute/runtime/CL/functions/CLGEMM.h>
-#include <arm_compute/runtime/CL/functions/CLPermute.h>
-
namespace armnn
{
-arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputX,
- const TensorInfo& inputY,
- const TensorInfo& output,
- const BatchMatMulDescriptor& descriptor)
+arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputInfoX,
+ const TensorInfo& inputInfoY,
+ const TensorInfo& outputInfo,
+ const BatchMatMulDescriptor& descriptor,
+ const ActivationDescriptor* activationDescriptor)
{
if (descriptor.m_AdjointX || descriptor.m_AdjointY )
{
@@ -40,76 +35,23 @@ arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputX,
throw Exception("Only supported the MatMul in the last 2 dimensions");
}
- arm_compute::Status statusGEMM = arm_compute::Status(arm_compute::ErrorCode::OK);
- arm_compute::Status statusPermuteX = arm_compute::Status(arm_compute::ErrorCode::OK);
- arm_compute::Status statusPermuteY = arm_compute::Status(arm_compute::ErrorCode::OK);
-
- // ClGemmMatrixMultiplyNativeKernel used by CLGEMM can only support 3 dimensional
- // tensors so try to reduce the dimensions to 3
- const auto aclInputXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputX, descriptor.m_DataLayoutX, 3);
- const auto aclInputYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputY, descriptor.m_DataLayoutY, 3);
- const auto aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output, descriptor.m_DataLayoutY, 3);
-
- arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
- arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
-
- if (descriptor.m_TransposeX == true)
- {
- armnn::TensorInfo inputXStripped = armnnUtils::ReduceDims(inputX, 3);
-
- auto permutationXVector = GeneratePermutationVectorOnLastTwoDimensions(inputXStripped.GetNumDimensions());
- const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
- const TensorInfo permutedXInfo = armnnUtils::Permuted(inputXStripped, permutationXVector);
- aclPermutedXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedXInfo, 3);
-
- statusPermuteX = arm_compute::CLPermute::validate(&aclInputXInfo,
- &aclPermutedXInfo,
- aclPermutationXVector);
- }
-
- if (descriptor.m_TransposeY == true)
- {
- armnn::TensorInfo inputYStripped = armnnUtils::ReduceDims(inputY, 3);
-
- auto permutationYVector = GeneratePermutationVectorOnLastTwoDimensions(inputYStripped.GetNumDimensions());
- const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
- const TensorInfo permutedYInfo = armnnUtils::Permuted(inputYStripped, permutationYVector);
- aclPermutedYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedYInfo, 3);
-
- statusPermuteY = arm_compute::CLPermute::validate(&aclInputYInfo,
- &aclPermutedYInfo,
- aclPermutationYVector);
- }
-
- const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false, // is inputX reshaped
- false, // is inputY reshaped
- false); // is inputY reshaped only 1st run
+ arm_compute::TensorInfo aclInputInfoX = armcomputetensorutils::BuildArmComputeTensorInfo(inputInfoX);
+ arm_compute::TensorInfo aclInputInfoY = armcomputetensorutils::BuildArmComputeTensorInfo(inputInfoY);
+ const arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(outputInfo);
+ // The GeMM dispatch kernel handles dynamic (non-constant) inputs differently to static ones, so this flag needs to be set
+ aclInputInfoX.set_are_values_constant(false);
+ aclInputInfoY.set_are_values_constant(false);
- statusGEMM = arm_compute::CLGEMM::validate(descriptor.m_TransposeX ? &aclPermutedXInfo : &aclInputXInfo,
- descriptor.m_TransposeY ? &aclPermutedYInfo : &aclInputYInfo,
- nullptr,
- &aclOutputInfo,
- 1.0,
- 0,
- gemm_info);
+ const arm_compute::ActivationLayerInfo activationInfo = ConvertActivationDescriptorToAclActivationLayerInfo(
+ activationDescriptor);
- if (statusPermuteX.error_code() == arm_compute::ErrorCode::OK &&
- statusPermuteY.error_code() == arm_compute::ErrorCode::OK &&
- statusGEMM.error_code() == arm_compute::ErrorCode::OK)
- {
- return arm_compute::Status(arm_compute::ErrorCode::OK,
- "All Batch Mat Mul layers validate status OK.");
- }
- else
- {
- return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR,
- "BatchMatMul layer validate status failed."
- + statusGEMM.error_description()
- + statusPermuteX.error_description()
- + statusPermuteY.error_description());
- }
+ arm_compute::MatMulInfo matMulInfo;
+ matMulInfo.adj_lhs(descriptor.m_TransposeX);
+ matMulInfo.adj_rhs(descriptor.m_TransposeY);
+ matMulInfo.fused_activation(activationInfo);
+ return arm_compute::CLMatMul::validate(&aclInputInfoX, &aclInputInfoY, &aclOutputInfo, matMulInfo);
}
ClBatchMatMulWorkload::ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor,
@@ -135,86 +77,37 @@ ClBatchMatMulWorkload::ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& d
m_Data.ValidateInputsOutputs("ClBatchMatMulWorkload", 2, 1);
- const arm_compute::ICLTensor& inputX = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
- const arm_compute::ICLTensor& inputY = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
- arm_compute::ICLTensor& output = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
+ arm_compute::ICLTensor& inputX = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
+ arm_compute::ICLTensor& inputY = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
+ auto outputHandle = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Outputs[0]);
+ arm_compute::ICLTensor& output = outputHandle->GetTensor();
- inputX.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutX));
- arm_compute::TensorShape inputXTensorInfo = armcomputetensorutils::BuildArmComputeTensorShape(
- info.m_InputTensorInfos[0].GetShape(), 3);
- inputX.info()->set_tensor_shape(inputXTensorInfo);
- inputY.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutY));
- arm_compute::TensorShape inputYTensorInfo = armcomputetensorutils::BuildArmComputeTensorShape(
- info.m_InputTensorInfos[1].GetShape(), 3);
- inputY.info()->set_tensor_shape(inputYTensorInfo);
+ // The GeMM dispatch kernel handles dynamic (non-constant) inputs differently to static ones, so this flag needs to be set
+ inputX.info()->set_are_values_constant(false);
+ inputY.info()->set_are_values_constant(false);
- arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
- arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
+ const arm_compute::ActivationLayerInfo activationInfo = ConvertAdditionalInfoToAclActivationLayerInfo(descriptor);
- if (descriptor.m_Parameters.m_TransposeX == true)
- {
- armnn::TensorInfo strippedInfo = armnnUtils::ReduceDims(info.m_InputTensorInfos[0], 3);
-
- armnn::PermutationVector permutationXVector
- = GeneratePermutationVectorOnLastTwoDimensions(strippedInfo.GetNumDimensions());
- const TensorInfo permutedXInfo = armnnUtils::Permuted(strippedInfo, permutationXVector);
- const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
- armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorX, permutedXInfo);
- armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorX);
-
- auto permuteLayerX = std::make_unique<arm_compute::CLPermute>();
- permuteLayerX->configure(clCompileContext,
- &inputX,
- &m_PermutedTensorX,
- aclPermutationXVector);
- m_PermuteLayerX.reset(permuteLayerX.release());
- }
+ arm_compute::MatMulInfo matMulInfo;
+ matMulInfo.adj_lhs(descriptor.m_Parameters.m_TransposeX);
+ matMulInfo.adj_rhs(descriptor.m_Parameters.m_TransposeY);
+ matMulInfo.fused_activation(activationInfo);
- if (descriptor.m_Parameters.m_TransposeY == true)
- {
- armnn::TensorInfo strippedInfo = armnnUtils::ReduceDims(info.m_InputTensorInfos[1], 3);
-
- armnn::PermutationVector permutationYVector
- = GeneratePermutationVectorOnLastTwoDimensions(strippedInfo.GetNumDimensions());
- const TensorInfo permutedYInfo = armnnUtils::Permuted(strippedInfo, permutationYVector);
- const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
- armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorY, permutedYInfo);
- armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorY);
-
- auto permuteLayerY = std::make_unique<arm_compute::CLPermute>();
- permuteLayerY->configure(clCompileContext,
- &inputY,
- &m_PermutedTensorY,
- aclPermutationYVector);
- m_PermuteLayerY.reset(permuteLayerY.release());
- }
+ m_MatMulLayer.configure(clCompileContext, &inputX, &inputY, &output, matMulInfo);
- const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false, // is inputX reshaped
- false, // is inputY reshaped
- false); // is inputY reshaped only 1st run
- auto gemmLayer = std::make_unique<arm_compute::CLGEMM>();
- gemmLayer->configure(clCompileContext,
- descriptor.m_Parameters.m_TransposeX ? &m_PermutedTensorX : &inputX,
- descriptor.m_Parameters.m_TransposeY ? &m_PermutedTensorY : &inputY,
- nullptr,
- &output,
- 1.0,
- 0,
- gemm_info);
- m_GEMMLayer.reset(gemmLayer.release());
+ // Report Profiling Details
+ WorkloadInfo detailsInfo;
+ detailsInfo.m_InputTensorInfos = info.m_InputTensorInfos;
+ detailsInfo.m_OutputTensorInfos = info.m_OutputTensorInfos;
+ ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClBatchMatMulWorkload_Construct",
+ descriptor.m_Parameters,
+ detailsInfo,
+ GetGuid());
}
void ClBatchMatMulWorkload::Execute() const
{
ARMNN_SCOPED_PROFILING_EVENT_CL_GUID("ClBatchMatMulWorkload_Execute", this->GetGuid());
- if (m_PermuteLayerX)
- {
- m_PermuteLayerX->run();
- }
- if (m_PermuteLayerY)
- {
- m_PermuteLayerY->run();
- }
- m_GEMMLayer->run();
+ RunClFunction(m_MatMulLayer, CHECK_LOCATION());
}
} //namespace armnn
diff --git a/src/backends/cl/workloads/ClBatchMatMulWorkload.hpp b/src/backends/cl/workloads/ClBatchMatMulWorkload.hpp
index 5277efc947..d45fb7edb4 100644
--- a/src/backends/cl/workloads/ClBatchMatMulWorkload.hpp
+++ b/src/backends/cl/workloads/ClBatchMatMulWorkload.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -7,35 +7,25 @@
#include "ClBaseWorkload.hpp"
-#include <arm_compute/runtime/IFunction.h>
-#include <arm_compute/runtime/CL/CLTensor.h>
-#include <memory>
+#include <arm_compute/runtime/CL/functions/CLMatMul.h>
namespace armnn
{
- arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputX,
- const TensorInfo& inputY,
- const TensorInfo& output,
- const BatchMatMulDescriptor& descriptor);
+arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputX,
+ const TensorInfo& inputY,
+ const TensorInfo& output,
+ const BatchMatMulDescriptor& descriptor,
+ const ActivationDescriptor* activationDescriptor);
- class ClBatchMatMulWorkload : public ClBaseWorkload<BatchMatMulQueueDescriptor>
- {
- public:
- ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor,
- const WorkloadInfo& info,
- const arm_compute::CLCompileContext& clCompileContext);
- virtual void Execute() const override;
-
- private:
- // ACL layers required to fully form a Batch Mat Mul layer.
- std::unique_ptr<arm_compute::IFunction> m_GEMMLayer;
- std::unique_ptr<arm_compute::IFunction> m_PermuteLayerX;
- std::unique_ptr<arm_compute::IFunction> m_PermuteLayerY;
-
- // Additional CL arm_compute::Tensors.
- // Required to perform permutations.
- arm_compute::CLTensor m_PermutedTensorX;
- arm_compute::CLTensor m_PermutedTensorY;
-
- };
+class ClBatchMatMulWorkload : public ClBaseWorkload<BatchMatMulQueueDescriptor>
+{
+public:
+ ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor,
+ const WorkloadInfo& info,
+ const arm_compute::CLCompileContext& clCompileContext);
+ virtual void Execute() const override;
+
+private:
+ mutable arm_compute::CLMatMul m_MatMulLayer;
+};
} //namespace armnn
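
A side note on the new header: m_MatMulLayer is declared mutable because IWorkload::Execute() is const while arm_compute::IFunction::run() is not. A tiny illustration of the pattern, with a hypothetical class name:

#include <arm_compute/runtime/CL/functions/CLMatMul.h>

// Hypothetical cut-down shape of the workload, showing why `mutable` is needed.
class ExampleWorkload
{
public:
    void Execute() const
    {
        m_MatMulLayer.run();  // run() is non-const; allowed here only because the member is mutable
    }

private:
    mutable arm_compute::CLMatMul m_MatMulLayer;
};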