From 00afd11eaa7d408ff873732639c9a724fece9058 Mon Sep 17 00:00:00 2001 From: Pablo Tello Date: Thu, 4 Jan 2018 10:34:24 +0000 Subject: COMPMID-719: NEPermuteKernel refactoring Change-Id: I91b43d9706ac3244ce43684967ace0b022d35bad Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/114988 Tested-by: Jenkins Reviewed-by: Anthony Barbier --- src/core/CPP/kernels/CPPPermuteKernel.cpp | 94 ++++++++++++++++--------------- 1 file changed, 50 insertions(+), 44 deletions(-) (limited to 'src/core/CPP/kernels') diff --git a/src/core/CPP/kernels/CPPPermuteKernel.cpp b/src/core/CPP/kernels/CPPPermuteKernel.cpp index 80b0abaabc..c7bae870d1 100644 --- a/src/core/CPP/kernels/CPPPermuteKernel.cpp +++ b/src/core/CPP/kernels/CPPPermuteKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017, 2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -29,6 +29,7 @@ #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Validate.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" #include #include @@ -37,13 +38,6 @@ using namespace arm_compute; namespace { -TensorShape get_output_shape(const ITensorInfo *input, const PermutationVector &perm) -{ - TensorShape output_shape = input->tensor_shape(); - permute(output_shape, perm); - return output_shape; -} - Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const PermutationVector &perm) { ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QS8, DataType::QASYMM8, @@ -57,7 +51,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c || (perm[0] != 1 && perm[1] != 2 && perm[2] != 0))), "Only [2, 0, 1],[1, 2, 0] and [3, 2, 0, 1] permutation is supported"); - const TensorShape output_shape = get_output_shape(input, perm); + const TensorShape output_shape = misc::shape_calculator::compute_permutation_output_shape(*input, perm); // Validate configured output if(output->total_size() != 0) @@ -69,59 +63,71 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c return Status{}; } + +template +inline void permute_strides(Dimensions &dimensions, const PermutationVector &perm) +{ + const auto old_dim = utility::make_array::num_max_dimensions>(dimensions.begin(), dimensions.end()); + for(unsigned int i = 0; i < perm.num_dimensions(); ++i) + { + dimensions[perm[i]] = old_dim[i]; + } +} + } // namespace + + + template void CPPPermuteKernel::run_permute(const Window &window) { - const int output_stride_x = _output->info()->strides_in_bytes().x(); - const int output_stride_y = _output->info()->strides_in_bytes().y(); - const int output_stride_z = _output->info()->strides_in_bytes().z(); - const int output_stride_w = _output->info()->strides_in_bytes()[3]; - - Window window_out(window); - window_out.set(Window::DimX, Window::Dimension(0, 0, 0)); - window_out.set(Window::DimY, Window::Dimension(0, 0, 0)); - window_out.set(Window::DimZ, Window::Dimension(0, 0, 0)); - window_out.set(3, Window::Dimension(0, 0, 0)); + Strides strides = _output->info()->strides_in_bytes(); + Strides perm_strides = strides; + permute_strides(perm_strides,_perm); + const int output_stride_w = strides[3]; + Window window_out(window); + const Window::Dimension zero_window = Window::Dimension(0, 0, 0); + for(size_t d = 0; d <= _perm.num_dimensions(); ++d) + { + window_out.set(d, zero_window); + } // Create iterators Iterator in(_input, window); Iterator out(_output, window_out); - - // Run [2, 0, 1] permute - if(_perm[0] == 2 && _perm[1] == 0 && _perm[2] == 1) + ARM_COMPUTE_ERROR_ON(_perm.num_dimensions() > _input->info()->num_dimensions()); + if(_input->info()->num_dimensions() <= 3) { execute_window_loop(window, [&](const Coordinates & id) { - const int idx = id[3] * output_stride_w + id.y() * output_stride_z + id.x() * output_stride_y + id.z() * output_stride_x; + const int idx = id[0] * perm_strides[0] + id[1] * perm_strides[1] + id[2] * perm_strides[2]; *(reinterpret_cast(out.ptr() + idx)) = *(reinterpret_cast(in.ptr())); }, in, out); } - // Run [1, 2, 0] permute - else if(_perm[0] == 1 && _perm[1] == 2 && _perm[2] == 0) + else if(_input->info()->num_dimensions() >= 4) { - execute_window_loop(window, [&](const Coordinates & id) + if(_perm.num_dimensions() < _input->info()->num_dimensions()) { - const int idx = id[3] * output_stride_w + id.x() * output_stride_z + id.z() * output_stride_y + id.y() * output_stride_x; - *(reinterpret_cast(out.ptr() + idx)) = *(reinterpret_cast(in.ptr())); - }, - in, out); - } - // Run [3, 2, 0, 1] permute - else if(_perm[0] == 3 && _perm[1] == 2 && _perm[2] == 0 && _perm[3] == 1) - { - execute_window_loop(window, [&](const Coordinates & id) + // special case: perm.size = 3 and tensor size > 3, _perm[3] would be invalid so we handle this with id[3] * output_stride_w instead of id[_perm[3]] + ARM_COMPUTE_ERROR_ON(_perm.num_dimensions() < 3); + execute_window_loop(window, [&](const Coordinates & id) + { + const int idx = id[0] * perm_strides[0] + id[1] * perm_strides[1] + id[2] * perm_strides[2] + id[3] * output_stride_w; + *(reinterpret_cast(out.ptr() + idx)) = *(reinterpret_cast(in.ptr())); + }, + in, out); + } + else { - const int idx = id[3] * output_stride_x + id[2] * output_stride_y + id[0] * output_stride_z + id[1] * output_stride_w; - *(reinterpret_cast(out.ptr() + idx)) = *(reinterpret_cast(in.ptr())); - }, - in, out); - } - else - { - ARM_COMPUTE_ERROR("Not supported."); + execute_window_loop(window, [&](const Coordinates & id) + { + const int idx = id[0] * perm_strides[0] + id[1] * perm_strides[1] + id[2] * perm_strides[2] + id[3] * perm_strides[3]; + *(reinterpret_cast(out.ptr() + idx)) = *(reinterpret_cast(in.ptr())); + }, + in, out); + } } } @@ -133,7 +139,7 @@ CPPPermuteKernel::CPPPermuteKernel() void CPPPermuteKernel::configure(const ITensor *input, ITensor *output, const PermutationVector &perm) { ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); - const TensorShape output_shape = get_output_shape(input->info(), perm); + const TensorShape output_shape = misc::shape_calculator::compute_permutation_output_shape(*input->info(), perm); // Output auto inizialitation if not yet initialized auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape)); -- cgit v1.2.1