Diffstat (limited to 'src/cpu/kernels/CpuTransposeKernel.cpp')
-rw-r--r--  src/cpu/kernels/CpuTransposeKernel.cpp  772
1 file changed, 452 insertions(+), 320 deletions(-)
diff --git a/src/cpu/kernels/CpuTransposeKernel.cpp b/src/cpu/kernels/CpuTransposeKernel.cpp
index b2cebc4230..615bc6ce1e 100644
--- a/src/cpu/kernels/CpuTransposeKernel.cpp
+++ b/src/cpu/kernels/CpuTransposeKernel.cpp
@@ -28,8 +28,9 @@
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+#include "arm_compute/core/Validate.h"
+
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
@@ -45,7 +46,7 @@ namespace
{
unsigned int num_elems_processed(size_t element_size)
{
- switch(element_size)
+ switch (element_size)
{
case 1:
return 8;
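
The value returned here becomes the per-iteration step of the kernel window, i.e. the side of the square tile each SIMD iteration transposes. Only the 1-byte case is visible in this hunk; a hedged sketch of the full selection, inferred from the tile sizes the per-type kernels below actually use (8x8 for 8-bit, 4x4 for 16-bit, and 8x8 on AArch64 but 4x4 on AArch32 for 32-bit), might look like:

    // Hypothetical reconstruction -- only "case 1: return 8;" survives the hunk
    // above; the remaining values are inferred from the loops later in the file.
    unsigned int num_elems_processed_sketch(size_t element_size)
    {
        switch (element_size)
        {
            case 1:
                return 8; // transpose_8bit_elements: 8x8 uint8 tiles
            case 2:
                return 4; // transpose_16bit_elements: 4x4 uint16 tiles
            case 4:
    #ifdef __aarch64__
                return 8; // 8x8 tiles via uint32x4x2_t loads
    #else  // __aarch64__
                return 4; // 4x4 tiles via uint32x4_t loads
    #endif // __aarch64__
            default:
                ARM_COMPUTE_ERROR("Element size not supported");
                return 0;
        }
    }
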
@@ -81,10 +82,10 @@ void transpose_8bit_elements(const ITensor *in, ITensor *out, const Window &wind
Window window_in(window);
window_in.set(Window::DimX, Window::Dimension(0, 1, 1));
- if(left_over_loop_y)
+ if (left_over_loop_y)
{
// Check if window_end_y_multiple_of is greater than window_start_y
- if(window_end_y_multiple_of > window_start_y)
+ if (window_end_y_multiple_of > window_start_y)
{
window_in.set(Window::DimY, Window::Dimension(window_start_y, window_end_y_multiple_of, window_step_y));
}
@@ -101,87 +102,121 @@ void transpose_8bit_elements(const ITensor *in, ITensor *out, const Window &wind
Iterator output(out, window_out);
// Run the SIMD path if and only if the input is not a row-vector
- if(in->info()->dimension(1) != 1)
+ if (in->info()->dimension(1) != 1)
{
Iterator input(in, window_in);
- execute_window_loop(window_in, [&](const Coordinates & id)
- {
- // Compute 8x8 elements per iteration
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ execute_window_loop(
+ window_in,
+ [&](const Coordinates &id)
{
- const uint8x8_t row0 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 0 * input_stride_in_bytes));
- const uint8x8_t row1 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 1 * input_stride_in_bytes));
- const uint8x8_t row2 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 2 * input_stride_in_bytes));
- const uint8x8_t row3 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 3 * input_stride_in_bytes));
- const uint8x8_t row4 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 4 * input_stride_in_bytes));
- const uint8x8_t row5 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 5 * input_stride_in_bytes));
- const uint8x8_t row6 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 6 * input_stride_in_bytes));
- const uint8x8_t row7 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 7 * input_stride_in_bytes));
-
- // Transpose 2x2
- const uint8x8x2_t k0_u8 = vtrn_u8(row0, row1);
- const uint8x8x2_t k1_u8 = vtrn_u8(row2, row3);
- const uint8x8x2_t k2_u8 = vtrn_u8(row4, row5);
- const uint8x8x2_t k3_u8 = vtrn_u8(row6, row7);
-
- // Transpose 4x4
- const uint16x4x2_t k0_u16 = vtrn_u16(vreinterpret_u16_u8(k0_u8.val[0]), vreinterpret_u16_u8(k1_u8.val[0]));
- const uint16x4x2_t k1_u16 = vtrn_u16(vreinterpret_u16_u8(k0_u8.val[1]), vreinterpret_u16_u8(k1_u8.val[1]));
- const uint16x4x2_t k2_u16 = vtrn_u16(vreinterpret_u16_u8(k2_u8.val[0]), vreinterpret_u16_u8(k3_u8.val[0]));
- const uint16x4x2_t k3_u16 = vtrn_u16(vreinterpret_u16_u8(k2_u8.val[1]), vreinterpret_u16_u8(k3_u8.val[1]));
-
- // Transpose 8x8
- const uint32x2x2_t k0_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[0]), vreinterpret_u32_u16(k2_u16.val[0]));
- const uint32x2x2_t k1_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[1]), vreinterpret_u32_u16(k2_u16.val[1]));
- const uint32x2x2_t k2_u32 = vtrn_u32(vreinterpret_u32_u16(k1_u16.val[0]), vreinterpret_u32_u16(k3_u16.val[0]));
- const uint32x2x2_t k3_u32 = vtrn_u32(vreinterpret_u32_u16(k1_u16.val[1]), vreinterpret_u32_u16(k3_u16.val[1]));
-
- // Compute destination address
- const size_t dst_offset_in_bytes = id.y() * sizeof(uint8_t) + x * output_stride_in_bytes;
-
- vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k0_u32.val[0])));
- vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k2_u32.val[0])));
- vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k1_u32.val[0])));
- vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k3_u32.val[0])));
- vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 4 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k0_u32.val[1])));
- vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 5 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k2_u32.val[1])));
- vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 6 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k1_u32.val[1])));
- vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 7 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k3_u32.val[1])));
- }
-
- // Compute left-over elements along the x dimension (1x8)
- for(; x < window_end_x; ++x)
- {
- const uint8_t val0 = *(input.ptr() + x + 0 * input_stride_in_bytes);
- const uint8_t val1 = *(input.ptr() + x + 1 * input_stride_in_bytes);
- const uint8_t val2 = *(input.ptr() + x + 2 * input_stride_in_bytes);
- const uint8_t val3 = *(input.ptr() + x + 3 * input_stride_in_bytes);
- const uint8_t val4 = *(input.ptr() + x + 4 * input_stride_in_bytes);
- const uint8_t val5 = *(input.ptr() + x + 5 * input_stride_in_bytes);
- const uint8_t val6 = *(input.ptr() + x + 6 * input_stride_in_bytes);
- const uint8_t val7 = *(input.ptr() + x + 7 * input_stride_in_bytes);
-
- uint8x8_t result = vdup_n_u8(0);
- result = vset_lane_u8(val0, result, 0);
- result = vset_lane_u8(val1, result, 1);
- result = vset_lane_u8(val2, result, 2);
- result = vset_lane_u8(val3, result, 3);
- result = vset_lane_u8(val4, result, 4);
- result = vset_lane_u8(val5, result, 5);
- result = vset_lane_u8(val6, result, 6);
- result = vset_lane_u8(val7, result, 7);
-
- // Compute destination address
- const size_t dst_offset_in_bytes = id.y() * sizeof(uint8_t) + x * output_stride_in_bytes;
-
- vst1_u8(output.ptr() + dst_offset_in_bytes, result);
- }
- },
- input, output);
+ // Compute 8x8 elements per iteration
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const uint8x8_t row0 =
+ vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 0 * input_stride_in_bytes));
+ const uint8x8_t row1 =
+ vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 1 * input_stride_in_bytes));
+ const uint8x8_t row2 =
+ vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 2 * input_stride_in_bytes));
+ const uint8x8_t row3 =
+ vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 3 * input_stride_in_bytes));
+ const uint8x8_t row4 =
+ vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 4 * input_stride_in_bytes));
+ const uint8x8_t row5 =
+ vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 5 * input_stride_in_bytes));
+ const uint8x8_t row6 =
+ vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 6 * input_stride_in_bytes));
+ const uint8x8_t row7 =
+ vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 7 * input_stride_in_bytes));
+
+ // Transpose 2x2
+ const uint8x8x2_t k0_u8 = vtrn_u8(row0, row1);
+ const uint8x8x2_t k1_u8 = vtrn_u8(row2, row3);
+ const uint8x8x2_t k2_u8 = vtrn_u8(row4, row5);
+ const uint8x8x2_t k3_u8 = vtrn_u8(row6, row7);
+
+ // Transpose 4x4
+ const uint16x4x2_t k0_u16 =
+ vtrn_u16(vreinterpret_u16_u8(k0_u8.val[0]), vreinterpret_u16_u8(k1_u8.val[0]));
+ const uint16x4x2_t k1_u16 =
+ vtrn_u16(vreinterpret_u16_u8(k0_u8.val[1]), vreinterpret_u16_u8(k1_u8.val[1]));
+ const uint16x4x2_t k2_u16 =
+ vtrn_u16(vreinterpret_u16_u8(k2_u8.val[0]), vreinterpret_u16_u8(k3_u8.val[0]));
+ const uint16x4x2_t k3_u16 =
+ vtrn_u16(vreinterpret_u16_u8(k2_u8.val[1]), vreinterpret_u16_u8(k3_u8.val[1]));
+
+ // Transpose 8x8
+ const uint32x2x2_t k0_u32 =
+ vtrn_u32(vreinterpret_u32_u16(k0_u16.val[0]), vreinterpret_u32_u16(k2_u16.val[0]));
+ const uint32x2x2_t k1_u32 =
+ vtrn_u32(vreinterpret_u32_u16(k0_u16.val[1]), vreinterpret_u32_u16(k2_u16.val[1]));
+ const uint32x2x2_t k2_u32 =
+ vtrn_u32(vreinterpret_u32_u16(k1_u16.val[0]), vreinterpret_u32_u16(k3_u16.val[0]));
+ const uint32x2x2_t k3_u32 =
+ vtrn_u32(vreinterpret_u32_u16(k1_u16.val[1]), vreinterpret_u32_u16(k3_u16.val[1]));
+
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint8_t) + x * output_stride_in_bytes;
+
+ vst1_u8(
+ reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes),
+ vreinterpret_u8_u16(vreinterpret_u16_u32(k0_u32.val[0])));
+ vst1_u8(
+ reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes),
+ vreinterpret_u8_u16(vreinterpret_u16_u32(k2_u32.val[0])));
+ vst1_u8(
+ reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes),
+ vreinterpret_u8_u16(vreinterpret_u16_u32(k1_u32.val[0])));
+ vst1_u8(
+ reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes),
+ vreinterpret_u8_u16(vreinterpret_u16_u32(k3_u32.val[0])));
+ vst1_u8(
+ reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 4 * output_stride_in_bytes),
+ vreinterpret_u8_u16(vreinterpret_u16_u32(k0_u32.val[1])));
+ vst1_u8(
+ reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 5 * output_stride_in_bytes),
+ vreinterpret_u8_u16(vreinterpret_u16_u32(k2_u32.val[1])));
+ vst1_u8(
+ reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 6 * output_stride_in_bytes),
+ vreinterpret_u8_u16(vreinterpret_u16_u32(k1_u32.val[1])));
+ vst1_u8(
+ reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 7 * output_stride_in_bytes),
+ vreinterpret_u8_u16(vreinterpret_u16_u32(k3_u32.val[1])));
+ }
+
+ // Compute left-over elements along the x dimension (1x8)
+ for (; x < window_end_x; ++x)
+ {
+ const uint8_t val0 = *(input.ptr() + x + 0 * input_stride_in_bytes);
+ const uint8_t val1 = *(input.ptr() + x + 1 * input_stride_in_bytes);
+ const uint8_t val2 = *(input.ptr() + x + 2 * input_stride_in_bytes);
+ const uint8_t val3 = *(input.ptr() + x + 3 * input_stride_in_bytes);
+ const uint8_t val4 = *(input.ptr() + x + 4 * input_stride_in_bytes);
+ const uint8_t val5 = *(input.ptr() + x + 5 * input_stride_in_bytes);
+ const uint8_t val6 = *(input.ptr() + x + 6 * input_stride_in_bytes);
+ const uint8_t val7 = *(input.ptr() + x + 7 * input_stride_in_bytes);
+
+ uint8x8_t result = vdup_n_u8(0);
+ result = vset_lane_u8(val0, result, 0);
+ result = vset_lane_u8(val1, result, 1);
+ result = vset_lane_u8(val2, result, 2);
+ result = vset_lane_u8(val3, result, 3);
+ result = vset_lane_u8(val4, result, 4);
+ result = vset_lane_u8(val5, result, 5);
+ result = vset_lane_u8(val6, result, 6);
+ result = vset_lane_u8(val7, result, 7);
+
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint8_t) + x * output_stride_in_bytes;
+
+ vst1_u8(output.ptr() + dst_offset_in_bytes, result);
+ }
+ },
+ input, output);
}
- if(left_over_loop_y)
+ if (left_over_loop_y)
{
window_in.set(Window::DimX, Window::Dimension(window.x().start(), window.x().end(), 1));
window_in.set(Window::DimY, Window::Dimension(window_end_y_multiple_of, window_end_y, 1));
@@ -190,16 +225,18 @@ void transpose_8bit_elements(const ITensor *in, ITensor *out, const Window &wind
Iterator output(out, window_out);
// Compute left-over elements along the y dimension (1x1)
- execute_window_loop(window_in, [&](const Coordinates & id)
- {
- const uint8_t val0 = *input.ptr();
+ execute_window_loop(
+ window_in,
+ [&](const Coordinates &id)
+ {
+ const uint8_t val0 = *input.ptr();
- // Compute destination address
- const size_t dst_offset_in_bytes = id.y() * sizeof(uint8_t) + id.x() * output_stride_in_bytes;
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint8_t) + id.x() * output_stride_in_bytes;
- *(output.ptr() + dst_offset_in_bytes) = val0;
- },
- input, output);
+ *(output.ptr() + dst_offset_in_bytes) = val0;
+ },
+ input, output);
}
}
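
Stripped of the Window/Iterator plumbing, the hot loop above is the classic three-stage NEON 8x8 byte transpose: vtrn_u8 swaps elements within 2x2 blocks, vtrn_u16 swaps 2x2 blocks within 4x4 blocks, and vtrn_u32 swaps 4x4 blocks within the 8x8 tile. Transposing the tile's position comes from the destination offset, which swaps the coordinate roles (dst_offset_in_bytes = y * sizeof(uint8_t) + x * output_stride_in_bytes), so the tile read at (x, y) is written at (y, x). A minimal self-contained sketch of the same intrinsic sequence (function name, pointers, and strides are hypothetical):

    #include <arm_neon.h>
    #include <stddef.h>
    #include <stdint.h>

    // Minimal sketch: transpose one 8x8 uint8 tile with the same vtrn cascade
    // as the kernel above. Pointers and byte strides are hypothetical.
    void transpose_tile_8x8_u8(const uint8_t *src, size_t src_stride, uint8_t *dst, size_t dst_stride)
    {
        const uint8x8_t row0 = vld1_u8(src + 0 * src_stride);
        const uint8x8_t row1 = vld1_u8(src + 1 * src_stride);
        const uint8x8_t row2 = vld1_u8(src + 2 * src_stride);
        const uint8x8_t row3 = vld1_u8(src + 3 * src_stride);
        const uint8x8_t row4 = vld1_u8(src + 4 * src_stride);
        const uint8x8_t row5 = vld1_u8(src + 5 * src_stride);
        const uint8x8_t row6 = vld1_u8(src + 6 * src_stride);
        const uint8x8_t row7 = vld1_u8(src + 7 * src_stride);

        // Stage 1: transpose 2x2 blocks of bytes
        const uint8x8x2_t k0_u8 = vtrn_u8(row0, row1);
        const uint8x8x2_t k1_u8 = vtrn_u8(row2, row3);
        const uint8x8x2_t k2_u8 = vtrn_u8(row4, row5);
        const uint8x8x2_t k3_u8 = vtrn_u8(row6, row7);

        // Stage 2: transpose 2x2 blocks within 4x4, viewing byte pairs as uint16 lanes
        const uint16x4x2_t k0_u16 = vtrn_u16(vreinterpret_u16_u8(k0_u8.val[0]), vreinterpret_u16_u8(k1_u8.val[0]));
        const uint16x4x2_t k1_u16 = vtrn_u16(vreinterpret_u16_u8(k0_u8.val[1]), vreinterpret_u16_u8(k1_u8.val[1]));
        const uint16x4x2_t k2_u16 = vtrn_u16(vreinterpret_u16_u8(k2_u8.val[0]), vreinterpret_u16_u8(k3_u8.val[0]));
        const uint16x4x2_t k3_u16 = vtrn_u16(vreinterpret_u16_u8(k2_u8.val[1]), vreinterpret_u16_u8(k3_u8.val[1]));

        // Stage 3: transpose 4x4 blocks within the 8x8 tile, as uint32 lanes
        const uint32x2x2_t k0_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[0]), vreinterpret_u32_u16(k2_u16.val[0]));
        const uint32x2x2_t k1_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[1]), vreinterpret_u32_u16(k2_u16.val[1]));
        const uint32x2x2_t k2_u32 = vtrn_u32(vreinterpret_u32_u16(k1_u16.val[0]), vreinterpret_u32_u16(k3_u16.val[0]));
        const uint32x2x2_t k3_u32 = vtrn_u32(vreinterpret_u32_u16(k1_u16.val[1]), vreinterpret_u32_u16(k3_u16.val[1]));

        // Store rows in the interleaved order the cascade produces
        vst1_u8(dst + 0 * dst_stride, vreinterpret_u8_u32(k0_u32.val[0]));
        vst1_u8(dst + 1 * dst_stride, vreinterpret_u8_u32(k2_u32.val[0]));
        vst1_u8(dst + 2 * dst_stride, vreinterpret_u8_u32(k1_u32.val[0]));
        vst1_u8(dst + 3 * dst_stride, vreinterpret_u8_u32(k3_u32.val[0]));
        vst1_u8(dst + 4 * dst_stride, vreinterpret_u8_u32(k0_u32.val[1]));
        vst1_u8(dst + 5 * dst_stride, vreinterpret_u8_u32(k2_u32.val[1]));
        vst1_u8(dst + 6 * dst_stride, vreinterpret_u8_u32(k1_u32.val[1]));
        vst1_u8(dst + 7 * dst_stride, vreinterpret_u8_u32(k3_u32.val[1]));
    }
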
@@ -220,10 +257,10 @@ void transpose_16bit_elements(const ITensor *in, ITensor *out, const Window &win
Window window_in(window);
window_in.set(Window::DimX, Window::Dimension(0, 1, 1));
- if(left_over_loop_y)
+ if (left_over_loop_y)
{
// Check if window_end_y_multiple_of is greater than window_start_y
- if(window_end_y_multiple_of > window_start_y)
+ if (window_end_y_multiple_of > window_start_y)
{
window_in.set(Window::DimY, Window::Dimension(window_start_y, window_end_y_multiple_of, window_step_y));
}
@@ -240,61 +277,77 @@ void transpose_16bit_elements(const ITensor *in, ITensor *out, const Window &win
Iterator output(out, window_out);
// Run the SIMD path if and only if the input is not a row-vector
- if(in->info()->dimension(1) != 1)
+ if (in->info()->dimension(1) != 1)
{
Iterator input(in, window_in);
- execute_window_loop(window_in, [&](const Coordinates & id)
- {
- // Compute 4x4 elements per iteration
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ execute_window_loop(
+ window_in,
+ [&](const Coordinates &id)
{
- const uint16x4_t row0 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
- const uint16x4_t row1 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
- const uint16x4_t row2 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
- const uint16x4_t row3 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
-
- // Transpose 2x2
- const uint16x4x2_t k0_u16 = vtrn_u16(row0, row1);
- const uint16x4x2_t k1_u16 = vtrn_u16(row2, row3);
-
- // Transpose 4x4
- const uint32x2x2_t k0_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[0]), vreinterpret_u32_u16(k1_u16.val[0]));
- const uint32x2x2_t k1_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[1]), vreinterpret_u32_u16(k1_u16.val[1]));
-
- // Compute destination address
- const size_t dst_offset_in_bytes = id.y() * sizeof(uint16_t) + x * output_stride_in_bytes;
-
- vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), vreinterpret_u16_u32(k0_u32.val[0]));
- vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), vreinterpret_u16_u32(k1_u32.val[0]));
- vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), vreinterpret_u16_u32(k0_u32.val[1]));
- vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), vreinterpret_u16_u32(k1_u32.val[1]));
- }
-
- // Compute left-over elements (1x4)
- for(; x < window_end_x; ++x)
- {
- const uint16_t val0 = *(reinterpret_cast<uint16_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
- const uint16_t val1 = *(reinterpret_cast<uint16_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
- const uint16_t val2 = *(reinterpret_cast<uint16_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
- const uint16_t val3 = *(reinterpret_cast<uint16_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
-
- uint16x4_t result = vdup_n_u16(0);
- result = vset_lane_u16(val0, result, 0);
- result = vset_lane_u16(val1, result, 1);
- result = vset_lane_u16(val2, result, 2);
- result = vset_lane_u16(val3, result, 3);
-
- // Compute destination address
- const size_t dst_offset_in_bytes = id.y() * sizeof(uint16_t) + x * output_stride_in_bytes;
-
- vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes), result);
- }
- },
- input, output);
+ // Compute 4x4 elements per iteration
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const uint16x4_t row0 =
+ vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
+ const uint16x4_t row1 =
+ vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
+ const uint16x4_t row2 =
+ vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
+ const uint16x4_t row3 =
+ vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
+
+ // Transpose 2x2
+ const uint16x4x2_t k0_u16 = vtrn_u16(row0, row1);
+ const uint16x4x2_t k1_u16 = vtrn_u16(row2, row3);
+
+ // Transpose 4x4
+ const uint32x2x2_t k0_u32 =
+ vtrn_u32(vreinterpret_u32_u16(k0_u16.val[0]), vreinterpret_u32_u16(k1_u16.val[0]));
+ const uint32x2x2_t k1_u32 =
+ vtrn_u32(vreinterpret_u32_u16(k0_u16.val[1]), vreinterpret_u32_u16(k1_u16.val[1]));
+
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint16_t) + x * output_stride_in_bytes;
+
+ vst1_u16(
+ reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes),
+ vreinterpret_u16_u32(k0_u32.val[0]));
+ vst1_u16(
+ reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes),
+ vreinterpret_u16_u32(k1_u32.val[0]));
+ vst1_u16(
+ reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes),
+ vreinterpret_u16_u32(k0_u32.val[1]));
+ vst1_u16(
+ reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes),
+ vreinterpret_u16_u32(k1_u32.val[1]));
+ }
+
+ // Compute left-over elements (1x4)
+ for (; x < window_end_x; ++x)
+ {
+ const uint16_t val0 = *(reinterpret_cast<uint16_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
+ const uint16_t val1 = *(reinterpret_cast<uint16_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
+ const uint16_t val2 = *(reinterpret_cast<uint16_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
+ const uint16_t val3 = *(reinterpret_cast<uint16_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
+
+ uint16x4_t result = vdup_n_u16(0);
+ result = vset_lane_u16(val0, result, 0);
+ result = vset_lane_u16(val1, result, 1);
+ result = vset_lane_u16(val2, result, 2);
+ result = vset_lane_u16(val3, result, 3);
+
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint16_t) + x * output_stride_in_bytes;
+
+ vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes), result);
+ }
+ },
+ input, output);
}
- if(left_over_loop_y)
+ if (left_over_loop_y)
{
window_in.set(Window::DimX, Window::Dimension(window.x().start(), window.x().end(), 1));
window_in.set(Window::DimY, Window::Dimension(window_end_y_multiple_of, window_end_y, 1));
@@ -303,16 +356,18 @@ void transpose_16bit_elements(const ITensor *in, ITensor *out, const Window &win
Iterator output(out, window_out);
// Compute left-over elements along the y dimension (1x1)
- execute_window_loop(window_in, [&](const Coordinates & id)
- {
- const uint16_t val0 = *(reinterpret_cast<uint16_t *>(input.ptr()));
+ execute_window_loop(
+ window_in,
+ [&](const Coordinates &id)
+ {
+ const uint16_t val0 = *(reinterpret_cast<uint16_t *>(input.ptr()));
- // Compute destination address
- const size_t dst_offset_in_bytes = id.y() * sizeof(uint16_t) + id.x() * output_stride_in_bytes;
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint16_t) + id.x() * output_stride_in_bytes;
- *(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes)) = val0;
- },
- input, output);
+ *(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes)) = val0;
+ },
+ input, output);
}
}
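
The 16-bit path is the same cascade one stage shorter: vtrn_u16 transposes the 2x2 blocks and vtrn_u32 finishes the 4x4 tile. A minimal sketch under the same assumptions (hypothetical function name, pointers, and byte strides):

    // Minimal sketch: transpose one 4x4 uint16 tile, mirroring the kernel above.
    void transpose_tile_4x4_u16(const uint8_t *src, size_t src_stride, uint8_t *dst, size_t dst_stride)
    {
        const uint16x4_t row0 = vld1_u16(reinterpret_cast<const uint16_t *>(src + 0 * src_stride));
        const uint16x4_t row1 = vld1_u16(reinterpret_cast<const uint16_t *>(src + 1 * src_stride));
        const uint16x4_t row2 = vld1_u16(reinterpret_cast<const uint16_t *>(src + 2 * src_stride));
        const uint16x4_t row3 = vld1_u16(reinterpret_cast<const uint16_t *>(src + 3 * src_stride));

        // Transpose 2x2 blocks, then 4x4 as uint32 lanes
        const uint16x4x2_t k0_u16 = vtrn_u16(row0, row1);
        const uint16x4x2_t k1_u16 = vtrn_u16(row2, row3);
        const uint32x2x2_t k0_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[0]), vreinterpret_u32_u16(k1_u16.val[0]));
        const uint32x2x2_t k1_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[1]), vreinterpret_u32_u16(k1_u16.val[1]));

        vst1_u16(reinterpret_cast<uint16_t *>(dst + 0 * dst_stride), vreinterpret_u16_u32(k0_u32.val[0]));
        vst1_u16(reinterpret_cast<uint16_t *>(dst + 1 * dst_stride), vreinterpret_u16_u32(k1_u32.val[0]));
        vst1_u16(reinterpret_cast<uint16_t *>(dst + 2 * dst_stride), vreinterpret_u16_u32(k0_u32.val[1]));
        vst1_u16(reinterpret_cast<uint16_t *>(dst + 3 * dst_stride), vreinterpret_u16_u32(k1_u32.val[1]));
    }
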
@@ -347,10 +402,10 @@ void transpose_32bit_elements(const ITensor *in, ITensor *out, const Window &win
Window window_in(window);
window_in.set(Window::DimX, Window::Dimension(0, 1, 1));
- if(left_over_loop_y)
+ if (left_over_loop_y)
{
// Check if window_end_y_multiple_of is greater than window_start_y
- if(window_end_y_multiple_of > window_start_y)
+ if (window_end_y_multiple_of > window_start_y)
{
window_in.set(Window::DimY, Window::Dimension(window_start_y, window_end_y_multiple_of, window_step_y));
}
@@ -367,102 +422,160 @@ void transpose_32bit_elements(const ITensor *in, ITensor *out, const Window &win
Iterator output(out, window_out);
// Run the SIMD path if and only if the input is not a row-vector
- if(in->info()->dimension(1) != 1)
+ if (in->info()->dimension(1) != 1)
{
Iterator input(in, window_in);
- execute_window_loop(window_in, [&](const Coordinates & id)
- {
- // Compute 8x8 elements per iteration
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ execute_window_loop(
+ window_in,
+ [&](const Coordinates &id)
{
- // Load
- const uint32x4x2_t row0 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
- const uint32x4x2_t row1 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
- const uint32x4x2_t row2 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
- const uint32x4x2_t row3 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
- const uint32x4x2_t row4 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 4 * input_stride_in_bytes) + x);
- const uint32x4x2_t row5 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 5 * input_stride_in_bytes) + x);
- const uint32x4x2_t row6 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 6 * input_stride_in_bytes) + x);
- const uint32x4x2_t row7 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 7 * input_stride_in_bytes) + x);
-
- // Transpose 2x4
- const uint32x4x2_t k0_u32 = {vtrn1q_u32(row0.val[0], row1.val[0]), vtrn2q_u32(row0.val[0], row1.val[0])};
- const uint32x4x2_t k1_u32 = {vtrn1q_u32(row0.val[1], row1.val[1]), vtrn2q_u32(row0.val[1], row1.val[1])};
- const uint32x4x2_t k2_u32 = {vtrn1q_u32(row2.val[0], row3.val[0]), vtrn2q_u32(row2.val[0], row3.val[0])};
- const uint32x4x2_t k3_u32 = {vtrn1q_u32(row2.val[1], row3.val[1]), vtrn2q_u32(row2.val[1], row3.val[1])};
- const uint32x4x2_t k4_u32 = {vtrn1q_u32(row4.val[0], row5.val[0]), vtrn2q_u32(row4.val[0], row5.val[0])};
- const uint32x4x2_t k5_u32 = {vtrn1q_u32(row4.val[1], row5.val[1]), vtrn2q_u32(row4.val[1], row5.val[1])};
- const uint32x4x2_t k6_u32 = {vtrn1q_u32(row6.val[0], row7.val[0]), vtrn2q_u32(row6.val[0], row7.val[0])};
- const uint32x4x2_t k7_u32 = {vtrn1q_u32(row6.val[1], row7.val[1]), vtrn2q_u32(row6.val[1], row7.val[1])};
-
- // Transpose 2x2
- const uint64x2x2_t k0_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k0_u32.val[0]), vreinterpretq_u64_u32(k2_u32.val[0])), vtrn2q_u64(vreinterpretq_u64_u32(k0_u32.val[0]), vreinterpretq_u64_u32(k2_u32.val[0]))};
- const uint64x2x2_t k1_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k0_u32.val[1]), vreinterpretq_u64_u32(k2_u32.val[1])), vtrn2q_u64(vreinterpretq_u64_u32(k0_u32.val[1]), vreinterpretq_u64_u32(k2_u32.val[1]))};
- const uint64x2x2_t k2_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k1_u32.val[0]), vreinterpretq_u64_u32(k3_u32.val[0])), vtrn2q_u64(vreinterpretq_u64_u32(k1_u32.val[0]), vreinterpretq_u64_u32(k3_u32.val[0]))};
- const uint64x2x2_t k3_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k1_u32.val[1]), vreinterpretq_u64_u32(k3_u32.val[1])), vtrn2q_u64(vreinterpretq_u64_u32(k1_u32.val[1]), vreinterpretq_u64_u32(k3_u32.val[1]))};
- const uint64x2x2_t k4_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k4_u32.val[0]), vreinterpretq_u64_u32(k6_u32.val[0])), vtrn2q_u64(vreinterpretq_u64_u32(k4_u32.val[0]), vreinterpretq_u64_u32(k6_u32.val[0]))};
- const uint64x2x2_t k5_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k4_u32.val[1]), vreinterpretq_u64_u32(k6_u32.val[1])), vtrn2q_u64(vreinterpretq_u64_u32(k4_u32.val[1]), vreinterpretq_u64_u32(k6_u32.val[1]))};
- const uint64x2x2_t k6_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k5_u32.val[0]), vreinterpretq_u64_u32(k7_u32.val[0])), vtrn2q_u64(vreinterpretq_u64_u32(k5_u32.val[0]), vreinterpretq_u64_u32(k7_u32.val[0]))};
- const uint64x2x2_t k7_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k5_u32.val[1]), vreinterpretq_u64_u32(k7_u32.val[1])), vtrn2q_u64(vreinterpretq_u64_u32(k5_u32.val[1]), vreinterpretq_u64_u32(k7_u32.val[1]))};
-
- // Swap blocks
- const uint32x4x2_t col0 = {vreinterpretq_u32_u64(k0_u64.val[0]), vreinterpretq_u32_u64(k4_u64.val[0])};
- const uint32x4x2_t col1 = {vreinterpretq_u32_u64(k1_u64.val[0]), vreinterpretq_u32_u64(k5_u64.val[0])};
- const uint32x4x2_t col2 = {vreinterpretq_u32_u64(k0_u64.val[1]), vreinterpretq_u32_u64(k4_u64.val[1])};
- const uint32x4x2_t col3 = {vreinterpretq_u32_u64(k1_u64.val[1]), vreinterpretq_u32_u64(k5_u64.val[1])};
- const uint32x4x2_t col4 = {vreinterpretq_u32_u64(k2_u64.val[0]), vreinterpretq_u32_u64(k6_u64.val[0])};
- const uint32x4x2_t col5 = {vreinterpretq_u32_u64(k3_u64.val[0]), vreinterpretq_u32_u64(k7_u64.val[0])};
- const uint32x4x2_t col6 = {vreinterpretq_u32_u64(k2_u64.val[1]), vreinterpretq_u32_u64(k6_u64.val[1])};
- const uint32x4x2_t col7 = {vreinterpretq_u32_u64(k3_u64.val[1]), vreinterpretq_u32_u64(k7_u64.val[1])};
-
- // Compute destination address
- const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes;
-
- // Store
- vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), col0);
- vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), col1);
- vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), col2);
- vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), col3);
- vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 4 * output_stride_in_bytes), col4);
- vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 5 * output_stride_in_bytes), col5);
- vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 6 * output_stride_in_bytes), col6);
- vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 7 * output_stride_in_bytes), col7);
- }
-
- // Compute left-over elements (8x1)
- for(; x < window_end_x; ++x)
- {
- const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
- const uint32_t val1 = *(reinterpret_cast<uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
- const uint32_t val2 = *(reinterpret_cast<uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
- const uint32_t val3 = *(reinterpret_cast<uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
- const uint32_t val4 = *(reinterpret_cast<uint32_t *>(input.ptr() + 4 * input_stride_in_bytes) + x);
- const uint32_t val5 = *(reinterpret_cast<uint32_t *>(input.ptr() + 5 * input_stride_in_bytes) + x);
- const uint32_t val6 = *(reinterpret_cast<uint32_t *>(input.ptr() + 6 * input_stride_in_bytes) + x);
- const uint32_t val7 = *(reinterpret_cast<uint32_t *>(input.ptr() + 7 * input_stride_in_bytes) + x);
-
- uint32x4_t result0 = vdupq_n_u32(0);
- uint32x4_t result1 = vdupq_n_u32(0);
- result0 = vsetq_lane_u32(val0, result0, 0);
- result0 = vsetq_lane_u32(val1, result0, 1);
- result0 = vsetq_lane_u32(val2, result0, 2);
- result0 = vsetq_lane_u32(val3, result0, 3);
- result1 = vsetq_lane_u32(val4, result1, 0);
- result1 = vsetq_lane_u32(val5, result1, 1);
- result1 = vsetq_lane_u32(val6, result1, 2);
- result1 = vsetq_lane_u32(val7, result1, 3);
-
- // Compute destination address
- const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes;
-
- vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes), {result0, result1});
- }
- },
- input, output);
+ // Compute 8x8 elements per iteration
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ // Load
+ const uint32x4x2_t row0 =
+ vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
+ const uint32x4x2_t row1 =
+ vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
+ const uint32x4x2_t row2 =
+ vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
+ const uint32x4x2_t row3 =
+ vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
+ const uint32x4x2_t row4 =
+ vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 4 * input_stride_in_bytes) + x);
+ const uint32x4x2_t row5 =
+ vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 5 * input_stride_in_bytes) + x);
+ const uint32x4x2_t row6 =
+ vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 6 * input_stride_in_bytes) + x);
+ const uint32x4x2_t row7 =
+ vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 7 * input_stride_in_bytes) + x);
+
+ // Transpose 2x4
+ const uint32x4x2_t k0_u32 = {vtrn1q_u32(row0.val[0], row1.val[0]),
+ vtrn2q_u32(row0.val[0], row1.val[0])};
+ const uint32x4x2_t k1_u32 = {vtrn1q_u32(row0.val[1], row1.val[1]),
+ vtrn2q_u32(row0.val[1], row1.val[1])};
+ const uint32x4x2_t k2_u32 = {vtrn1q_u32(row2.val[0], row3.val[0]),
+ vtrn2q_u32(row2.val[0], row3.val[0])};
+ const uint32x4x2_t k3_u32 = {vtrn1q_u32(row2.val[1], row3.val[1]),
+ vtrn2q_u32(row2.val[1], row3.val[1])};
+ const uint32x4x2_t k4_u32 = {vtrn1q_u32(row4.val[0], row5.val[0]),
+ vtrn2q_u32(row4.val[0], row5.val[0])};
+ const uint32x4x2_t k5_u32 = {vtrn1q_u32(row4.val[1], row5.val[1]),
+ vtrn2q_u32(row4.val[1], row5.val[1])};
+ const uint32x4x2_t k6_u32 = {vtrn1q_u32(row6.val[0], row7.val[0]),
+ vtrn2q_u32(row6.val[0], row7.val[0])};
+ const uint32x4x2_t k7_u32 = {vtrn1q_u32(row6.val[1], row7.val[1]),
+ vtrn2q_u32(row6.val[1], row7.val[1])};
+
+ // Transpose 2x2
+ const uint64x2x2_t k0_u64 = {
+ vtrn1q_u64(vreinterpretq_u64_u32(k0_u32.val[0]), vreinterpretq_u64_u32(k2_u32.val[0])),
+ vtrn2q_u64(vreinterpretq_u64_u32(k0_u32.val[0]), vreinterpretq_u64_u32(k2_u32.val[0]))};
+ const uint64x2x2_t k1_u64 = {
+ vtrn1q_u64(vreinterpretq_u64_u32(k0_u32.val[1]), vreinterpretq_u64_u32(k2_u32.val[1])),
+ vtrn2q_u64(vreinterpretq_u64_u32(k0_u32.val[1]), vreinterpretq_u64_u32(k2_u32.val[1]))};
+ const uint64x2x2_t k2_u64 = {
+ vtrn1q_u64(vreinterpretq_u64_u32(k1_u32.val[0]), vreinterpretq_u64_u32(k3_u32.val[0])),
+ vtrn2q_u64(vreinterpretq_u64_u32(k1_u32.val[0]), vreinterpretq_u64_u32(k3_u32.val[0]))};
+ const uint64x2x2_t k3_u64 = {
+ vtrn1q_u64(vreinterpretq_u64_u32(k1_u32.val[1]), vreinterpretq_u64_u32(k3_u32.val[1])),
+ vtrn2q_u64(vreinterpretq_u64_u32(k1_u32.val[1]), vreinterpretq_u64_u32(k3_u32.val[1]))};
+ const uint64x2x2_t k4_u64 = {
+ vtrn1q_u64(vreinterpretq_u64_u32(k4_u32.val[0]), vreinterpretq_u64_u32(k6_u32.val[0])),
+ vtrn2q_u64(vreinterpretq_u64_u32(k4_u32.val[0]), vreinterpretq_u64_u32(k6_u32.val[0]))};
+ const uint64x2x2_t k5_u64 = {
+ vtrn1q_u64(vreinterpretq_u64_u32(k4_u32.val[1]), vreinterpretq_u64_u32(k6_u32.val[1])),
+ vtrn2q_u64(vreinterpretq_u64_u32(k4_u32.val[1]), vreinterpretq_u64_u32(k6_u32.val[1]))};
+ const uint64x2x2_t k6_u64 = {
+ vtrn1q_u64(vreinterpretq_u64_u32(k5_u32.val[0]), vreinterpretq_u64_u32(k7_u32.val[0])),
+ vtrn2q_u64(vreinterpretq_u64_u32(k5_u32.val[0]), vreinterpretq_u64_u32(k7_u32.val[0]))};
+ const uint64x2x2_t k7_u64 = {
+ vtrn1q_u64(vreinterpretq_u64_u32(k5_u32.val[1]), vreinterpretq_u64_u32(k7_u32.val[1])),
+ vtrn2q_u64(vreinterpretq_u64_u32(k5_u32.val[1]), vreinterpretq_u64_u32(k7_u32.val[1]))};
+
+ // Swap blocks
+ const uint32x4x2_t col0 = {vreinterpretq_u32_u64(k0_u64.val[0]),
+ vreinterpretq_u32_u64(k4_u64.val[0])};
+ const uint32x4x2_t col1 = {vreinterpretq_u32_u64(k1_u64.val[0]),
+ vreinterpretq_u32_u64(k5_u64.val[0])};
+ const uint32x4x2_t col2 = {vreinterpretq_u32_u64(k0_u64.val[1]),
+ vreinterpretq_u32_u64(k4_u64.val[1])};
+ const uint32x4x2_t col3 = {vreinterpretq_u32_u64(k1_u64.val[1]),
+ vreinterpretq_u32_u64(k5_u64.val[1])};
+ const uint32x4x2_t col4 = {vreinterpretq_u32_u64(k2_u64.val[0]),
+ vreinterpretq_u32_u64(k6_u64.val[0])};
+ const uint32x4x2_t col5 = {vreinterpretq_u32_u64(k3_u64.val[0]),
+ vreinterpretq_u32_u64(k7_u64.val[0])};
+ const uint32x4x2_t col6 = {vreinterpretq_u32_u64(k2_u64.val[1]),
+ vreinterpretq_u32_u64(k6_u64.val[1])};
+ const uint32x4x2_t col7 = {vreinterpretq_u32_u64(k3_u64.val[1]),
+ vreinterpretq_u32_u64(k7_u64.val[1])};
+
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes;
+
+ // Store
+ vst1q_u32_x2_(
+ reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes),
+ col0);
+ vst1q_u32_x2_(
+ reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes),
+ col1);
+ vst1q_u32_x2_(
+ reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes),
+ col2);
+ vst1q_u32_x2_(
+ reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes),
+ col3);
+ vst1q_u32_x2_(
+ reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 4 * output_stride_in_bytes),
+ col4);
+ vst1q_u32_x2_(
+ reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 5 * output_stride_in_bytes),
+ col5);
+ vst1q_u32_x2_(
+ reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 6 * output_stride_in_bytes),
+ col6);
+ vst1q_u32_x2_(
+ reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 7 * output_stride_in_bytes),
+ col7);
+ }
+
+ // Compute left-over elements (8x1)
+ for (; x < window_end_x; ++x)
+ {
+ const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
+ const uint32_t val1 = *(reinterpret_cast<uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
+ const uint32_t val2 = *(reinterpret_cast<uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
+ const uint32_t val3 = *(reinterpret_cast<uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
+ const uint32_t val4 = *(reinterpret_cast<uint32_t *>(input.ptr() + 4 * input_stride_in_bytes) + x);
+ const uint32_t val5 = *(reinterpret_cast<uint32_t *>(input.ptr() + 5 * input_stride_in_bytes) + x);
+ const uint32_t val6 = *(reinterpret_cast<uint32_t *>(input.ptr() + 6 * input_stride_in_bytes) + x);
+ const uint32_t val7 = *(reinterpret_cast<uint32_t *>(input.ptr() + 7 * input_stride_in_bytes) + x);
+
+ uint32x4_t result0 = vdupq_n_u32(0);
+ uint32x4_t result1 = vdupq_n_u32(0);
+ result0 = vsetq_lane_u32(val0, result0, 0);
+ result0 = vsetq_lane_u32(val1, result0, 1);
+ result0 = vsetq_lane_u32(val2, result0, 2);
+ result0 = vsetq_lane_u32(val3, result0, 3);
+ result1 = vsetq_lane_u32(val4, result1, 0);
+ result1 = vsetq_lane_u32(val5, result1, 1);
+ result1 = vsetq_lane_u32(val6, result1, 2);
+ result1 = vsetq_lane_u32(val7, result1, 3);
+
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes;
+
+ vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes), {result0, result1});
+ }
+ },
+ input, output);
}
- if(left_over_loop_y)
+ if (left_over_loop_y)
{
window_in.set(Window::DimX, Window::Dimension(window.x().start(), window.x().end(), 1));
window_in.set(Window::DimY, Window::Dimension(window_end_y_multiple_of, window_end_y, 1));
@@ -471,40 +584,42 @@ void transpose_32bit_elements(const ITensor *in, ITensor *out, const Window &win
Iterator output(out, window_out);
// Compute left-over elements along the y dimension (1x1)
- execute_window_loop(window_in, [&](const Coordinates & id)
- {
- const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr()));
+ execute_window_loop(
+ window_in,
+ [&](const Coordinates &id)
+ {
+ const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr()));
- // Compute destination address
- const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + id.x() * output_stride_in_bytes;
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + id.x() * output_stride_in_bytes;
- *(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes)) = val0;
- },
- input, output);
+ *(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes)) = val0;
+ },
+ input, output);
}
}
#else // __aarch64__
void transpose_32bit_elements(const ITensor *in, ITensor *out, const Window &window)
{
- const int window_step_x = 4;
- const int window_step_y = 4;
- const int window_start_x = window.x().start();
- const int window_end_x = window.x().end();
- const int window_start_y = window.y().start();
- const int window_end_y = std::min(window.y().end(), static_cast<int>(in->info()->dimension(1)));
- const int window_end_y_multiple_of = ((window_end_y - window_start_y) / window_step_y) * window_step_y;
- const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1];
- const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1];
+ const int window_step_x = 4;
+ const int window_step_y = 4;
+ const int window_start_x = window.x().start();
+ const int window_end_x = window.x().end();
+ const int window_start_y = window.y().start();
+ const int window_end_y = std::min(window.y().end(), static_cast<int>(in->info()->dimension(1)));
+ const int window_end_y_multiple_of = ((window_end_y - window_start_y) / window_step_y) * window_step_y;
+ const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1];
+ const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1];
// Check if we need a left-over loop for the y dimension
bool left_over_loop_y = (((window_end_y - window_start_y) % window_step_y) != 0);
Window window_in(window);
window_in.set(Window::DimX, Window::Dimension(0, 1, 1));
- if(left_over_loop_y)
+ if (left_over_loop_y)
{
// Check if window_end_y_multiple_of is greater than window_start_y
- if(window_end_y_multiple_of > window_start_y)
+ if (window_end_y_multiple_of > window_start_y)
{
window_in.set(Window::DimY, Window::Dimension(window_start_y, window_end_y_multiple_of, window_step_y));
}
@@ -521,60 +636,74 @@ void transpose_32bit_elements(const ITensor *in, ITensor *out, const Window &win
Iterator output(out, window_out);
// Run the SIMD path if and only if the input is not a row-vector
- if(in->info()->dimension(1) != 1)
+ if (in->info()->dimension(1) != 1)
{
Iterator input(in, window_in);
- execute_window_loop(window_in, [&](const Coordinates & id)
- {
- // Compute 4x4 elements per iteration
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const uint32x4_t row0 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
- const uint32x4_t row1 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
- const uint32x4_t row2 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
- const uint32x4_t row3 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
-
- // Transpose 2x2
- const uint32x2x2_t k0_u32 = vtrn_u32(vget_low_u32(row0), vget_low_u32(row1));
- const uint32x2x2_t k1_u32 = vtrn_u32(vget_high_u32(row2), vget_high_u32(row3));
- const uint32x2x2_t k2_u32 = vtrn_u32(vget_high_u32(row0), vget_high_u32(row1));
- const uint32x2x2_t k3_u32 = vtrn_u32(vget_low_u32(row2), vget_low_u32(row3));
-
- // Compute destination address
- const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes;
-
- // Swap block 01 with block 10 and store
- vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), vcombine_u32(k0_u32.val[0], k3_u32.val[0]));
- vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), vcombine_u32(k0_u32.val[1], k3_u32.val[1]));
- vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), vcombine_u32(k2_u32.val[0], k1_u32.val[0]));
- vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), vcombine_u32(k2_u32.val[1], k1_u32.val[1]));
- }
-
- // Compute left-over elements (1x4)
- for(; x < window_end_x; ++x)
+ execute_window_loop(
+ window_in,
+ [&](const Coordinates &id)
{
- const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
- const uint32_t val1 = *(reinterpret_cast<uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
- const uint32_t val2 = *(reinterpret_cast<uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
- const uint32_t val3 = *(reinterpret_cast<uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
-
- uint32x4_t result = vdupq_n_u32(0);
- result = vsetq_lane_u32(val0, result, 0);
- result = vsetq_lane_u32(val1, result, 1);
- result = vsetq_lane_u32(val2, result, 2);
- result = vsetq_lane_u32(val3, result, 3);
-
- // Compute destination address
- const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes;
-
- vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes), result);
- }
- },
- input, output);
+ // Compute 4x4 elements per iteration
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const uint32x4_t row0 =
+ vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
+ const uint32x4_t row1 =
+ vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
+ const uint32x4_t row2 =
+ vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
+ const uint32x4_t row3 =
+ vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
+
+ // Transpose 2x2
+ const uint32x2x2_t k0_u32 = vtrn_u32(vget_low_u32(row0), vget_low_u32(row1));
+ const uint32x2x2_t k1_u32 = vtrn_u32(vget_high_u32(row2), vget_high_u32(row3));
+ const uint32x2x2_t k2_u32 = vtrn_u32(vget_high_u32(row0), vget_high_u32(row1));
+ const uint32x2x2_t k3_u32 = vtrn_u32(vget_low_u32(row2), vget_low_u32(row3));
+
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes;
+
+ // Swap block 01 with block 10 and store
+ vst1q_u32(
+ reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes),
+ vcombine_u32(k0_u32.val[0], k3_u32.val[0]));
+ vst1q_u32(
+ reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes),
+ vcombine_u32(k0_u32.val[1], k3_u32.val[1]));
+ vst1q_u32(
+ reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes),
+ vcombine_u32(k2_u32.val[0], k1_u32.val[0]));
+ vst1q_u32(
+ reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes),
+ vcombine_u32(k2_u32.val[1], k1_u32.val[1]));
+ }
+
+ // Compute left-over elements (1x4)
+ for (; x < window_end_x; ++x)
+ {
+ const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
+ const uint32_t val1 = *(reinterpret_cast<uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
+ const uint32_t val2 = *(reinterpret_cast<uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
+ const uint32_t val3 = *(reinterpret_cast<uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
+
+ uint32x4_t result = vdupq_n_u32(0);
+ result = vsetq_lane_u32(val0, result, 0);
+ result = vsetq_lane_u32(val1, result, 1);
+ result = vsetq_lane_u32(val2, result, 2);
+ result = vsetq_lane_u32(val3, result, 3);
+
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes;
+
+ vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes), result);
+ }
+ },
+ input, output);
}
- if(left_over_loop_y)
+ if (left_over_loop_y)
{
window_in.set(Window::DimX, Window::Dimension(window.x().start(), window.x().end(), 1));
window_in.set(Window::DimY, Window::Dimension(window_end_y_multiple_of, window_end_y, 1));
@@ -583,16 +712,18 @@ void transpose_32bit_elements(const ITensor *in, ITensor *out, const Window &win
Iterator output(out, window_out);
// Compute left-over elements along the y dimension (1x1)
- execute_window_loop(window_in, [&](const Coordinates & id)
- {
- const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr()));
+ execute_window_loop(
+ window_in,
+ [&](const Coordinates &id)
+ {
+ const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr()));
- // Compute destination address
- const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + id.x() * output_stride_in_bytes;
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + id.x() * output_stride_in_bytes;
- *(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes)) = val0;
- },
- input, output);
+ *(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes)) = val0;
+ },
+ input, output);
}
}
#endif // __aarch64__
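
AArch32 NEON lacks the full-width vtrn1q/vtrn2q instructions used in the AArch64 path, so this variant transposes each 2x2 sub-block through 64-bit vtrn_u32 and then swaps the off-diagonal blocks with vcombine_u32 at store time. A self-contained sketch of that block swap (hypothetical function name, pointers, and byte strides):

    // Minimal sketch of the AArch32 4x4 uint32 tile transpose above.
    void transpose_tile_4x4_u32(const uint8_t *src, size_t src_stride, uint8_t *dst, size_t dst_stride)
    {
        const uint32x4_t row0 = vld1q_u32(reinterpret_cast<const uint32_t *>(src + 0 * src_stride));
        const uint32x4_t row1 = vld1q_u32(reinterpret_cast<const uint32_t *>(src + 1 * src_stride));
        const uint32x4_t row2 = vld1q_u32(reinterpret_cast<const uint32_t *>(src + 2 * src_stride));
        const uint32x4_t row3 = vld1q_u32(reinterpret_cast<const uint32_t *>(src + 3 * src_stride));

        // Transpose each 2x2 sub-block of the 4x4 tile
        const uint32x2x2_t k0 = vtrn_u32(vget_low_u32(row0), vget_low_u32(row1));   // top-left
        const uint32x2x2_t k1 = vtrn_u32(vget_high_u32(row2), vget_high_u32(row3)); // bottom-right
        const uint32x2x2_t k2 = vtrn_u32(vget_high_u32(row0), vget_high_u32(row1)); // top-right
        const uint32x2x2_t k3 = vtrn_u32(vget_low_u32(row2), vget_low_u32(row3));   // bottom-left

        // Swap block 01 with block 10 while recombining full output rows
        vst1q_u32(reinterpret_cast<uint32_t *>(dst + 0 * dst_stride), vcombine_u32(k0.val[0], k3.val[0]));
        vst1q_u32(reinterpret_cast<uint32_t *>(dst + 1 * dst_stride), vcombine_u32(k0.val[1], k3.val[1]));
        vst1q_u32(reinterpret_cast<uint32_t *>(dst + 2 * dst_stride), vcombine_u32(k2.val[0], k1.val[0]));
        vst1q_u32(reinterpret_cast<uint32_t *>(dst + 3 * dst_stride), vcombine_u32(k2.val[1], k1.val[1]));
    }
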
@@ -616,7 +747,8 @@ void CpuTransposeKernel::configure(const ITensorInfo *src, ITensorInfo *dst)
const unsigned int num_elems_processed_per_iteration_y = num_elems_processed(src->element_size());
// Configure kernel window
- Window win = calculate_max_window(*src, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+ Window win =
+ calculate_max_window(*src, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
// The CpuTranspose doesn't need padding so update_window_and_padding() can be skipped
Coordinates coord;
@@ -637,7 +769,7 @@ Status CpuTransposeKernel::validate(const ITensorInfo *src, const ITensorInfo *d
"Element size not supported");
// Validate configured destination
- if(dst->total_size() != 0)
+ if (dst->total_size() != 0)
{
const TensorShape dst_shape = misc::shape_calculator::compute_transposed_shape(*src);
@@ -658,7 +790,7 @@ void CpuTransposeKernel::run_op(ITensorPack &tensors, const Window &window, cons
const auto src = tensors.get_const_tensor(TensorType::ACL_SRC);
auto dst = tensors.get_tensor(TensorType::ACL_DST);
- switch(src->info()->element_size())
+ switch (src->info()->element_size())
{
case 1:
transpose_8bit_elements(src, dst, window);
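
The dispatch is truncated here; presumably the remaining cases mirror the element sizes validated above (1, 2, and 4 bytes), each routed to the matching per-type function in this file. A hedged sketch of the complete switch:

    // Hypothetical completion -- only "case 1" is visible in the hunk above.
    switch (src->info()->element_size())
    {
        case 1:
            transpose_8bit_elements(src, dst, window);
            break;
        case 2:
            transpose_16bit_elements(src, dst, window);
            break;
        case 4:
            transpose_32bit_elements(src, dst, window);
            break;
        default:
            ARM_COMPUTE_ERROR("Element size not supported");
            break;
    }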