From b27e13a0ad630d3d9b3143c0374b5ff5000eebc0 Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Fri, 27 Sep 2019 11:04:27 +0100 Subject: COMPMID-2685: [CL] Use Weights manager Change-Id: Ia1818e6ecd9386e96378e64f14d02592fe3cdf0f Signed-off-by: Michalis Spyrou Reviewed-on: https://review.mlplatform.org/c/1997 Comments-Addressed: Arm Jenkins Reviewed-by: Gian Marco Iodice Tested-by: Arm Jenkins --- .../CL/functions/CLConvertFullyConnectedWeights.h | 51 ++++++++++ .../runtime/CL/functions/CLFullyConnectedLayer.h | 89 ++++++++++++---- arm_compute/runtime/CL/functions/CLGEMM.h | 81 ++++++++++++--- .../runtime/CL/functions/CLGEMMConvolutionLayer.h | 76 ++++++++++++-- src/graph/GraphContext.cpp | 2 +- src/graph/backends/CL/CLDeviceBackend.cpp | 14 ++- src/runtime/CL/functions/CLFullyConnectedLayer.cpp | 89 ++++++++++++---- src/runtime/CL/functions/CLGEMM.cpp | 112 +++++++++++++++++---- .../CL/functions/CLGEMMConvolutionLayer.cpp | 52 ++++++++-- 9 files changed, 469 insertions(+), 97 deletions(-) diff --git a/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h b/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h index 43abb6769b..e4e6f0760e 100644 --- a/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h +++ b/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h @@ -25,7 +25,9 @@ #define __ARM_COMPUTE_CLCONVERTFULLYCONNECTEDWEIGHTS_H__ #include "arm_compute/core/CL/kernels/CLConvertFullyConnectedWeightsKernel.h" +#include "arm_compute/runtime/CL/CLTensor.h" #include "arm_compute/runtime/CL/ICLSimpleFunction.h" +#include "arm_compute/runtime/ITransformWeights.h" namespace arm_compute { @@ -54,5 +56,54 @@ public: */ static Status validate(const ITensorInfo *input, const ITensorInfo *output, const TensorShape &original_input_shape, DataLayout data_layout); }; + +namespace weights_transformations +{ +/** Basic function to run @ref CLConvertFullyConnectedWeightsKernel. */ +class CLConvertFullyConnectedWeightsManaged : public ITransformWeights +{ +public: + //Inherited method override + void run() override + { + _output.allocator()->allocate(); + _func.run(); + _reshape_run = true; + } + + //Inherited method override + void release() override + { + _output.allocator()->free(); + } + + //Inherited method override + ICLTensor *get_weights() override + { + return &_output; + } + + //Inherited method override + uint32_t uid() override + { + return _uid; + } + /** Configures the @ref CLConvertFullyConnectedWeights function + * + * @param[in] input Source weights tensor info to convert. Data type supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32. + * @param[in] original_input_shape Shape of the original input tensor (the one entering fully connected layer). + * @param[in] data_layout The data layout the weights have been trained in. 
+     */
+    void configure(const ICLTensor *input, const TensorShape &original_input_shape, DataLayout data_layout)
+    {
+        _func.configure(input, &_output, original_input_shape, data_layout);
+    }
+
+private:
+    static constexpr uint32_t _uid = 0x5;
+    CLTensor                       _output{};
+    CLConvertFullyConnectedWeights _func{};
+};
+} // namespace weights_transformations
 } // namespace arm_compute
 #endif /* __ARM_COMPUTE_CLCONVERTFULLYCONNECTEDWEIGHTS_H__ */
diff --git a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
index d54304ed77..9512b22c08 100644
--- a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
+++ b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
@@ -64,6 +64,54 @@ public:
     static Status validate(const ITensorInfo *input, const ITensorInfo *output);
 };
 
+namespace weights_transformations
+{
+/** Basic function to manage the reshape weights generated from @ref CLFullyConnectedLayerReshapeWeights */
+class CLFullyConnectedLayerReshapeWeightsManaged : public ITransformWeights
+{
+public:
+    //Inherited method override
+    void run() override
+    {
+        _output.allocator()->allocate();
+        _func.run();
+        _reshape_run = true;
+    }
+
+    //Inherited method override
+    void release() override
+    {
+        _output.allocator()->free();
+    }
+
+    //Inherited method override
+    ICLTensor *get_weights() override
+    {
+        return &_output;
+    }
+
+    //Inherited method override
+    uint32_t uid() override
+    {
+        return _uid;
+    }
+
+    /** Configures the @ref CLFullyConnectedLayerReshapeWeights function
+     *
+     * @param[in] input Source tensor. Data type supported: QASYMM8/F16/F32.
+     */
+    void configure(const ICLTensor *input)
+    {
+        _func.configure(input, &_output);
+    }
+
+private:
+    static constexpr uint32_t _uid = 0x0;
+    CLTensor                            _output{};
+    CLFullyConnectedLayerReshapeWeights _func{};
+};
+} // namespace weights_transformations
+
 /** Basic function to compute a Fully Connected layer on OpenCL. This function calls the following OpenCL kernels:
  *
  * -# @ref CLIm2ColKernel (called when the input comes from a convolutional layer)
@@ -130,25 +178,28 @@ private:
     void configure_conv_fc(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool retain_internal_weights);
     void configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool retain_internal_weights);
 
-    MemoryGroup                                         _memory_group;
-    CLConvertFullyConnectedWeights                      _convert_weights;
-    CLFlattenLayer                                      _flatten_layer;
-    CLFullyConnectedLayerReshapeWeights                 _reshape_weights_kernel;
-    CLGEMM                                              _mm_gemm;
-    CLGEMMLowpMatrixMultiplyCore                        _mm_gemmlowp;
-    CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint _gemmlowp_output_stage;
-    CLGEMMMatrixAccumulateBiasesKernel                  _accumulate_biases_kernel; // TODO(COMPMID-1889): Use CLGEMM to add bias in CLFullyConnectedLayer
-    CLTensor                                            _flatten_output;
-    CLTensor                                            _gemmlowp_output;
-    CLTensor                                            _converted_weights_output;
-    CLTensor                                            _reshape_weights_output;
-    bool                                                _are_weights_converted;
-    bool                                                _are_weights_reshaped;
-    bool                                                _is_fc_after_conv;
-    bool                                                _accumulate_biases;
-    bool                                                _is_quantized;
-    bool                                                _is_prepared;
-    const ICLTensor                                     *_original_weights;
+    MemoryGroup                                                         _memory_group;
+    IWeightsManager                                                    *_weights_manager;
+    CLConvertFullyConnectedWeights                                      _convert_weights;
+    weights_transformations::CLConvertFullyConnectedWeightsManaged      _convert_weights_managed;
+    weights_transformations::CLFullyConnectedLayerReshapeWeightsManaged _reshape_weights_managed_function;
+    CLFlattenLayer                                                      _flatten_layer;
+    CLFullyConnectedLayerReshapeWeights                                 _reshape_weights_function;
+    CLGEMM                                                              _mm_gemm;
+    CLGEMMLowpMatrixMultiplyCore                                        _mm_gemmlowp;
+    CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint                 _gemmlowp_output_stage;
+    CLGEMMMatrixAccumulateBiasesKernel                                  _accumulate_biases_kernel; // TODO(COMPMID-1889): Use CLGEMM to add bias in CLFullyConnectedLayer
+    CLTensor                                                            _flatten_output;
+    CLTensor                                                            _gemmlowp_output;
+    CLTensor                                                            _converted_weights_output;
+    CLTensor                                                            _reshape_weights_output;
+    bool                                                                _are_weights_converted;
+    bool                                                                _are_weights_reshaped;
+    bool                                                                _is_fc_after_conv;
+    bool                                                                _accumulate_biases;
+    bool                                                                _is_quantized;
+    bool                                                                _is_prepared;
+    const ICLTensor                                                    *_original_weights;
 };
 } // namespace arm_compute
 #endif /* __ARM_COMPUTE_CLFULLYCONNECTEDLAYER_H__ */
diff --git a/arm_compute/runtime/CL/functions/CLGEMM.h b/arm_compute/runtime/CL/functions/CLGEMM.h
index b8e5fa67dd..3691fe9e21 100644
--- a/arm_compute/runtime/CL/functions/CLGEMM.h
+++ b/arm_compute/runtime/CL/functions/CLGEMM.h
@@ -32,12 +32,62 @@
 #include "arm_compute/runtime/CL/CLTensor.h"
 #include "arm_compute/runtime/IFunction.h"
 #include "arm_compute/runtime/IMemoryManager.h"
+#include "arm_compute/runtime/IWeightsManager.h"
 #include "arm_compute/runtime/MemoryGroup.h"
 
 namespace arm_compute
 {
 class ICLTensor;
 
+namespace weights_transformations
+{
+/** Basic function to manage the reshape weights generated from @ref CLGEMMReshapeRHSMatrixKernel */
+class CLGEMMReshapeRHSMatrixKernelManaged : public ITransformWeights
+{
+public:
+    //Inherited method override
+    void run() override
+    {
+        _output.allocator()->allocate();
+        CLScheduler::get().enqueue(_kernel, false);
+        _reshape_run = true;
+    }
+
+    //Inherited method override
+    void release() override
+    {
+        _output.allocator()->free();
+    }
+
+    //Inherited method override
+    ICLTensor *get_weights() override
+    {
+        return &_output;
+    }
+
+    //Inherited method override
+    uint32_t uid() override
+    {
+        return _uid;
+    }
+
+    /** Configures the @ref CLGEMMReshapeRHSMatrixKernel kernel
+     *
+     * @param[in] input Input tensor. Data types supported: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
+     * @param[in] info  RHS matrix information to be used for reshaping.
+     */
+    void configure(const ICLTensor *input, GEMMRHSMatrixInfo info)
+    {
+        _kernel.configure(input, &_output, info);
+    }
+
+private:
+    static constexpr uint32_t _uid = 0x15;
+    CLTensor                     _output{};
+    CLGEMMReshapeRHSMatrixKernel _kernel{};
+};
+} // namespace weights_transformations
+
 /** Basic function to execute GEMM on OpenCL. This function calls the following OpenCL kernels:
  *
  * -# @ref CLGEMMReshapeLHSMatrixKernel (only if the RESHAPED_V1 is selected by the heuristic model)
@@ -52,9 +102,10 @@ class CLGEMM : public IFunction
 public:
     /** Default constructor.
      *
-     * @param[in] memory_manager (Optional) Memory manager.
+     * @param[in] memory_manager  (Optional) Memory manager.
+     * @param[in] weights_manager (Optional) Weights manager.
      */
-    CLGEMM(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
+    CLGEMM(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
     /** Prevent instances of this class from being copied (As this class contains pointers) */
     CLGEMM(const CLGEMM &) = delete;
     /** Default move constructor */
@@ -123,18 +174,20 @@ private:
     static Status validate_reshaped_v2(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
     static Status validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
 
-    MemoryGroup                               _memory_group;
-    CLGEMMMatrixMultiplyKernel                _mm_kernel;
-    CLGEMMReshapeLHSMatrixKernel              _reshape_lhs_kernel;
-    CLGEMMReshapeRHSMatrixKernel              _reshape_rhs_kernel;
-    CLGEMMMatrixMultiplyReshapedKernel        _mm_reshaped_kernel;
-    CLGEMMMatrixMultiplyReshapedOnlyRHSKernel _mm_reshaped_only_rhs_kernel;
-    CLTensor                                  _tmp_a;
-    CLTensor                                  _tmp_b;
-    const ICLTensor                           *_original_b;
-    bool                                      _reshape_b_only_on_first_run;
-    bool                                      _is_prepared;
-    GEMMType                                  _gemm_type;
+    MemoryGroup                                                  _memory_group;
+    IWeightsManager                                             *_weights_manager;
+    CLGEMMMatrixMultiplyKernel                                   _mm_kernel;
+    CLGEMMReshapeLHSMatrixKernel                                 _reshape_lhs_kernel;
+    CLGEMMReshapeRHSMatrixKernel                                 _reshape_rhs_kernel;
+    weights_transformations::CLGEMMReshapeRHSMatrixKernelManaged _reshape_rhs_kernel_managed;
+    CLGEMMMatrixMultiplyReshapedKernel                           _mm_reshaped_kernel;
+    CLGEMMMatrixMultiplyReshapedOnlyRHSKernel                    _mm_reshaped_only_rhs_kernel;
+    CLTensor                                                     _tmp_a;
+    CLTensor                                                     _tmp_b;
+    const ICLTensor                                             *_original_b;
+    bool                                                         _reshape_b_only_on_first_run;
+    bool                                                         _is_prepared;
+    GEMMType                                                     _gemm_type;
 };
 } // namespace arm_compute
diff --git a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
index 0b27c824d9..017bf78938 100644
--- a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
+++ b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
@@ -39,6 +39,8 @@
 #include "arm_compute/runtime/CL/functions/CLGEMMLowpOutputStage.h"
 #include "arm_compute/runtime/CL/functions/CLReshapeLayer.h"
 #include "arm_compute/runtime/IMemoryManager.h"
+#include "arm_compute/runtime/ITransformWeights.h"
+#include "arm_compute/runtime/IWeightsManager.h"
 #include "arm_compute/runtime/MemoryGroup.h"
 
 #include <memory>
 
@@ -82,6 +84,59 @@ private:
     CLWeightsReshapeKernel _weights_reshape_kernel;
 };
 
+namespace weights_transformations
+{
+/** Basic function to manage the reshape weights generated from @ref CLConvolutionLayerReshapeWeights */
+class CLConvolutionLayerReshapeWeightsTransform : public ITransformWeights
+{
+public:
+    /** Configures the @ref CLConvolutionLayerReshapeWeights function
+     *
+     * @param[in] input      Input tensor. Data type supported: QASYMM8/F16/F32.
+     * @param[in] biases     Biases tensor. Data type supported: Same as @p input.
+     * @param[in] num_groups Number of groups when performing a grouped convolution.
+     */
+    void configure(const ICLTensor *input, const ICLTensor *biases, unsigned int num_groups)
+    {
+        _bias_bit   = (biases != nullptr) ? 1 : 0;
+        _num_groups = num_groups;
+        _func.configure(input, biases, &_output, num_groups);
+    }
+
+    //Inherited method override
+    void run() override
+    {
+        _output.allocator()->allocate();
+        _func.run();
+        _reshape_run = true;
+    }
+
+    //Inherited method override
+    ICLTensor *get_weights() override
+    {
+        return &_output;
+    }
+
+    //Inherited method override
+    void release() override
+    {
+        _output.allocator()->free();
+    }
+
+    //Inherited method override
+    uint32_t uid() override
+    {
+        return ((0x9) | (_bias_bit << 7) | (_num_groups << 8));
+    }
+
+private:
+    CLTensor                         _output{};
+    CLConvolutionLayerReshapeWeights _func{};
+    int32_t                          _bias_bit{ 0 };
+    unsigned int                     _num_groups{ 0 };
+};
+} // namespace weights_transformations
+
 /** Basic function to compute the convolution layer. This function calls the following OpenCL kernels/functions:
  *
  * -# @ref CLIm2ColKernel
@@ -96,9 +151,10 @@ class CLGEMMConvolutionLayer : public IFunction
 public:
     /** Constructor
      *
-     * @param[in] memory_manager (Optional) Memory manager.
+     * @param[in] memory_manager  (Optional) Memory manager.
+     * @param[in] weights_manager (Optional) Weights manager.
      */
-    CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
+    CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
     /** Prevent instances of this class from being copied (As this class contains pointers) */
     CLGEMMConvolutionLayer(const CLGEMMConvolutionLayer &) = delete;
     /** Default move constructor */
@@ -186,13 +242,15 @@ private:
                              int gemm_3d_depth, bool skip_im2col, const ActivationLayerInfo &act_info);
 
 private:
-    MemoryGroup                      _memory_group;
-    CLConvolutionLayerReshapeWeights _reshape_weights;
-    CLIm2ColKernel                   _im2col_kernel;
-    CLGEMM                           _mm_gemm;
-    CLGEMMLowpMatrixMultiplyCore     _mm_gemmlowp;
-    CLCol2ImKernel                   _col2im_kernel;
-    CLActivationLayer                _activationlayer_function;
+    MemoryGroup                                                         _memory_group;
+    IWeightsManager                                                    *_weights_manager;
+    CLConvolutionLayerReshapeWeights                                    _reshape_weights;
+    weights_transformations::CLConvolutionLayerReshapeWeightsTransform  _reshape_weights_managed;
+    CLIm2ColKernel                                                      _im2col_kernel;
+    CLGEMM                                                              _mm_gemm;
+    CLGEMMLowpMatrixMultiplyCore                                        _mm_gemmlowp;
+    CLCol2ImKernel                                                      _col2im_kernel;
+    CLActivationLayer                                                   _activationlayer_function;
 
     const ICLTensor *_original_weights;
diff --git a/src/graph/GraphContext.cpp b/src/graph/GraphContext.cpp
index c959d5e35c..4d978073e1 100644
--- a/src/graph/GraphContext.cpp
+++ b/src/graph/GraphContext.cpp
@@ -79,7 +79,7 @@ bool GraphContext::insert_weights_management_ctx(WeightsManagerContext &&weights
 {
     Target target = weights_managers.target;
 
-    if(target != Target::NEON || _weights_managers.find(target) != std::end(_weights_managers))
+    if(_weights_managers.find(target) != std::end(_weights_managers))
     {
         return false;
     }
diff --git a/src/graph/backends/CL/CLDeviceBackend.cpp b/src/graph/backends/CL/CLDeviceBackend.cpp
index 9b7c879b2a..ea3b6b801a 100644
--- a/src/graph/backends/CL/CLDeviceBackend.cpp
+++ b/src/graph/backends/CL/CLDeviceBackend.cpp
@@ -38,6 +38,7 @@
 #include "arm_compute/runtime/BlobLifetimeManager.h"
 #include "arm_compute/runtime/CL/CLBufferAllocator.h"
 #include "arm_compute/runtime/CL/CLScheduler.h"
+#include "arm_compute/runtime/IWeightsManager.h"
 #include "arm_compute/runtime/MemoryGroup.h"
 #include "arm_compute/runtime/MemoryManagerOnDemand.h"
 #include "arm_compute/runtime/PoolManager.h"
@@ -137,6 +138,16 @@ void CLDeviceBackend::setup_backend_context(GraphContext &ctx)
 
         ctx.insert_memory_management_ctx(std::move(mm_ctx));
     }
+
+    // Create function level weights manager
+    if(ctx.weights_management_ctx(Target::CL) == nullptr)
+    {
+        WeightsManagerContext wm_ctx;
+        wm_ctx.target = Target::CL;
+        wm_ctx.wm     = create_weights_manager();
+
+        ctx.insert_weights_management_ctx(std::move(wm_ctx));
+    }
 }
 
 bool CLDeviceBackend::is_backend_supported()
@@ -207,7 +218,8 @@ std::shared_ptr<arm_compute::IMemoryManager> CLDeviceBackend::create_memory_mana
 
 std::shared_ptr<arm_compute::IWeightsManager> CLDeviceBackend::create_weights_manager()
 {
-    return nullptr;
+    auto weights_mgr = std::make_shared<IWeightsManager>();
+    return weights_mgr;
 }
 } // namespace backends
 } // namespace graph
diff --git a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
index 0452a236c5..91f722fdce 100644
--- a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
+++ b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
@@ -25,6 +25,7 @@
 #include "arm_compute/core/Size2D.h"
 #include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/Cast.h"
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/core/utils/quantization/AsymmHelpers.h"
 #include "arm_compute/runtime/CL/CLScheduler.h"
@@ -32,8 +33,10 @@
 #include <algorithm>
 
-using namespace arm_compute;
+namespace arm_compute
+{
 using namespace arm_compute::misc::shape_calculator;
+using namespace arm_compute::utils::cast;
 
 namespace
 {
@@ -77,9 +80,10 @@ Status CLFullyConnectedLayerReshapeWeights::validate(const ITensorInfo *input, c
 }
 
 CLFullyConnectedLayer::CLFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
-    : _memory_group(memory_manager), _convert_weights(), _flatten_layer(), _reshape_weights_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(),
-      _accumulate_biases_kernel(), _flatten_output(), _gemmlowp_output(), _converted_weights_output(), _reshape_weights_output(), _are_weights_converted(true), _are_weights_reshaped(true),
-      _is_fc_after_conv(true), _accumulate_biases(false), _is_quantized(false), _is_prepared(false), _original_weights(nullptr)
+    : _memory_group(memory_manager), _weights_manager(weights_manager), _convert_weights(), _convert_weights_managed(), _reshape_weights_managed_function(), _flatten_layer(), _reshape_weights_function(),
+      _mm_gemm(memory_manager, weights_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), _accumulate_biases_kernel(), _flatten_output(), _gemmlowp_output(), _converted_weights_output(),
+      _reshape_weights_output(), _are_weights_converted(true), _are_weights_reshaped(true), _is_fc_after_conv(true), _accumulate_biases(false), _is_quantized(false), _is_prepared(false),
+      _original_weights(nullptr)
 {
 }
 
 void CLFullyConnectedLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool retain_internal_weights)
@@ -157,6 +161,11 @@ void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *w
     _is_prepared      = fc_info.retain_internal_weights;
     _original_weights = weights;
 
+    if(_weights_manager)
+    {
+        _weights_manager->manage(weights);
+    }
+
     // Configure gemmlowp output
     if(_is_quantized)
     {
@@ -199,21 +208,39 @@ void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *w
     // Reshape weights if needed
     if(!_are_weights_reshaped)
     {
-        // Reshape the weights
-        _reshape_weights_kernel.configure(weights, &_reshape_weights_output);
-        weights_to_use = &_reshape_weights_output;
+        if(_weights_manager && _weights_manager->are_weights_managed(weights))
+        {
+            _reshape_weights_managed_function.configure(weights);
+            weights_to_use = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(weights, &_reshape_weights_managed_function));
+        }
+        else
+        {
+            // Reshape the weights
+            _reshape_weights_function.configure(weights, &_reshape_weights_output);
+            weights_to_use = &_reshape_weights_output;
+        }
     }
 
     // Convert weights if needed
     if(_is_fc_after_conv && (input->info()->data_layout() != fc_info.weights_trained_layout))
     {
-        // Convert weights
-        _convert_weights.configure(weights_to_use,
-                                   &_converted_weights_output,
-                                   input->info()->tensor_shape(),
-                                   fc_info.weights_trained_layout);
+        if(_weights_manager && _weights_manager->are_weights_managed(weights_to_use))
+        {
+            _convert_weights_managed.configure(weights_to_use,
+                                               input->info()->tensor_shape(),
+                                               fc_info.weights_trained_layout);
+            weights_to_use = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(weights, &_convert_weights_managed));
+        }
+        else
+        {
+            // Convert weights
+            _convert_weights.configure(weights_to_use,
+                                       &_converted_weights_output,
+                                       input->info()->tensor_shape(),
+                                       fc_info.weights_trained_layout);
 
-        weights_to_use = &_converted_weights_output;
+            weights_to_use = &_converted_weights_output;
+        }
         _are_weights_converted = false;
     }
@@ -384,7 +411,10 @@ void CLFullyConnectedLayer::prepare()
 {
     if(!_is_prepared)
     {
-        ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
+        if(!_weights_manager)
+        {
+            ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
+        }
 
         auto release_unused = [](CLTensor * w)
         {
@@ -401,22 +431,36 @@ void CLFullyConnectedLayer::prepare()
         // Reshape of the weights if needed (happens only once)
         if(!_are_weights_reshaped)
        {
-            // Run reshape weights kernel and mark weights as unused
-            _reshape_weights_output.allocator()->allocate();
-            _reshape_weights_kernel.run();
+            if(_weights_manager && _weights_manager->are_weights_managed(_original_weights))
+            {
+                cur_weights = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->run(cur_weights, &_reshape_weights_managed_function));
+            }
+            else
+            {
+                // Run reshape weights kernel and mark weights as unused
+                _reshape_weights_output.allocator()->allocate();
+                _reshape_weights_function.run();
 
-            cur_weights->mark_as_unused();
-            cur_weights = &_reshape_weights_output;
+                cur_weights->mark_as_unused();
+                cur_weights = &_reshape_weights_output;
+            }
             _are_weights_reshaped = true;
         }
 
         // Convert weights if needed (happens only once)
         if(!_are_weights_converted)
         {
-            _converted_weights_output.allocator()->allocate();
-            _convert_weights.run();
+            if(_weights_manager && _weights_manager->are_weights_managed(cur_weights))
+            {
+                _weights_manager->run(cur_weights, &_convert_weights_managed);
+            }
+            else
+            {
+                _converted_weights_output.allocator()->allocate();
+                _convert_weights.run();
+                cur_weights->mark_as_unused();
+            }
 
-            cur_weights->mark_as_unused();
             _are_weights_converted = true;
         }
@@ -436,3 +480,4 @@ void CLFullyConnectedLayer::prepare()
         _is_prepared = true;
     }
 }
+} // namespace arm_compute
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index 762b00177c..2a027d872c 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -36,6 +36,7 @@
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/core/utils/helpers/float_ops.h"
+#include "arm_compute/core/utils/misc/Cast.h"
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/runtime/CL/CLScheduler.h"
 #include "arm_compute/runtime/ITensorAllocator.h"
@@ -44,12 +45,15 @@ namespace arm_compute
 {
 using namespace arm_compute::misc::shape_calculator;
 using namespace arm_compute::cl_gemm;
+using namespace arm_compute::utils::cast;
 
-CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager)
+CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
     : _memory_group(std::move(memory_manager)),
+      _weights_manager(weights_manager),
       _mm_kernel(),
       _reshape_lhs_kernel(),
       _reshape_rhs_kernel(),
+      _reshape_rhs_kernel_managed(),
       _mm_reshaped_kernel(),
       _mm_reshaped_only_rhs_kernel(),
       _tmp_a(),
@@ -178,8 +182,12 @@ void CLGEMM::configure_reshaped_v1(const ICLTensor *a, const ICLTensor *b, const
 
     GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias());
 
+    const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
+
+    // Manage intermediate buffers
     _memory_group.manage(&_tmp_a);
-    if(!_reshape_b_only_on_first_run)
+
+    if(!_reshape_b_only_on_first_run && use_mm_b)
     {
         _memory_group.manage(&_tmp_b);
     }
@@ -188,16 +196,26 @@ void CLGEMM::configure_reshaped_v1(const ICLTensor *a, const ICLTensor *b, const
     _reshape_lhs_kernel.configure(a, &_tmp_a, lhs_info, reinterpret_input_as_3d);
 
     // Configure transpose kernel
-    _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
+    ICLTensor *reshaped_rhs = &_tmp_b;
+    if(_weights_manager && _weights_manager->are_weights_managed(b))
+    {
+        _reshape_rhs_kernel_managed.configure(b, rhs_info);
+        reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed));
+    }
+    else
+    {
+        _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
+    }
 
     // Configure and tune matrix multiply kernel
-    _mm_kernel.configure(&_tmp_a, &_tmp_b, c, output, alpha, beta, true, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
+    _mm_kernel.configure(&_tmp_a, reshaped_rhs, c, output, alpha, beta, true, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
 
     CLScheduler::get().tune_kernel_static(_mm_kernel);
 
     // Allocate intermediate tensors
     _tmp_a.allocator()->allocate();
-    if(!_reshape_b_only_on_first_run)
+
+    if(!_reshape_b_only_on_first_run && use_mm_b)
     {
         _tmp_b.allocator()->allocate();
     }
@@ -228,12 +246,16 @@ void CLGEMM::configure_reshaped_v2(const ICLTensor *a, const ICLTensor *b, const
     _reshape_lhs_kernel.set_target(gpu_target);
     _mm_kernel.set_target(gpu_target);
 
+    const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
+
     // Manage intermediate buffers
     _memory_group.manage(&_tmp_a);
-    if(!_reshape_b_only_on_first_run)
+
+    if(!_reshape_b_only_on_first_run && use_mm_b)
     {
         _memory_group.manage(&_tmp_b);
     }
+
     // _tmp_a and _tmp_b will be auto configured in _interleave_kernel and in _transpose_kernel
 
     GEMMLHSMatrixInfo lhs_info{};
@@ -247,14 +269,25 @@ void CLGEMM::configure_reshaped_v2(const ICLTensor *a, const ICLTensor *b, const
     std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
 
     _reshape_lhs_kernel.configure(a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
-    _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
+
+    ICLTensor *reshaped_rhs = &_tmp_b;
+    if(_weights_manager && _weights_manager->are_weights_managed(b))
+    {
+        _reshape_rhs_kernel_managed.configure(b, rhs_info);
+        reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed));
+    }
+    else
+    {
+        _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
+    }
 
     // Configure and tune matrix multiply kernel
-    _mm_reshaped_kernel.configure(&_tmp_a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
+    _mm_reshaped_kernel.configure(&_tmp_a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
 
     // Allocate intermediate tensors
     _tmp_a.allocator()->allocate();
-    if(!_reshape_b_only_on_first_run)
+
+    if(!_reshape_b_only_on_first_run && use_mm_b)
    {
         _tmp_b.allocator()->allocate();
     }
@@ -284,8 +317,10 @@ void CLGEMM::configure_reshaped_only_rhs(const ICLTensor *a, const ICLTensor *b,
     // Set the target for the kernels
     _mm_kernel.set_target(gpu_target);
 
+    const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
+
     // Manage intermediate buffers
-    if(!_reshape_b_only_on_first_run)
+    if(!_reshape_b_only_on_first_run && use_mm_b)
     {
         _memory_group.manage(&_tmp_b);
     }
@@ -300,12 +335,21 @@ void CLGEMM::configure_reshaped_only_rhs(const ICLTensor *a, const ICLTensor *b,
     // Configure lhs_info and rhs_info
     std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
 
-    _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
+    ICLTensor *reshaped_rhs = &_tmp_b;
+    if(_weights_manager && _weights_manager->are_weights_managed(b))
+    {
+        _reshape_rhs_kernel_managed.configure(b, rhs_info);
+        reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed));
+    }
+    else
+    {
+        _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
+    }
 
     // Configure and tune matrix multiply kernel
-    _mm_reshaped_only_rhs_kernel.configure(a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
+    _mm_reshaped_only_rhs_kernel.configure(a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
 
-    if(!_reshape_b_only_on_first_run)
+    if(!_reshape_b_only_on_first_run && use_mm_b)
     {
         _tmp_b.allocator()->allocate();
     }
@@ -607,7 +651,14 @@ void CLGEMM::run()
             if(!_reshape_b_only_on_first_run)
             {
                 // Run transpose kernel
-                CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+                if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
+                {
+                    _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
+                }
+                else
+                {
+                    CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+                }
             }
 
             CLScheduler::get().enqueue(_mm_kernel, true);
@@ -621,7 +672,14 @@ void CLGEMM::run()
             if(!_reshape_b_only_on_first_run)
             {
                 // Run transpose kernel
-                CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+                if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
+                {
+                    _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
+                }
+                else
+                {
+                    CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+                }
             }
 
             CLScheduler::get().enqueue(_mm_reshaped_kernel, true);
@@ -632,7 +690,14 @@ void CLGEMM::run()
             if(!_reshape_b_only_on_first_run)
             {
                 // Run transpose kernel
-                CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+                if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
+                {
+                    _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
+                }
+                else
+                {
+                    CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+                }
             }
 
             CLScheduler::get().enqueue(_mm_reshaped_only_rhs_kernel, true);
@@ -651,10 +716,17 @@ void CLGEMM::prepare()
     {
         if(_gemm_type != GEMMType::NATIVE && _reshape_b_only_on_first_run)
         {
-            // Run transpose kernel and mark original weights tensor as unused
-            _tmp_b.allocator()->allocate();
-            CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
-            _original_b->mark_as_unused();
+            if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
+            {
+                _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
+            }
+            else
+            {
+                // Run transpose kernel and mark original weights tensor as unused
+                _tmp_b.allocator()->allocate();
+                CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+                _original_b->mark_as_unused();
+            }
         }
         CLScheduler::get().queue().finish();
         _is_prepared = true;
diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
index 594c8eef34..831f108b85 100644
--- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
@@ -27,6 +27,7 @@
 #include "arm_compute/core/Size2D.h"
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/Cast.h"
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/core/utils/quantization/AsymmHelpers.h"
 #include "arm_compute/runtime/CL/CLScheduler.h"
@@ -35,8 +36,10 @@
 #include <cmath>
 #include <memory>
 
-using namespace arm_compute;
+namespace arm_compute
+{
 using namespace arm_compute::misc::shape_calculator;
+using namespace arm_compute::utils::cast;
 
 CLConvolutionLayerReshapeWeights::CLConvolutionLayerReshapeWeights()
     : _weights_reshape_kernel()
@@ -90,9 +93,10 @@ void CLConvolutionLayerReshapeWeights::run()
     CLScheduler::get().enqueue(_weights_reshape_kernel);
 }
 
-CLGEMMConvolutionLayer::CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(memory_manager), _reshape_weights(), _im2col_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _col2im_kernel(), _activationlayer_function(),
-      _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _skip_im2col(false), _skip_col2im(false), _is_quantized(false), _fuse_activation(true), _is_prepared(false)
+CLGEMMConvolutionLayer::CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
+    : _memory_group(memory_manager), _weights_manager(weights_manager), _reshape_weights(), _reshape_weights_managed(), _im2col_kernel(), _mm_gemm(memory_manager, weights_manager),
+      _mm_gemmlowp(memory_manager), _col2im_kernel(), _activationlayer_function(), _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _skip_im2col(false),
+      _skip_col2im(false), _is_quantized(false), _fuse_activation(true), _is_prepared(false)
 {
 }
 
@@ -238,6 +242,7 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *
     const ICLTensor *biases_to_use = biases;
     bool             append_bias   = false;
+    ICLTensor       *weights_to_use = &_weights_reshaped;
 
     if(num_groups != 1 && biases != nullptr)
     {
         // num_groups != 1 can only be for NCHW
@@ -245,11 +250,27 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *
         biases_to_use = nullptr;
         append_bias   = true;
 
-        _reshape_weights.configure(weights, biases, &_weights_reshaped, num_groups);
+        if(_weights_manager && _weights_manager->are_weights_managed(weights))
+        {
+            _reshape_weights_managed.configure(weights, biases, num_groups);
+            weights_to_use = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(weights, &_reshape_weights_managed));
+        }
+        else
+        {
+            _reshape_weights.configure(weights, biases, &_weights_reshaped, num_groups);
+        }
     }
     else
     {
-        _reshape_weights.configure(weights, nullptr, &_weights_reshaped, num_groups);
+        if(_weights_manager && _weights_manager->are_weights_managed(weights))
+        {
+            _reshape_weights_managed.configure(weights, nullptr, num_groups);
+            weights_to_use = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(weights, &_reshape_weights_managed));
+        }
+        else
+        {
+            _reshape_weights.configure(weights, nullptr, &_weights_reshaped, num_groups);
+        }
     }
 
     // Create tensor to store im2col reshaped inputs
@@ -340,7 +361,7 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *
     // In case of NHWC, we need to run GEMM3D (gemm_3d_depth != 0) in order to avoid reshaping the output matrix
     const unsigned int gemm_3d_depth = (data_layout == DataLayout::NHWC) ? conv_h : 0;
 
-    configure_mm(gemm_input_to_use, &_weights_reshaped, biases_to_use, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, act_info);
+    configure_mm(gemm_input_to_use, weights_to_use, biases_to_use, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, act_info);
 
     if(!_skip_im2col)
     {
@@ -601,10 +622,18 @@ void CLGEMMConvolutionLayer::prepare()
 {
     if(!_is_prepared)
     {
-        // Run weights reshaping and mark original weights tensor as unused
-        _weights_reshaped.allocator()->allocate();
-        _reshape_weights.run();
-        _original_weights->mark_as_unused();
+        ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
+        if(_weights_manager && _weights_manager->are_weights_managed(_original_weights))
+        {
+            _weights_manager->run(_original_weights, &_reshape_weights_managed);
+        }
+        else
+        {
+            // Run weights reshaping and mark original weights tensor as unused
+            _weights_reshaped.allocator()->allocate();
+            _reshape_weights.run();
+            _original_weights->mark_as_unused();
+        }
 
         // Prepare GEMM
         _is_quantized ? _mm_gemmlowp.prepare() : _mm_gemm.prepare();
@@ -617,3 +646,4 @@ void CLGEMMConvolutionLayer::prepare()
         _is_prepared = true;
    }
 }
+} // namespace arm_compute
-- 
cgit v1.2.1
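
Usage sketch (illustrative only, not part of the change): with this patch a caller can share one IWeightsManager between several CL functions that consume the same weights tensor, so the managed reshape runs once at prepare time and later functions reuse the cached result. Constructor and configure() signatures are taken from the diff above; the tensor shapes and the null memory manager below are placeholder assumptions.

    // Hedged sketch: two fully connected layers sharing one weights manager.
    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/core/TensorShape.h"
    #include "arm_compute/core/Types.h"
    #include "arm_compute/runtime/CL/CLScheduler.h"
    #include "arm_compute/runtime/CL/CLTensor.h"
    #include "arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h"
    #include "arm_compute/runtime/IWeightsManager.h"

    #include <memory>

    using namespace arm_compute;

    int main()
    {
        CLScheduler::get().default_init();

        // One manager shared by every function that consumes the same weights.
        auto wm = std::make_shared<IWeightsManager>();

        CLTensor src0{}, src1{}, weights{}, dst0{}, dst1{};
        src0.allocator()->init(TensorInfo(TensorShape(128U), 1, DataType::F32));
        src1.allocator()->init(TensorInfo(TensorShape(128U), 1, DataType::F32));
        weights.allocator()->init(TensorInfo(TensorShape(128U, 64U), 1, DataType::F32));
        dst0.allocator()->init(TensorInfo(TensorShape(64U), 1, DataType::F32));
        dst1.allocator()->init(TensorInfo(TensorShape(64U), 1, DataType::F32));

        // configure() registers `weights` with the manager (see the diff above),
        // so both functions point at the same managed transformation.
        CLFullyConnectedLayer fc0(nullptr, wm.get());
        CLFullyConnectedLayer fc1(nullptr, wm.get());
        fc0.configure(&src0, &weights, nullptr, &dst0);
        fc1.configure(&src1, &weights, nullptr, &dst1);

        for(auto *t : { &src0, &src1, &weights, &dst0, &dst1 })
        {
            t->allocator()->allocate();
        }
        // ... map and fill src0/src1/weights here ...

        fc0.run(); // first run prepares: the managed reshape executes once
        fc1.run(); // reuses the reshaped weights held by the manager
        return 0;
    }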
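The packed uid in CLConvolutionLayerReshapeWeightsTransform::uid() is worth a note: it folds the transformation identity (0x9), bias presence (bit 7) and the group count (bits 8 and up) into one key, so the same weights tensor reshaped under different configurations is tracked as distinct transformations. A small self-contained check of that arithmetic (illustrative, mirrors the expression in the diff):

    #include <cstdint>

    // Same packing as CLConvolutionLayerReshapeWeightsTransform::uid():
    // base id 0x9 in the low bits, bias flag at bit 7, group count from bit 8.
    constexpr uint32_t reshape_uid(uint32_t bias_bit, uint32_t num_groups)
    {
        return 0x9u | (bias_bit << 7) | (num_groups << 8);
    }

    int main()
    {
        static_assert(reshape_uid(0, 1) == 0x109, "no bias, one group");
        static_assert(reshape_uid(1, 1) == 0x189, "appended bias, one group");
        static_assert(reshape_uid(0, 4) == 0x409, "no bias, four groups");
        return 0;
    }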
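On the graph side, setup_backend_context() now registers a CL weights manager in the GraphContext. Below is a hedged sketch of how a backend function factory could fetch it when building a function; the helper name make_conv() and the null memory manager are assumptions for illustration, while weights_management_ctx() and the WeightsManagerContext fields come from the diff above.

    #include "arm_compute/graph/GraphContext.h"
    #include "arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h"
    #include "arm_compute/runtime/IWeightsManager.h"

    #include <memory>

    using namespace arm_compute;
    using namespace arm_compute::graph;

    std::unique_ptr<CLGEMMConvolutionLayer> make_conv(GraphContext &ctx)
    {
        // Pull the manager that CLDeviceBackend::setup_backend_context() inserted.
        IWeightsManager       *wm     = nullptr;
        WeightsManagerContext *wm_ctx = ctx.weights_management_ctx(Target::CL);
        if(wm_ctx != nullptr)
        {
            wm = wm_ctx->wm.get(); // shared_ptr<IWeightsManager> owned by the context
        }
        // Memory manager omitted for brevity; both constructor arguments are optional.
        return std::make_unique<CLGEMMConvolutionLayer>(nullptr, wm);
    }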