From b27e13a0ad630d3d9b3143c0374b5ff5000eebc0 Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Fri, 27 Sep 2019 11:04:27 +0100 Subject: COMPMID-2685: [CL] Use Weights manager Change-Id: Ia1818e6ecd9386e96378e64f14d02592fe3cdf0f Signed-off-by: Michalis Spyrou Reviewed-on: https://review.mlplatform.org/c/1997 Comments-Addressed: Arm Jenkins Reviewed-by: Gian Marco Iodice Tested-by: Arm Jenkins --- src/runtime/CL/functions/CLFullyConnectedLayer.cpp | 89 ++++++++++++++++------ 1 file changed, 67 insertions(+), 22 deletions(-) (limited to 'src/runtime/CL/functions/CLFullyConnectedLayer.cpp') diff --git a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp index 0452a236c5..91f722fdce 100644 --- a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp +++ b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp @@ -25,6 +25,7 @@ #include "arm_compute/core/Size2D.h" #include "arm_compute/core/Validate.h" +#include "arm_compute/core/utils/misc/Cast.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/core/utils/quantization/AsymmHelpers.h" #include "arm_compute/runtime/CL/CLScheduler.h" @@ -32,8 +33,10 @@ #include -using namespace arm_compute; +namespace arm_compute +{ using namespace arm_compute::misc::shape_calculator; +using namespace arm_compute::utils::cast; namespace { @@ -77,9 +80,10 @@ Status CLFullyConnectedLayerReshapeWeights::validate(const ITensorInfo *input, c } CLFullyConnectedLayer::CLFullyConnectedLayer(std::shared_ptr memory_manager, IWeightsManager *weights_manager) - : _memory_group(memory_manager), _convert_weights(), _flatten_layer(), _reshape_weights_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), - _accumulate_biases_kernel(), _flatten_output(), _gemmlowp_output(), _converted_weights_output(), _reshape_weights_output(), _are_weights_converted(true), _are_weights_reshaped(true), - _is_fc_after_conv(true), _accumulate_biases(false), _is_quantized(false), _is_prepared(false), _original_weights(nullptr) + : _memory_group(memory_manager), _weights_manager(weights_manager), _convert_weights(), _convert_weights_managed(), _reshape_weights_managed_function(), _flatten_layer(), _reshape_weights_function(), + _mm_gemm(memory_manager, weights_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), _accumulate_biases_kernel(), _flatten_output(), _gemmlowp_output(), _converted_weights_output(), + _reshape_weights_output(), _are_weights_converted(true), _are_weights_reshaped(true), _is_fc_after_conv(true), _accumulate_biases(false), _is_quantized(false), _is_prepared(false), + _original_weights(nullptr) { } void CLFullyConnectedLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool retain_internal_weights) @@ -157,6 +161,11 @@ void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *w _is_prepared = fc_info.retain_internal_weights; _original_weights = weights; + if(_weights_manager) + { + _weights_manager->manage(weights); + } + // Configure gemmlowp output if(_is_quantized) { @@ -199,21 +208,39 @@ void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *w // Reshape weights if needed if(!_are_weights_reshaped) { - // Reshape the weights - _reshape_weights_kernel.configure(weights, &_reshape_weights_output); - weights_to_use = &_reshape_weights_output; + if(_weights_manager && _weights_manager->are_weights_managed(weights)) + { + _reshape_weights_managed_function.configure(weights); + weights_to_use = utils::cast::polymorphic_downcast(_weights_manager->acquire(weights, &_reshape_weights_managed_function)); + } + else + { + // Reshape the weights + _reshape_weights_function.configure(weights, &_reshape_weights_output); + weights_to_use = &_reshape_weights_output; + } } // Convert weights if needed if(_is_fc_after_conv && (input->info()->data_layout() != fc_info.weights_trained_layout)) { - // Convert weights - _convert_weights.configure(weights_to_use, - &_converted_weights_output, - input->info()->tensor_shape(), - fc_info.weights_trained_layout); + if(_weights_manager && _weights_manager->are_weights_managed(weights_to_use)) + { + _convert_weights_managed.configure(weights_to_use, + input->info()->tensor_shape(), + fc_info.weights_trained_layout); + weights_to_use = utils::cast::polymorphic_downcast(_weights_manager->acquire(weights, &_convert_weights_managed)); + } + else + { + // Convert weights + _convert_weights.configure(weights_to_use, + &_converted_weights_output, + input->info()->tensor_shape(), + fc_info.weights_trained_layout); - weights_to_use = &_converted_weights_output; + weights_to_use = &_converted_weights_output; + } _are_weights_converted = false; } @@ -384,7 +411,10 @@ void CLFullyConnectedLayer::prepare() { if(!_is_prepared) { - ARM_COMPUTE_ERROR_ON(!_original_weights->is_used()); + if(!_weights_manager) + { + ARM_COMPUTE_ERROR_ON(!_original_weights->is_used()); + } auto release_unused = [](CLTensor * w) { @@ -401,22 +431,36 @@ void CLFullyConnectedLayer::prepare() // Reshape of the weights if needed (happens only once) if(!_are_weights_reshaped) { - // Run reshape weights kernel and mark weights as unused - _reshape_weights_output.allocator()->allocate(); - _reshape_weights_kernel.run(); + if(_weights_manager && _weights_manager->are_weights_managed(_original_weights)) + { + cur_weights = utils::cast::polymorphic_downcast(_weights_manager->run(cur_weights, &_reshape_weights_managed_function)); + } + else + { + // Run reshape weights kernel and mark weights as unused + _reshape_weights_output.allocator()->allocate(); + _reshape_weights_function.run(); - cur_weights->mark_as_unused(); - cur_weights = &_reshape_weights_output; + cur_weights->mark_as_unused(); + cur_weights = &_reshape_weights_output; + } _are_weights_reshaped = true; } // Convert weights if needed (happens only once) if(!_are_weights_converted) { - _converted_weights_output.allocator()->allocate(); - _convert_weights.run(); + if(_weights_manager && _weights_manager->are_weights_managed(cur_weights)) + { + _weights_manager->run(cur_weights, &_convert_weights_managed); + } + else + { + _converted_weights_output.allocator()->allocate(); + _convert_weights.run(); + cur_weights->mark_as_unused(); + } - cur_weights->mark_as_unused(); _are_weights_converted = true; } @@ -436,3 +480,4 @@ void CLFullyConnectedLayer::prepare() _is_prepared = true; } } +} // namespace arm_compute -- cgit v1.2.1