about | summary | refs | log | tree | commit | diff
path: root/src/runtime/CL/functions/CLGEMM.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/functions/CLGEMM.cpp')
-rw-r--r-- src/runtime/CL/functions/CLGEMM.cpp 112
1 files changed, 92 insertions, 20 deletions
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index 762b00177c..2a027d872c 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -36,6 +36,7 @@
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/helpers/float_ops.h"
+#include "arm_compute/core/utils/misc/Cast.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "arm_compute/runtime/ITensorAllocator.h"
@@ -44,12 +45,15 @@ namespace arm_compute
{
using namespace arm_compute::misc::shape_calculator;
using namespace arm_compute::cl_gemm;
+using namespace arm_compute::utils::cast;
-CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager)
+CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
: _memory_group(std::move(memory_manager)),
+ _weights_manager(weights_manager),
_mm_kernel(),
_reshape_lhs_kernel(),
_reshape_rhs_kernel(),
+ _reshape_rhs_kernel_managed(),
_mm_reshaped_kernel(),
_mm_reshaped_only_rhs_kernel(),
_tmp_a(),
@@ -178,8 +182,12 @@ void CLGEMM::configure_reshaped_v1(const ICLTensor *a, const ICLTensor *b, const
GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias());
+ const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
+
+ // Manage intermediate buffers
_memory_group.manage(&_tmp_a);
- if(!_reshape_b_only_on_first_run)
+
+ if(!_reshape_b_only_on_first_run && use_mm_b)
{
_memory_group.manage(&_tmp_b);
}
@@ -188,16 +196,26 @@ void CLGEMM::configure_reshaped_v1(const ICLTensor *a, const ICLTensor *b, const
_reshape_lhs_kernel.configure(a, &_tmp_a, lhs_info, reinterpret_input_as_3d);
// Configure transpose kernel
- _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
+ ICLTensor *reshaped_rhs = &_tmp_b;
+ if(_weights_manager && _weights_manager->are_weights_managed(b))
+ {
+ _reshape_rhs_kernel_managed.configure(b, rhs_info);
+ reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed));
+ }
+ else
+ {
+ _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
+ }
// Configure and tune matrix multiply kernel
- _mm_kernel.configure(&_tmp_a, &_tmp_b, c, output, alpha, beta, true, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
+ _mm_kernel.configure(&_tmp_a, reshaped_rhs, c, output, alpha, beta, true, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
CLScheduler::get().tune_kernel_static(_mm_kernel);
// Allocate intermediate tensors
_tmp_a.allocator()->allocate();
- if(!_reshape_b_only_on_first_run)
+
+ if(!_reshape_b_only_on_first_run && use_mm_b)
{
_tmp_b.allocator()->allocate();
}
@@ -228,12 +246,16 @@ void CLGEMM::configure_reshaped_v2(const ICLTensor *a, const ICLTensor *b, const
_reshape_lhs_kernel.set_target(gpu_target);
_mm_kernel.set_target(gpu_target);
+ const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
+
// Manage intermediate buffers
_memory_group.manage(&_tmp_a);
- if(!_reshape_b_only_on_first_run)
+
+ if(!_reshape_b_only_on_first_run && use_mm_b)
{
_memory_group.manage(&_tmp_b);
}
+
// _tmp_a and _tmp_b will be auto configured in _interleave_kernel and in _transpose_kernel
GEMMLHSMatrixInfo lhs_info{};
@@ -247,14 +269,25 @@ void CLGEMM::configure_reshaped_v2(const ICLTensor *a, const ICLTensor *b, const
std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
_reshape_lhs_kernel.configure(a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
- _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
+
+ ICLTensor *reshaped_rhs = &_tmp_b;
+ if(_weights_manager && _weights_manager->are_weights_managed(b))
+ {
+ _reshape_rhs_kernel_managed.configure(b, rhs_info);
+ reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed));
+ }
+ else
+ {
+ _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
+ }
// Configure and tune matrix multiply kernel
- _mm_reshaped_kernel.configure(&_tmp_a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
+ _mm_reshaped_kernel.configure(&_tmp_a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
// Allocate intermediate tensors
_tmp_a.allocator()->allocate();
- if(!_reshape_b_only_on_first_run)
+
+ if(!_reshape_b_only_on_first_run && use_mm_b)
{
_tmp_b.allocator()->allocate();
}
@@ -284,8 +317,10 @@ void CLGEMM::configure_reshaped_only_rhs(const ICLTensor *a, const ICLTensor *b,
// Set the target for the kernels
_mm_kernel.set_target(gpu_target);
+ const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
+
// Manage intermediate buffers
- if(!_reshape_b_only_on_first_run)
+ if(!_reshape_b_only_on_first_run && use_mm_b)
{
_memory_group.manage(&_tmp_b);
}
@@ -300,12 +335,21 @@ void CLGEMM::configure_reshaped_only_rhs(const ICLTensor *a, const ICLTensor *b,
// Configure lhs_info and rhs_info
std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
- _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
+ ICLTensor *reshaped_rhs = &_tmp_b;
+ if(_weights_manager && _weights_manager->are_weights_managed(b))
+ {
+ _reshape_rhs_kernel_managed.configure(b, rhs_info);
+ reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed));
+ }
+ else
+ {
+ _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
+ }
// Configure and tune matrix multiply kernel
- _mm_reshaped_only_rhs_kernel.configure(a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
+ _mm_reshaped_only_rhs_kernel.configure(a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
- if(!_reshape_b_only_on_first_run)
+ if(!_reshape_b_only_on_first_run && use_mm_b)
{
_tmp_b.allocator()->allocate();
}
@@ -607,7 +651,14 @@ void CLGEMM::run()
if(!_reshape_b_only_on_first_run)
{
// Run transpose kernel
- CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+ if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
+ {
+ _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
+ }
+ else
+ {
+ CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+ }
}
CLScheduler::get().enqueue(_mm_kernel, true);
@@ -621,7 +672,14 @@ void CLGEMM::run()
if(!_reshape_b_only_on_first_run)
{
// Run transpose kernel
- CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+ if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
+ {
+ _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
+ }
+ else
+ {
+ CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+ }
}
CLScheduler::get().enqueue(_mm_reshaped_kernel, true);
@@ -632,7 +690,14 @@ void CLGEMM::run()
if(!_reshape_b_only_on_first_run)
{
// Run transpose kernel
- CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+ if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
+ {
+ _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
+ }
+ else
+ {
+ CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+ }
}
CLScheduler::get().enqueue(_mm_reshaped_only_rhs_kernel, true);
@@ -651,10 +716,17 @@ void CLGEMM::prepare()
{
if(_gemm_type != GEMMType::NATIVE && _reshape_b_only_on_first_run)
{
- // Run transpose kernel and mark original weights tensor as unused
- _tmp_b.allocator()->allocate();
- CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
- _original_b->mark_as_unused();
+ if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
+ {
+ _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
+ }
+ else
+ {
+ // Run transpose kernel and mark original weights tensor as unused
+ _tmp_b.allocator()->allocate();
+ CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+ _original_b->mark_as_unused();
+ }
}
CLScheduler::get().queue().finish();
_is_prepared = true;