 src/core/NEON/kernels/NETransposeKernel.cpp | 436 +++++++++++++++++++--------
 tests/validation/NEON/Transpose.cpp         |  57 +++++--
 2 files changed, 383 insertions(+), 110 deletions(-)
diff --git a/src/core/NEON/kernels/NETransposeKernel.cpp b/src/core/NEON/kernels/NETransposeKernel.cpp
index e84beeeb36..fc22b05823 100644
--- a/src/core/NEON/kernels/NETransposeKernel.cpp
+++ b/src/core/NEON/kernels/NETransposeKernel.cpp
@@ -91,20 +91,23 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output)
std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output)
{
- const unsigned int num_elems_processed_per_iteration = num_elems_processed(input->element_size());
+ // Note: This kernel processes one block of elements per iteration (8x8 for 8-bit data, 4x4 for 16-bit and 32-bit data).
+ // However, since we use left-over for loops on both dimensions (X and Y), no read or write may fall outside of memory.
+ // For this reason num_elems_processed_per_iteration_x is set to 1
+ const unsigned int num_elems_processed_per_iteration_x = 1;
+ const unsigned int num_elems_processed_per_iteration_y = num_elems_processed(input->element_size());
// Configure kernel window
- Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration, num_elems_processed_per_iteration));
+ Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
- AccessWindowRectangle input_access(input, 0, 0, num_elems_processed_per_iteration, num_elems_processed_per_iteration);
+ AccessWindowStatic input_access(input, 0, 0, input->dimension(0), input->dimension(1));
bool window_changed = update_window_and_padding(win, input_access);
if(output->total_size() != 0)
{
// TODO (COMPMID-708): Replace AccessWindowStatic with AccessWindowTranspose
- AccessWindowStatic output_access(output, 0, 0, ceil_to_multiple(output->dimension(0), num_elems_processed_per_iteration), ceil_to_multiple(output->dimension(1),
- num_elems_processed_per_iteration));
+ AccessWindowStatic output_access(output, 0, 0, output->dimension(0), output->dimension(1));
window_changed = window_changed || update_window_and_padding(win, output_access);
@@ -117,133 +120,366 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen
void transpose_8bit_elements(const ITensor *in, ITensor *out, const Window &window)
{
+ const int window_step_x = 8;
+ const int window_step_y = 8;
+ const int window_start_x = window.x().start();
+ const int window_end_x = window.x().end();
+ const int window_start_y = window.y().start();
+ const int window_end_y = std::min(window.y().end(), static_cast<int>(in->info()->dimension(1)));
+ const int window_end_y_multiple_of = ((window_end_y - window_start_y) / window_step_y) * window_step_y;
+ const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1];
+ const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1];
+
+ // Check if we need a left-over loop for the y dimension
+ bool left_over_loop_y = (((window_end_y - window_start_y) % window_step_y) != 0);
+
+ Window window_in(window);
+ window_in.set(Window::DimX, Window::Dimension(0, 1, 1));
+ if(left_over_loop_y)
+ {
+ // Check if window_end_y_multiple_of is greater than window_start_y
+ if(window_end_y_multiple_of > window_start_y)
+ {
+ window_in.set(Window::DimY, Window::Dimension(window_start_y, window_end_y_multiple_of, window_step_y));
+ }
+ else
+ {
+ window_in.set(Window::DimY, Window::Dimension(0, 0, 1));
+ }
+ }
+
Window window_out(window);
window_out.set(Window::DimX, Window::Dimension(0, 0, 0));
window_out.set(Window::DimY, Window::Dimension(0, 0, 0));
- Iterator input(in, window);
Iterator output(out, window_out);
- const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1];
- const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1];
+ // Run the NEON path if and only if the input is not a row-vector
+ if(in->info()->dimension(1) != 1)
+ {
+ Iterator input(in, window_in);
+ execute_window_loop(window_in, [&](const Coordinates & id)
+ {
+ // Compute 8x8 elements per iteration
+ int x = window_start_x;
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const uint8x8_t row0 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 0 * input_stride_in_bytes));
+ const uint8x8_t row1 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 1 * input_stride_in_bytes));
+ const uint8x8_t row2 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 2 * input_stride_in_bytes));
+ const uint8x8_t row3 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 3 * input_stride_in_bytes));
+ const uint8x8_t row4 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 4 * input_stride_in_bytes));
+ const uint8x8_t row5 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 5 * input_stride_in_bytes));
+ const uint8x8_t row6 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 6 * input_stride_in_bytes));
+ const uint8x8_t row7 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 7 * input_stride_in_bytes));
+
+ // Transpose 2x2
+ const uint8x8x2_t k0_u8 = vtrn_u8(row0, row1);
+ const uint8x8x2_t k1_u8 = vtrn_u8(row2, row3);
+ const uint8x8x2_t k2_u8 = vtrn_u8(row4, row5);
+ const uint8x8x2_t k3_u8 = vtrn_u8(row6, row7);
+
+ // Transpose 4x4
+ const uint16x4x2_t k0_u16 = vtrn_u16(vreinterpret_u16_u8(k0_u8.val[0]), vreinterpret_u16_u8(k1_u8.val[0]));
+ const uint16x4x2_t k1_u16 = vtrn_u16(vreinterpret_u16_u8(k0_u8.val[1]), vreinterpret_u16_u8(k1_u8.val[1]));
+ const uint16x4x2_t k2_u16 = vtrn_u16(vreinterpret_u16_u8(k2_u8.val[0]), vreinterpret_u16_u8(k3_u8.val[0]));
+ const uint16x4x2_t k3_u16 = vtrn_u16(vreinterpret_u16_u8(k2_u8.val[1]), vreinterpret_u16_u8(k3_u8.val[1]));
+
+ // Transpose 8x8
+ const uint32x2x2_t k0_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[0]), vreinterpret_u32_u16(k2_u16.val[0]));
+ const uint32x2x2_t k1_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[1]), vreinterpret_u32_u16(k2_u16.val[1]));
+ const uint32x2x2_t k2_u32 = vtrn_u32(vreinterpret_u32_u16(k1_u16.val[0]), vreinterpret_u32_u16(k3_u16.val[0]));
+ const uint32x2x2_t k3_u32 = vtrn_u32(vreinterpret_u32_u16(k1_u16.val[1]), vreinterpret_u32_u16(k3_u16.val[1]));
+
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint8_t) + x * output_stride_in_bytes;
+
+ vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k0_u32.val[0])));
+ vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k2_u32.val[0])));
+ vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k1_u32.val[0])));
+ vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k3_u32.val[0])));
+ vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 4 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k0_u32.val[1])));
+ vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 5 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k2_u32.val[1])));
+ vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 6 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k1_u32.val[1])));
+ vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 7 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k3_u32.val[1])));
+ }
+
+ // Compute left-over elements along the x dimension (1x8)
+ for(; x < window_end_x; ++x)
+ {
+ const uint8_t val0 = *(input.ptr() + x + 0 * input_stride_in_bytes);
+ const uint8_t val1 = *(input.ptr() + x + 1 * input_stride_in_bytes);
+ const uint8_t val2 = *(input.ptr() + x + 2 * input_stride_in_bytes);
+ const uint8_t val3 = *(input.ptr() + x + 3 * input_stride_in_bytes);
+ const uint8_t val4 = *(input.ptr() + x + 4 * input_stride_in_bytes);
+ const uint8_t val5 = *(input.ptr() + x + 5 * input_stride_in_bytes);
+ const uint8_t val6 = *(input.ptr() + x + 6 * input_stride_in_bytes);
+ const uint8_t val7 = *(input.ptr() + x + 7 * input_stride_in_bytes);
+
+ uint8x8_t result = vdup_n_u8(0);
+ result = vset_lane_u8(val0, result, 0);
+ result = vset_lane_u8(val1, result, 1);
+ result = vset_lane_u8(val2, result, 2);
+ result = vset_lane_u8(val3, result, 3);
+ result = vset_lane_u8(val4, result, 4);
+ result = vset_lane_u8(val5, result, 5);
+ result = vset_lane_u8(val6, result, 6);
+ result = vset_lane_u8(val7, result, 7);
+
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint8_t) + x * output_stride_in_bytes;
+
+ vst1_u8(output.ptr() + dst_offset_in_bytes, result);
+ }
+ },
+ input, output);
+ }
- execute_window_loop(window, [&](const Coordinates & id)
+ if(left_over_loop_y)
{
- const uint8x8_t row0 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 0 * input_stride_in_bytes));
- const uint8x8_t row1 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 1 * input_stride_in_bytes));
- const uint8x8_t row2 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 2 * input_stride_in_bytes));
- const uint8x8_t row3 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 3 * input_stride_in_bytes));
- const uint8x8_t row4 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 4 * input_stride_in_bytes));
- const uint8x8_t row5 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 5 * input_stride_in_bytes));
- const uint8x8_t row6 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 6 * input_stride_in_bytes));
- const uint8x8_t row7 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 7 * input_stride_in_bytes));
-
- // Transpose 2x2
- const uint8x8x2_t k0_u8 = vtrn_u8(row0, row1);
- const uint8x8x2_t k1_u8 = vtrn_u8(row2, row3);
- const uint8x8x2_t k2_u8 = vtrn_u8(row4, row5);
- const uint8x8x2_t k3_u8 = vtrn_u8(row6, row7);
-
- // Transpose 4x4
- const uint16x4x2_t k0_u16 = vtrn_u16(vreinterpret_u16_u8(k0_u8.val[0]), vreinterpret_u16_u8(k1_u8.val[0]));
- const uint16x4x2_t k1_u16 = vtrn_u16(vreinterpret_u16_u8(k0_u8.val[1]), vreinterpret_u16_u8(k1_u8.val[1]));
- const uint16x4x2_t k2_u16 = vtrn_u16(vreinterpret_u16_u8(k2_u8.val[0]), vreinterpret_u16_u8(k3_u8.val[0]));
- const uint16x4x2_t k3_u16 = vtrn_u16(vreinterpret_u16_u8(k2_u8.val[1]), vreinterpret_u16_u8(k3_u8.val[1]));
-
- // Transpose 8x8
- const uint32x2x2_t k0_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[0]), vreinterpret_u32_u16(k2_u16.val[0]));
- const uint32x2x2_t k1_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[1]), vreinterpret_u32_u16(k2_u16.val[1]));
- const uint32x2x2_t k2_u32 = vtrn_u32(vreinterpret_u32_u16(k1_u16.val[0]), vreinterpret_u32_u16(k3_u16.val[0]));
- const uint32x2x2_t k3_u32 = vtrn_u32(vreinterpret_u32_u16(k1_u16.val[1]), vreinterpret_u32_u16(k3_u16.val[1]));
-
- // Compute destination address
- const size_t dst_offset_in_bytes = id.y() * sizeof(uint8_t) + id.x() * output_stride_in_bytes;
-
- vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k0_u32.val[0])));
- vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k2_u32.val[0])));
- vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k1_u32.val[0])));
- vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k3_u32.val[0])));
- vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 4 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k0_u32.val[1])));
- vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 5 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k2_u32.val[1])));
- vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 6 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k1_u32.val[1])));
- vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 7 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k3_u32.val[1])));
- },
- input, output);
+ window_in.set(Window::DimX, Window::Dimension(window.x().start(), window.x().end(), 1));
+ window_in.set(Window::DimY, Window::Dimension(window_end_y_multiple_of, window_end_y, 1));
+
+ Iterator input(in, window_in);
+ Iterator output(out, window_out);
+
+ // Compute left-over elements along the y dimension (1x1)
+ execute_window_loop(window_in, [&](const Coordinates & id)
+ {
+ const uint8_t val0 = *input.ptr();
+
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint8_t) + id.x() * output_stride_in_bytes;
+
+ *(output.ptr() + dst_offset_in_bytes) = val0;
+ },
+ input, output);
+ }
}
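The 8-bit main loop is the classic NEON byte transpose: vtrn_u8 builds transposed 2x2 blocks, vtrn_u16 on the reinterpreted vectors merges them into 4x4 blocks, and vtrn_u32 completes the 8x8. A self-contained harness that checks this cascade against a scalar transpose, assuming a NEON-capable compiler target (e.g. aarch64); the names here are illustrative, not from the library:

    // Standalone sketch: verify the vtrn cascade used above against a
    // scalar 8x8 transpose.
    #include <arm_neon.h>
    #include <cassert>
    #include <cstdint>

    int main()
    {
        uint8_t in[8][8], ref[8][8], out[8][8];
        for(int r = 0; r < 8; ++r)
            for(int c = 0; c < 8; ++c)
            {
                in[r][c]  = static_cast<uint8_t>(r * 8 + c);
                ref[c][r] = in[r][c]; // scalar reference transpose
            }

        uint8x8_t row[8];
        for(int r = 0; r < 8; ++r)
        {
            row[r] = vld1_u8(in[r]);
        }

        // Same 2x2 -> 4x4 -> 8x8 cascade as the kernel
        const uint8x8x2_t  k0_u8  = vtrn_u8(row[0], row[1]);
        const uint8x8x2_t  k1_u8  = vtrn_u8(row[2], row[3]);
        const uint8x8x2_t  k2_u8  = vtrn_u8(row[4], row[5]);
        const uint8x8x2_t  k3_u8  = vtrn_u8(row[6], row[7]);
        const uint16x4x2_t k0_u16 = vtrn_u16(vreinterpret_u16_u8(k0_u8.val[0]), vreinterpret_u16_u8(k1_u8.val[0]));
        const uint16x4x2_t k1_u16 = vtrn_u16(vreinterpret_u16_u8(k0_u8.val[1]), vreinterpret_u16_u8(k1_u8.val[1]));
        const uint16x4x2_t k2_u16 = vtrn_u16(vreinterpret_u16_u8(k2_u8.val[0]), vreinterpret_u16_u8(k3_u8.val[0]));
        const uint16x4x2_t k3_u16 = vtrn_u16(vreinterpret_u16_u8(k2_u8.val[1]), vreinterpret_u16_u8(k3_u8.val[1]));
        const uint32x2x2_t k0_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[0]), vreinterpret_u32_u16(k2_u16.val[0]));
        const uint32x2x2_t k1_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[1]), vreinterpret_u32_u16(k2_u16.val[1]));
        const uint32x2x2_t k2_u32 = vtrn_u32(vreinterpret_u32_u16(k1_u16.val[0]), vreinterpret_u32_u16(k3_u16.val[0]));
        const uint32x2x2_t k3_u32 = vtrn_u32(vreinterpret_u32_u16(k1_u16.val[1]), vreinterpret_u32_u16(k3_u16.val[1]));

        // Store in the kernel's output order: rows 0..7 hold input columns 0..7
        vst1_u8(out[0], vreinterpret_u8_u32(k0_u32.val[0]));
        vst1_u8(out[1], vreinterpret_u8_u32(k2_u32.val[0]));
        vst1_u8(out[2], vreinterpret_u8_u32(k1_u32.val[0]));
        vst1_u8(out[3], vreinterpret_u8_u32(k3_u32.val[0]));
        vst1_u8(out[4], vreinterpret_u8_u32(k0_u32.val[1]));
        vst1_u8(out[5], vreinterpret_u8_u32(k2_u32.val[1]));
        vst1_u8(out[6], vreinterpret_u8_u32(k1_u32.val[1]));
        vst1_u8(out[7], vreinterpret_u8_u32(k3_u32.val[1]));

        for(int r = 0; r < 8; ++r)
            for(int c = 0; c < 8; ++c)
                assert(out[r][c] == ref[r][c]);
        return 0;
    }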
void transpose_16bit_elements(const ITensor *in, ITensor *out, const Window &window)
{
+ const int window_step_x = 4;
+ const int window_step_y = 4;
+ const int window_start_x = window.x().start();
+ const int window_end_x = window.x().end();
+ const int window_start_y = window.y().start();
+ const int window_end_y = std::min(window.y().end(), static_cast<int>(in->info()->dimension(1)));
+ const int window_end_y_multiple_of = ((window_end_y - window_start_y) / window_step_y) * window_step_y;
+ const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1];
+ const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1];
+
+ // Check if we need a left-over loop for the y dimension
+ bool left_over_loop_y = (((window_end_y - window_start_y) % window_step_y) != 0);
+
+ Window window_in(window);
+ window_in.set(Window::DimX, Window::Dimension(0, 1, 1));
+ if(left_over_loop_y)
+ {
+ // Check if window_end_y_multiple_of is greater than window_start_y
+ if(window_end_y_multiple_of > window_start_y)
+ {
+ window_in.set(Window::DimY, Window::Dimension(window_start_y, window_end_y_multiple_of, window_step_y));
+ }
+ else
+ {
+ window_in.set(Window::DimY, Window::Dimension(0, 0, 1));
+ }
+ }
+
Window window_out(window);
window_out.set(Window::DimX, Window::Dimension(0, 0, 0));
window_out.set(Window::DimY, Window::Dimension(0, 0, 0));
- Iterator input(in, window);
Iterator output(out, window_out);
- const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1];
- const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1];
+ // Run the NEON path if and only if the input is not a row-vector
+ if(in->info()->dimension(1) != 1)
+ {
+ Iterator input(in, window_in);
+ execute_window_loop(window_in, [&](const Coordinates & id)
+ {
+ // Compute 4x4 elements per iteration
+ int x = window_start_x;
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const uint16x4_t row0 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
+ const uint16x4_t row1 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
+ const uint16x4_t row2 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
+ const uint16x4_t row3 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
+
+ // Transpose 2x2
+ const uint16x4x2_t k0_u16 = vtrn_u16(row0, row1);
+ const uint16x4x2_t k1_u16 = vtrn_u16(row2, row3);
+
+ // Transpose 4x4
+ const uint32x2x2_t k0_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[0]), vreinterpret_u32_u16(k1_u16.val[0]));
+ const uint32x2x2_t k1_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[1]), vreinterpret_u32_u16(k1_u16.val[1]));
+
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint16_t) + x * output_stride_in_bytes;
+
+ vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), vreinterpret_u16_u32(k0_u32.val[0]));
+ vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), vreinterpret_u16_u32(k1_u32.val[0]));
+ vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), vreinterpret_u16_u32(k0_u32.val[1]));
+ vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), vreinterpret_u16_u32(k1_u32.val[1]));
+ }
+
+ // Compute left-over elements along the x dimension (1x4)
+ for(; x < window_end_x; ++x)
+ {
+ const uint16_t val0 = *(reinterpret_cast<uint16_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
+ const uint16_t val1 = *(reinterpret_cast<uint16_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
+ const uint16_t val2 = *(reinterpret_cast<uint16_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
+ const uint16_t val3 = *(reinterpret_cast<uint16_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
+
+ uint16x4_t result = vdup_n_u16(0);
+ result = vset_lane_u16(val0, result, 0);
+ result = vset_lane_u16(val1, result, 1);
+ result = vset_lane_u16(val2, result, 2);
+ result = vset_lane_u16(val3, result, 3);
+
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint16_t) + x * output_stride_in_bytes;
+
+ vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes), result);
+ }
+ },
+ input, output);
+ }
- execute_window_loop(window, [&](const Coordinates & id)
+ if(left_over_loop_y)
{
- const uint16x4_t row0 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 0 * input_stride_in_bytes));
- const uint16x4_t row1 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 1 * input_stride_in_bytes));
- const uint16x4_t row2 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 2 * input_stride_in_bytes));
- const uint16x4_t row3 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 3 * input_stride_in_bytes));
-
- // Transpose 2x2
- const uint16x4x2_t k0_u16 = vtrn_u16(row0, row1);
- const uint16x4x2_t k1_u16 = vtrn_u16(row2, row3);
-
- // Transpose 4x4
- const uint32x2x2_t k0_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[0]), vreinterpret_u32_u16(k1_u16.val[0]));
- const uint32x2x2_t k1_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[1]), vreinterpret_u32_u16(k1_u16.val[1]));
-
- // Compute destination address
- const size_t dst_offset_in_bytes = id.y() * sizeof(uint16_t) + id.x() * output_stride_in_bytes;
-
- vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), vreinterpret_u16_u32(k0_u32.val[0]));
- vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), vreinterpret_u16_u32(k1_u32.val[0]));
- vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), vreinterpret_u16_u32(k0_u32.val[1]));
- vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), vreinterpret_u16_u32(k1_u32.val[1]));
- },
- input, output);
+ window_in.set(Window::DimX, Window::Dimension(window.x().start(), window.x().end(), 1));
+ window_in.set(Window::DimY, Window::Dimension(window_end_y_multiple_of, window_end_y, 1));
+
+ Iterator input(in, window_in);
+ Iterator output(out, window_out);
+
+ // Compute left-over elements along the y dimension (1x1)
+ execute_window_loop(window_in, [&](const Coordinates & id)
+ {
+ const uint16_t val0 = *(reinterpret_cast<uint16_t *>(input.ptr()));
+
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint16_t) + id.x() * output_stride_in_bytes;
+
+ *(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes)) = val0;
+ },
+ input, output);
+ }
}
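All three element paths compute the destination address the same way: since the output iterator stays pinned at the origin, input element (x, y) is written at byte offset y * element_size + x * output_stride, i.e. output element (y, x). A small illustrative helper capturing that identity (not from the library):

    // Hypothetical helper: byte offset written by the kernels above.
    // Input element (x, y) lands at output element (y, x), that is,
    // column y of output row x.
    #include <cstddef>

    inline std::size_t transposed_offset(std::size_t x, std::size_t y,
                                         std::size_t element_size,
                                         std::size_t output_stride_in_bytes)
    {
        return y * element_size + x * output_stride_in_bytes;
    }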
void transpose_32bit_elements(const ITensor *in, ITensor *out, const Window &window)
{
+ const int window_step_x = 4;
+ const int window_step_y = 4;
+ const int window_start_x = window.x().start();
+ const int window_end_x = window.x().end();
+ const int window_start_y = window.y().start();
+ const int window_end_y = std::min(window.y().end(), static_cast<int>(in->info()->dimension(1)));
+ const int window_end_y_multiple_of = ((window_end_y - window_start_y) / window_step_y) * window_step_y;
+ const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1];
+ const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1];
+
+ // Check if we need a left-over loop for the y dimension
+ bool left_over_loop_y = (((window_end_y - window_start_y) % window_step_y) != 0);
+
+ Window window_in(window);
+ window_in.set(Window::DimX, Window::Dimension(0, 1, 1));
+ if(left_over_loop_y)
+ {
+ // Check if window_end_y_multiple_of is greater than window_start_y
+ if(window_end_y_multiple_of > window_start_y)
+ {
+ window_in.set(Window::DimY, Window::Dimension(window_start_y, window_end_y_multiple_of, window_step_y));
+ }
+ else
+ {
+ window_in.set(Window::DimY, Window::Dimension(0, 0, 1));
+ }
+ }
+
Window window_out(window);
window_out.set(Window::DimX, Window::Dimension(0, 0, 0));
window_out.set(Window::DimY, Window::Dimension(0, 0, 0));
- Iterator input(in, window);
Iterator output(out, window_out);
- const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1];
- const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1];
+ // Run the NEON path if and only if the input is not a row-vector
+ if(in->info()->dimension(1) != 1)
+ {
+ Iterator input(in, window_in);
+ execute_window_loop(window_in, [&](const Coordinates & id)
+ {
+ // Compute 4x4 elements per iteration
+ int x = window_start_x;
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const uint32x4_t row0 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
+ const uint32x4_t row1 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
+ const uint32x4_t row2 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
+ const uint32x4_t row3 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
+
+ // Transpose 2x2
+ const uint32x2x2_t k0_u32 = vtrn_u32(vget_low_u32(row0), vget_low_u32(row1));
+ const uint32x2x2_t k1_u32 = vtrn_u32(vget_high_u32(row2), vget_high_u32(row3));
+ const uint32x2x2_t k2_u32 = vtrn_u32(vget_high_u32(row0), vget_high_u32(row1));
+ const uint32x2x2_t k3_u32 = vtrn_u32(vget_low_u32(row2), vget_low_u32(row3));
+
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes;
+
+ // Swap block 01 with block 10 and store
+ vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), vcombine_u32(k0_u32.val[0], k3_u32.val[0]));
+ vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), vcombine_u32(k0_u32.val[1], k3_u32.val[1]));
+ vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), vcombine_u32(k2_u32.val[0], k1_u32.val[0]));
+ vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), vcombine_u32(k2_u32.val[1], k1_u32.val[1]));
+ }
+
+ // Compute left-over elements along the x dimension (1x4)
+ for(; x < window_end_x; ++x)
+ {
+ const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
+ const uint32_t val1 = *(reinterpret_cast<uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
+ const uint32_t val2 = *(reinterpret_cast<uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
+ const uint32_t val3 = *(reinterpret_cast<uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
+
+ uint32x4_t result = vdupq_n_u32(0);
+ result = vsetq_lane_u32(val0, result, 0);
+ result = vsetq_lane_u32(val1, result, 1);
+ result = vsetq_lane_u32(val2, result, 2);
+ result = vsetq_lane_u32(val3, result, 3);
+
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes;
+
+ vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes), result);
+ }
+ },
+ input, output);
+ }
- execute_window_loop(window, [&](const Coordinates & id)
+ if(left_over_loop_y)
{
- const uint32x4_t row0 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 0 * input_stride_in_bytes));
- const uint32x4_t row1 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 1 * input_stride_in_bytes));
- const uint32x4_t row2 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 2 * input_stride_in_bytes));
- const uint32x4_t row3 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 3 * input_stride_in_bytes));
-
- // Transpose 2x2
- const uint32x2x2_t k0_u32 = vtrn_u32(vget_low_u32(row0), vget_low_u32(row1));
- const uint32x2x2_t k1_u32 = vtrn_u32(vget_high_u32(row2), vget_high_u32(row3));
- const uint32x2x2_t k2_u32 = vtrn_u32(vget_high_u32(row0), vget_high_u32(row1));
- const uint32x2x2_t k3_u32 = vtrn_u32(vget_low_u32(row2), vget_low_u32(row3));
-
- // Compute destination address
- const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + id.x() * output_stride_in_bytes;
-
- // Swap block 01 with block 10 and store
- vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), vcombine_u32(k0_u32.val[0], k3_u32.val[0]));
- vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), vcombine_u32(k0_u32.val[1], k3_u32.val[1]));
- vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), vcombine_u32(k2_u32.val[0], k1_u32.val[0]));
- vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), vcombine_u32(k2_u32.val[1], k1_u32.val[1]));
- },
- input, output);
+ window_in.set(Window::DimX, Window::Dimension(window.x().start(), window.x().end(), 1));
+ window_in.set(Window::DimY, Window::Dimension(window_end_y_multiple_of, window_end_y, 1));
+
+ Iterator input(in, window_in);
+ Iterator output(out, window_out);
+
+ // Compute left-over elements along the y dimension (1x1)
+ execute_window_loop(window_in, [&](const Coordinates & id)
+ {
+ const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr()));
+
+ // Compute destination address
+ const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + id.x() * output_stride_in_bytes;
+
+ *(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes)) = val0;
+ },
+ input, output);
+ }
}
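The 32-bit path uses a block decomposition instead of a full vtrn cascade: each 2x2 sub-block of 32-bit lanes is transposed with vtrn_u32, and the off-diagonal blocks are swapped while recombining with vcombine_u32, which yields the full 4x4 transpose. A standalone sketch of the same trick on a local matrix, assuming a NEON-capable target; names are illustrative:

    // Standalone sketch: transpose the four 2x2 sub-blocks with vtrn_u32,
    // then swap block (0,1) with block (1,0) while recombining.
    #include <arm_neon.h>
    #include <cassert>
    #include <cstdint>

    int main()
    {
        uint32_t in[4][4], out[4][4];
        for(int r = 0; r < 4; ++r)
            for(int c = 0; c < 4; ++c)
                in[r][c] = static_cast<uint32_t>(r * 4 + c);

        const uint32x4_t row0 = vld1q_u32(in[0]);
        const uint32x4_t row1 = vld1q_u32(in[1]);
        const uint32x4_t row2 = vld1q_u32(in[2]);
        const uint32x4_t row3 = vld1q_u32(in[3]);

        // Transposed 2x2 sub-blocks, matching the kernel's naming:
        // k0 = top-left, k2 = top-right, k3 = bottom-left, k1 = bottom-right
        const uint32x2x2_t k0 = vtrn_u32(vget_low_u32(row0), vget_low_u32(row1));
        const uint32x2x2_t k1 = vtrn_u32(vget_high_u32(row2), vget_high_u32(row3));
        const uint32x2x2_t k2 = vtrn_u32(vget_high_u32(row0), vget_high_u32(row1));
        const uint32x2x2_t k3 = vtrn_u32(vget_low_u32(row2), vget_low_u32(row3));

        // Recombine with the off-diagonal blocks swapped:
        // output row i is input column i
        vst1q_u32(out[0], vcombine_u32(k0.val[0], k3.val[0]));
        vst1q_u32(out[1], vcombine_u32(k0.val[1], k3.val[1]));
        vst1q_u32(out[2], vcombine_u32(k2.val[0], k1.val[0]));
        vst1q_u32(out[3], vcombine_u32(k2.val[1], k1.val[1]));

        for(int r = 0; r < 4; ++r)
            for(int c = 0; c < 4; ++c)
                assert(out[r][c] == in[c][r]);
        return 0;
    }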
} // namespace
diff --git a/tests/validation/NEON/Transpose.cpp b/tests/validation/NEON/Transpose.cpp
index d0dbd4fd3f..f2ef7162a2 100644
--- a/tests/validation/NEON/Transpose.cpp
+++ b/tests/validation/NEON/Transpose.cpp
@@ -43,6 +43,34 @@ namespace validation
TEST_SUITE(NEON)
TEST_SUITE(Transpose)
+// *INDENT-OFF*
+// clang-format off
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(
+ framework::dataset::make("InputInfo", { TensorInfo(TensorShape(21U, 13U), 1, DataType::U8), // Input not a multiple of 8
+ TensorInfo(TensorShape(21U, 13U), 1, DataType::U16), // Invalid shape
+ TensorInfo(TensorShape(20U, 13U), 1, DataType::U32),
+ TensorInfo(TensorShape(20U, 13U), 1, DataType::U8), // Wrong data type
+ TensorInfo(TensorShape(20U, 13U), 1, DataType::U16),
+ TensorInfo(TensorShape(20U, 13U), 1, DataType::U32),
+ }),
+ framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(13U, 21U), 1, DataType::U8),
+ TensorInfo(TensorShape(21U, 13U), 1, DataType::U16),
+ TensorInfo(TensorShape(13U, 20U), 1, DataType::U32),
+ TensorInfo(TensorShape(31U, 20U), 1, DataType::U16),
+ TensorInfo(TensorShape(13U, 20U), 1, DataType::U16),
+ TensorInfo(TensorShape(13U, 20U), 1, DataType::U32),
+ })),
+ framework::dataset::make("Expected", { true, false, true, false, true, true })),
+ a_info, output_info, expected)
+{
+ // Clone the tensor infos and mark them as not resizable
+ Status status = NETranspose::validate(&a_info.clone()->set_is_resizable(false),
+ &output_info.clone()->set_is_resizable(false));
+ ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
+}
+// clang-format on
+// *INDENT-ON*
+
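The new data test exercises NETranspose::validate() as a standalone pre-flight check. A typical call, sketched with hypothetical shapes mirroring the dataset above:

    // Hedged usage sketch: query validity before configuring the function.
    TensorInfo src_info(TensorShape(20U, 13U), 1, DataType::U8);
    TensorInfo dst_info(TensorShape(13U, 20U), 1, DataType::U8);
    if(bool(NETranspose::validate(&src_info, &dst_info)))
    {
        // safe to call NETranspose::configure() with matching tensors
    }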
DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(concat(datasets::Small2DShapes(), datasets::Large2DShapes()), framework::dataset::make("DataType", { DataType::S8, DataType::U8, DataType::S16, DataType::U16, DataType::U32, DataType::S32, DataType::F16, DataType::F32 })),
shape, data_type)
{
@@ -50,30 +78,35 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(concat(datase
TensorShape output_shape{ shape[1], shape[0] };
// Create tensors
- Tensor ref_src = create_tensor<Tensor>(shape, data_type);
- Tensor dst = create_tensor<Tensor>(output_shape, data_type);
+ Tensor src = create_tensor<Tensor>(shape, data_type);
+ Tensor dst = create_tensor<Tensor>(output_shape, data_type);
// Create and Configure function
NETranspose trans;
- trans.configure(&ref_src, &dst);
+ trans.configure(&src, &dst);
// Validate valid region
const ValidRegion valid_region = shape_to_valid_region(output_shape);
validate(dst.info()->valid_region(), valid_region);
- // TODO(bsgcomp): Add padding validation (COMPMID-659)
+ // Validate padding
+ const PaddingSize padding(0, 0);
+ validate(src.info()->padding(), padding);
+ validate(dst.info()->padding(), padding);
}
template <typename T>
using NETransposeFixture = TransposeValidationFixture<Tensor, Accessor, NETranspose, T>;
TEST_SUITE(U8)
-FIXTURE_DATA_TEST_CASE(RunSmall, NETransposeFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(datasets::Small2DShapes(), framework::dataset::make("DataType", DataType::U8)))
+FIXTURE_DATA_TEST_CASE(RunSmall, NETransposeFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(concat(datasets::Small1DShapes(), datasets::Small2DShapes()),
+ framework::dataset::make("DataType", DataType::U8)))
{
// Validate output
validate(Accessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NETransposeFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(datasets::Large2DShapes(), framework::dataset::make("DataType", DataType::U8)))
+FIXTURE_DATA_TEST_CASE(RunLarge, NETransposeFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(concat(datasets::Large1DShapes(), datasets::Large2DShapes()),
+ framework::dataset::make("DataType", DataType::U8)))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -81,12 +114,14 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NETransposeFixture<uint8_t>, framework::Dataset
TEST_SUITE_END()
TEST_SUITE(U16)
-FIXTURE_DATA_TEST_CASE(RunSmall, NETransposeFixture<uint16_t>, framework::DatasetMode::PRECOMMIT, combine(datasets::Small2DShapes(), framework::dataset::make("DataType", DataType::U16)))
+FIXTURE_DATA_TEST_CASE(RunSmall, NETransposeFixture<uint16_t>, framework::DatasetMode::PRECOMMIT, combine(concat(datasets::Small1DShapes(), datasets::Small2DShapes()),
+ framework::dataset::make("DataType", DataType::U16)))
{
// Validate output
validate(Accessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NETransposeFixture<uint16_t>, framework::DatasetMode::NIGHTLY, combine(datasets::Large2DShapes(), framework::dataset::make("DataType", DataType::U16)))
+FIXTURE_DATA_TEST_CASE(RunLarge, NETransposeFixture<uint16_t>, framework::DatasetMode::NIGHTLY, combine(concat(datasets::Large1DShapes(), datasets::Large2DShapes()),
+ framework::dataset::make("DataType", DataType::U16)))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -94,12 +129,14 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NETransposeFixture<uint16_t>, framework::Datase
TEST_SUITE_END()
TEST_SUITE(U32)
-FIXTURE_DATA_TEST_CASE(RunSmall, NETransposeFixture<uint32_t>, framework::DatasetMode::PRECOMMIT, combine(datasets::Small2DShapes(), framework::dataset::make("DataType", DataType::U32)))
+FIXTURE_DATA_TEST_CASE(RunSmall, NETransposeFixture<uint32_t>, framework::DatasetMode::PRECOMMIT, combine(concat(datasets::Small1DShapes(), datasets::Small2DShapes()),
+ framework::dataset::make("DataType", DataType::U32)))
{
// Validate output
validate(Accessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NETransposeFixture<uint32_t>, framework::DatasetMode::NIGHTLY, combine(datasets::Large2DShapes(), framework::dataset::make("DataType", DataType::U32)))
+FIXTURE_DATA_TEST_CASE(RunLarge, NETransposeFixture<uint32_t>, framework::DatasetMode::NIGHTLY, combine(concat(datasets::Large1DShapes(), datasets::Large2DShapes()),
+ framework::dataset::make("DataType", DataType::U32)))
{
// Validate output
validate(Accessor(_target), _reference);