diff options
Diffstat (limited to 'src/cpu/kernels/CpuTransposeKernel.cpp')
-rw-r--r-- | src/cpu/kernels/CpuTransposeKernel.cpp | 772 |
1 files changed, 452 insertions, 320 deletions
diff --git a/src/cpu/kernels/CpuTransposeKernel.cpp b/src/cpu/kernels/CpuTransposeKernel.cpp index b2cebc4230..615bc6ce1e 100644 --- a/src/cpu/kernels/CpuTransposeKernel.cpp +++ b/src/cpu/kernels/CpuTransposeKernel.cpp @@ -28,8 +28,9 @@ #include "arm_compute/core/ITensor.h" #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" -#include "arm_compute/core/Validate.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "arm_compute/core/Validate.h" + #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" @@ -45,7 +46,7 @@ namespace { unsigned int num_elems_processed(size_t element_size) { - switch(element_size) + switch (element_size) { case 1: return 8; @@ -81,10 +82,10 @@ void transpose_8bit_elements(const ITensor *in, ITensor *out, const Window &wind Window window_in(window); window_in.set(Window::DimX, Window::Dimension(0, 1, 1)); - if(left_over_loop_y) + if (left_over_loop_y) { // Check if window_end_y_multiple_of is greater than window_start_y - if(window_end_y_multiple_of > window_start_y) + if (window_end_y_multiple_of > window_start_y) { window_in.set(Window::DimY, Window::Dimension(window_start_y, window_end_y_multiple_of, window_step_y)); } @@ -101,87 +102,121 @@ void transpose_8bit_elements(const ITensor *in, ITensor *out, const Window &wind Iterator output(out, window_out); // Run the SIMD path if and only if the input is not a row-vector - if(in->info()->dimension(1) != 1) + if (in->info()->dimension(1) != 1) { Iterator input(in, window_in); - execute_window_loop(window_in, [&](const Coordinates & id) - { - // Compute 8x8 elements per iteration - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) + execute_window_loop( + window_in, + [&](const Coordinates &id) { - const uint8x8_t row0 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 0 * input_stride_in_bytes)); - const uint8x8_t row1 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 1 * input_stride_in_bytes)); - const uint8x8_t row2 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 2 * input_stride_in_bytes)); - const uint8x8_t row3 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 3 * input_stride_in_bytes)); - const uint8x8_t row4 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 4 * input_stride_in_bytes)); - const uint8x8_t row5 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 5 * input_stride_in_bytes)); - const uint8x8_t row6 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 6 * input_stride_in_bytes)); - const uint8x8_t row7 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 7 * input_stride_in_bytes)); - - // Transpose 2x2 - const uint8x8x2_t k0_u8 = vtrn_u8(row0, row1); - const uint8x8x2_t k1_u8 = vtrn_u8(row2, row3); - const uint8x8x2_t k2_u8 = vtrn_u8(row4, row5); - const uint8x8x2_t k3_u8 = vtrn_u8(row6, row7); - - // Transpose 4x4 - const uint16x4x2_t k0_u16 = vtrn_u16(vreinterpret_u16_u8(k0_u8.val[0]), vreinterpret_u16_u8(k1_u8.val[0])); - const uint16x4x2_t k1_u16 = vtrn_u16(vreinterpret_u16_u8(k0_u8.val[1]), vreinterpret_u16_u8(k1_u8.val[1])); - const uint16x4x2_t k2_u16 = vtrn_u16(vreinterpret_u16_u8(k2_u8.val[0]), vreinterpret_u16_u8(k3_u8.val[0])); - const uint16x4x2_t k3_u16 = vtrn_u16(vreinterpret_u16_u8(k2_u8.val[1]), vreinterpret_u16_u8(k3_u8.val[1])); - - // Transpose 8x8 - const uint32x2x2_t k0_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[0]), vreinterpret_u32_u16(k2_u16.val[0])); - const uint32x2x2_t k1_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[1]), vreinterpret_u32_u16(k2_u16.val[1])); - const uint32x2x2_t k2_u32 = vtrn_u32(vreinterpret_u32_u16(k1_u16.val[0]), vreinterpret_u32_u16(k3_u16.val[0])); - const uint32x2x2_t k3_u32 = vtrn_u32(vreinterpret_u32_u16(k1_u16.val[1]), vreinterpret_u32_u16(k3_u16.val[1])); - - // Compute destination address - const size_t dst_offset_in_bytes = id.y() * sizeof(uint8_t) + x * output_stride_in_bytes; - - vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k0_u32.val[0]))); - vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k2_u32.val[0]))); - vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k1_u32.val[0]))); - vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k3_u32.val[0]))); - vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 4 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k0_u32.val[1]))); - vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 5 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k2_u32.val[1]))); - vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 6 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k1_u32.val[1]))); - vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 7 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k3_u32.val[1]))); - } - - // Compute left-over elements along the x dimension (1x8) - for(; x < window_end_x; ++x) - { - const uint8_t val0 = *(input.ptr() + x + 0 * input_stride_in_bytes); - const uint8_t val1 = *(input.ptr() + x + 1 * input_stride_in_bytes); - const uint8_t val2 = *(input.ptr() + x + 2 * input_stride_in_bytes); - const uint8_t val3 = *(input.ptr() + x + 3 * input_stride_in_bytes); - const uint8_t val4 = *(input.ptr() + x + 4 * input_stride_in_bytes); - const uint8_t val5 = *(input.ptr() + x + 5 * input_stride_in_bytes); - const uint8_t val6 = *(input.ptr() + x + 6 * input_stride_in_bytes); - const uint8_t val7 = *(input.ptr() + x + 7 * input_stride_in_bytes); - - uint8x8_t result = vdup_n_u8(0); - result = vset_lane_u8(val0, result, 0); - result = vset_lane_u8(val1, result, 1); - result = vset_lane_u8(val2, result, 2); - result = vset_lane_u8(val3, result, 3); - result = vset_lane_u8(val4, result, 4); - result = vset_lane_u8(val5, result, 5); - result = vset_lane_u8(val6, result, 6); - result = vset_lane_u8(val7, result, 7); - - // Compute destination address - const size_t dst_offset_in_bytes = id.y() * sizeof(uint8_t) + x * output_stride_in_bytes; - - vst1_u8(output.ptr() + dst_offset_in_bytes, result); - } - }, - input, output); + // Compute 8x8 elements per iteration + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + const uint8x8_t row0 = + vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 0 * input_stride_in_bytes)); + const uint8x8_t row1 = + vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 1 * input_stride_in_bytes)); + const uint8x8_t row2 = + vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 2 * input_stride_in_bytes)); + const uint8x8_t row3 = + vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 3 * input_stride_in_bytes)); + const uint8x8_t row4 = + vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 4 * input_stride_in_bytes)); + const uint8x8_t row5 = + vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 5 * input_stride_in_bytes)); + const uint8x8_t row6 = + vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 6 * input_stride_in_bytes)); + const uint8x8_t row7 = + vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 7 * input_stride_in_bytes)); + + // Transpose 2x2 + const uint8x8x2_t k0_u8 = vtrn_u8(row0, row1); + const uint8x8x2_t k1_u8 = vtrn_u8(row2, row3); + const uint8x8x2_t k2_u8 = vtrn_u8(row4, row5); + const uint8x8x2_t k3_u8 = vtrn_u8(row6, row7); + + // Transpose 4x4 + const uint16x4x2_t k0_u16 = + vtrn_u16(vreinterpret_u16_u8(k0_u8.val[0]), vreinterpret_u16_u8(k1_u8.val[0])); + const uint16x4x2_t k1_u16 = + vtrn_u16(vreinterpret_u16_u8(k0_u8.val[1]), vreinterpret_u16_u8(k1_u8.val[1])); + const uint16x4x2_t k2_u16 = + vtrn_u16(vreinterpret_u16_u8(k2_u8.val[0]), vreinterpret_u16_u8(k3_u8.val[0])); + const uint16x4x2_t k3_u16 = + vtrn_u16(vreinterpret_u16_u8(k2_u8.val[1]), vreinterpret_u16_u8(k3_u8.val[1])); + + // Transpose 8x8 + const uint32x2x2_t k0_u32 = + vtrn_u32(vreinterpret_u32_u16(k0_u16.val[0]), vreinterpret_u32_u16(k2_u16.val[0])); + const uint32x2x2_t k1_u32 = + vtrn_u32(vreinterpret_u32_u16(k0_u16.val[1]), vreinterpret_u32_u16(k2_u16.val[1])); + const uint32x2x2_t k2_u32 = + vtrn_u32(vreinterpret_u32_u16(k1_u16.val[0]), vreinterpret_u32_u16(k3_u16.val[0])); + const uint32x2x2_t k3_u32 = + vtrn_u32(vreinterpret_u32_u16(k1_u16.val[1]), vreinterpret_u32_u16(k3_u16.val[1])); + + // Compute destination address + const size_t dst_offset_in_bytes = id.y() * sizeof(uint8_t) + x * output_stride_in_bytes; + + vst1_u8( + reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), + vreinterpret_u8_u16(vreinterpret_u16_u32(k0_u32.val[0]))); + vst1_u8( + reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), + vreinterpret_u8_u16(vreinterpret_u16_u32(k2_u32.val[0]))); + vst1_u8( + reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), + vreinterpret_u8_u16(vreinterpret_u16_u32(k1_u32.val[0]))); + vst1_u8( + reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), + vreinterpret_u8_u16(vreinterpret_u16_u32(k3_u32.val[0]))); + vst1_u8( + reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 4 * output_stride_in_bytes), + vreinterpret_u8_u16(vreinterpret_u16_u32(k0_u32.val[1]))); + vst1_u8( + reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 5 * output_stride_in_bytes), + vreinterpret_u8_u16(vreinterpret_u16_u32(k2_u32.val[1]))); + vst1_u8( + reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 6 * output_stride_in_bytes), + vreinterpret_u8_u16(vreinterpret_u16_u32(k1_u32.val[1]))); + vst1_u8( + reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 7 * output_stride_in_bytes), + vreinterpret_u8_u16(vreinterpret_u16_u32(k3_u32.val[1]))); + } + + // Compute left-over elements along the x dimension (1x8) + for (; x < window_end_x; ++x) + { + const uint8_t val0 = *(input.ptr() + x + 0 * input_stride_in_bytes); + const uint8_t val1 = *(input.ptr() + x + 1 * input_stride_in_bytes); + const uint8_t val2 = *(input.ptr() + x + 2 * input_stride_in_bytes); + const uint8_t val3 = *(input.ptr() + x + 3 * input_stride_in_bytes); + const uint8_t val4 = *(input.ptr() + x + 4 * input_stride_in_bytes); + const uint8_t val5 = *(input.ptr() + x + 5 * input_stride_in_bytes); + const uint8_t val6 = *(input.ptr() + x + 6 * input_stride_in_bytes); + const uint8_t val7 = *(input.ptr() + x + 7 * input_stride_in_bytes); + + uint8x8_t result = vdup_n_u8(0); + result = vset_lane_u8(val0, result, 0); + result = vset_lane_u8(val1, result, 1); + result = vset_lane_u8(val2, result, 2); + result = vset_lane_u8(val3, result, 3); + result = vset_lane_u8(val4, result, 4); + result = vset_lane_u8(val5, result, 5); + result = vset_lane_u8(val6, result, 6); + result = vset_lane_u8(val7, result, 7); + + // Compute destination address + const size_t dst_offset_in_bytes = id.y() * sizeof(uint8_t) + x * output_stride_in_bytes; + + vst1_u8(output.ptr() + dst_offset_in_bytes, result); + } + }, + input, output); } - if(left_over_loop_y) + if (left_over_loop_y) { window_in.set(Window::DimX, Window::Dimension(window.x().start(), window.x().end(), 1)); window_in.set(Window::DimY, Window::Dimension(window_end_y_multiple_of, window_end_y, 1)); @@ -190,16 +225,18 @@ void transpose_8bit_elements(const ITensor *in, ITensor *out, const Window &wind Iterator output(out, window_out); // Compute left-over elements along the y dimension (1x1) - execute_window_loop(window_in, [&](const Coordinates & id) - { - const uint8_t val0 = *input.ptr(); + execute_window_loop( + window_in, + [&](const Coordinates &id) + { + const uint8_t val0 = *input.ptr(); - // Compute destination address - const size_t dst_offset_in_bytes = id.y() * sizeof(uint8_t) + id.x() * output_stride_in_bytes; + // Compute destination address + const size_t dst_offset_in_bytes = id.y() * sizeof(uint8_t) + id.x() * output_stride_in_bytes; - *(output.ptr() + dst_offset_in_bytes) = val0; - }, - input, output); + *(output.ptr() + dst_offset_in_bytes) = val0; + }, + input, output); } } @@ -220,10 +257,10 @@ void transpose_16bit_elements(const ITensor *in, ITensor *out, const Window &win Window window_in(window); window_in.set(Window::DimX, Window::Dimension(0, 1, 1)); - if(left_over_loop_y) + if (left_over_loop_y) { // Check if window_end_y_multiple_of is greater than window_start_y - if(window_end_y_multiple_of > window_start_y) + if (window_end_y_multiple_of > window_start_y) { window_in.set(Window::DimY, Window::Dimension(window_start_y, window_end_y_multiple_of, window_step_y)); } @@ -240,61 +277,77 @@ void transpose_16bit_elements(const ITensor *in, ITensor *out, const Window &win Iterator output(out, window_out); // Run the SIMD path if and only if the input is not a row-vector - if(in->info()->dimension(1) != 1) + if (in->info()->dimension(1) != 1) { Iterator input(in, window_in); - execute_window_loop(window_in, [&](const Coordinates & id) - { - // Compute 4x4 elements per iteration - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) + execute_window_loop( + window_in, + [&](const Coordinates &id) { - const uint16x4_t row0 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 0 * input_stride_in_bytes) + x); - const uint16x4_t row1 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 1 * input_stride_in_bytes) + x); - const uint16x4_t row2 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 2 * input_stride_in_bytes) + x); - const uint16x4_t row3 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 3 * input_stride_in_bytes) + x); - - // Transpose 2x2 - const uint16x4x2_t k0_u16 = vtrn_u16(row0, row1); - const uint16x4x2_t k1_u16 = vtrn_u16(row2, row3); - - // Transpose 4x4 - const uint32x2x2_t k0_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[0]), vreinterpret_u32_u16(k1_u16.val[0])); - const uint32x2x2_t k1_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[1]), vreinterpret_u32_u16(k1_u16.val[1])); - - // Compute destination address - const size_t dst_offset_in_bytes = id.y() * sizeof(uint16_t) + x * output_stride_in_bytes; - - vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), vreinterpret_u16_u32(k0_u32.val[0])); - vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), vreinterpret_u16_u32(k1_u32.val[0])); - vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), vreinterpret_u16_u32(k0_u32.val[1])); - vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), vreinterpret_u16_u32(k1_u32.val[1])); - } - - // Compute left-over elements (1x4) - for(; x < window_end_x; ++x) - { - const uint16_t val0 = *(reinterpret_cast<uint16_t *>(input.ptr() + 0 * input_stride_in_bytes) + x); - const uint16_t val1 = *(reinterpret_cast<uint16_t *>(input.ptr() + 1 * input_stride_in_bytes) + x); - const uint16_t val2 = *(reinterpret_cast<uint16_t *>(input.ptr() + 2 * input_stride_in_bytes) + x); - const uint16_t val3 = *(reinterpret_cast<uint16_t *>(input.ptr() + 3 * input_stride_in_bytes) + x); - - uint16x4_t result = vdup_n_u16(0); - result = vset_lane_u16(val0, result, 0); - result = vset_lane_u16(val1, result, 1); - result = vset_lane_u16(val2, result, 2); - result = vset_lane_u16(val3, result, 3); - - // Compute destination address - const size_t dst_offset_in_bytes = id.y() * sizeof(uint16_t) + x * output_stride_in_bytes; - - vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes), result); - } - }, - input, output); + // Compute 4x4 elements per iteration + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + const uint16x4_t row0 = + vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 0 * input_stride_in_bytes) + x); + const uint16x4_t row1 = + vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 1 * input_stride_in_bytes) + x); + const uint16x4_t row2 = + vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 2 * input_stride_in_bytes) + x); + const uint16x4_t row3 = + vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 3 * input_stride_in_bytes) + x); + + // Transpose 2x2 + const uint16x4x2_t k0_u16 = vtrn_u16(row0, row1); + const uint16x4x2_t k1_u16 = vtrn_u16(row2, row3); + + // Transpose 4x4 + const uint32x2x2_t k0_u32 = + vtrn_u32(vreinterpret_u32_u16(k0_u16.val[0]), vreinterpret_u32_u16(k1_u16.val[0])); + const uint32x2x2_t k1_u32 = + vtrn_u32(vreinterpret_u32_u16(k0_u16.val[1]), vreinterpret_u32_u16(k1_u16.val[1])); + + // Compute destination address + const size_t dst_offset_in_bytes = id.y() * sizeof(uint16_t) + x * output_stride_in_bytes; + + vst1_u16( + reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), + vreinterpret_u16_u32(k0_u32.val[0])); + vst1_u16( + reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), + vreinterpret_u16_u32(k1_u32.val[0])); + vst1_u16( + reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), + vreinterpret_u16_u32(k0_u32.val[1])); + vst1_u16( + reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), + vreinterpret_u16_u32(k1_u32.val[1])); + } + + // Compute left-over elements (1x4) + for (; x < window_end_x; ++x) + { + const uint16_t val0 = *(reinterpret_cast<uint16_t *>(input.ptr() + 0 * input_stride_in_bytes) + x); + const uint16_t val1 = *(reinterpret_cast<uint16_t *>(input.ptr() + 1 * input_stride_in_bytes) + x); + const uint16_t val2 = *(reinterpret_cast<uint16_t *>(input.ptr() + 2 * input_stride_in_bytes) + x); + const uint16_t val3 = *(reinterpret_cast<uint16_t *>(input.ptr() + 3 * input_stride_in_bytes) + x); + + uint16x4_t result = vdup_n_u16(0); + result = vset_lane_u16(val0, result, 0); + result = vset_lane_u16(val1, result, 1); + result = vset_lane_u16(val2, result, 2); + result = vset_lane_u16(val3, result, 3); + + // Compute destination address + const size_t dst_offset_in_bytes = id.y() * sizeof(uint16_t) + x * output_stride_in_bytes; + + vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes), result); + } + }, + input, output); } - if(left_over_loop_y) + if (left_over_loop_y) { window_in.set(Window::DimX, Window::Dimension(window.x().start(), window.x().end(), 1)); window_in.set(Window::DimY, Window::Dimension(window_end_y_multiple_of, window_end_y, 1)); @@ -303,16 +356,18 @@ void transpose_16bit_elements(const ITensor *in, ITensor *out, const Window &win Iterator output(out, window_out); // Compute left-over elements along the y dimension (1x1) - execute_window_loop(window_in, [&](const Coordinates & id) - { - const uint16_t val0 = *(reinterpret_cast<uint16_t *>(input.ptr())); + execute_window_loop( + window_in, + [&](const Coordinates &id) + { + const uint16_t val0 = *(reinterpret_cast<uint16_t *>(input.ptr())); - // Compute destination address - const size_t dst_offset_in_bytes = id.y() * sizeof(uint16_t) + id.x() * output_stride_in_bytes; + // Compute destination address + const size_t dst_offset_in_bytes = id.y() * sizeof(uint16_t) + id.x() * output_stride_in_bytes; - *(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes)) = val0; - }, - input, output); + *(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes)) = val0; + }, + input, output); } } @@ -347,10 +402,10 @@ void transpose_32bit_elements(const ITensor *in, ITensor *out, const Window &win Window window_in(window); window_in.set(Window::DimX, Window::Dimension(0, 1, 1)); - if(left_over_loop_y) + if (left_over_loop_y) { // Check if window_end_y_multiple_of is greater than window_start_y - if(window_end_y_multiple_of > window_start_y) + if (window_end_y_multiple_of > window_start_y) { window_in.set(Window::DimY, Window::Dimension(window_start_y, window_end_y_multiple_of, window_step_y)); } @@ -367,102 +422,160 @@ void transpose_32bit_elements(const ITensor *in, ITensor *out, const Window &win Iterator output(out, window_out); // Run the SIMD path if and only if the input is not a row-vector - if(in->info()->dimension(1) != 1) + if (in->info()->dimension(1) != 1) { Iterator input(in, window_in); - execute_window_loop(window_in, [&](const Coordinates & id) - { - // Compute 8x8 elements per iteration - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) + execute_window_loop( + window_in, + [&](const Coordinates &id) { - // Load - const uint32x4x2_t row0 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x); - const uint32x4x2_t row1 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x); - const uint32x4x2_t row2 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x); - const uint32x4x2_t row3 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x); - const uint32x4x2_t row4 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 4 * input_stride_in_bytes) + x); - const uint32x4x2_t row5 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 5 * input_stride_in_bytes) + x); - const uint32x4x2_t row6 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 6 * input_stride_in_bytes) + x); - const uint32x4x2_t row7 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 7 * input_stride_in_bytes) + x); - - // Transpose 2x4 - const uint32x4x2_t k0_u32 = {vtrn1q_u32(row0.val[0], row1.val[0]), vtrn2q_u32(row0.val[0], row1.val[0])}; - const uint32x4x2_t k1_u32 = {vtrn1q_u32(row0.val[1], row1.val[1]), vtrn2q_u32(row0.val[1], row1.val[1])}; - const uint32x4x2_t k2_u32 = {vtrn1q_u32(row2.val[0], row3.val[0]), vtrn2q_u32(row2.val[0], row3.val[0])}; - const uint32x4x2_t k3_u32 = {vtrn1q_u32(row2.val[1], row3.val[1]), vtrn2q_u32(row2.val[1], row3.val[1])}; - const uint32x4x2_t k4_u32 = {vtrn1q_u32(row4.val[0], row5.val[0]), vtrn2q_u32(row4.val[0], row5.val[0])}; - const uint32x4x2_t k5_u32 = {vtrn1q_u32(row4.val[1], row5.val[1]), vtrn2q_u32(row4.val[1], row5.val[1])}; - const uint32x4x2_t k6_u32 = {vtrn1q_u32(row6.val[0], row7.val[0]), vtrn2q_u32(row6.val[0], row7.val[0])}; - const uint32x4x2_t k7_u32 = {vtrn1q_u32(row6.val[1], row7.val[1]), vtrn2q_u32(row6.val[1], row7.val[1])}; - - // Transpose 2x2 - const uint64x2x2_t k0_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k0_u32.val[0]), vreinterpretq_u64_u32(k2_u32.val[0])), vtrn2q_u64(vreinterpretq_u64_u32(k0_u32.val[0]), vreinterpretq_u64_u32(k2_u32.val[0]))}; - const uint64x2x2_t k1_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k0_u32.val[1]), vreinterpretq_u64_u32(k2_u32.val[1])), vtrn2q_u64(vreinterpretq_u64_u32(k0_u32.val[1]), vreinterpretq_u64_u32(k2_u32.val[1]))}; - const uint64x2x2_t k2_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k1_u32.val[0]), vreinterpretq_u64_u32(k3_u32.val[0])), vtrn2q_u64(vreinterpretq_u64_u32(k1_u32.val[0]), vreinterpretq_u64_u32(k3_u32.val[0]))}; - const uint64x2x2_t k3_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k1_u32.val[1]), vreinterpretq_u64_u32(k3_u32.val[1])), vtrn2q_u64(vreinterpretq_u64_u32(k1_u32.val[1]), vreinterpretq_u64_u32(k3_u32.val[1]))}; - const uint64x2x2_t k4_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k4_u32.val[0]), vreinterpretq_u64_u32(k6_u32.val[0])), vtrn2q_u64(vreinterpretq_u64_u32(k4_u32.val[0]), vreinterpretq_u64_u32(k6_u32.val[0]))}; - const uint64x2x2_t k5_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k4_u32.val[1]), vreinterpretq_u64_u32(k6_u32.val[1])), vtrn2q_u64(vreinterpretq_u64_u32(k4_u32.val[1]), vreinterpretq_u64_u32(k6_u32.val[1]))}; - const uint64x2x2_t k6_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k5_u32.val[0]), vreinterpretq_u64_u32(k7_u32.val[0])), vtrn2q_u64(vreinterpretq_u64_u32(k5_u32.val[0]), vreinterpretq_u64_u32(k7_u32.val[0]))}; - const uint64x2x2_t k7_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k5_u32.val[1]), vreinterpretq_u64_u32(k7_u32.val[1])), vtrn2q_u64(vreinterpretq_u64_u32(k5_u32.val[1]), vreinterpretq_u64_u32(k7_u32.val[1]))}; - - // Swap blocks - const uint32x4x2_t col0 = {vreinterpretq_u32_u64(k0_u64.val[0]), vreinterpretq_u32_u64(k4_u64.val[0])}; - const uint32x4x2_t col1 = {vreinterpretq_u32_u64(k1_u64.val[0]), vreinterpretq_u32_u64(k5_u64.val[0])}; - const uint32x4x2_t col2 = {vreinterpretq_u32_u64(k0_u64.val[1]), vreinterpretq_u32_u64(k4_u64.val[1])}; - const uint32x4x2_t col3 = {vreinterpretq_u32_u64(k1_u64.val[1]), vreinterpretq_u32_u64(k5_u64.val[1])}; - const uint32x4x2_t col4 = {vreinterpretq_u32_u64(k2_u64.val[0]), vreinterpretq_u32_u64(k6_u64.val[0])}; - const uint32x4x2_t col5 = {vreinterpretq_u32_u64(k3_u64.val[0]), vreinterpretq_u32_u64(k7_u64.val[0])}; - const uint32x4x2_t col6 = {vreinterpretq_u32_u64(k2_u64.val[1]), vreinterpretq_u32_u64(k6_u64.val[1])}; - const uint32x4x2_t col7 = {vreinterpretq_u32_u64(k3_u64.val[1]), vreinterpretq_u32_u64(k7_u64.val[1])}; - - // Compute destination address - const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes; - - // Store - vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), col0); - vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), col1); - vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), col2); - vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), col3); - vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 4 * output_stride_in_bytes), col4); - vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 5 * output_stride_in_bytes), col5); - vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 6 * output_stride_in_bytes), col6); - vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 7 * output_stride_in_bytes), col7); - } - - // Compute left-over elements (8x1) - for(; x < window_end_x; ++x) - { - const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x); - const uint32_t val1 = *(reinterpret_cast<uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x); - const uint32_t val2 = *(reinterpret_cast<uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x); - const uint32_t val3 = *(reinterpret_cast<uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x); - const uint32_t val4 = *(reinterpret_cast<uint32_t *>(input.ptr() + 4 * input_stride_in_bytes) + x); - const uint32_t val5 = *(reinterpret_cast<uint32_t *>(input.ptr() + 5 * input_stride_in_bytes) + x); - const uint32_t val6 = *(reinterpret_cast<uint32_t *>(input.ptr() + 6 * input_stride_in_bytes) + x); - const uint32_t val7 = *(reinterpret_cast<uint32_t *>(input.ptr() + 7 * input_stride_in_bytes) + x); - - uint32x4_t result0 = vdupq_n_u32(0); - uint32x4_t result1 = vdupq_n_u32(0); - result0 = vsetq_lane_u32(val0, result0, 0); - result0 = vsetq_lane_u32(val1, result0, 1); - result0 = vsetq_lane_u32(val2, result0, 2); - result0 = vsetq_lane_u32(val3, result0, 3); - result1 = vsetq_lane_u32(val4, result1, 0); - result1 = vsetq_lane_u32(val5, result1, 1); - result1 = vsetq_lane_u32(val6, result1, 2); - result1 = vsetq_lane_u32(val7, result1, 3); - - // Compute destination address - const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes; - - vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes), {result0, result1}); - } - }, - input, output); + // Compute 8x8 elements per iteration + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + // Load + const uint32x4x2_t row0 = + vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x); + const uint32x4x2_t row1 = + vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x); + const uint32x4x2_t row2 = + vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x); + const uint32x4x2_t row3 = + vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x); + const uint32x4x2_t row4 = + vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 4 * input_stride_in_bytes) + x); + const uint32x4x2_t row5 = + vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 5 * input_stride_in_bytes) + x); + const uint32x4x2_t row6 = + vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 6 * input_stride_in_bytes) + x); + const uint32x4x2_t row7 = + vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 7 * input_stride_in_bytes) + x); + + // Transpose 2x4 + const uint32x4x2_t k0_u32 = {vtrn1q_u32(row0.val[0], row1.val[0]), + vtrn2q_u32(row0.val[0], row1.val[0])}; + const uint32x4x2_t k1_u32 = {vtrn1q_u32(row0.val[1], row1.val[1]), + vtrn2q_u32(row0.val[1], row1.val[1])}; + const uint32x4x2_t k2_u32 = {vtrn1q_u32(row2.val[0], row3.val[0]), + vtrn2q_u32(row2.val[0], row3.val[0])}; + const uint32x4x2_t k3_u32 = {vtrn1q_u32(row2.val[1], row3.val[1]), + vtrn2q_u32(row2.val[1], row3.val[1])}; + const uint32x4x2_t k4_u32 = {vtrn1q_u32(row4.val[0], row5.val[0]), + vtrn2q_u32(row4.val[0], row5.val[0])}; + const uint32x4x2_t k5_u32 = {vtrn1q_u32(row4.val[1], row5.val[1]), + vtrn2q_u32(row4.val[1], row5.val[1])}; + const uint32x4x2_t k6_u32 = {vtrn1q_u32(row6.val[0], row7.val[0]), + vtrn2q_u32(row6.val[0], row7.val[0])}; + const uint32x4x2_t k7_u32 = {vtrn1q_u32(row6.val[1], row7.val[1]), + vtrn2q_u32(row6.val[1], row7.val[1])}; + + // Transpose 2x2 + const uint64x2x2_t k0_u64 = { + vtrn1q_u64(vreinterpretq_u64_u32(k0_u32.val[0]), vreinterpretq_u64_u32(k2_u32.val[0])), + vtrn2q_u64(vreinterpretq_u64_u32(k0_u32.val[0]), vreinterpretq_u64_u32(k2_u32.val[0]))}; + const uint64x2x2_t k1_u64 = { + vtrn1q_u64(vreinterpretq_u64_u32(k0_u32.val[1]), vreinterpretq_u64_u32(k2_u32.val[1])), + vtrn2q_u64(vreinterpretq_u64_u32(k0_u32.val[1]), vreinterpretq_u64_u32(k2_u32.val[1]))}; + const uint64x2x2_t k2_u64 = { + vtrn1q_u64(vreinterpretq_u64_u32(k1_u32.val[0]), vreinterpretq_u64_u32(k3_u32.val[0])), + vtrn2q_u64(vreinterpretq_u64_u32(k1_u32.val[0]), vreinterpretq_u64_u32(k3_u32.val[0]))}; + const uint64x2x2_t k3_u64 = { + vtrn1q_u64(vreinterpretq_u64_u32(k1_u32.val[1]), vreinterpretq_u64_u32(k3_u32.val[1])), + vtrn2q_u64(vreinterpretq_u64_u32(k1_u32.val[1]), vreinterpretq_u64_u32(k3_u32.val[1]))}; + const uint64x2x2_t k4_u64 = { + vtrn1q_u64(vreinterpretq_u64_u32(k4_u32.val[0]), vreinterpretq_u64_u32(k6_u32.val[0])), + vtrn2q_u64(vreinterpretq_u64_u32(k4_u32.val[0]), vreinterpretq_u64_u32(k6_u32.val[0]))}; + const uint64x2x2_t k5_u64 = { + vtrn1q_u64(vreinterpretq_u64_u32(k4_u32.val[1]), vreinterpretq_u64_u32(k6_u32.val[1])), + vtrn2q_u64(vreinterpretq_u64_u32(k4_u32.val[1]), vreinterpretq_u64_u32(k6_u32.val[1]))}; + const uint64x2x2_t k6_u64 = { + vtrn1q_u64(vreinterpretq_u64_u32(k5_u32.val[0]), vreinterpretq_u64_u32(k7_u32.val[0])), + vtrn2q_u64(vreinterpretq_u64_u32(k5_u32.val[0]), vreinterpretq_u64_u32(k7_u32.val[0]))}; + const uint64x2x2_t k7_u64 = { + vtrn1q_u64(vreinterpretq_u64_u32(k5_u32.val[1]), vreinterpretq_u64_u32(k7_u32.val[1])), + vtrn2q_u64(vreinterpretq_u64_u32(k5_u32.val[1]), vreinterpretq_u64_u32(k7_u32.val[1]))}; + + // Swap blocks + const uint32x4x2_t col0 = {vreinterpretq_u32_u64(k0_u64.val[0]), + vreinterpretq_u32_u64(k4_u64.val[0])}; + const uint32x4x2_t col1 = {vreinterpretq_u32_u64(k1_u64.val[0]), + vreinterpretq_u32_u64(k5_u64.val[0])}; + const uint32x4x2_t col2 = {vreinterpretq_u32_u64(k0_u64.val[1]), + vreinterpretq_u32_u64(k4_u64.val[1])}; + const uint32x4x2_t col3 = {vreinterpretq_u32_u64(k1_u64.val[1]), + vreinterpretq_u32_u64(k5_u64.val[1])}; + const uint32x4x2_t col4 = {vreinterpretq_u32_u64(k2_u64.val[0]), + vreinterpretq_u32_u64(k6_u64.val[0])}; + const uint32x4x2_t col5 = {vreinterpretq_u32_u64(k3_u64.val[0]), + vreinterpretq_u32_u64(k7_u64.val[0])}; + const uint32x4x2_t col6 = {vreinterpretq_u32_u64(k2_u64.val[1]), + vreinterpretq_u32_u64(k6_u64.val[1])}; + const uint32x4x2_t col7 = {vreinterpretq_u32_u64(k3_u64.val[1]), + vreinterpretq_u32_u64(k7_u64.val[1])}; + + // Compute destination address + const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes; + + // Store + vst1q_u32_x2_( + reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), + col0); + vst1q_u32_x2_( + reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), + col1); + vst1q_u32_x2_( + reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), + col2); + vst1q_u32_x2_( + reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), + col3); + vst1q_u32_x2_( + reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 4 * output_stride_in_bytes), + col4); + vst1q_u32_x2_( + reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 5 * output_stride_in_bytes), + col5); + vst1q_u32_x2_( + reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 6 * output_stride_in_bytes), + col6); + vst1q_u32_x2_( + reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 7 * output_stride_in_bytes), + col7); + } + + // Compute left-over elements (8x1) + for (; x < window_end_x; ++x) + { + const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x); + const uint32_t val1 = *(reinterpret_cast<uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x); + const uint32_t val2 = *(reinterpret_cast<uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x); + const uint32_t val3 = *(reinterpret_cast<uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x); + const uint32_t val4 = *(reinterpret_cast<uint32_t *>(input.ptr() + 4 * input_stride_in_bytes) + x); + const uint32_t val5 = *(reinterpret_cast<uint32_t *>(input.ptr() + 5 * input_stride_in_bytes) + x); + const uint32_t val6 = *(reinterpret_cast<uint32_t *>(input.ptr() + 6 * input_stride_in_bytes) + x); + const uint32_t val7 = *(reinterpret_cast<uint32_t *>(input.ptr() + 7 * input_stride_in_bytes) + x); + + uint32x4_t result0 = vdupq_n_u32(0); + uint32x4_t result1 = vdupq_n_u32(0); + result0 = vsetq_lane_u32(val0, result0, 0); + result0 = vsetq_lane_u32(val1, result0, 1); + result0 = vsetq_lane_u32(val2, result0, 2); + result0 = vsetq_lane_u32(val3, result0, 3); + result1 = vsetq_lane_u32(val4, result1, 0); + result1 = vsetq_lane_u32(val5, result1, 1); + result1 = vsetq_lane_u32(val6, result1, 2); + result1 = vsetq_lane_u32(val7, result1, 3); + + // Compute destination address + const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes; + + vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes), {result0, result1}); + } + }, + input, output); } - if(left_over_loop_y) + if (left_over_loop_y) { window_in.set(Window::DimX, Window::Dimension(window.x().start(), window.x().end(), 1)); window_in.set(Window::DimY, Window::Dimension(window_end_y_multiple_of, window_end_y, 1)); @@ -471,40 +584,42 @@ void transpose_32bit_elements(const ITensor *in, ITensor *out, const Window &win Iterator output(out, window_out); // Compute left-over elements along the y dimension (1x1) - execute_window_loop(window_in, [&](const Coordinates & id) - { - const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr())); + execute_window_loop( + window_in, + [&](const Coordinates &id) + { + const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr())); - // Compute destination address - const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + id.x() * output_stride_in_bytes; + // Compute destination address + const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + id.x() * output_stride_in_bytes; - *(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes)) = val0; - }, - input, output); + *(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes)) = val0; + }, + input, output); } } #else // __aarch64__ void transpose_32bit_elements(const ITensor *in, ITensor *out, const Window &window) { - const int window_step_x = 4; - const int window_step_y = 4; - const int window_start_x = window.x().start(); - const int window_end_x = window.x().end(); - const int window_start_y = window.y().start(); - const int window_end_y = std::min(window.y().end(), static_cast<int>(in->info()->dimension(1))); - const int window_end_y_multiple_of = ((window_end_y - window_start_y) / window_step_y) * window_step_y; - const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1]; - const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1]; + const int window_step_x = 4; + const int window_step_y = 4; + const int window_start_x = window.x().start(); + const int window_end_x = window.x().end(); + const int window_start_y = window.y().start(); + const int window_end_y = std::min(window.y().end(), static_cast<int>(in->info()->dimension(1))); + const int window_end_y_multiple_of = ((window_end_y - window_start_y) / window_step_y) * window_step_y; + const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1]; + const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1]; // Check if we need a left-over loop for the y dimension bool left_over_loop_y = (((window_end_y - window_start_y) % window_step_y) != 0); Window window_in(window); window_in.set(Window::DimX, Window::Dimension(0, 1, 1)); - if(left_over_loop_y) + if (left_over_loop_y) { // Check if window_end_y_multiple_of is greater than window_start_y - if(window_end_y_multiple_of > window_start_y) + if (window_end_y_multiple_of > window_start_y) { window_in.set(Window::DimY, Window::Dimension(window_start_y, window_end_y_multiple_of, window_step_y)); } @@ -521,60 +636,74 @@ void transpose_32bit_elements(const ITensor *in, ITensor *out, const Window &win Iterator output(out, window_out); // Run the SIMD path if and only if the input is not a row-vector - if(in->info()->dimension(1) != 1) + if (in->info()->dimension(1) != 1) { Iterator input(in, window_in); - execute_window_loop(window_in, [&](const Coordinates & id) - { - // Compute 4x4 elements per iteration - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const uint32x4_t row0 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x); - const uint32x4_t row1 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x); - const uint32x4_t row2 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x); - const uint32x4_t row3 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x); - - // Transpose 2x2 - const uint32x2x2_t k0_u32 = vtrn_u32(vget_low_u32(row0), vget_low_u32(row1)); - const uint32x2x2_t k1_u32 = vtrn_u32(vget_high_u32(row2), vget_high_u32(row3)); - const uint32x2x2_t k2_u32 = vtrn_u32(vget_high_u32(row0), vget_high_u32(row1)); - const uint32x2x2_t k3_u32 = vtrn_u32(vget_low_u32(row2), vget_low_u32(row3)); - - // Compute destination address - const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes; - - // Swap block 01 with block 10 and store - vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), vcombine_u32(k0_u32.val[0], k3_u32.val[0])); - vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), vcombine_u32(k0_u32.val[1], k3_u32.val[1])); - vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), vcombine_u32(k2_u32.val[0], k1_u32.val[0])); - vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), vcombine_u32(k2_u32.val[1], k1_u32.val[1])); - } - - // Compute left-over elements (1x4) - for(; x < window_end_x; ++x) + execute_window_loop( + window_in, + [&](const Coordinates &id) { - const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x); - const uint32_t val1 = *(reinterpret_cast<uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x); - const uint32_t val2 = *(reinterpret_cast<uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x); - const uint32_t val3 = *(reinterpret_cast<uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x); - - uint32x4_t result = vdupq_n_u32(0); - result = vsetq_lane_u32(val0, result, 0); - result = vsetq_lane_u32(val1, result, 1); - result = vsetq_lane_u32(val2, result, 2); - result = vsetq_lane_u32(val3, result, 3); - - // Compute destination address - const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes; - - vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes), result); - } - }, - input, output); + // Compute 4x4 elements per iteration + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + const uint32x4_t row0 = + vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x); + const uint32x4_t row1 = + vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x); + const uint32x4_t row2 = + vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x); + const uint32x4_t row3 = + vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x); + + // Transpose 2x2 + const uint32x2x2_t k0_u32 = vtrn_u32(vget_low_u32(row0), vget_low_u32(row1)); + const uint32x2x2_t k1_u32 = vtrn_u32(vget_high_u32(row2), vget_high_u32(row3)); + const uint32x2x2_t k2_u32 = vtrn_u32(vget_high_u32(row0), vget_high_u32(row1)); + const uint32x2x2_t k3_u32 = vtrn_u32(vget_low_u32(row2), vget_low_u32(row3)); + + // Compute destination address + const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes; + + // Swap block 01 with block 10 and store + vst1q_u32( + reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), + vcombine_u32(k0_u32.val[0], k3_u32.val[0])); + vst1q_u32( + reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), + vcombine_u32(k0_u32.val[1], k3_u32.val[1])); + vst1q_u32( + reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), + vcombine_u32(k2_u32.val[0], k1_u32.val[0])); + vst1q_u32( + reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), + vcombine_u32(k2_u32.val[1], k1_u32.val[1])); + } + + // Compute left-over elements (1x4) + for (; x < window_end_x; ++x) + { + const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x); + const uint32_t val1 = *(reinterpret_cast<uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x); + const uint32_t val2 = *(reinterpret_cast<uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x); + const uint32_t val3 = *(reinterpret_cast<uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x); + + uint32x4_t result = vdupq_n_u32(0); + result = vsetq_lane_u32(val0, result, 0); + result = vsetq_lane_u32(val1, result, 1); + result = vsetq_lane_u32(val2, result, 2); + result = vsetq_lane_u32(val3, result, 3); + + // Compute destination address + const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes; + + vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes), result); + } + }, + input, output); } - if(left_over_loop_y) + if (left_over_loop_y) { window_in.set(Window::DimX, Window::Dimension(window.x().start(), window.x().end(), 1)); window_in.set(Window::DimY, Window::Dimension(window_end_y_multiple_of, window_end_y, 1)); @@ -583,16 +712,18 @@ void transpose_32bit_elements(const ITensor *in, ITensor *out, const Window &win Iterator output(out, window_out); // Compute left-over elements along the y dimension (1x1) - execute_window_loop(window_in, [&](const Coordinates & id) - { - const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr())); + execute_window_loop( + window_in, + [&](const Coordinates &id) + { + const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr())); - // Compute destination address - const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + id.x() * output_stride_in_bytes; + // Compute destination address + const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + id.x() * output_stride_in_bytes; - *(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes)) = val0; - }, - input, output); + *(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes)) = val0; + }, + input, output); } } #endif // __aarch64__ @@ -616,7 +747,8 @@ void CpuTransposeKernel::configure(const ITensorInfo *src, ITensorInfo *dst) const unsigned int num_elems_processed_per_iteration_y = num_elems_processed(src->element_size()); // Configure kernel window - Window win = calculate_max_window(*src, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y)); + Window win = + calculate_max_window(*src, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y)); // The CpuTranspose doesn't need padding so update_window_and_padding() can be skipped Coordinates coord; @@ -637,7 +769,7 @@ Status CpuTransposeKernel::validate(const ITensorInfo *src, const ITensorInfo *d "Element size not supported"); // Validate configured destination - if(dst->total_size() != 0) + if (dst->total_size() != 0) { const TensorShape dst_shape = misc::shape_calculator::compute_transposed_shape(*src); @@ -658,7 +790,7 @@ void CpuTransposeKernel::run_op(ITensorPack &tensors, const Window &window, cons const auto src = tensors.get_const_tensor(TensorType::ACL_SRC); auto dst = tensors.get_tensor(TensorType::ACL_DST); - switch(src->info()->element_size()) + switch (src->info()->element_size()) { case 1: transpose_8bit_elements(src, dst, window); |