Diffstat (limited to 'src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp')
-rw-r--r--  src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp  |  52
1 file changed, 41 insertions(+), 11 deletions(-)
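
In summary, this patch teaches CLGEMMConvolutionLayer to reshape its weights through an IWeightsManager when one is supplied: the constructor gains a weights_manager parameter (forwarded to the inner CLGEMM), configure() acquires manager-owned reshaped weights via the new _reshape_weights_managed transform whenever the weights are managed, prepare() delegates the reshape to the manager in that case, and the file body is wrapped in the arm_compute namespace instead of a using directive. A minimal sketch of the manager contract follows the diff.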
diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
index 594c8eef34..831f108b85 100644
--- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
@@ -27,6 +27,7 @@
#include "arm_compute/core/Size2D.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/Cast.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
@@ -35,8 +36,10 @@
#include <memory>
#include <tuple>
-using namespace arm_compute;
+namespace arm_compute
+{
using namespace arm_compute::misc::shape_calculator;
+using namespace arm_compute::utils::cast;
CLConvolutionLayerReshapeWeights::CLConvolutionLayerReshapeWeights()
: _weights_reshape_kernel()
@@ -90,9 +93,10 @@ void CLConvolutionLayerReshapeWeights::run()
CLScheduler::get().enqueue(_weights_reshape_kernel);
}
-CLGEMMConvolutionLayer::CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(memory_manager), _reshape_weights(), _im2col_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _col2im_kernel(), _activationlayer_function(),
- _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _skip_im2col(false), _skip_col2im(false), _is_quantized(false), _fuse_activation(true), _is_prepared(false)
+CLGEMMConvolutionLayer::CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
+ : _memory_group(memory_manager), _weights_manager(weights_manager), _reshape_weights(), _reshape_weights_managed(), _im2col_kernel(), _mm_gemm(memory_manager, weights_manager),
+ _mm_gemmlowp(memory_manager), _col2im_kernel(), _activationlayer_function(), _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _skip_im2col(false),
+ _skip_col2im(false), _is_quantized(false), _fuse_activation(true), _is_prepared(false)
{
}
@@ -238,6 +242,7 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *
const ICLTensor *biases_to_use = biases;
bool append_bias = false;
+ ICLTensor *weights_to_use = &_weights_reshaped;
if(num_groups != 1 && biases != nullptr)
{
// num_groups != 1 can only be for NCHW
@@ -245,11 +250,27 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *
biases_to_use = nullptr;
append_bias = true;
- _reshape_weights.configure(weights, biases, &_weights_reshaped, num_groups);
+ if(_weights_manager && _weights_manager->are_weights_managed(weights))
+ {
+ _reshape_weights_managed.configure(weights, biases, num_groups);
+ weights_to_use = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(weights, &_reshape_weights_managed));
+ }
+ else
+ {
+ _reshape_weights.configure(weights, biases, &_weights_reshaped, num_groups);
+ }
}
else
{
- _reshape_weights.configure(weights, nullptr, &_weights_reshaped, num_groups);
+ if(_weights_manager && _weights_manager->are_weights_managed(weights))
+ {
+ _reshape_weights_managed.configure(weights, nullptr, num_groups);
+ weights_to_use = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(weights, &_reshape_weights_managed));
+ }
+ else
+ {
+ _reshape_weights.configure(weights, nullptr, &_weights_reshaped, num_groups);
+ }
}
// Create tensor to store im2col reshaped inputs
@@ -340,7 +361,7 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *
// In case of NHWC, we need to run GEMM3D (gemm_3d_depth != 0) in order to avoid reshaping the output matrix
const unsigned int gemm_3d_depth = (data_layout == DataLayout::NHWC) ? conv_h : 0;
- configure_mm(gemm_input_to_use, &_weights_reshaped, biases_to_use, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, act_info);
+ configure_mm(gemm_input_to_use, weights_to_use, biases_to_use, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, act_info);
if(!_skip_im2col)
{
@@ -601,10 +622,18 @@ void CLGEMMConvolutionLayer::prepare()
{
if(!_is_prepared)
{
- // Run weights reshaping and mark original weights tensor as unused
- _weights_reshaped.allocator()->allocate();
- _reshape_weights.run();
- _original_weights->mark_as_unused();
+ ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
+ if(_weights_manager && _weights_manager->are_weights_managed(_original_weights))
+ {
+ _weights_manager->run(_original_weights, &_reshape_weights_managed);
+ }
+ else
+ {
+ // Run weights reshaping and mark original weights tensor as unused
+ _weights_reshaped.allocator()->allocate();
+ _reshape_weights.run();
+ _original_weights->mark_as_unused();
+ }
// Prepare GEMM
_is_quantized ? _mm_gemmlowp.prepare() : _mm_gemm.prepare();
@@ -617,3 +646,4 @@ void CLGEMMConvolutionLayer::prepare()
_is_prepared = true;
}
}
+} // namespace arm_compute
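
For readers unfamiliar with the pattern, the hunks above implement a two-phase contract with the weights manager: configure() asks are_weights_managed() and, if so, acquires the manager-owned reshaped tensor instead of the layer's own _weights_reshaped; prepare() then routes the actual reshape through the manager, which can run the transform once and share the result across layers. Below is a minimal, self-contained C++ sketch of that contract. Tensor, TransformWeights and WeightsManager are hypothetical stand-ins; only the three calls they mimic (are_weights_managed(), acquire(), run()) come from the diff, and the real IWeightsManager interface may differ.

#include <iostream>
#include <unordered_map>

// Hypothetical stand-ins for ICLTensor / ITransformWeights / IWeightsManager.
struct Tensor
{
    const char *name;
};

// A reshape function whose output tensor is owned by the manager rather than
// by the layer (the role _reshape_weights_managed plays in the diff above).
struct TransformWeights
{
    Tensor reshaped{"manager-owned reshaped weights"};
    void run(const Tensor *src)
    {
        std::cout << "reshaping " << src->name << '\n';
    }
};

class WeightsManager
{
public:
    void manage(const Tensor *weights)
    {
        _managed[weights] = nullptr;
    }
    bool are_weights_managed(const Tensor *weights) const
    {
        return _managed.count(weights) != 0;
    }
    // configure() time: register the transform and hand back its (not yet
    // populated) output tensor, so the layer can wire it into the GEMM.
    Tensor *acquire(const Tensor *weights, TransformWeights *transform)
    {
        _managed[weights] = transform;
        return &transform->reshaped;
    }
    // prepare() time: run the transform; a real manager could run it once and
    // reuse the result for every layer that shares these weights.
    void run(const Tensor *weights, TransformWeights *transform)
    {
        transform->run(weights);
    }

private:
    std::unordered_map<const Tensor *, TransformWeights *> _managed;
};

int main()
{
    Tensor weights{"conv weights"};
    WeightsManager wm;
    wm.manage(&weights);

    TransformWeights reshape_managed;
    Tensor weights_reshaped{"layer-owned reshaped weights"};
    Tensor *weights_to_use = &weights_reshaped;

    // configure(): choose manager-owned vs layer-owned reshaped weights.
    if(wm.are_weights_managed(&weights))
    {
        weights_to_use = wm.acquire(&weights, &reshape_managed);
    }

    // prepare(): reshape through the manager when the weights are managed,
    // otherwise run the layer's own reshape (the else branch in the diff,
    // which also marks the original weights as unused).
    if(wm.are_weights_managed(&weights))
    {
        wm.run(&weights, &reshape_managed);
    }
    else
    {
        reshape_managed.run(&weights);
    }

    std::cout << "GEMM consumes: " << weights_to_use->name << '\n';
    return 0;
}

The design point mirrored here is that the layer never owns managed reshaped weights: it only holds the pointer returned by acquire(), which is why the patched configure() threads weights_to_use through to configure_mm() instead of always passing &_weights_reshaped.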