about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- src/runtime/CL/functions/CLGEMM.cpp 32
1 files changed, 21 insertions, 11 deletions
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index 1bc785a0a7..35126ec0d7 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -53,7 +53,7 @@ struct CLGEMM::Impl
ITensorPack prep_pack{};
MemoryRequirements aux_mem_req{};
WorkspaceData<CLTensor> workspace_tensors{};
- bool _is_prepared{ false };
+ bool is_prepared{ false };
};
CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
@@ -74,19 +74,29 @@ void CLGEMM::configure(const CLCompileContext &compile_context, const ICLTensor
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);
- _impl->a = a;
- _impl->b = b;
- _impl->c = c;
- _impl->dst = output;
- _impl->op = std::make_unique<OperatorType>();
+ _impl->a = a;
+ _impl->b = b;
+ _impl->c = c;
+ _impl->dst = output;
+ _impl->op = std::make_unique<OperatorType>();
+ _impl->is_prepared = gemm_info.retain_internal_weights();
_impl->op->configure(compile_context, a->info(), b->info(), c != nullptr ? c->info() : nullptr, output->info(), alpha, beta, gemm_info);
_impl->aux_mem_req = _impl->op->workspace();
// Manage/allocate auxilairy tensors
- _impl->run_pack = { { ACL_SRC_0, _impl->a }, { ACL_SRC_2, _impl->c }, { ACL_DST, _impl->dst } };
- _impl->prep_pack = { { ACL_SRC_1, _impl->b } };
- _impl->workspace_tensors = manage_workspace<CLTensor>(_impl->op->workspace(), _impl->memory_group, _impl->run_pack, _impl->prep_pack);
+ if(_impl->is_prepared)
+ {
+ _impl->run_pack.add_const_tensor(ACL_SRC_0, _impl->a);
+ _impl->run_pack.add_tensor(ACL_DST, _impl->dst);
+ }
+ else
+ {
+ _impl->run_pack = { { ACL_SRC_0, _impl->a }, { ACL_SRC_2, _impl->c }, { ACL_DST, _impl->dst } };
+ _impl->prep_pack = { { ACL_SRC_1, _impl->b } };
+
+ _impl->workspace_tensors = manage_workspace<CLTensor>(_impl->op->workspace(), _impl->memory_group, _impl->run_pack, _impl->prep_pack);
+ }
}
Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
@@ -106,7 +116,7 @@ void CLGEMM::run()
void CLGEMM::prepare()
{
- if(!_impl->_is_prepared)
+ if(!_impl->is_prepared)
{
_impl->op->prepare(_impl->prep_pack);
@@ -123,7 +133,7 @@ void CLGEMM::prepare()
// Pack the B matrix to be used as the underlying GEMM performs no reshapes
_impl->run_pack.add_const_tensor(ACL_SRC_1, _impl->b);
}
- _impl->_is_prepared = true;
+ _impl->is_prepared = true;
}
}
} // namespace arm_compute