diff options
author | Michalis Spyrou <michalis.spyrou@arm.com> | 2019-09-27 11:04:27 +0100 |
---|---|---|
committer | Michalis Spyrou <michalis.spyrou@arm.com> | 2019-10-03 15:59:01 +0000 |
commit | b27e13a0ad630d3d9b3143c0374b5ff5000eebc0 (patch) | |
tree | 86defdbcd080fb8ab7f22c8c46e7793eeac80640 /arm_compute/runtime/CL/functions/CLGEMM.h | |
parent | 2ff0009ca9245304c48889c8ba8d3a39d42febed (diff) | |
download | ComputeLibrary-b27e13a0ad630d3d9b3143c0374b5ff5000eebc0.tar.gz |
COMPMID-2685: [CL] Use Weights manager
Change-Id: Ia1818e6ecd9386e96378e64f14d02592fe3cdf0f
Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1997
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLGEMM.h')
-rw-r--r-- | arm_compute/runtime/CL/functions/CLGEMM.h | 81 |
1 files changed, 67 insertions, 14 deletions
diff --git a/arm_compute/runtime/CL/functions/CLGEMM.h b/arm_compute/runtime/CL/functions/CLGEMM.h index b8e5fa67dd..3691fe9e21 100644 --- a/arm_compute/runtime/CL/functions/CLGEMM.h +++ b/arm_compute/runtime/CL/functions/CLGEMM.h @@ -32,12 +32,62 @@ #include "arm_compute/runtime/CL/CLTensor.h" #include "arm_compute/runtime/IFunction.h" #include "arm_compute/runtime/IMemoryManager.h" +#include "arm_compute/runtime/IWeightsManager.h" #include "arm_compute/runtime/MemoryGroup.h" namespace arm_compute { class ICLTensor; +namespace weights_transformations +{ +/** Basic function to manage the reshape weights generated from @ref CLGEMMReshapeRHSMatrixKernel */ +class CLGEMMReshapeRHSMatrixKernelManaged : public ITransformWeights +{ +public: + //Inherited method override + void run() override + { + _output.allocator()->allocate(); + CLScheduler::get().enqueue(_kernel, false); + _reshape_run = true; + } + + //Inherited method override + void release() override + { + _output.allocator()->free(); + } + + //Inherited method override + ICLTensor *get_weights() override + { + return &_output; + } + + //Inherited method override + uint32_t uid() override + { + return _uid; + } + + /** Configures the @ref CLGEMMReshapeRHSMatrixKernel kernel + * + * @param[in] input Input tensor. Data types supported: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32 + * @param[in] info RHS matrix information to be used for reshaping. + */ + void configure(const ICLTensor *input, GEMMRHSMatrixInfo info) + { + _kernel.configure(input, &_output, info); + } + +private: + static constexpr uint32_t _uid = 0x15; + CLTensor _output{}; + CLGEMMReshapeRHSMatrixKernel _kernel{}; +}; +} // namespace weights_transformations + /** Basic function to execute GEMM on OpenCL. This function calls the following OpenCL kernels: * * -# @ref CLGEMMReshapeLHSMatrixKernel (only if the RESHAPED_V1 is selected by the heuristic model) @@ -52,9 +102,10 @@ class CLGEMM : public IFunction public: /** Default constructor. * - * @param[in] memory_manager (Optional) Memory manager. + * @param[in] memory_manager (Optional) Memory manager. + * @param[in] weights_manager (Optional) Weights manager. */ - CLGEMM(std::shared_ptr<IMemoryManager> memory_manager = nullptr); + CLGEMM(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr); /** Prevent instances of this class from being copied (As this class contains pointers) */ CLGEMM(const CLGEMM &) = delete; /** Default move constructor */ @@ -123,18 +174,20 @@ private: static Status validate_reshaped_v2(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); static Status validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); - MemoryGroup _memory_group; - CLGEMMMatrixMultiplyKernel _mm_kernel; - CLGEMMReshapeLHSMatrixKernel _reshape_lhs_kernel; - CLGEMMReshapeRHSMatrixKernel _reshape_rhs_kernel; - CLGEMMMatrixMultiplyReshapedKernel _mm_reshaped_kernel; - CLGEMMMatrixMultiplyReshapedOnlyRHSKernel _mm_reshaped_only_rhs_kernel; - CLTensor _tmp_a; - CLTensor _tmp_b; - const ICLTensor *_original_b; - bool _reshape_b_only_on_first_run; - bool _is_prepared; - GEMMType _gemm_type; + MemoryGroup _memory_group; + IWeightsManager *_weights_manager; + CLGEMMMatrixMultiplyKernel _mm_kernel; + CLGEMMReshapeLHSMatrixKernel _reshape_lhs_kernel; + CLGEMMReshapeRHSMatrixKernel _reshape_rhs_kernel; + weights_transformations::CLGEMMReshapeRHSMatrixKernelManaged _reshape_rhs_kernel_managed; + CLGEMMMatrixMultiplyReshapedKernel _mm_reshaped_kernel; + CLGEMMMatrixMultiplyReshapedOnlyRHSKernel _mm_reshaped_only_rhs_kernel; + CLTensor _tmp_a; + CLTensor _tmp_b; + const ICLTensor *_original_b; + bool _reshape_b_only_on_first_run; + bool _is_prepared; + GEMMType _gemm_type; }; } // namespace arm_compute |