aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NEGEMMConv2d.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMConv2d.cpp')
-rw-r--r--src/runtime/NEON/functions/NEGEMMConv2d.cpp79
1 files changed, 61 insertions, 18 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMConv2d.cpp b/src/runtime/NEON/functions/NEGEMMConv2d.cpp
index 564ce2f514..6cca02eea9 100644
--- a/src/runtime/NEON/functions/NEGEMMConv2d.cpp
+++ b/src/runtime/NEON/functions/NEGEMMConv2d.cpp
@@ -24,50 +24,93 @@
#include "arm_compute/runtime/NEON/functions/NEGEMMConv2d.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
-#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
-#include "arm_compute/runtime/NEON/NEScheduler.h"
-#include "src/runtime/cpu/operators/CpuGemmDirectConv2d.h"
+#include "arm_compute/runtime/Tensor.h"
-#include <set>
+#include "src/core/helpers/MemoryHelpers.h"
+#include "src/cpu/operators/CpuGemmDirectConv2d.h"
namespace arm_compute
{
using OperatorType = cpu::CpuGemmDirectConv2d;
+using namespace arm_compute::experimental;
struct NEGEMMConv2d::Impl
{
- ITensorPack tensors{};
- std::unique_ptr<OperatorType> op{ nullptr };
+ const ITensor *weights{nullptr};
+ std::unique_ptr<OperatorType> op{nullptr};
+ ITensorPack run_pack{};
+ ITensorPack prep_pack{};
+ WorkspaceData<Tensor> workspace{};
+ MemoryGroup memory_group{};
+ bool is_prepared{false};
+ experimental::MemoryRequirements aux_mem_req{};
};
-NEGEMMConv2d::NEGEMMConv2d(const std::shared_ptr<IMemoryManager> &memory_manager)
- : _impl(std::make_unique<Impl>())
+NEGEMMConv2d::NEGEMMConv2d(const std::shared_ptr<IMemoryManager> &memory_manager) : _impl(std::make_unique<Impl>())
{
- _impl->op = std::make_unique<OperatorType>(memory_manager);
+ _impl->memory_group = MemoryGroup(memory_manager);
}
NEGEMMConv2d::~NEGEMMConv2d() = default;
-void NEGEMMConv2d::configure(ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const Conv2dInfo &info)
+void NEGEMMConv2d::configure(
+ ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const Conv2dInfo &info)
{
- _impl->tensors.add_const_tensor(TensorType::ACL_SRC_0, input);
- _impl->tensors.add_const_tensor(TensorType::ACL_SRC_1, weights);
- _impl->tensors.add_const_tensor(TensorType::ACL_SRC_2, biases);
- _impl->tensors.add_tensor(TensorType::ACL_DST, output);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
- _impl->op->configure(input->info(), weights->info(), biases != nullptr ? biases->info() : nullptr, output->info(), info);
+ _impl->weights = weights;
+ _impl->is_prepared = false;
+ _impl->op = std::make_unique<OperatorType>();
+
+ _impl->op->configure(input->info(), weights->info(), biases != nullptr ? biases->info() : nullptr, output->info(),
+ info);
+
+ _impl->aux_mem_req = _impl->op->workspace();
+ _impl->run_pack = {{TensorType::ACL_SRC_0, input}, {TensorType::ACL_SRC_2, biases}, {TensorType::ACL_DST, output}};
+ _impl->prep_pack = {{TensorType::ACL_SRC_1, weights}, {TensorType::ACL_SRC_2, biases}};
+ _impl->workspace =
+ manage_workspace<Tensor>(_impl->op->workspace(), _impl->memory_group, _impl->run_pack, _impl->prep_pack);
}
-Status NEGEMMConv2d::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const Conv2dInfo &info)
+Status NEGEMMConv2d::validate(const ITensorInfo *input,
+ const ITensorInfo *weights,
+ const ITensorInfo *biases,
+ const ITensorInfo *output,
+ const Conv2dInfo &info)
{
return OperatorType::validate(input, weights, biases, output, info);
}
+
void NEGEMMConv2d::run()
{
- _impl->op->run(_impl->tensors);
+ prepare();
+
+ MemoryGroupResourceScope scope_mg(_impl->memory_group);
+ _impl->op->run(_impl->run_pack);
}
+
void NEGEMMConv2d::prepare()
{
- _impl->op->prepare(_impl->tensors);
+ if (!_impl->is_prepared)
+ {
+ _impl->op->prepare(_impl->prep_pack);
+
+ auto has_reshape =
+ std::find_if(_impl->aux_mem_req.begin(), _impl->aux_mem_req.end(),
+ [](const MemoryInfo &m) -> bool { return m.lifetime == MemoryLifetime::Persistent; });
+
+ if (has_reshape != std::end(_impl->aux_mem_req))
+ {
+ _impl->weights->mark_as_unused();
+ }
+ else
+ {
+ _impl->run_pack.add_const_tensor(ACL_SRC_1, _impl->weights);
+ }
+
+ // Release temporary tensors that are only used in prepare stage
+ release_temporaries<Tensor>(_impl->aux_mem_req, _impl->workspace);
+ _impl->is_prepared = true;
+ }
}
} // namespace arm_compute