diff options
Diffstat (limited to 'src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp | 52 |
1 files changed, 41 insertions, 11 deletions
diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp index 594c8eef34..831f108b85 100644 --- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp +++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp @@ -27,6 +27,7 @@ #include "arm_compute/core/Size2D.h" #include "arm_compute/core/Utils.h" #include "arm_compute/core/Validate.h" +#include "arm_compute/core/utils/misc/Cast.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/core/utils/quantization/AsymmHelpers.h" #include "arm_compute/runtime/CL/CLScheduler.h" @@ -35,8 +36,10 @@ #include <memory> #include <tuple> -using namespace arm_compute; +namespace arm_compute +{ using namespace arm_compute::misc::shape_calculator; +using namespace arm_compute::utils::cast; CLConvolutionLayerReshapeWeights::CLConvolutionLayerReshapeWeights() : _weights_reshape_kernel() @@ -90,9 +93,10 @@ void CLConvolutionLayerReshapeWeights::run() CLScheduler::get().enqueue(_weights_reshape_kernel); } -CLGEMMConvolutionLayer::CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager) - : _memory_group(memory_manager), _reshape_weights(), _im2col_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _col2im_kernel(), _activationlayer_function(), - _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _skip_im2col(false), _skip_col2im(false), _is_quantized(false), _fuse_activation(true), _is_prepared(false) +CLGEMMConvolutionLayer::CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager) + : _memory_group(memory_manager), _weights_manager(weights_manager), _reshape_weights(), _reshape_weights_managed(), _im2col_kernel(), _mm_gemm(memory_manager, weights_manager), + _mm_gemmlowp(memory_manager), _col2im_kernel(), _activationlayer_function(), _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _skip_im2col(false), + _skip_col2im(false), _is_quantized(false), _fuse_activation(true), _is_prepared(false) { } @@ -238,6 +242,7 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor * const ICLTensor *biases_to_use = biases; bool append_bias = false; + ICLTensor *weights_to_use = &_weights_reshaped; if(num_groups != 1 && biases != nullptr) { // num_groups != 1 can only be for NCHW @@ -245,11 +250,27 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor * biases_to_use = nullptr; append_bias = true; - _reshape_weights.configure(weights, biases, &_weights_reshaped, num_groups); + if(_weights_manager && _weights_manager->are_weights_managed(weights)) + { + _reshape_weights_managed.configure(weights, biases, num_groups); + weights_to_use = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(weights, &_reshape_weights_managed)); + } + else + { + _reshape_weights.configure(weights, biases, &_weights_reshaped, num_groups); + } } else { - _reshape_weights.configure(weights, nullptr, &_weights_reshaped, num_groups); + if(_weights_manager && _weights_manager->are_weights_managed(weights)) + { + _reshape_weights_managed.configure(weights, nullptr, num_groups); + weights_to_use = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(weights, &_reshape_weights_managed)); + } + else + { + _reshape_weights.configure(weights, nullptr, &_weights_reshaped, num_groups); + } } // Create tensor to store im2col reshaped inputs @@ -340,7 +361,7 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor * // In case of NHWC, we need to run GEMM3D (gemm_3d_depth != 0) in order to avoid reshaping the output matrix const unsigned int gemm_3d_depth = (data_layout == DataLayout::NHWC) ? conv_h : 0; - configure_mm(gemm_input_to_use, &_weights_reshaped, biases_to_use, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, act_info); + configure_mm(gemm_input_to_use, weights_to_use, biases_to_use, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, act_info); if(!_skip_im2col) { @@ -601,10 +622,18 @@ void CLGEMMConvolutionLayer::prepare() { if(!_is_prepared) { - // Run weights reshaping and mark original weights tensor as unused - _weights_reshaped.allocator()->allocate(); - _reshape_weights.run(); - _original_weights->mark_as_unused(); + ARM_COMPUTE_ERROR_ON(!_original_weights->is_used()); + if(_weights_manager && _weights_manager->are_weights_managed(_original_weights)) + { + _weights_manager->run(_original_weights, &_reshape_weights_managed); + } + else + { + // Run weights reshaping and mark original weights tensor as unused + _weights_reshaped.allocator()->allocate(); + _reshape_weights.run(); + _original_weights->mark_as_unused(); + } // Prepare GEMM _is_quantized ? _mm_gemmlowp.prepare() : _mm_gemm.prepare(); @@ -617,3 +646,4 @@ void CLGEMMConvolutionLayer::prepare() _is_prepared = true; } } +} // namespace arm_compute |