diff options
Diffstat (limited to 'src/runtime/NEON/functions/NEFullyConnectedLayer.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NEFullyConnectedLayer.cpp | 19 |
1 files changed, 2 insertions, 17 deletions
diff --git a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp index d815a73b93..504200e9ce 100644 --- a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp +++ b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp @@ -44,7 +44,6 @@ struct NEFullyConnectedLayer::Impl const ITensor *original_weights{ nullptr }; ITensorPack run_pack{}; - ITensorPack prep_pack{}; WorkspaceData<Tensor> workspace{}; experimental::MemoryRequirements aux_mem_req{}; @@ -79,8 +78,7 @@ void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weigh _impl->aux_mem_req = _impl->op->workspace(); _impl->run_pack = { { ACL_SRC_0, input }, { ACL_SRC_1, weights }, { ACL_SRC_2, biases }, { ACL_DST, output } }; - _impl->prep_pack = { { ACL_SRC_1, weights }, { ACL_SRC_2, biases } }; - _impl->workspace = manage_workspace<Tensor>(_impl->aux_mem_req, _impl->memory_group, _impl->run_pack, _impl->prep_pack); + _impl->workspace = manage_workspace<Tensor>(_impl->aux_mem_req, _impl->memory_group, _impl->run_pack, _impl->run_pack); } Status NEFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, @@ -101,20 +99,7 @@ void NEFullyConnectedLayer::prepare() { if(!_impl->is_prepared) { - _impl->op->prepare(_impl->prep_pack); - - auto has_reshape = std::find_if(_impl->aux_mem_req.begin(), - _impl->aux_mem_req.end(), - [](const MemoryInfo & m) -> bool { return m.lifetime == MemoryLifetime::Persistent; }); - - if(has_reshape != std::end(_impl->aux_mem_req)) - { - _impl->original_weights->mark_as_unused(); - } - else - { - _impl->run_pack.add_const_tensor(ACL_SRC_1, _impl->original_weights); - } + _impl->op->prepare(_impl->run_pack); // Release temporary tensors that are only used in prepare stage release_temporaries<Tensor>(_impl->aux_mem_req, _impl->workspace); |