diff options
Diffstat (limited to 'src/runtime/CL/functions')
-rw-r--r-- | src/runtime/CL/functions/CLGEMM.cpp | 32 |
1 files changed, 21 insertions, 11 deletions
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp index 1bc785a0a7..35126ec0d7 100644 --- a/src/runtime/CL/functions/CLGEMM.cpp +++ b/src/runtime/CL/functions/CLGEMM.cpp @@ -53,7 +53,7 @@ struct CLGEMM::Impl ITensorPack prep_pack{}; MemoryRequirements aux_mem_req{}; WorkspaceData<CLTensor> workspace_tensors{}; - bool _is_prepared{ false }; + bool is_prepared{ false }; }; CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager) @@ -74,19 +74,29 @@ void CLGEMM::configure(const CLCompileContext &compile_context, const ICLTensor { ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output); - _impl->a = a; - _impl->b = b; - _impl->c = c; - _impl->dst = output; - _impl->op = std::make_unique<OperatorType>(); + _impl->a = a; + _impl->b = b; + _impl->c = c; + _impl->dst = output; + _impl->op = std::make_unique<OperatorType>(); + _impl->is_prepared = gemm_info.retain_internal_weights(); _impl->op->configure(compile_context, a->info(), b->info(), c != nullptr ? c->info() : nullptr, output->info(), alpha, beta, gemm_info); _impl->aux_mem_req = _impl->op->workspace(); // Manage/allocate auxilairy tensors - _impl->run_pack = { { ACL_SRC_0, _impl->a }, { ACL_SRC_2, _impl->c }, { ACL_DST, _impl->dst } }; - _impl->prep_pack = { { ACL_SRC_1, _impl->b } }; - _impl->workspace_tensors = manage_workspace<CLTensor>(_impl->op->workspace(), _impl->memory_group, _impl->run_pack, _impl->prep_pack); + if(_impl->is_prepared) + { + _impl->run_pack.add_const_tensor(ACL_SRC_0, _impl->a); + _impl->run_pack.add_tensor(ACL_DST, _impl->dst); + } + else + { + _impl->run_pack = { { ACL_SRC_0, _impl->a }, { ACL_SRC_2, _impl->c }, { ACL_DST, _impl->dst } }; + _impl->prep_pack = { { ACL_SRC_1, _impl->b } }; + + _impl->workspace_tensors = manage_workspace<CLTensor>(_impl->op->workspace(), _impl->memory_group, _impl->run_pack, _impl->prep_pack); + } } Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info) @@ -106,7 +116,7 @@ void CLGEMM::run() void CLGEMM::prepare() { - if(!_impl->_is_prepared) + if(!_impl->is_prepared) { _impl->op->prepare(_impl->prep_pack); @@ -123,7 +133,7 @@ void CLGEMM::prepare() // Pack the B matrix to be used as the underlying GEMM performs no reshapes _impl->run_pack.add_const_tensor(ACL_SRC_1, _impl->b); } - _impl->_is_prepared = true; + _impl->is_prepared = true; } } } // namespace arm_compute |