From b62280aca3148dd6762e57e5af3da0cb0a9e2db5 Mon Sep 17 00:00:00 2001 From: Michele Di Giorgio Date: Thu, 31 May 2018 17:31:05 +0100 Subject: COMPMID-1244: Allow retaining weights in CLGEMMConvolutionLayer and CLFullyConnectedLayer Change-Id: I1c3b2197906cd4b905309bbd5f2012bbae6a7dba Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/133730 Tested-by: Jenkins Reviewed-by: Anthony Barbier --- .../CL/functions/CLGEMMConvolutionLayer.cpp | 25 +++++++++++++--------- 1 file changed, 15 insertions(+), 10 deletions(-) (limited to 'src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp') diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp index 610eec4d67..4f87043373 100644 --- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp +++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp @@ -91,7 +91,8 @@ void CLConvolutionLayerReshapeWeights::run() CLGEMMConvolutionLayer::CLGEMMConvolutionLayer(std::shared_ptr memory_manager) : _memory_group(memory_manager), _reshape_weights(), _im2col_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), _col2im_kernel(), _activationlayer_function(), - _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(), _is_quantized(false), _is_activationlayer_enabled(false), _is_prepared(false) + _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(), _is_quantized(false), _is_activationlayer_enabled(false), _is_prepared(false), + _retain_internal_weights(false) { } @@ -165,9 +166,10 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor * dilation, act_info)); - _is_prepared = false; - _original_weights = weights; - _is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type()); + _is_prepared = false; + _original_weights = weights; + _is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type()); + _retain_internal_weights = weights_info.retain_internal_weights(); const DataType dt = input->info()->data_type(); @@ -404,17 +406,20 @@ void CLGEMMConvolutionLayer::prepare() { if(!_is_prepared) { - // Run weights reshaping and mark as unused - ARM_COMPUTE_ERROR_ON(!_original_weights->is_used()); - _weights_reshaped.allocator()->allocate(); - _reshape_weights.run(); - _original_weights->mark_as_unused(); + if(!_retain_internal_weights) + { + // Run weights reshaping and mark as unused + ARM_COMPUTE_ERROR_ON(!_original_weights->is_used()); + _weights_reshaped.allocator()->allocate(); + _reshape_weights.run(); + _original_weights->mark_as_unused(); + } // Run GEMM prepare if(!_is_quantized) { _mm_gemm.prepare(); - if(!_weights_reshaped.is_used()) + if(!_weights_reshaped.is_used() && !_retain_internal_weights) { _weights_reshaped.allocator()->free(); } -- cgit v1.2.1