about | summary | refs | log | tree | commit | diff
path: root/src/backends/cl/workloads/ClBatchMatMulWorkload.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/cl/workloads/ClBatchMatMulWorkload.hpp')
-rw-r--r-- src/backends/cl/workloads/ClBatchMatMulWorkload.hpp 46
1 files changed, 18 insertions, 28 deletions
diff --git a/src/backends/cl/workloads/ClBatchMatMulWorkload.hpp b/src/backends/cl/workloads/ClBatchMatMulWorkload.hpp
index 5277efc947..d45fb7edb4 100644
--- a/src/backends/cl/workloads/ClBatchMatMulWorkload.hpp
+++ b/src/backends/cl/workloads/ClBatchMatMulWorkload.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -7,35 +7,25 @@
#include "ClBaseWorkload.hpp"
-#include <arm_compute/runtime/IFunction.h>
-#include <arm_compute/runtime/CL/CLTensor.h>
-#include <memory>
+#include <arm_compute/runtime/CL/functions/CLMatMul.h>
namespace armnn
{
- arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputX,
- const TensorInfo& inputY,
- const TensorInfo& output,
- const BatchMatMulDescriptor& descriptor);
+arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputX,
+ const TensorInfo& inputY,
+ const TensorInfo& output,
+ const BatchMatMulDescriptor& descriptor,
+ const ActivationDescriptor* activationDescriptor);
- class ClBatchMatMulWorkload : public ClBaseWorkload<BatchMatMulQueueDescriptor>
- {
- public:
- ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor,
- const WorkloadInfo& info,
- const arm_compute::CLCompileContext& clCompileContext);
- virtual void Execute() const override;
-
- private:
- // ACL layers required to fully form a Batch Mat Mul layer.
- std::unique_ptr<arm_compute::IFunction> m_GEMMLayer;
- std::unique_ptr<arm_compute::IFunction> m_PermuteLayerX;
- std::unique_ptr<arm_compute::IFunction> m_PermuteLayerY;
-
- // Additional CL arm_compute::Tensors.
- // Required to perform permutations.
- arm_compute::CLTensor m_PermutedTensorX;
- arm_compute::CLTensor m_PermutedTensorY;
-
- };
+class ClBatchMatMulWorkload : public ClBaseWorkload<BatchMatMulQueueDescriptor>
+{
+public:
+ ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor,
+ const WorkloadInfo& info,
+ const arm_compute::CLCompileContext& clCompileContext);
+ virtual void Execute() const override;
+
+private:
+ mutable arm_compute::CLMatMul m_MatMulLayer;
+};
} //namespace armnn