diff options
Diffstat (limited to 'src/runtime/gpu/cl/operators/ClWinogradConv2d.cpp')
-rw-r--r-- | src/runtime/gpu/cl/operators/ClWinogradConv2d.cpp | 19 |
1 files changed, 7 insertions, 12 deletions
diff --git a/src/runtime/gpu/cl/operators/ClWinogradConv2d.cpp b/src/runtime/gpu/cl/operators/ClWinogradConv2d.cpp index 2ca1ff59df..07f90ddaef 100644 --- a/src/runtime/gpu/cl/operators/ClWinogradConv2d.cpp +++ b/src/runtime/gpu/cl/operators/ClWinogradConv2d.cpp @@ -233,37 +233,32 @@ Status ClWinogradConv2d::validate(const ITensorInfo *src, const ITensorInfo *wei void ClWinogradConv2d::run(ITensorPack &tensors) { - prepare(tensors); + const bool is_gemm_reshaped = _aux_mem[3].lifetime == MemoryLifetime::Prepare; auto src = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_0)); auto biases = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_2)); auto dst = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(TensorType::ACL_DST)); CLAuxTensorHandler input0(offset_int_vec(2), _input0, tensors, true); - CLAuxTensorHandler input1(offset_int_vec(3), _input1, tensors, true); + CLAuxTensorHandler input1(offset_int_vec(3), _input1, tensors, true, is_gemm_reshaped); CLAuxTensorHandler batched_mm_output(offset_int_vec(4), _batched_mm_output, tensors, true); + prepare(tensors); + // Run input transform ITensorPack pack_it { { TensorType::ACL_SRC, src }, { TensorType::ACL_DST, input0.get() }, }; - CLScheduler::get().enqueue_op(_border_handler, pack_it); - CLScheduler::get().enqueue_op(*_input_transform, pack_it); + CLScheduler::get().enqueue_op(_border_handler, pack_it, false); + CLScheduler::get().enqueue_op(*_input_transform, pack_it, false); // Run batched matrix multiplication ITensorPack pack_mm = tensors; pack_mm.add_const_tensor(TensorType::ACL_SRC_0, input0.get()); pack_mm.add_tensor(TensorType::ACL_DST, batched_mm_output.get()); - if(_aux_mem[3].lifetime == MemoryLifetime::Prepare) - { - pack_mm.remove_tensor(TensorType::ACL_SRC_1); - } - else - { - pack_mm.add_const_tensor(TensorType::ACL_SRC_1, input1.get()); - } + is_gemm_reshaped ? pack_mm.remove_tensor(TensorType::ACL_SRC_1) : pack_mm.add_const_tensor(TensorType::ACL_SRC_1, input1.get()); _batched_mm.run(pack_mm); // Run output transform |