Diffstat (limited to 'src/runtime/gpu/cl/operators/ClWinogradConv2d.cpp')
 src/runtime/gpu/cl/operators/ClWinogradConv2d.cpp | 32 ++++++++++++++++++++++----------
 1 file changed, 22 insertions(+), 10 deletions(-)
diff --git a/src/runtime/gpu/cl/operators/ClWinogradConv2d.cpp b/src/runtime/gpu/cl/operators/ClWinogradConv2d.cpp
index c8db697778..2ca1ff59df 100644
--- a/src/runtime/gpu/cl/operators/ClWinogradConv2d.cpp
+++ b/src/runtime/gpu/cl/operators/ClWinogradConv2d.cpp
@@ -212,9 +212,15 @@ void ClWinogradConv2d::configure(const ClCompileContext &compile_context, ITenso
// Configure output transform
_output_transform->configure(compile_context, &_batched_mm_output, biases, dst, winograd_info, act_info);
- _aux_mem = _batched_mm.workspace();
+ _aux_mem = _batched_mm.workspace();
+ const MemoryLifetime wino_wei_lifetm = std::any_of(std::begin(_aux_mem), std::end(_aux_mem), [](const auto & r)
+ {
+ return (r.lifetime == MemoryLifetime::Persistent) && (r.size > 0);
+ }) ?
+ MemoryLifetime::Prepare :
+ MemoryLifetime::Persistent;
_aux_mem.push_back(MemoryInfo(offset_int_vec(2), MemoryLifetime::Temporary, _input0.total_size()));
- _aux_mem.push_back(MemoryInfo(offset_int_vec(3), MemoryLifetime::Persistent, _input1.total_size()));
+ _aux_mem.push_back(MemoryInfo(offset_int_vec(3), wino_wei_lifetm, _input1.total_size()));
_aux_mem.push_back(MemoryInfo(offset_int_vec(4), MemoryLifetime::Temporary, _batched_mm_output.total_size()));
}
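
Note: the hunk above picks a lifetime for the Winograd-transformed weights: if the batched GEMM already reports a non-empty persistent workspace (i.e. it keeps its own copy of the reshaped weights), the transformed weights are only needed during prepare() and can be given MemoryLifetime::Prepare; otherwise they must stay Persistent. The following is a minimal standalone sketch of that selection idiom; the enum and struct here are simplified stand-ins for the library's MemoryLifetime/MemoryInfo types, and the sizes are made up for illustration.

    // Simplified stand-ins for the library types (assumption, not the real ACL definitions).
    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    enum class MemoryLifetime { Temporary, Persistent, Prepare };

    struct MemoryInfo
    {
        MemoryLifetime lifetime;
        size_t         size;
    };

    int main()
    {
        // Workspace reported by the batched GEMM (hypothetical entries).
        std::vector<MemoryInfo> aux_mem{ { MemoryLifetime::Temporary, 1024 }, { MemoryLifetime::Persistent, 4096 } };

        // If the GEMM keeps a persistent copy of the reshaped weights, the
        // Winograd-transformed weights only have to survive until prepare().
        const MemoryLifetime wino_wei_lifetm =
            std::any_of(std::begin(aux_mem), std::end(aux_mem),
                        [](const MemoryInfo &r) { return (r.lifetime == MemoryLifetime::Persistent) && (r.size > 0); })
                ? MemoryLifetime::Prepare
                : MemoryLifetime::Persistent;

        std::cout << (wino_wei_lifetm == MemoryLifetime::Prepare ? "Prepare\n" : "Persistent\n");
        return 0;
    }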
@@ -229,7 +235,6 @@ void ClWinogradConv2d::run(ITensorPack &tensors)
{
prepare(tensors);
- // Run input transform
auto src = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_0));
auto biases = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_2));
auto dst = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(TensorType::ACL_DST));
@@ -238,6 +243,7 @@ void ClWinogradConv2d::run(ITensorPack &tensors)
CLAuxTensorHandler input1(offset_int_vec(3), _input1, tensors, true);
CLAuxTensorHandler batched_mm_output(offset_int_vec(4), _batched_mm_output, tensors, true);
+ // Run input transform
ITensorPack pack_it
{
{ TensorType::ACL_SRC, src },
@@ -247,12 +253,17 @@ void ClWinogradConv2d::run(ITensorPack &tensors)
CLScheduler::get().enqueue_op(*_input_transform, pack_it);
// Run batched matrix multiplication
- ITensorPack pack_mm
+ ITensorPack pack_mm = tensors;
+ pack_mm.add_const_tensor(TensorType::ACL_SRC_0, input0.get());
+ pack_mm.add_tensor(TensorType::ACL_DST, batched_mm_output.get());
+ if(_aux_mem[3].lifetime == MemoryLifetime::Prepare)
{
- { TensorType::ACL_SRC_0, input0.get() },
- { TensorType::ACL_SRC_1, input1.get() },
- { TensorType::ACL_DST, batched_mm_output.get() },
- };
+ pack_mm.remove_tensor(TensorType::ACL_SRC_1);
+ }
+ else
+ {
+ pack_mm.add_const_tensor(TensorType::ACL_SRC_1, input1.get());
+ }
_batched_mm.run(pack_mm);
// Run output transform
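
Note: in the hunk above, the run-time GEMM pack is now a copy of the caller's pack, and ACL_SRC_1 (the transformed weights) is only added when their lifetime is not Prepare; when the GEMM owns a persistent copy of the reshaped weights, the entry is dropped so no released buffer is passed in. Below is a minimal standalone sketch of that conditional pack handling; a plain std::map keyed by an enum stands in for ITensorPack, and the tensor names and lifetime value are illustrative assumptions.

    #include <iostream>
    #include <map>
    #include <string>

    enum class TensorId { ACL_SRC_0, ACL_SRC_1, ACL_DST };
    enum class MemoryLifetime { Prepare, Persistent };

    using Pack = std::map<TensorId, std::string>; // tensor name stands in for ICLTensor*

    int main()
    {
        const MemoryLifetime weights_lifetime = MemoryLifetime::Prepare; // assumed, mirrors _aux_mem[3].lifetime

        Pack tensors{ { TensorId::ACL_SRC_1, "user_weights" } }; // caller-provided pack

        Pack pack_mm = tensors;                        // start from the caller's pack
        pack_mm[TensorId::ACL_SRC_0] = "input0";       // Winograd input transform output
        pack_mm[TensorId::ACL_DST]   = "batched_mm_output";

        if(weights_lifetime == MemoryLifetime::Prepare)
        {
            pack_mm.erase(TensorId::ACL_SRC_1);        // GEMM already holds its reshaped weights
        }
        else
        {
            pack_mm[TensorId::ACL_SRC_1] = "input1";   // transformed weights stay resident
        }

        std::cout << "pack_mm holds " << pack_mm.size() << " tensors\n";
        return 0;
    }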
@@ -282,9 +293,10 @@ void ClWinogradConv2d::prepare(ITensorPack &tensors)
CLScheduler::get().enqueue_op(*_filter_transform, pack_ft, false);
weights->mark_as_unused();
- tensors.add_tensor(ACL_SRC_1, input1.get());
// Prepare GEMM and release reshaped weights if marked unused by ClGemm
- _batched_mm.prepare(tensors);
+ ITensorPack mm_prepare_pack = tensors;
+ mm_prepare_pack.add_tensor(ACL_SRC_1, input1.get());
+ _batched_mm.prepare(mm_prepare_pack);
CLScheduler::get().queue().finish();
_is_prepared = true;