aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEthan Doe <yidoe@amazon.com>2023-04-14 17:24:33 +0000
committerPablo Marquez Tello <pablo.tello@arm.com>2023-04-19 14:26:25 +0000
commita07c01b6cad1fa37f98a05f08019b40bd4303a92 (patch)
treee543988368a7526877a434a27f4dccf6562a6e6f
parent9c7c2d2d23693877867bb3284c577b33cfbff471 (diff)
downloadComputeLibrary-a07c01b6cad1fa37f98a05f08019b40bd4303a92.tar.gz
NETranspose 8x8 kernel for 32-bit elements
The existing 4x4 tiling for 32-bit transpose is not efficient on aarch64, given that there are a lot more Neon registers available. So making the tile size to 8x8 will greatly improve NETranspose latency. For example, on AWS Graviton3 processors, with this change I have observed transposing a 768x768 matrix improves latency from 0.32ms down to 0.19ms. Improvement can also be seen across different matrix sizes. Further enlarging the tile size to 8x16 or 16x16 won't make it perform as good as 8x8 due to register pressure. This change is to mitigate the issue reported at: https://github.com/ARM-software/ComputeLibrary/issues/1045 Signed-off-by: Ethan Doe <yidoe@amazon.com> Change-Id: Ia09859cdf2f6d312e67219a9d95a3a3bf1db1999 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9448 Benchmark: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Reviewed-by: Pablo Marquez Tello <pablo.tello@arm.com>
-rw-r--r--src/cpu/kernels/CpuTransposeKernel.cpp174
1 files changed, 174 insertions, 0 deletions
diff --git a/src/cpu/kernels/CpuTransposeKernel.cpp b/src/cpu/kernels/CpuTransposeKernel.cpp
index 2f981c15e4..b2cebc4230 100644
--- a/src/cpu/kernels/CpuTransposeKernel.cpp
+++ b/src/cpu/kernels/CpuTransposeKernel.cpp
@@ -50,8 +50,13 @@ unsigned int num_elems_processed(size_t element_size)
case 1:
return 8;
case 2:
+ return 4;
case 4:
+#ifdef __aarch64__
+ return 8;
+#else // __aarch64__
return 4;
+#endif // __aarch64__
default:
break;
}
@@ -311,6 +316,174 @@ void transpose_16bit_elements(const ITensor *in, ITensor *out, const Window &win
}
}
+#ifdef __aarch64__
+inline uint32x4x2_t vld1q_u32_x2_(const uint32_t *ptr)
+{
+ // gcc-7 doesn't support vld1q_u32_x2 instruction
+ return {vld1q_u32(ptr), vld1q_u32(ptr + 4)};
+}
+
+inline void vst1q_u32_x2_(const uint32_t *ptr, const uint32x4x2_t &val)
+{
+ // gcc-7 doesn't support vst1q_u32_x2 instruction
+ vst1q_u32(const_cast<uint32_t *>(ptr), val.val[0]);
+ vst1q_u32(const_cast<uint32_t *>(ptr + 4), val.val[1]);
+}
+
+void transpose_32bit_elements(const ITensor *in, ITensor *out, const Window &window)
+{
+ constexpr int window_step_x = 8;
+ constexpr int window_step_y = 8;
+ const int window_start_x = window.x().start();
+ const int window_end_x = window.x().end();
+ const int window_start_y = window.y().start();
+ const int window_end_y = std::min(window.y().end(), static_cast<int>(in->info()->dimension(1)));
+ const int window_end_y_multiple_of = ((window_end_y - window_start_y) / window_step_y) * window_step_y;
+ const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1];
+ const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1];
+
+ // Check if we need a left-over loop for the y dimension
+ bool left_over_loop_y = (((window_end_y - window_start_y) % window_step_y) != 0);
+
+ Window window_in(window);
+ window_in.set(Window::DimX, Window::Dimension(0, 1, 1));
+ if(left_over_loop_y)
+ {
+ // Check if window_end_y_multiple_of is greater than window_start_y
+ if(window_end_y_multiple_of > window_start_y)
+ {
+ window_in.set(Window::DimY, Window::Dimension(window_start_y, window_end_y_multiple_of, window_step_y));
+ }
+ else
+ {
+ window_in.set(Window::DimY, Window::Dimension(0, 0, 1));
+ }
+ }
+
+ Window window_out(window);
+ window_out.set(Window::DimX, Window::Dimension(0, 0, 0));
+ window_out.set(Window::DimY, Window::Dimension(0, 0, 0));
+
+ Iterator output(out, window_out);
+
+ // Run the SIMD path if and only if the input is not a row-vector
+ if(in->info()->dimension(1) != 1)
+ {
+ Iterator input(in, window_in);
+ execute_window_loop(window_in, [&](const Coordinates & id)
+ {
+ // Compute 8x8 elements per iteration
+ int x = window_start_x;
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ // Load
+ const uint32x4x2_t row0 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
+ const uint32x4x2_t row1 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
+ const uint32x4x2_t row2 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
+ const uint32x4x2_t row3 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
+ const uint32x4x2_t row4 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 4 * input_stride_in_bytes) + x);
+ const uint32x4x2_t row5 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 5 * input_stride_in_bytes) + x);
+ const uint32x4x2_t row6 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 6 * input_stride_in_bytes) + x);
+ const uint32x4x2_t row7 = vld1q_u32_x2_(reinterpret_cast<const uint32_t *>(input.ptr() + 7 * input_stride_in_bytes) + x);
+
+ // Transpose 2x4
+ const uint32x4x2_t k0_u32 = {vtrn1q_u32(row0.val[0], row1.val[0]), vtrn2q_u32(row0.val[0], row1.val[0])};
+ const uint32x4x2_t k1_u32 = {vtrn1q_u32(row0.val[1], row1.val[1]), vtrn2q_u32(row0.val[1], row1.val[1])};
+ const uint32x4x2_t k2_u32 = {vtrn1q_u32(row2.val[0], row3.val[0]), vtrn2q_u32(row2.val[0], row3.val[0])};
+ const uint32x4x2_t k3_u32 = {vtrn1q_u32(row2.val[1], row3.val[1]), vtrn2q_u32(row2.val[1], row3.val[1])};
+ const uint32x4x2_t k4_u32 = {vtrn1q_u32(row4.val[0], row5.val[0]), vtrn2q_u32(row4.val[0], row5.val[0])};
+ const uint32x4x2_t k5_u32 = {vtrn1q_u32(row4.val[1], row5.val[1]), vtrn2q_u32(row4.val[1], row5.val[1])};
+ const uint32x4x2_t k6_u32 = {vtrn1q_u32(row6.val[0], row7.val[0]), vtrn2q_u32(row6.val[0], row7.val[0])};
+ const uint32x4x2_t k7_u32 = {vtrn1q_u32(row6.val[1], row7.val[1]), vtrn2q_u32(row6.val[1], row7.val[1])};
+
+ // Transpose 2x2
+ const uint64x2x2_t k0_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k0_u32.val[0]), vreinterpretq_u64_u32(k2_u32.val[0])), vtrn2q_u64(vreinterpretq_u64_u32(k0_u32.val[0]), vreinterpretq_u64_u32(k2_u32.val[0]))};
+ const uint64x2x2_t k1_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k0_u32.val[1]), vreinterpretq_u64_u32(k2_u32.val[1])), vtrn2q_u64(vreinterpretq_u64_u32(k0_u32.val[1]), vreinterpretq_u64_u32(k2_u32.val[1]))};
+ const uint64x2x2_t k2_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k1_u32.val[0]), vreinterpretq_u64_u32(k3_u32.val[0])), vtrn2q_u64(vreinterpretq_u64_u32(k1_u32.val[0]), vreinterpretq_u64_u32(k3_u32.val[0]))};
+ const uint64x2x2_t k3_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k1_u32.val[1]), vreinterpretq_u64_u32(k3_u32.val[1])), vtrn2q_u64(vreinterpretq_u64_u32(k1_u32.val[1]), vreinterpretq_u64_u32(k3_u32.val[1]))};
+ const uint64x2x2_t k4_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k4_u32.val[0]), vreinterpretq_u64_u32(k6_u32.val[0])), vtrn2q_u64(vreinterpretq_u64_u32(k4_u32.val[0]), vreinterpretq_u64_u32(k6_u32.val[0]))};
+ const uint64x2x2_t k5_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k4_u32.val[1]), vreinterpretq_u64_u32(k6_u32.val[1])), vtrn2q_u64(vreinterpretq_u64_u32(k4_u32.val[1]), vreinterpretq_u64_u32(k6_u32.val[1]))};
+ const uint64x2x2_t k6_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k5_u32.val[0]), vreinterpretq_u64_u32(k7_u32.val[0])), vtrn2q_u64(vreinterpretq_u64_u32(k5_u32.val[0]), vreinterpretq_u64_u32(k7_u32.val[0]))};
+ const uint64x2x2_t k7_u64 = {vtrn1q_u64(vreinterpretq_u64_u32(k5_u32.val[1]), vreinterpretq_u64_u32(k7_u32.val[1])), vtrn2q_u64(vreinterpretq_u64_u32(k5_u32.val[1]), vreinterpretq_u64_u32(k7_u32.val[1]))};
+
+ // Swap blocks
+ const uint32x4x2_t col0 = {vreinterpretq_u32_u64(k0_u64.val[0]), vreinterpretq_u32_u64(k4_u64.val[0])};
+ const uint32x4x2_t col1 = {vreinterpretq_u32_u64(k1_u64.val[0]), vreinterpretq_u32_u64(k5_u64.val[0])};
+ const uint32x4x2_t col2 = {vreinterpretq_u32_u64(k0_u64.val[1]), vreinterpretq_u32_u64(k4_u64.val[1])};
+ const uint32x4x2_t col3 = {vreinterpretq_u32_u64(k1_u64.val[1]), vreinterpretq_u32_u64(k5_u64.val[1])};
+ const uint32x4x2_t col4 = {vreinterpretq_u32_u64(k2_u64.val[0]), vreinterpretq_u32_u64(k6_u64.val[0])};
+ const uint32x4x2_t col5 = {vreinterpretq_u32_u64(k3_u64.val[0]), vreinterpretq_u32_u64(k7_u64.val[0])};
+ const uint32x4x2_t col6 = {vreinterpretq_u32_u64(k2_u64.val[1]), vreinterpretq_u32_u64(k6_u64.val[1])};
+ const uint32x4x2_t col7 = {vreinterpretq_u32_u64(k3_u64.val[1]), vreinterpretq_u32_u64(k7_u64.val[1])};
+
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes;
+
+ // Store
+ vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), col0);
+ vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), col1);
+ vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), col2);
+ vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), col3);
+ vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 4 * output_stride_in_bytes), col4);
+ vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 5 * output_stride_in_bytes), col5);
+ vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 6 * output_stride_in_bytes), col6);
+ vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 7 * output_stride_in_bytes), col7);
+ }
+
+ // Compute left-over elements (8x1)
+ for(; x < window_end_x; ++x)
+ {
+ const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
+ const uint32_t val1 = *(reinterpret_cast<uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
+ const uint32_t val2 = *(reinterpret_cast<uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
+ const uint32_t val3 = *(reinterpret_cast<uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
+ const uint32_t val4 = *(reinterpret_cast<uint32_t *>(input.ptr() + 4 * input_stride_in_bytes) + x);
+ const uint32_t val5 = *(reinterpret_cast<uint32_t *>(input.ptr() + 5 * input_stride_in_bytes) + x);
+ const uint32_t val6 = *(reinterpret_cast<uint32_t *>(input.ptr() + 6 * input_stride_in_bytes) + x);
+ const uint32_t val7 = *(reinterpret_cast<uint32_t *>(input.ptr() + 7 * input_stride_in_bytes) + x);
+
+ uint32x4_t result0 = vdupq_n_u32(0);
+ uint32x4_t result1 = vdupq_n_u32(0);
+ result0 = vsetq_lane_u32(val0, result0, 0);
+ result0 = vsetq_lane_u32(val1, result0, 1);
+ result0 = vsetq_lane_u32(val2, result0, 2);
+ result0 = vsetq_lane_u32(val3, result0, 3);
+ result1 = vsetq_lane_u32(val4, result1, 0);
+ result1 = vsetq_lane_u32(val5, result1, 1);
+ result1 = vsetq_lane_u32(val6, result1, 2);
+ result1 = vsetq_lane_u32(val7, result1, 3);
+
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes;
+
+ vst1q_u32_x2_(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes), {result0, result1});
+ }
+ },
+ input, output);
+ }
+
+ if(left_over_loop_y)
+ {
+ window_in.set(Window::DimX, Window::Dimension(window.x().start(), window.x().end(), 1));
+ window_in.set(Window::DimY, Window::Dimension(window_end_y_multiple_of, window_end_y, 1));
+
+ Iterator input(in, window_in);
+ Iterator output(out, window_out);
+
+ // Compute left-over elements along the y dimension (1x1)
+ execute_window_loop(window_in, [&](const Coordinates & id)
+ {
+ const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr()));
+
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + id.x() * output_stride_in_bytes;
+
+ *(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes)) = val0;
+ },
+ input, output);
+ }
+}
+#else // __aarch64__
void transpose_32bit_elements(const ITensor *in, ITensor *out, const Window &window)
{
const int window_step_x = 4;
@@ -422,6 +595,7 @@ void transpose_32bit_elements(const ITensor *in, ITensor *out, const Window &win
input, output);
}
}
+#endif // __aarch64__
} // namespace
void CpuTransposeKernel::configure(const ITensorInfo *src, ITensorInfo *dst)