diff options
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMConv2d.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NEGEMMConv2d.cpp | 79 |
1 files changed, 61 insertions, 18 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMConv2d.cpp b/src/runtime/NEON/functions/NEGEMMConv2d.cpp index 564ce2f514..6cca02eea9 100644 --- a/src/runtime/NEON/functions/NEGEMMConv2d.cpp +++ b/src/runtime/NEON/functions/NEGEMMConv2d.cpp @@ -24,50 +24,93 @@ #include "arm_compute/runtime/NEON/functions/NEGEMMConv2d.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" -#include "arm_compute/core/utils/quantization/AsymmHelpers.h" -#include "arm_compute/runtime/NEON/NEScheduler.h" -#include "src/runtime/cpu/operators/CpuGemmDirectConv2d.h" +#include "arm_compute/runtime/Tensor.h" -#include <set> +#include "src/core/helpers/MemoryHelpers.h" +#include "src/cpu/operators/CpuGemmDirectConv2d.h" namespace arm_compute { using OperatorType = cpu::CpuGemmDirectConv2d; +using namespace arm_compute::experimental; struct NEGEMMConv2d::Impl { - ITensorPack tensors{}; - std::unique_ptr<OperatorType> op{ nullptr }; + const ITensor *weights{nullptr}; + std::unique_ptr<OperatorType> op{nullptr}; + ITensorPack run_pack{}; + ITensorPack prep_pack{}; + WorkspaceData<Tensor> workspace{}; + MemoryGroup memory_group{}; + bool is_prepared{false}; + experimental::MemoryRequirements aux_mem_req{}; }; -NEGEMMConv2d::NEGEMMConv2d(const std::shared_ptr<IMemoryManager> &memory_manager) - : _impl(std::make_unique<Impl>()) +NEGEMMConv2d::NEGEMMConv2d(const std::shared_ptr<IMemoryManager> &memory_manager) : _impl(std::make_unique<Impl>()) { - _impl->op = std::make_unique<OperatorType>(memory_manager); + _impl->memory_group = MemoryGroup(memory_manager); } NEGEMMConv2d::~NEGEMMConv2d() = default; -void NEGEMMConv2d::configure(ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const Conv2dInfo &info) +void NEGEMMConv2d::configure( + ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const Conv2dInfo &info) { - _impl->tensors.add_const_tensor(TensorType::ACL_SRC_0, input); - _impl->tensors.add_const_tensor(TensorType::ACL_SRC_1, weights); - _impl->tensors.add_const_tensor(TensorType::ACL_SRC_2, biases); - _impl->tensors.add_tensor(TensorType::ACL_DST, output); + ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output); - _impl->op->configure(input->info(), weights->info(), biases != nullptr ? biases->info() : nullptr, output->info(), info); + _impl->weights = weights; + _impl->is_prepared = false; + _impl->op = std::make_unique<OperatorType>(); + + _impl->op->configure(input->info(), weights->info(), biases != nullptr ? biases->info() : nullptr, output->info(), + info); + + _impl->aux_mem_req = _impl->op->workspace(); + _impl->run_pack = {{TensorType::ACL_SRC_0, input}, {TensorType::ACL_SRC_2, biases}, {TensorType::ACL_DST, output}}; + _impl->prep_pack = {{TensorType::ACL_SRC_1, weights}, {TensorType::ACL_SRC_2, biases}}; + _impl->workspace = + manage_workspace<Tensor>(_impl->op->workspace(), _impl->memory_group, _impl->run_pack, _impl->prep_pack); } -Status NEGEMMConv2d::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const Conv2dInfo &info) +Status NEGEMMConv2d::validate(const ITensorInfo *input, + const ITensorInfo *weights, + const ITensorInfo *biases, + const ITensorInfo *output, + const Conv2dInfo &info) { return OperatorType::validate(input, weights, biases, output, info); } + void NEGEMMConv2d::run() { - _impl->op->run(_impl->tensors); + prepare(); + + MemoryGroupResourceScope scope_mg(_impl->memory_group); + _impl->op->run(_impl->run_pack); } + void NEGEMMConv2d::prepare() { - _impl->op->prepare(_impl->tensors); + if (!_impl->is_prepared) + { + _impl->op->prepare(_impl->prep_pack); + + auto has_reshape = + std::find_if(_impl->aux_mem_req.begin(), _impl->aux_mem_req.end(), + [](const MemoryInfo &m) -> bool { return m.lifetime == MemoryLifetime::Persistent; }); + + if (has_reshape != std::end(_impl->aux_mem_req)) + { + _impl->weights->mark_as_unused(); + } + else + { + _impl->run_pack.add_const_tensor(ACL_SRC_1, _impl->weights); + } + + // Release temporary tensors that are only used in prepare stage + release_temporaries<Tensor>(_impl->aux_mem_req, _impl->workspace); + _impl->is_prepared = true; + } } } // namespace arm_compute |