diff options
Diffstat (limited to 'src/runtime/NEON/functions/NEStackLayer.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NEStackLayer.cpp | 37 |
1 files changed, 16 insertions, 21 deletions
diff --git a/src/runtime/NEON/functions/NEStackLayer.cpp b/src/runtime/NEON/functions/NEStackLayer.cpp index 351497c96e..2f88ffca2a 100644 --- a/src/runtime/NEON/functions/NEStackLayer.cpp +++ b/src/runtime/NEON/functions/NEStackLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 ARM Limited. + * Copyright (c) 2018-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -31,27 +31,26 @@ #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/runtime/NEON/NEScheduler.h" +#include "src/common/utils/Log.h" +#include "src/core/NEON/kernels/NEStackLayerKernel.h" + namespace arm_compute { +NEStackLayer::~NEStackLayer() = default; + NEStackLayer::NEStackLayer() // NOLINT - : _input(), - _stack_kernels(), - _num_inputs(0) + : _stack_kernel(std::make_unique<NEStackLayerKernel>()), _is_prepared(false) { } void NEStackLayer::configure(const std::vector<ITensor *> &input, int axis, ITensor *output) { - _num_inputs = input.size(); - _stack_kernels.resize(_num_inputs); + ARM_COMPUTE_LOG_PARAMS(input, axis, output); // Wrap around negative values const unsigned int axis_u = wrap_around(axis, static_cast<int>(input[0]->info()->num_dimensions() + 1)); - for(unsigned int i = 0; i < _num_inputs; i++) - { - _stack_kernels[i].configure(input[i], axis_u, i, _num_inputs, output); - } + _stack_kernel->configure(input, axis_u, output); } Status NEStackLayer::validate(const std::vector<ITensorInfo *> &input, int axis, const ITensorInfo *output) @@ -63,24 +62,20 @@ Status NEStackLayer::validate(const std::vector<ITensorInfo *> &input, int axis, const size_t rank = input[0]->num_dimensions(); const unsigned int axis_u = wrap_around(axis, static_cast<int>(rank + 1)); - const unsigned int num_inputs = input.size(); - - for(unsigned int i = 0; i < num_inputs; i++) - { - // All the tensors must have the same rank - ARM_COMPUTE_RETURN_ERROR_ON(input[i]->num_dimensions() != rank); - // Validate Kernel - ARM_COMPUTE_RETURN_ON_ERROR(NEStackLayerKernel::validate(input[i], axis_u, i, num_inputs, output)); - } + // Validate Kernel + ARM_COMPUTE_RETURN_ON_ERROR(NEStackLayerKernel::validate(input, axis_u, output)); return Status{}; } void NEStackLayer::run() { - for(unsigned i = 0; i < _num_inputs; i++) + if (!_is_prepared) { - NEScheduler::get().schedule(&_stack_kernels[i], Window::DimY); + _stack_kernel->prepare(); + _is_prepared = true; } + + NEScheduler::get().schedule(_stack_kernel.get(), _stack_kernel->get_split_dimension()); } } // namespace arm_compute |