aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/CL/functions/CLGEMM.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLGEMM.h')
-rw-r--r-- arm_compute/runtime/CL/functions/CLGEMM.h 102
1 files changed, 50 insertions, 52 deletions
diff --git a/arm_compute/runtime/CL/functions/CLGEMM.h b/arm_compute/runtime/CL/functions/CLGEMM.h
index 92f9736e35..0b13e7dbbf 100644
--- a/arm_compute/runtime/CL/functions/CLGEMM.h
+++ b/arm_compute/runtime/CL/functions/CLGEMM.h
@@ -24,11 +24,6 @@
#ifndef ARM_COMPUTE_CLGEMM_H
#define ARM_COMPUTE_CLGEMM_H
-#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
-#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h"
-#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h"
-#include "arm_compute/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h"
-#include "arm_compute/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h"
#include "arm_compute/runtime/CL/CLTensor.h"
#include "arm_compute/runtime/CL/CLTypes.h"
#include "arm_compute/runtime/IFunction.h"
@@ -36,9 +31,18 @@
#include "arm_compute/runtime/IWeightsManager.h"
#include "arm_compute/runtime/MemoryGroup.h"
+#include <memory>
+
namespace arm_compute
{
+class CLCompileContext;
+class CLGEMMReshapeRHSMatrixKernel;
+class CLGEMMMatrixMultiplyKernel;
+class CLGEMMMatrixMultiplyReshapedKernel;
+class CLGEMMMatrixMultiplyReshapedOnlyRHSKernel;
+class CLGEMMReshapeLHSMatrixKernel;
class ICLTensor;
+class ITensorInfo;
namespace weights_transformations
{
@@ -46,41 +50,36 @@ namespace weights_transformations
class CLGEMMReshapeRHSMatrixKernelManaged : public ITransformWeights
{
public:
+ /** Default constructor */
+ CLGEMMReshapeRHSMatrixKernelManaged();
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ CLGEMMReshapeRHSMatrixKernelManaged(const CLGEMMReshapeRHSMatrixKernelManaged &) = delete;
+ /** Default move constructor */
+ CLGEMMReshapeRHSMatrixKernelManaged(CLGEMMReshapeRHSMatrixKernelManaged &&) = default;
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ CLGEMMReshapeRHSMatrixKernelManaged &operator=(const CLGEMMReshapeRHSMatrixKernelManaged &) = delete;
+ /** Default move assignment operator */
+ CLGEMMReshapeRHSMatrixKernelManaged &operator=(CLGEMMReshapeRHSMatrixKernelManaged &&) = default;
+ /** Default destructor */
+ ~CLGEMMReshapeRHSMatrixKernelManaged();
//Inherited method override
- void run() override
- {
- _output.allocator()->allocate();
- CLScheduler::get().enqueue(_kernel, false);
- _reshape_run = true;
- }
+ void run() override;
//Inherited method override
- void release() override
- {
- _output.allocator()->free();
- }
+ void release() override;
//Inherited method override
- ICLTensor *get_weights() override
- {
- return &_output;
- }
+ ICLTensor *get_weights() override;
//Inherited method override
- uint32_t uid() override
- {
- return _uid;
- }
+ uint32_t uid() override;
/** Configures the @ref CLGEMMReshapeRHSMatrixKernel kernel
*
* @param[in] input Input tensor. Data types supported: All
* @param[in] info RHS matrix information to be used for reshaping.
*/
- void configure(const ICLTensor *input, GEMMRHSMatrixInfo info)
- {
- configure(CLKernelLibrary::get().get_compile_context(), input, info);
- }
+ void configure(const ICLTensor *input, GEMMRHSMatrixInfo info);
/** Configures the @ref CLGEMMReshapeRHSMatrixKernel kernel
*
@@ -88,15 +87,12 @@ public:
* @param[in] input Input tensor. Data types supported: All
* @param[in] info RHS matrix information to be used for reshaping.
*/
- void configure(const CLCompileContext &compile_context, const ICLTensor *input, GEMMRHSMatrixInfo info)
- {
- _kernel.configure(compile_context, input, &_output, info);
- }
+ void configure(const CLCompileContext &compile_context, const ICLTensor *input, GEMMRHSMatrixInfo info);
private:
- static constexpr uint32_t _uid = 0x15;
- CLTensor _output{};
- CLGEMMReshapeRHSMatrixKernel _kernel{};
+ static constexpr uint32_t _uid{ 0x15 };
+ CLTensor _output{};
+ std::unique_ptr<CLGEMMReshapeRHSMatrixKernel> _kernel;
};
} // namespace weights_transformations
@@ -126,6 +122,8 @@ public:
CLGEMM &operator=(const CLGEMM &) = delete;
/** Default move assignment operator */
CLGEMM &operator=(CLGEMM &&) = default;
+ /** Default destructor */
+ ~CLGEMM();
/** Initialise the kernel's inputs and output
*
* @note GEMM: General Matrix Multiply - [alpha * A * B + beta * C].
@@ -198,24 +196,24 @@ private:
static Status validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
static Status validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
- MemoryGroup _memory_group;
- IWeightsManager *_weights_manager;
- CLGEMMMatrixMultiplyKernel _mm_kernel;
- CLGEMMReshapeLHSMatrixKernel _reshape_lhs_kernel;
- CLGEMMReshapeRHSMatrixKernel _reshape_rhs_kernel;
- weights_transformations::CLGEMMReshapeRHSMatrixKernelManaged _reshape_rhs_kernel_managed;
- CLGEMMMatrixMultiplyReshapedKernel _mm_reshaped_kernel;
- CLGEMMMatrixMultiplyReshapedOnlyRHSKernel _mm_reshaped_only_rhs_kernel;
- CLGEMMMatrixMultiplyReshapedOnlyRHSKernel _mm_reshaped_only_rhs_fallback_kernel;
- CLTensor _tmp_a;
- CLTensor _tmp_b;
- const ICLTensor *_original_b;
- const ICLTensor *_lhs;
- ICLTensor *_dst;
- bool _reshape_b_only_on_first_run;
- bool _is_prepared;
- bool _has_pad_y;
- CLGEMMKernelType _gemm_kernel_type;
+ MemoryGroup _memory_group;
+ IWeightsManager *_weights_manager;
+ std::unique_ptr<CLGEMMMatrixMultiplyKernel> _mm_kernel;
+ std::unique_ptr<CLGEMMReshapeLHSMatrixKernel> _reshape_lhs_kernel;
+ std::unique_ptr<CLGEMMReshapeRHSMatrixKernel> _reshape_rhs_kernel;
+ std::unique_ptr<weights_transformations::CLGEMMReshapeRHSMatrixKernelManaged> _reshape_rhs_kernel_managed;
+ std::unique_ptr<CLGEMMMatrixMultiplyReshapedKernel> _mm_reshaped_kernel;
+ std::unique_ptr<CLGEMMMatrixMultiplyReshapedOnlyRHSKernel> _mm_reshaped_only_rhs_kernel;
+ std::unique_ptr<CLGEMMMatrixMultiplyReshapedOnlyRHSKernel> _mm_reshaped_only_rhs_fallback_kernel;
+ CLTensor _tmp_a;
+ CLTensor _tmp_b;
+ const ICLTensor *_original_b;
+ const ICLTensor *_lhs;
+ ICLTensor *_dst;
+ bool _reshape_b_only_on_first_run;
+ bool _is_prepared;
+ bool _has_pad_y;
+ CLGEMMKernelType _gemm_kernel_type;
};
} // namespace arm_compute