From b27e13a0ad630d3d9b3143c0374b5ff5000eebc0 Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Fri, 27 Sep 2019 11:04:27 +0100 Subject: COMPMID-2685: [CL] Use Weights manager Change-Id: Ia1818e6ecd9386e96378e64f14d02592fe3cdf0f Signed-off-by: Michalis Spyrou Reviewed-on: https://review.mlplatform.org/c/1997 Comments-Addressed: Arm Jenkins Reviewed-by: Gian Marco Iodice Tested-by: Arm Jenkins --- .../runtime/CL/functions/CLGEMMConvolutionLayer.h | 76 +++++++++++++++++++--- 1 file changed, 67 insertions(+), 9 deletions(-) (limited to 'arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h') diff --git a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h index 0b27c824d9..017bf78938 100644 --- a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h +++ b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h @@ -39,6 +39,8 @@ #include "arm_compute/runtime/CL/functions/CLGEMMLowpOutputStage.h" #include "arm_compute/runtime/CL/functions/CLReshapeLayer.h" #include "arm_compute/runtime/IMemoryManager.h" +#include "arm_compute/runtime/ITransformWeights.h" +#include "arm_compute/runtime/IWeightsManager.h" #include "arm_compute/runtime/MemoryGroup.h" #include @@ -82,6 +84,59 @@ private: CLWeightsReshapeKernel _weights_reshape_kernel; }; +namespace weights_transformations +{ +/** Basic function to manage the reshape weights generated from @ref CLConvolutionLayerReshapeWeights */ +class CLConvolutionLayerReshapeWeightsTransform : public ITransformWeights +{ +public: + /** Configures the @ref CLConvolutionLayerReshapeWeights function + * + * @param[in] input Input tensor. Data type supported: QASYMM8/F16/F32. + * @param[in] biases Biases tensor. Data type supported: Same as @p input. + * @param[in] num_groups Number of groups when performing a grouped convolution. + */ + void configure(const ICLTensor *input, const ICLTensor *biases, unsigned int num_groups) + { + _bias_bit = (biases != nullptr) ? 1 : 0; + _num_groups = num_groups; + _func.configure(input, biases, &_output, num_groups); + } + + //Inherited method override + void run() override + { + _output.allocator()->allocate(); + _func.run(); + _reshape_run = true; + } + + //Inherited method override + ICLTensor *get_weights() override + { + return &_output; + } + + //Inherited method override + void release() override + { + _output.allocator()->free(); + } + + //Inherited method override + uint32_t uid() override + { + return ((0x9) | (_bias_bit << 7) | (_num_groups << 8)); + } + +private: + CLTensor _output{}; + CLConvolutionLayerReshapeWeights _func{}; + int32_t _bias_bit{ 0 }; + unsigned int _num_groups{ 0 }; +}; +} // namespace weights_transformations + /** Basic function to compute the convolution layer. This function calls the following OpenCL kernels/functions: * * -# @ref CLIm2ColKernel @@ -96,9 +151,10 @@ class CLGEMMConvolutionLayer : public IFunction public: /** Constructor * - * @param[in] memory_manager (Optional) Memory manager. + * @param[in] memory_manager (Optional) Memory manager. + * @param[in] weights_manager (Optional) Weights manager. */ - CLGEMMConvolutionLayer(std::shared_ptr memory_manager = nullptr); + CLGEMMConvolutionLayer(std::shared_ptr memory_manager = nullptr, IWeightsManager *weights_manager = nullptr); /** Prevent instances of this class from being copied (As this class contains pointers) */ CLGEMMConvolutionLayer(const CLGEMMConvolutionLayer &) = delete; /** Default move constructor */ @@ -186,13 +242,15 @@ private: int gemm_3d_depth, bool skip_im2col, const ActivationLayerInfo &act_info); private: - MemoryGroup _memory_group; - CLConvolutionLayerReshapeWeights _reshape_weights; - CLIm2ColKernel _im2col_kernel; - CLGEMM _mm_gemm; - CLGEMMLowpMatrixMultiplyCore _mm_gemmlowp; - CLCol2ImKernel _col2im_kernel; - CLActivationLayer _activationlayer_function; + MemoryGroup _memory_group; + IWeightsManager *_weights_manager; + CLConvolutionLayerReshapeWeights _reshape_weights; + weights_transformations::CLConvolutionLayerReshapeWeightsTransform _reshape_weights_managed; + CLIm2ColKernel _im2col_kernel; + CLGEMM _mm_gemm; + CLGEMMLowpMatrixMultiplyCore _mm_gemmlowp; + CLCol2ImKernel _col2im_kernel; + CLActivationLayer _activationlayer_function; const ICLTensor *_original_weights; -- cgit v1.2.1