From b27e13a0ad630d3d9b3143c0374b5ff5000eebc0 Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Fri, 27 Sep 2019 11:04:27 +0100 Subject: COMPMID-2685: [CL] Use Weights manager Change-Id: Ia1818e6ecd9386e96378e64f14d02592fe3cdf0f Signed-off-by: Michalis Spyrou Reviewed-on: https://review.mlplatform.org/c/1997 Comments-Addressed: Arm Jenkins Reviewed-by: Gian Marco Iodice Tested-by: Arm Jenkins --- .../runtime/CL/functions/CLFullyConnectedLayer.h | 89 +++++++++++++++++----- 1 file changed, 70 insertions(+), 19 deletions(-) (limited to 'arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h') diff --git a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h index d54304ed77..9512b22c08 100644 --- a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h +++ b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h @@ -64,6 +64,54 @@ public: static Status validate(const ITensorInfo *input, const ITensorInfo *output); }; +namespace weights_transformations +{ +/** Basic function to manage the reshape weights generated from @ref CLFullyConnectedLayerReshapeWeights */ +class CLFullyConnectedLayerReshapeWeightsManaged : public ITransformWeights +{ +public: + //Inherited method override + void run() override + { + _output.allocator()->allocate(); + _func.run(); + _reshape_run = true; + } + + //Inherited method override + void release() override + { + _output.allocator()->free(); + } + + //Inherited method override + ICLTensor *get_weights() override + { + return &_output; + } + + //Inherited method override + uint32_t uid() override + { + return _uid; + } + + /** Configures the @ref CLFullyConnectedLayerReshapeWeights function + * + * @param[in] input Source tensor. Data type supported: QASYMM8/F16/F32. + */ + void configure(const ICLTensor *input) + { + _func.configure(input, &_output); + } + +private: + static constexpr uint32_t _uid = 0x0; + CLTensor _output{}; + CLFullyConnectedLayerReshapeWeights _func{}; +}; +} // namespace weights_transformations + /** Basic function to compute a Fully Connected layer on OpenCL. This function calls the following OpenCL kernels: * * -# @ref CLIm2ColKernel (called when the input comes from a convolutional layer) @@ -130,25 +178,28 @@ private: void configure_conv_fc(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool retain_internal_weights); void configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool retain_internal_weights); - MemoryGroup _memory_group; - CLConvertFullyConnectedWeights _convert_weights; - CLFlattenLayer _flatten_layer; - CLFullyConnectedLayerReshapeWeights _reshape_weights_kernel; - CLGEMM _mm_gemm; - CLGEMMLowpMatrixMultiplyCore _mm_gemmlowp; - CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint _gemmlowp_output_stage; - CLGEMMMatrixAccumulateBiasesKernel _accumulate_biases_kernel; // TODO(COMPMID-1889): Use CLGEMM to add bias in CLFullyConnectedLayer - CLTensor _flatten_output; - CLTensor _gemmlowp_output; - CLTensor _converted_weights_output; - CLTensor _reshape_weights_output; - bool _are_weights_converted; - bool _are_weights_reshaped; - bool _is_fc_after_conv; - bool _accumulate_biases; - bool _is_quantized; - bool _is_prepared; - const ICLTensor *_original_weights; + MemoryGroup _memory_group; + IWeightsManager *_weights_manager; + CLConvertFullyConnectedWeights _convert_weights; + weights_transformations::CLConvertFullyConnectedWeightsManaged _convert_weights_managed; + weights_transformations::CLFullyConnectedLayerReshapeWeightsManaged _reshape_weights_managed_function; + CLFlattenLayer _flatten_layer; + CLFullyConnectedLayerReshapeWeights _reshape_weights_function; + CLGEMM _mm_gemm; + CLGEMMLowpMatrixMultiplyCore _mm_gemmlowp; + CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint _gemmlowp_output_stage; + CLGEMMMatrixAccumulateBiasesKernel _accumulate_biases_kernel; // TODO(COMPMID-1889): Use CLGEMM to add bias in CLFullyConnectedLayer + CLTensor _flatten_output; + CLTensor _gemmlowp_output; + CLTensor _converted_weights_output; + CLTensor _reshape_weights_output; + bool _are_weights_converted; + bool _are_weights_reshaped; + bool _is_fc_after_conv; + bool _accumulate_biases; + bool _is_quantized; + bool _is_prepared; + const ICLTensor *_original_weights; }; } // namespace arm_compute #endif /* __ARM_COMPUTE_CLFULLYCONNECTEDLAYER_H__ */ -- cgit v1.2.1