diff options
Diffstat (limited to 'src/runtime/NEON')
-rw-r--r-- | src/runtime/NEON/functions/NEStackLayer.cpp | 31 |
1 files changed, 10 insertions, 21 deletions
diff --git a/src/runtime/NEON/functions/NEStackLayer.cpp b/src/runtime/NEON/functions/NEStackLayer.cpp index 03e7026691..2f88ffca2a 100644 --- a/src/runtime/NEON/functions/NEStackLayer.cpp +++ b/src/runtime/NEON/functions/NEStackLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -39,7 +39,7 @@ namespace arm_compute NEStackLayer::~NEStackLayer() = default; NEStackLayer::NEStackLayer() // NOLINT - : _input(), _stack_kernels(), _num_inputs(0) + : _stack_kernel(std::make_unique<NEStackLayerKernel>()), _is_prepared(false) { } @@ -47,17 +47,10 @@ void NEStackLayer::configure(const std::vector<ITensor *> &input, int axis, ITen { ARM_COMPUTE_LOG_PARAMS(input, axis, output); - _num_inputs = input.size(); - _stack_kernels.resize(_num_inputs); - // Wrap around negative values const unsigned int axis_u = wrap_around(axis, static_cast<int>(input[0]->info()->num_dimensions() + 1)); - for (unsigned int i = 0; i < _num_inputs; i++) - { - _stack_kernels[i] = std::make_unique<NEStackLayerKernel>(); - _stack_kernels[i]->configure(input[i], axis_u, i, _num_inputs, output); - } + _stack_kernel->configure(input, axis_u, output); } Status NEStackLayer::validate(const std::vector<ITensorInfo *> &input, int axis, const ITensorInfo *output) @@ -69,24 +62,20 @@ Status NEStackLayer::validate(const std::vector<ITensorInfo *> &input, int axis, const size_t rank = input[0]->num_dimensions(); const unsigned int axis_u = wrap_around(axis, static_cast<int>(rank + 1)); - const unsigned int num_inputs = input.size(); - - for (unsigned int i = 0; i < num_inputs; i++) - { - // All the tensors must have the same rank - ARM_COMPUTE_RETURN_ERROR_ON(input[i]->num_dimensions() != rank); - // Validate Kernel - ARM_COMPUTE_RETURN_ON_ERROR(NEStackLayerKernel::validate(input[i], axis_u, i, num_inputs, output)); - } + // Validate Kernel + ARM_COMPUTE_RETURN_ON_ERROR(NEStackLayerKernel::validate(input, axis_u, output)); return Status{}; } void NEStackLayer::run() { - for (unsigned i = 0; i < _num_inputs; i++) + if (!_is_prepared) { - NEScheduler::get().schedule(_stack_kernels[i].get(), Window::DimY); + _stack_kernel->prepare(); + _is_prepared = true; } + + NEScheduler::get().schedule(_stack_kernel.get(), _stack_kernel->get_split_dimension()); } } // namespace arm_compute |