aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/CpuTransposeKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuTransposeKernel.cpp')
-rw-r--r--src/cpu/kernels/CpuTransposeKernel.cpp174
1 files changed, 174 insertions, 0 deletions
diff --git a/src/cpu/kernels/CpuTransposeKernel.cpp b/src/cpu/kernels/CpuTransposeKernel.cpp
index 2f981c15e4..b2cebc4230 100644
--- a/src/cpu/kernels/CpuTransposeKernel.cpp
+++ b/src/cpu/kernels/CpuTransposeKernel.cpp
@@ -50,8 +50,13 @@ unsigned int num_elems_processed(size_t element_size)
case 1:
return 8;
case 2:
+ return 4;
case 4:
+#ifdef __aarch64__
+ return 8;
+#else // __aarch64__
return 4;
+#endif // __aarch64__
default:
break;
}
@@ -311,6 +316,174 @@ void transpose_16bit_elements(const ITensor *in, ITensor *out, const Window &win
}
}
+#ifdef __aarch64__
+inline uint32x4x2_t vld1q_u32_x2_(const uint32_t *ptr)
+{
+ // gcc-7 doesn't support vld1q_u32_x2 instruction
+ return {vld1q_u32(ptr), vld1q_u32(ptr + 4)};
+}
+
+inline void vst1q_u32_x2_(const uint32_t *ptr, const uint32x4x2_t &val)
+{
+ // gcc-7 doesn't support vst1q_u32_x2 instruction
+ vst1q_u32(const_cast<uint32_t *>(ptr), val.val[0]);
+ vst1q_u32(const_cast<uint32_t *>(ptr + 4), val.val[1]);
+}
+
+void transpose_32bit_elements(const ITensor *in, ITensor *out, const Window &window)
+{
+ constexpr int window_step_x = 8;
+ constexpr int window_step_y = 8;
+ const int window_start_x = window.x().start();
+ const int window_end_x = window.x().end();
+ const int window_start_y = window.y().start();
+ const int window_end_y = std::min(window.y().end(), static_cast<int>(in->info()->dimension(1)));
+ const int window_end_y_multiple_of = ((window_end_y - window_start_y) / window_step_y) * window_step_y;
+ const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1];
+ const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1];
+
+ // Check if we need a left-over loop for the y dimension
+ bool left_over_loop_y = (((window_end_y - window_start_y) % window_step_y) != 0);
+
+ Window window_in(window);
+ window_in.set(Window::DimX, Window::Dimension(0, 1, 1));
+ if(left_over_loop_y)
+ {
+ // Check if window_end_y_multiple_of is greater than window_start_y
+ if(window_end_y_multiple_of > window_start_y)
+ {
+ window_in.set(Window::DimY, Window::Dimension(window_start_y, window_end_y_multiple_of, window_step_y));
+ }
+ else
+ {
+ window_in.set(Window::DimY, Window::Dimension(0, 0, 1));
+ }
+ }
+
+ Window window_out(window);
+ window_out.set(Window::DimX, Window::Dimension(0, 0, 0));
+ window_out.set(Window::DimY, Window::Dimension(0, 0, 0));
+
+ Iterator output(out, window_out);
+
+ // Run the SIMD path if and only if the input is not a row-vector
+ if(in->info()->dimension(1) != 1)
+ {
+ Iterator input(in, window_in);
+ execute_window_loop(window_in, [&](const Coordinates & id)
+ {
+ // Compute 8x8 elements per iteration
+ int x = window_start_x;
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ // Load
+ const uint32x4x2_t row0 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
+ const uint32x4x2_t row1 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
+ const uint32x4x2_t row2 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
+ const uint32x4x2_t row3 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
+ const uint32x4x2_t row4 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 4 * input_stride_in_bytes) + x);
+ const uint32x4x2_t row5 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 5 * input_stride_in_bytes) + x);
+ const uint32x4x2_t row6 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 6 * input_stride_in_bytes) + x);
+ const uint32x4x2_t row7 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 7 * input_stride_in_bytes) + x);
+
+ // Transpose 2x4
+ const uint32x4x2_t k0_u32 = {vtrn1q_u32(row0.val[0], row1.val[0]), vtrn2q_u32(row0.val[0], row1.val[0])};
+ const uint32x4x2_t k1_u32 = {vtrn1q_u32(row0.val[1], row1.val[1]), vtrn2q_u32(row0.val[1], row1.val[1])};
+ const uint32x4x2_t k2_u32 = {vtrn1q_u32(row2.val[0], row3.val[0]), vtrn2q_u32(row2.val[0], row3.val[0])};
+ const uint32x4x2_t k3_u32 = {vtrn1q_u32(row2.val[1], row3.val[1]), vtrn2q_u32(row2.val[1], row3.val[1])};
+ const uint32x4x2_t k4_u32 = {vtrn1q_u32(row4.val[0], row5.val[0]), vtrn2q_u32(row4.val[0], row5.val[0])};
+ const uint32x4x2_t k5_u32 = {vtrn1q_u32(row4.val[1], row5.val[1]), vtrn2q_u32(row4.val[1], row5.val[1])};
+ const uint32x4x2_t k6_u32 = {vtrn1q_u32(row6.val[0], row7.val[0]), vtrn2q_u32(row6.val[0], row7.val[0])};
+ const uint32x4x2_t k7_u32 = {vtrn1q_u32(row6.val[1], row7.val[1]), vtrn2q_u32(row6.val[1], row7.val[1])};
+
+ // Transpose 2x2
+ const uint64x2x2_t k0_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k0_u32.val[0]), vreinterpretq_u64_u32(k2_u32.val[0])), vtrn2q_u64(vreinterpretq_u64_u32(k0_u32.val[0]), vreinterpretq_u64_u32(k2_u32.val[0]))};
+ const uint64x2x2_t k1_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k0_u32.val[1]), vreinterpretq_u64_u32(k2_u32.val[1])), vtrn2q_u64(vreinterpretq_u64_u32(k0_u32.val[1]), vreinterpretq_u64_u32(k2_u32.val[1]))};
+ const uint64x2x2_t k2_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k1_u32.val[0]), vreinterpretq_u64_u32(k3_u32.val[0])), vtrn2q_u64(vreinterpretq_u64_u32(k1_u32.val[0]), vreinterpretq_u64_u32(k3_u32.val[0]))};
+ const uint64x2x2_t k3_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k1_u32.val[1]), vreinterpretq_u64_u32(k3_u32.val[1])), vtrn2q_u64(vreinterpretq_u64_u32(k1_u32.val[1]), vreinterpretq_u64_u32(k3_u32.val[1]))};
+ const uint64x2x2_t k4_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k4_u32.val[0]), vreinterpretq_u64_u32(k6_u32.val[0])), vtrn2q_u64(vreinterpretq_u64_u32(k4_u32.val[0]), vreinterpretq_u64_u32(k6_u32.val[0]))};
+ const uint64x2x2_t k5_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k4_u32.val[1]), vreinterpretq_u64_u32(k6_u32.val[1])), vtrn2q_u64(vreinterpretq_u64_u32(k4_u32.val[1]), vreinterpretq_u64_u32(k6_u32.val[1]))};
+ const uint64x2x2_t k6_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k5_u32.val[0]), vreinterpretq_u64_u32(k7_u32.val[0])), vtrn2q_u64(vreinterpretq_u64_u32(k5_u32.val[0]), vreinterpretq_u64_u32(k7_u32.val[0]))};
+ const uint64x2x2_t k7_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k5_u32.val[1]), vreinterpretq_u64_u32(k7_u32.val[1])), vtrn2q_u64(vreinterpretq_u64_u32(k5_u32.val[1]), vreinterpretq_u64_u32(k7_u32.val[1]))};
+
+ // Swap blocks
+ const uint32x4x2_t col0 = {vreinterpretq_u32_u64(k0_u64.val[0]), vreinterpretq_u32_u64(k4_u64.val[0])};
+ const uint32x4x2_t col1 = {vreinterpretq_u32_u64(k1_u64.val[0]), vreinterpretq_u32_u64(k5_u64.val[0])};
+ const uint32x4x2_t col2 = {vreinterpretq_u32_u64(k0_u64.val[1]), vreinterpretq_u32_u64(k4_u64.val[1])};
+ const uint32x4x2_t col3 = {vreinterpretq_u32_u64(k1_u64.val[1]), vreinterpretq_u32_u64(k5_u64.val[1])};
+ const uint32x4x2_t col4 = {vreinterpretq_u32_u64(k2_u64.val[0]), vreinterpretq_u32_u64(k6_u64.val[0])};
+ const uint32x4x2_t col5 = {vreinterpretq_u32_u64(k3_u64.val[0]), vreinterpretq_u32_u64(k7_u64.val[0])};
+ const uint32x4x2_t col6 = {vreinterpretq_u32_u64(k2_u64.val[1]), vreinterpretq_u32_u64(k6_u64.val[1])};
+ const uint32x4x2_t col7 = {vreinterpretq_u32_u64(k3_u64.val[1]), vreinterpretq_u32_u64(k7_u64.val[1])};
+
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes;
+
+ // Store
+ vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), col0);
+ vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), col1);
+ vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), col2);
+ vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), col3);
+ vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 4 * output_stride_in_bytes), col4);
+ vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 5 * output_stride_in_bytes), col5);
+ vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 6 * output_stride_in_bytes), col6);
+ vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 7 * output_stride_in_bytes), col7);
+ }
+
+ // Compute left-over elements (8x1)
+ for(; x < window_end_x; ++x)
+ {
+ const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
+ const uint32_t val1 = *(reinterpret_cast<uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
+ const uint32_t val2 = *(reinterpret_cast<uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
+ const uint32_t val3 = *(reinterpret_cast<uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
+ const uint32_t val4 = *(reinterpret_cast<uint32_t *>(input.ptr() + 4 * input_stride_in_bytes) + x);
+ const uint32_t val5 = *(reinterpret_cast<uint32_t *>(input.ptr() + 5 * input_stride_in_bytes) + x);
+ const uint32_t val6 = *(reinterpret_cast<uint32_t *>(input.ptr() + 6 * input_stride_in_bytes) + x);
+ const uint32_t val7 = *(reinterpret_cast<uint32_t *>(input.ptr() + 7 * input_stride_in_bytes) + x);
+
+ uint32x4_t result0 = vdupq_n_u32(0);
+ uint32x4_t result1 = vdupq_n_u32(0);
+ result0 = vsetq_lane_u32(val0, result0, 0);
+ result0 = vsetq_lane_u32(val1, result0, 1);
+ result0 = vsetq_lane_u32(val2, result0, 2);
+ result0 = vsetq_lane_u32(val3, result0, 3);
+ result1 = vsetq_lane_u32(val4, result1, 0);
+ result1 = vsetq_lane_u32(val5, result1, 1);
+ result1 = vsetq_lane_u32(val6, result1, 2);
+ result1 = vsetq_lane_u32(val7, result1, 3);
+
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes;
+
+ vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes), {result0, result1});
+ }
+ },
+ input, output);
+ }
+
+ if(left_over_loop_y)
+ {
+ window_in.set(Window::DimX, Window::Dimension(window.x().start(), window.x().end(), 1));
+ window_in.set(Window::DimY, Window::Dimension(window_end_y_multiple_of, window_end_y, 1));
+
+ Iterator input(in, window_in);
+ Iterator output(out, window_out);
+
+ // Compute left-over elements along the y dimension (1x1)
+ execute_window_loop(window_in, [&](const Coordinates & id)
+ {
+ const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr()));
+
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + id.x() * output_stride_in_bytes;
+
+ *(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes)) = val0;
+ },
+ input, output);
+ }
+}
+#else // __aarch64__
void transpose_32bit_elements(const ITensor *in, ITensor *out, const Window &window)
{
const int window_step_x = 4;
@@ -422,6 +595,7 @@ void transpose_32bit_elements(const ITensor *in, ITensor *out, const Window &win
input, output);
}
}
+#endif // __aarch64__
} // namespace
void CpuTransposeKernel::configure(const ITensorInfo *src, ITensorInfo *dst)