diff options
Diffstat (limited to 'src/core/NEON/kernels/NEStackLayerKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEStackLayerKernel.cpp | 55 |
1 files changed, 35 insertions, 20 deletions
diff --git a/src/core/NEON/kernels/NEStackLayerKernel.cpp b/src/core/NEON/kernels/NEStackLayerKernel.cpp index 93080e2ac7..e23b40a9aa 100644 --- a/src/core/NEON/kernels/NEStackLayerKernel.cpp +++ b/src/core/NEON/kernels/NEStackLayerKernel.cpp @@ -25,13 +25,13 @@ #include "arm_compute/core/Error.h" #include "arm_compute/core/Helpers.h" -#include "arm_compute/core/Helpers.h" #include "arm_compute/core/ITensor.h" #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Utils.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/core/Validate.h" #include "arm_compute/core/Window.h" -#include "arm_compute/core/utils/misc/ShapeCalculator.h" + #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" @@ -41,7 +41,11 @@ using namespace arm_compute::misc::shape_calculator; namespace { -Status validate_arguments(const ITensorInfo *input, unsigned int axis, unsigned int idx_input, unsigned int num_tensors, const ITensorInfo *output) +Status validate_arguments(const ITensorInfo *input, + unsigned int axis, + unsigned int idx_input, + unsigned int num_tensors, + const ITensorInfo *output) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output); // Note: ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input) is not needed here as this kernel doesn't use CPU FP16 instructions. @@ -50,9 +54,10 @@ Status validate_arguments(const ITensorInfo *input, unsigned int axis, unsigned ARM_COMPUTE_RETURN_ERROR_ON(axis > input->num_dimensions()); ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4); - if(output->total_size() != 0) + if (output->total_size() != 0) { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), compute_stack_shape(*input, axis, num_tensors)); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), + compute_stack_shape(*input, axis, num_tensors)); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output); } @@ -60,7 +65,8 @@ Status validate_arguments(const ITensorInfo *input, unsigned int axis, unsigned return Status{}; } -std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, unsigned int axis, unsigned int num_tensors, ITensorInfo *output) +std::pair<Status, Window> +validate_and_configure_window(ITensorInfo *input, unsigned int axis, unsigned int num_tensors, ITensorInfo *output) { // Output auto inizialitation if not yet initialized auto_init_if_empty(*output, input->clone()->set_tensor_shape(compute_stack_shape(*input, axis, num_tensors))); @@ -71,11 +77,12 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, unsi return std::make_pair(Status{}, win); } -inline Coordinates shift_from_axis_and_replace_coordinate(const Coordinates &id, unsigned int axis, unsigned int idx_input) +inline Coordinates +shift_from_axis_and_replace_coordinate(const Coordinates &id, unsigned int axis, unsigned int idx_input) { constexpr int max_out_coord = 5; // Input shape is max a 4D shape, output is max 5D Coordinates id_out = id; - for(unsigned int i = max_out_coord - 1; i > axis; --i) + for (unsigned int i = max_out_coord - 1; i > axis; --i) { id_out.set(i, id[i - 1]); } @@ -84,12 +91,12 @@ inline Coordinates shift_from_axis_and_replace_coordinate(const Coordinates &id, } } // namespace -NEStackLayerKernel::NEStackLayerKernel() - : _input(nullptr), _output(nullptr), _axis(), _idx_input() +NEStackLayerKernel::NEStackLayerKernel() : _input(nullptr), _output(nullptr), _axis(), _idx_input() { } -void NEStackLayerKernel::configure(const ITensor *input, unsigned int axis, unsigned int idx_input, unsigned int num_tensors, ITensor *output) +void NEStackLayerKernel::configure( + const ITensor *input, unsigned int axis, unsigned int idx_input, unsigned int num_tensors, ITensor *output) { ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), axis, idx_input, num_tensors, output->info())); @@ -106,10 +113,15 @@ void NEStackLayerKernel::configure(const ITensor *input, unsigned int axis, unsi INEKernel::configure(win_config.second); } -Status NEStackLayerKernel::validate(const ITensorInfo *input, unsigned int axis, unsigned int idx_input, unsigned int num_tensors, const ITensorInfo *output) +Status NEStackLayerKernel::validate(const ITensorInfo *input, + unsigned int axis, + unsigned int idx_input, + unsigned int num_tensors, + const ITensorInfo *output) { ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, axis, idx_input, num_tensors, output)); - ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), axis, num_tensors, output->clone().get()).first); + ARM_COMPUTE_RETURN_ON_ERROR( + validate_and_configure_window(input->clone().get(), axis, num_tensors, output->clone().get()).first); return Status{}; } @@ -131,12 +143,15 @@ void NEStackLayerKernel::run(const Window &window, const ThreadInfo &info) const int stride_w = _output->info()->num_dimensions() >= 3 ? _output->info()->strides_in_bytes()[3] : 0; const int stride_k = _output->info()->num_dimensions() >= 4 ? _output->info()->strides_in_bytes()[4] : 0; - execute_window_loop(window, [&](const Coordinates & id) - { - Coordinates id_out = shift_from_axis_and_replace_coordinate(id, _axis, _idx_input); - const int idx = id_out[0] * stride_x + id_out[1] * stride_y + id_out[2] * stride_z + id_out[3] * stride_w + id_out[4] * stride_k; - std::memcpy(output.ptr() + idx, input.ptr(), _input->info()->element_size()); - }, - input); + execute_window_loop( + window, + [&](const Coordinates &id) + { + Coordinates id_out = shift_from_axis_and_replace_coordinate(id, _axis, _idx_input); + const int idx = id_out[0] * stride_x + id_out[1] * stride_y + id_out[2] * stride_z + id_out[3] * stride_w + + id_out[4] * stride_k; + std::memcpy(output.ptr() + idx, input.ptr(), _input->info()->element_size()); + }, + input); } } // namespace arm_compute |