diff options
Diffstat (limited to 'src/runtime/gpu/cl/operators/ClGemm.cpp')
-rw-r--r-- | src/runtime/gpu/cl/operators/ClGemm.cpp | 26 |
1 files changed, 16 insertions, 10 deletions
diff --git a/src/runtime/gpu/cl/operators/ClGemm.cpp b/src/runtime/gpu/cl/operators/ClGemm.cpp index cb0eecae4b..2792dc470d 100644 --- a/src/runtime/gpu/cl/operators/ClGemm.cpp +++ b/src/runtime/gpu/cl/operators/ClGemm.cpp @@ -208,6 +208,7 @@ ClGemm::ClGemm() _tmp_b(), _reshape_b_only_on_first_run(false), _gemm_kernel_type(CLGEMMKernelType::NATIVE_V1), + _is_prepared(false), _aux_mem(AuxTensorIdx::Count) { } @@ -696,6 +697,7 @@ void ClGemm::run(ITensorPack &tensors) } ITensorPack gemm_reshaped_pack{ { ACL_SRC_0, lhs_reshaped.get() }, { ACL_SRC_1, rhs_reshaped.get() }, { ACL_SRC_2, src2 }, { ACL_DST, dst } }; + if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED) { CLScheduler::get().enqueue_op(*_mm_reshaped_kernel, gemm_reshaped_pack, true); @@ -740,19 +742,23 @@ void ClGemm::run(ITensorPack &tensors) void ClGemm::prepare(ITensorPack &constants) { - const ITensor *src1 = constants.get_const_tensor(ACL_SRC_1); - ICLTensor *rhs_aux = utils::cast::polymorphic_downcast<ICLTensor *>(constants.get_tensor(offset_int_vec(RhsReshape))); - - // If memory for RHS is persistent and src1 is provided re-transform else assume that RHS is transformed - if((_aux_mem[AuxTensorIdx::RhsReshape].lifetime == MemoryLifetime::Persistent) && (src1 != nullptr && rhs_aux != nullptr) && rhs_aux) + if(!_is_prepared) { - ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Transforming RHS Matrix!"); + const ITensor *src1 = constants.get_const_tensor(ACL_SRC_1); + ICLTensor *rhs_aux = utils::cast::polymorphic_downcast<ICLTensor *>(constants.get_tensor(offset_int_vec(RhsReshape))); - CLAuxTensorHandler rhs_reshaped(_tmp_b, *rhs_aux); - ARM_COMPUTE_ERROR_ON(rhs_reshaped.get()->cl_buffer().get() == nullptr); + // If memory for RHS is persistent and src1 is provided re-transform else assume that RHS is transformed + if((_aux_mem[AuxTensorIdx::RhsReshape].lifetime == MemoryLifetime::Persistent) && (src1 != nullptr && rhs_aux != nullptr) && rhs_aux) + { + ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Transforming RHS Matrix!"); - ITensorPack reshape_rhs_pack{ { ACL_SRC, src1 }, { ACL_DST, rhs_reshaped.get() } }; - CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, true); + CLAuxTensorHandler rhs_reshaped(_tmp_b, *rhs_aux); + ARM_COMPUTE_ERROR_ON(rhs_reshaped.get()->cl_buffer().get() == nullptr); + + ITensorPack reshape_rhs_pack{ { ACL_SRC, src1 }, { ACL_DST, rhs_reshaped.get() } }; + CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, true); + } + _is_prepared = true; } } |