about | summary | refs | log | tree | commit | diff
path: root/src/runtime/gpu/cl/operators/ClGemm.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/gpu/cl/operators/ClGemm.cpp')
-rw-r--r-- src/runtime/gpu/cl/operators/ClGemm.cpp 26
1 files changed, 16 insertions, 10 deletions
diff --git a/src/runtime/gpu/cl/operators/ClGemm.cpp b/src/runtime/gpu/cl/operators/ClGemm.cpp
index cb0eecae4b..2792dc470d 100644
--- a/src/runtime/gpu/cl/operators/ClGemm.cpp
+++ b/src/runtime/gpu/cl/operators/ClGemm.cpp
@@ -208,6 +208,7 @@ ClGemm::ClGemm()
_tmp_b(),
_reshape_b_only_on_first_run(false),
_gemm_kernel_type(CLGEMMKernelType::NATIVE_V1),
+ _is_prepared(false),
_aux_mem(AuxTensorIdx::Count)
{
}
@@ -696,6 +697,7 @@ void ClGemm::run(ITensorPack &tensors)
}
ITensorPack gemm_reshaped_pack{ { ACL_SRC_0, lhs_reshaped.get() }, { ACL_SRC_1, rhs_reshaped.get() }, { ACL_SRC_2, src2 }, { ACL_DST, dst } };
+
if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED)
{
CLScheduler::get().enqueue_op(*_mm_reshaped_kernel, gemm_reshaped_pack, true);
@@ -740,19 +742,23 @@ void ClGemm::run(ITensorPack &tensors)
void ClGemm::prepare(ITensorPack &constants)
{
- const ITensor *src1 = constants.get_const_tensor(ACL_SRC_1);
- ICLTensor *rhs_aux = utils::cast::polymorphic_downcast<ICLTensor *>(constants.get_tensor(offset_int_vec(RhsReshape)));
-
- // If memory for RHS is persistent and src1 is provided re-transform else assume that RHS is transformed
- if((_aux_mem[AuxTensorIdx::RhsReshape].lifetime == MemoryLifetime::Persistent) && (src1 != nullptr && rhs_aux != nullptr) && rhs_aux)
+ if(!_is_prepared)
{
- ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Transforming RHS Matrix!");
+ const ITensor *src1 = constants.get_const_tensor(ACL_SRC_1);
+ ICLTensor *rhs_aux = utils::cast::polymorphic_downcast<ICLTensor *>(constants.get_tensor(offset_int_vec(RhsReshape)));
- CLAuxTensorHandler rhs_reshaped(_tmp_b, *rhs_aux);
- ARM_COMPUTE_ERROR_ON(rhs_reshaped.get()->cl_buffer().get() == nullptr);
+ // If memory for RHS is persistent and src1 is provided re-transform else assume that RHS is transformed
+ if((_aux_mem[AuxTensorIdx::RhsReshape].lifetime == MemoryLifetime::Persistent) && (src1 != nullptr && rhs_aux != nullptr) && rhs_aux)
+ {
+ ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Transforming RHS Matrix!");
- ITensorPack reshape_rhs_pack{ { ACL_SRC, src1 }, { ACL_DST, rhs_reshaped.get() } };
- CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, true);
+ CLAuxTensorHandler rhs_reshaped(_tmp_b, *rhs_aux);
+ ARM_COMPUTE_ERROR_ON(rhs_reshaped.get()->cl_buffer().get() == nullptr);
+
+ ITensorPack reshape_rhs_pack{ { ACL_SRC, src1 }, { ACL_DST, rhs_reshaped.get() } };
+ CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, true);
+ }
+ _is_prepared = true;
}
}