diff options
Diffstat (limited to 'src/core/CPP/kernels/CPPPermuteKernel.cpp')
-rw-r--r-- | src/core/CPP/kernels/CPPPermuteKernel.cpp | 45 |
1 files changed, 25 insertions, 20 deletions
diff --git a/src/core/CPP/kernels/CPPPermuteKernel.cpp b/src/core/CPP/kernels/CPPPermuteKernel.cpp index 054c7bf05a..e68090d82b 100644 --- a/src/core/CPP/kernels/CPPPermuteKernel.cpp +++ b/src/core/CPP/kernels/CPPPermuteKernel.cpp @@ -25,6 +25,7 @@ #include "arm_compute/core/Helpers.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" + #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" @@ -43,7 +44,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c const TensorShape output_shape = misc::shape_calculator::compute_permutation_output_shape(*input, perm); // Validate configured output - if(output->total_size() != 0) + if (output->total_size() != 0) { ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); @@ -65,7 +66,7 @@ void CPPPermuteKernel::run_permute(const Window &window) // Create output window Window window_out(window); const Window::Dimension zero_window = Window::Dimension(0, 0, 0); - for(size_t d = 0; d <= _perm.num_dimensions(); ++d) + for (size_t d = 0; d <= _perm.num_dimensions(); ++d) { window_out.set(d, zero_window); } @@ -74,28 +75,32 @@ void CPPPermuteKernel::run_permute(const Window &window) Iterator in(_input, window); Iterator out(_output, window_out); - if(_input->info()->num_dimensions() <= 3) + if (_input->info()->num_dimensions() <= 3) { - execute_window_loop(window, [&](const Coordinates & id) - { - const int idx = id[0] * perm_strides[0] + id[1] * perm_strides[1] + id[2] * perm_strides[2]; - *(reinterpret_cast<T *>(out.ptr() + idx)) = *(reinterpret_cast<const T *>(in.ptr())); - }, - in, out); + execute_window_loop( + window, + [&](const Coordinates &id) + { + const int idx = id[0] * perm_strides[0] + id[1] * perm_strides[1] + id[2] * perm_strides[2]; + *(reinterpret_cast<T *>(out.ptr() + idx)) = *(reinterpret_cast<const T *>(in.ptr())); + }, + in, out); } - else if(_input->info()->num_dimensions() >= 4) + else if (_input->info()->num_dimensions() >= 4) { - execute_window_loop(window, [&](const Coordinates & id) - { - const int idx = id[0] * perm_strides[0] + id[1] * perm_strides[1] + id[2] * perm_strides[2] + id[3] * perm_strides[3]; - *(reinterpret_cast<T *>(out.ptr() + idx)) = *(reinterpret_cast<const T *>(in.ptr())); - }, - in, out); + execute_window_loop( + window, + [&](const Coordinates &id) + { + const int idx = id[0] * perm_strides[0] + id[1] * perm_strides[1] + id[2] * perm_strides[2] + + id[3] * perm_strides[3]; + *(reinterpret_cast<T *>(out.ptr() + idx)) = *(reinterpret_cast<const T *>(in.ptr())); + }, + in, out); } } -CPPPermuteKernel::CPPPermuteKernel() - : _func(), _input(nullptr), _output(nullptr), _perm() +CPPPermuteKernel::CPPPermuteKernel() : _func(), _input(nullptr), _output(nullptr), _perm() { } @@ -113,7 +118,7 @@ void CPPPermuteKernel::configure(const ITensor *input, ITensor *output, const Pe _output = output; _perm = perm; - switch(input->info()->element_size()) + switch (input->info()->element_size()) { case 1: _func = &CPPPermuteKernel::run_permute<uint8_t>; @@ -152,7 +157,7 @@ void CPPPermuteKernel::run(const Window &window, const ThreadInfo &info) ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICPPKernel::window(), window); - if(_func != nullptr) + if (_func != nullptr) { (this->*_func)(window); } |