aboutsummaryrefslogtreecommitdiff
path: root/src/core/CPP/kernels
diff options
context:
space:
mode:
authorPablo Tello <pablo.tello@arm.com>2018-01-04 10:34:24 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:43:42 +0000
commit00afd11eaa7d408ff873732639c9a724fece9058 (patch)
tree37dfbaca825c3363dd2197ea85f99f740748b5b0 /src/core/CPP/kernels
parent5237e01c342b9301951a799842e9c48813b66fd4 (diff)
downloadComputeLibrary-00afd11eaa7d408ff873732639c9a724fece9058.tar.gz
COMPMID-719: NEPermuteKernel refactoring
Change-Id: I91b43d9706ac3244ce43684967ace0b022d35bad Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/114988 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/CPP/kernels')
-rw-r--r--src/core/CPP/kernels/CPPPermuteKernel.cpp94
1 files changed, 50 insertions, 44 deletions
diff --git a/src/core/CPP/kernels/CPPPermuteKernel.cpp b/src/core/CPP/kernels/CPPPermuteKernel.cpp
index 80b0abaabc..c7bae870d1 100644
--- a/src/core/CPP/kernels/CPPPermuteKernel.cpp
+++ b/src/core/CPP/kernels/CPPPermuteKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017, 2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -29,6 +29,7 @@
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include <cstddef>
#include <cstdint>
@@ -37,13 +38,6 @@ using namespace arm_compute;
namespace
{
-TensorShape get_output_shape(const ITensorInfo *input, const PermutationVector &perm)
-{
- TensorShape output_shape = input->tensor_shape();
- permute(output_shape, perm);
- return output_shape;
-}
-
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const PermutationVector &perm)
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QS8, DataType::QASYMM8,
@@ -57,7 +51,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
|| (perm[0] != 1 && perm[1] != 2 && perm[2] != 0))),
"Only [2, 0, 1],[1, 2, 0] and [3, 2, 0, 1] permutation is supported");
- const TensorShape output_shape = get_output_shape(input, perm);
+ const TensorShape output_shape = misc::shape_calculator::compute_permutation_output_shape(*input, perm);
// Validate configured output
if(output->total_size() != 0)
@@ -69,59 +63,71 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
return Status{};
}
+
+template <typename T>
+inline void permute_strides(Dimensions<T> &dimensions, const PermutationVector &perm)
+{
+ const auto old_dim = utility::make_array<Dimensions<T>::num_max_dimensions>(dimensions.begin(), dimensions.end());
+ for(unsigned int i = 0; i < perm.num_dimensions(); ++i)
+ {
+ dimensions[perm[i]] = old_dim[i];
+ }
+}
+
} // namespace
+
+
+
template <typename T>
void CPPPermuteKernel::run_permute(const Window &window)
{
- const int output_stride_x = _output->info()->strides_in_bytes().x();
- const int output_stride_y = _output->info()->strides_in_bytes().y();
- const int output_stride_z = _output->info()->strides_in_bytes().z();
- const int output_stride_w = _output->info()->strides_in_bytes()[3];
-
- Window window_out(window);
- window_out.set(Window::DimX, Window::Dimension(0, 0, 0));
- window_out.set(Window::DimY, Window::Dimension(0, 0, 0));
- window_out.set(Window::DimZ, Window::Dimension(0, 0, 0));
- window_out.set(3, Window::Dimension(0, 0, 0));
+ Strides strides = _output->info()->strides_in_bytes();
+ Strides perm_strides = strides;
+ permute_strides(perm_strides,_perm);
+ const int output_stride_w = strides[3];
+ Window window_out(window);
+ const Window::Dimension zero_window = Window::Dimension(0, 0, 0);
+ for(size_t d = 0; d <= _perm.num_dimensions(); ++d)
+ {
+ window_out.set(d, zero_window);
+ }
// Create iterators
Iterator in(_input, window);
Iterator out(_output, window_out);
-
- // Run [2, 0, 1] permute
- if(_perm[0] == 2 && _perm[1] == 0 && _perm[2] == 1)
+ ARM_COMPUTE_ERROR_ON(_perm.num_dimensions() > _input->info()->num_dimensions());
+ if(_input->info()->num_dimensions() <= 3)
{
execute_window_loop(window, [&](const Coordinates & id)
{
- const int idx = id[3] * output_stride_w + id.y() * output_stride_z + id.x() * output_stride_y + id.z() * output_stride_x;
+ const int idx = id[0] * perm_strides[0] + id[1] * perm_strides[1] + id[2] * perm_strides[2];
*(reinterpret_cast<T *>(out.ptr() + idx)) = *(reinterpret_cast<const T *>(in.ptr()));
},
in, out);
}
- // Run [1, 2, 0] permute
- else if(_perm[0] == 1 && _perm[1] == 2 && _perm[2] == 0)
+ else if(_input->info()->num_dimensions() >= 4)
{
- execute_window_loop(window, [&](const Coordinates & id)
+ if(_perm.num_dimensions() < _input->info()->num_dimensions())
{
- const int idx = id[3] * output_stride_w + id.x() * output_stride_z + id.z() * output_stride_y + id.y() * output_stride_x;
- *(reinterpret_cast<T *>(out.ptr() + idx)) = *(reinterpret_cast<const T *>(in.ptr()));
- },
- in, out);
- }
- // Run [3, 2, 0, 1] permute
- else if(_perm[0] == 3 && _perm[1] == 2 && _perm[2] == 0 && _perm[3] == 1)
- {
- execute_window_loop(window, [&](const Coordinates & id)
+ // special case: perm.size = 3 and tensor size > 3, _perm[3] would be invalid so we handle this with id[3] * output_stride_w instead of id[_perm[3]]
+ ARM_COMPUTE_ERROR_ON(_perm.num_dimensions() < 3);
+ execute_window_loop(window, [&](const Coordinates & id)
+ {
+ const int idx = id[0] * perm_strides[0] + id[1] * perm_strides[1] + id[2] * perm_strides[2] + id[3] * output_stride_w;
+ *(reinterpret_cast<T *>(out.ptr() + idx)) = *(reinterpret_cast<const T *>(in.ptr()));
+ },
+ in, out);
+ }
+ else
{
- const int idx = id[3] * output_stride_x + id[2] * output_stride_y + id[0] * output_stride_z + id[1] * output_stride_w;
- *(reinterpret_cast<T *>(out.ptr() + idx)) = *(reinterpret_cast<const T *>(in.ptr()));
- },
- in, out);
- }
- else
- {
- ARM_COMPUTE_ERROR("Not supported.");
+ execute_window_loop(window, [&](const Coordinates & id)
+ {
+ const int idx = id[0] * perm_strides[0] + id[1] * perm_strides[1] + id[2] * perm_strides[2] + id[3] * perm_strides[3];
+ *(reinterpret_cast<T *>(out.ptr() + idx)) = *(reinterpret_cast<const T *>(in.ptr()));
+ },
+ in, out);
+ }
}
}
@@ -133,7 +139,7 @@ CPPPermuteKernel::CPPPermuteKernel()
void CPPPermuteKernel::configure(const ITensor *input, ITensor *output, const PermutationVector &perm)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
- const TensorShape output_shape = get_output_shape(input->info(), perm);
+ const TensorShape output_shape = misc::shape_calculator::compute_permutation_output_shape(*input->info(), perm);
// Output auto inizialitation if not yet initialized
auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape));