diff options
Diffstat (limited to 'src/cpu/kernels/CpuPermuteKernel.cpp')
-rw-r--r-- | src/cpu/kernels/CpuPermuteKernel.cpp | 155 |
1 files changed, 71 insertions, 84 deletions
diff --git a/src/cpu/kernels/CpuPermuteKernel.cpp b/src/cpu/kernels/CpuPermuteKernel.cpp index d65e011032..b444a25ff7 100644 --- a/src/cpu/kernels/CpuPermuteKernel.cpp +++ b/src/cpu/kernels/CpuPermuteKernel.cpp @@ -28,8 +28,9 @@ #include "arm_compute/core/ITensor.h" #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" -#include "arm_compute/core/Validate.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "arm_compute/core/Validate.h" + #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" @@ -48,56 +49,31 @@ namespace { inline bool is_permutation_supported(const PermutationVector &v) { - static const std::array<PermutationVector, 2> permutations2 = - { - { - PermutationVector(0U, 1U), - PermutationVector(1U, 0U), - } - }; - static const std::array<PermutationVector, 6> permutations3 = - { - { - PermutationVector(2U, 0U, 1U), - PermutationVector(1U, 2U, 0U), - PermutationVector(0U, 1U, 2U), - PermutationVector(0U, 2U, 1U), - PermutationVector(1U, 0U, 2U), - PermutationVector(2U, 1U, 0U), - } - }; - static const std::array<PermutationVector, 24> permutations4 = - { - { - PermutationVector(0U, 1U, 2U, 3U), - PermutationVector(1U, 0U, 2U, 3U), - PermutationVector(2U, 0U, 1U, 3U), - PermutationVector(0U, 2U, 1U, 3U), - PermutationVector(1U, 2U, 0U, 3U), - PermutationVector(2U, 1U, 0U, 3U), - PermutationVector(2U, 1U, 3U, 0U), - PermutationVector(1U, 2U, 3U, 0U), - PermutationVector(3U, 2U, 1U, 0U), - PermutationVector(2U, 3U, 1U, 0U), - PermutationVector(1U, 3U, 2U, 0U), - PermutationVector(3U, 1U, 2U, 0U), - PermutationVector(3U, 0U, 2U, 1U), - PermutationVector(0U, 3U, 2U, 1U), - PermutationVector(2U, 3U, 0U, 1U), - PermutationVector(3U, 2U, 0U, 1U), - PermutationVector(0U, 2U, 3U, 1U), - PermutationVector(2U, 0U, 3U, 1U), - PermutationVector(1U, 0U, 3U, 2U), - PermutationVector(0U, 1U, 3U, 2U), - PermutationVector(3U, 1U, 0U, 2U), - PermutationVector(1U, 3U, 0U, 2U), - PermutationVector(0U, 3U, 1U, 2U), - PermutationVector(3U, 0U, 1U, 2U) - } - }; + static const std::array<PermutationVector, 2> permutations2 = {{ + PermutationVector(0U, 1U), + PermutationVector(1U, 0U), + }}; + static const std::array<PermutationVector, 6> permutations3 = {{ + PermutationVector(2U, 0U, 1U), + PermutationVector(1U, 2U, 0U), + PermutationVector(0U, 1U, 2U), + PermutationVector(0U, 2U, 1U), + PermutationVector(1U, 0U, 2U), + PermutationVector(2U, 1U, 0U), + }}; + static const std::array<PermutationVector, 24> permutations4 = { + {PermutationVector(0U, 1U, 2U, 3U), PermutationVector(1U, 0U, 2U, 3U), PermutationVector(2U, 0U, 1U, 3U), + PermutationVector(0U, 2U, 1U, 3U), PermutationVector(1U, 2U, 0U, 3U), PermutationVector(2U, 1U, 0U, 3U), + PermutationVector(2U, 1U, 3U, 0U), PermutationVector(1U, 2U, 3U, 0U), PermutationVector(3U, 2U, 1U, 0U), + PermutationVector(2U, 3U, 1U, 0U), PermutationVector(1U, 3U, 2U, 0U), PermutationVector(3U, 1U, 2U, 0U), + PermutationVector(3U, 0U, 2U, 1U), PermutationVector(0U, 3U, 2U, 1U), PermutationVector(2U, 3U, 0U, 1U), + PermutationVector(3U, 2U, 0U, 1U), PermutationVector(0U, 2U, 3U, 1U), PermutationVector(2U, 0U, 3U, 1U), + PermutationVector(1U, 0U, 3U, 2U), PermutationVector(0U, 1U, 3U, 2U), PermutationVector(3U, 1U, 0U, 2U), + PermutationVector(1U, 3U, 0U, 2U), PermutationVector(0U, 3U, 1U, 2U), PermutationVector(3U, 0U, 1U, 2U)}}; - return (permutations2.end() != std::find(permutations2.begin(), permutations2.end(), v)) || (permutations3.end() != std::find(permutations3.begin(), permutations3.end(), v)) - || (permutations4.end() != std::find(permutations4.begin(), permutations4.end(), v)); + return (permutations2.end() != std::find(permutations2.begin(), permutations2.end(), v)) || + (permutations3.end() != std::find(permutations3.begin(), permutations3.end(), v)) || + (permutations4.end() != std::find(permutations4.begin(), permutations4.end(), v)); } Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, const PermutationVector &perm) @@ -108,7 +84,7 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, const const TensorShape dst_shape = misc::shape_calculator::compute_permutation_output_shape(*src, perm); // Validate configured destination - if(dst->total_size() != 0) + if (dst->total_size() != 0) { ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), dst_shape); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(src, dst); @@ -128,18 +104,22 @@ void run_permute(const Window &window, const ITensor *src, const ITensor *dst, c // we only support these two configs in src/core/NEON/kernels/convolution/common/shims.hpp, for all others // we have to fall back to C++ - if((src_layout == DataLayout::NCHW && perm == PermutationVector{ 2U, 0U, 1U }) || (src_layout == DataLayout::NHWC && perm == PermutationVector{ 1U, 2U, 0U })) + if ((src_layout == DataLayout::NCHW && perm == PermutationVector{2U, 0U, 1U}) || + (src_layout == DataLayout::NHWC && perm == PermutationVector{1U, 2U, 0U})) { - window_src.set(Window::DimX, Window::Dimension(window.x().start(), window.x().end(), window.x().end() - window.x().start())); - window_src.set(Window::DimY, Window::Dimension(window.y().start(), window.y().end(), window.y().end() - window.y().start())); - window_src.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), window.z().end() - window.z().start())); + window_src.set(Window::DimX, + Window::Dimension(window.x().start(), window.x().end(), window.x().end() - window.x().start())); + window_src.set(Window::DimY, + Window::Dimension(window.y().start(), window.y().end(), window.y().end() - window.y().start())); + window_src.set(Window::DimZ, + Window::Dimension(window.z().start(), window.z().end(), window.z().end() - window.z().start())); window_src.set(3, Window::Dimension(window[3].start(), window[3].end(), window[3].end() - window[3].start())); } // Destination window Window window_dst(window); const Window::Dimension zero_window = Window::Dimension(0, 0, 0); - for(size_t d = 0; d <= dst->info()->num_dimensions(); ++d) + for (size_t d = 0; d <= dst->info()->num_dimensions(); ++d) { window_dst.set(d, zero_window); } @@ -157,7 +137,7 @@ void run_permute(const Window &window, const ITensor *src, const ITensor *dst, c int n_channels = 0; int n_batches = 0; - switch(src_layout) + switch (src_layout) { case DataLayout::NCHW: { @@ -189,38 +169,42 @@ void run_permute(const Window &window, const ITensor *src, const ITensor *dst, c } // CHW -> HWC - if(src_layout == DataLayout::NCHW && perm == PermutationVector{ 2U, 0U, 1U }) + if (src_layout == DataLayout::NCHW && perm == PermutationVector{2U, 0U, 1U}) { const int out_channel_stride = dst->info()->strides_in_bytes().x() / sizeof(T); const int out_col_stride = dst->info()->strides_in_bytes().y() / sizeof(T); const int out_row_stride = dst->info()->strides_in_bytes().z() / sizeof(T); const int out_batch_stride = dst->info()->strides_in_bytes()[3] / sizeof(T); - execute_window_loop(window_src, [&](const Coordinates & id) - { - const int idx = id[0] * out_col_stride + id[1] * out_row_stride + id[2] * out_channel_stride; - reorder::nchw_to_nhwc(reinterpret_cast<const T *>(src_it.ptr()), reinterpret_cast<T *>(dst_it.ptr()) + idx, - n_batches, n_channels, n_rows, n_cols, - in_batch_stride, in_channel_stride, in_row_stride, - out_batch_stride, out_row_stride, out_col_stride); - }, - src_it, dst_it); + execute_window_loop( + window_src, + [&](const Coordinates &id) + { + const int idx = id[0] * out_col_stride + id[1] * out_row_stride + id[2] * out_channel_stride; + reorder::nchw_to_nhwc(reinterpret_cast<const T *>(src_it.ptr()), + reinterpret_cast<T *>(dst_it.ptr()) + idx, n_batches, n_channels, n_rows, n_cols, + in_batch_stride, in_channel_stride, in_row_stride, out_batch_stride, + out_row_stride, out_col_stride); + }, + src_it, dst_it); } // HWC -> CHW - else if(src_layout == DataLayout::NHWC && perm == PermutationVector{ 1U, 2U, 0U }) + else if (src_layout == DataLayout::NHWC && perm == PermutationVector{1U, 2U, 0U}) { const int out_col_stride = dst->info()->strides_in_bytes().x() / sizeof(T); const int out_row_stride = dst->info()->strides_in_bytes().y() / sizeof(T); const int out_channel_stride = dst->info()->strides_in_bytes().z() / sizeof(T); const int out_batch_stride = dst->info()->strides_in_bytes()[3] / sizeof(T); - execute_window_loop(window_src, [&](const Coordinates & id) - { - const int idx = id[0] * out_channel_stride + id[1] * out_col_stride + id[2] * out_row_stride; - reorder::nhwc_to_nchw(reinterpret_cast<const T *>(src_it.ptr()), reinterpret_cast<T *>(dst_it.ptr()) + idx, - n_batches, n_rows, n_cols, n_channels, - in_batch_stride, in_row_stride, in_col_stride, - out_batch_stride, out_channel_stride, out_row_stride); - }, - src_it, dst_it); + execute_window_loop( + window_src, + [&](const Coordinates &id) + { + const int idx = id[0] * out_channel_stride + id[1] * out_col_stride + id[2] * out_row_stride; + reorder::nhwc_to_nchw(reinterpret_cast<const T *>(src_it.ptr()), + reinterpret_cast<T *>(dst_it.ptr()) + idx, n_batches, n_rows, n_cols, n_channels, + in_batch_stride, in_row_stride, in_col_stride, out_batch_stride, + out_channel_stride, out_row_stride); + }, + src_it, dst_it); } else { @@ -230,12 +214,15 @@ void run_permute(const Window &window, const ITensor *src, const ITensor *dst, c Strides perm_strides = strides; permute_strides(perm_strides, perm); const int perm_stride_3 = src->info()->num_dimensions() >= 4 ? perm_strides[3] : 0; - execute_window_loop(window, [&](const Coordinates & id) - { - const int idx = id[0] * perm_strides[0] + id[1] * perm_strides[1] + id[2] * perm_strides[2] + id[3] * perm_stride_3; - *(reinterpret_cast<T *>(dst_it.ptr() + idx)) = *(reinterpret_cast<const T *>(src_it.ptr())); - }, - src_it, dst_it); + execute_window_loop( + window, + [&](const Coordinates &id) + { + const int idx = + id[0] * perm_strides[0] + id[1] * perm_strides[1] + id[2] * perm_strides[2] + id[3] * perm_stride_3; + *(reinterpret_cast<T *>(dst_it.ptr() + idx)) = *(reinterpret_cast<const T *>(src_it.ptr())); + }, + src_it, dst_it); } } } // namespace @@ -275,7 +262,7 @@ void CpuPermuteKernel::run_op(ITensorPack &tensors, const Window &window, const const auto src = tensors.get_const_tensor(TensorType::ACL_SRC); auto dst = tensors.get_tensor(TensorType::ACL_DST); - switch(src->info()->element_size()) + switch (src->info()->element_size()) { case 1: run_permute<uint8_t>(window, src, dst, _perm); |