aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/convolver.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/convolver.hpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/convolver.hpp88
1 files changed, 76 insertions, 12 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/convolver.hpp b/src/core/NEON/kernels/arm_gemm/convolver.hpp
index 879d95f5bb..b15f669132 100644
--- a/src/core/NEON/kernels/arm_gemm/convolver.hpp
+++ b/src/core/NEON/kernels/arm_gemm/convolver.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020 Arm Limited.
+ * Copyright (c) 2020,2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -103,11 +103,15 @@ private:
return (m_length_remaining == 0);
}
+ // Compute a block of output pointers, accounting for padding.
+ // This is performance critical.
std::tuple<unsigned int, unsigned int> next_block(const T ** const row_ptr) {
if (finished()) {
return std::make_tuple(0, 0);
}
+ const T *pad_ptr = m_convolver.m_pad_row.data();
+
// "in_width" is the amount of data that will be read in (copied)
// "out_width" is the total amount of data that will be produced (including padding)
unsigned int offset = (m_current_pos == m_parent.m_start_pos) ? m_parent.m_start_offset : 0;
@@ -117,23 +121,83 @@ private:
unsigned int output_y = m_start_output_y;
unsigned int output_x = m_start_output_x;
- for (unsigned int row=0; row<m_active_height; row++) {
+ // Loop over "row" (output points), but really there is one
+ // trip through this outer loop per row of output to
+ // minimize redundant padding calculations.
+ unsigned int row=0;
+ while (row < m_active_height) {
int input_y = (output_y * m_convolver.m_params.output_stride_h) + m_convolver.m_kernel_y[m_current_pos];
int input_x = (output_x * m_convolver.m_params.output_stride_w) + m_convolver.m_kernel_x[m_current_pos];
- // Out-of-bounds points will read the padding data,
- // otherwise find the correct address in the input image.
- if (input_y < 0 || input_y >= m_convolver.m_params.input_height || input_x < 0 || input_x >= m_convolver.m_params.input_width) {
- row_ptr[row] = m_convolver.m_pad_row.data();
- } else {
- row_ptr[row] = m_parent.m_input_base + ((input_y * m_convolver.m_params.input_width) + input_x) * m_parent.m_input_stride;
+ // Factor out base pointer computation.
+ const T *base_ptr = m_parent.m_input_base +
+ (input_y * m_convolver.m_params.input_width * m_parent.m_input_stride);
+
+ // To start with, check the input row is in-bounds. If
+ // not, (at least) this entire output row must be
+ // padding so handle accordingly.
+
+ // If input_y is off the bottom of the input, we are
+ // going to get padding for every remaining output
+ // point.
+ if (input_y >= m_convolver.m_params.input_height) {
+ while (row < m_active_height) {
+ row_ptr[row++] = pad_ptr;
+ }
+ break;
}
- output_x++;
- if (output_x == m_convolver.m_params.output_width) {
- output_y++;
- output_x=0;
+ // If input_y is less than zero, we are going to get
+ // padding for the rest of this output row.
+ if (input_y < 0) {
+ while (output_x < m_convolver.m_params.output_width && row<m_active_height) {
+ row_ptr[row++] = pad_ptr;
+ output_x++;
+ }
+ goto next_row;
}
+
+ // The input row is in bounds - so handle left
+ // padding, then non-padding output, then right
+ // padding.
+
+ // Left padding
+ while (row < m_active_height && input_x < 0) {
+ row_ptr[row++] = pad_ptr;
+
+ output_x++;
+ input_x+=m_convolver.m_params.output_stride_w;
+
+ // Need to detect the end of the row, in case it's
+ // all padding.
+ if (output_x == m_convolver.m_params.output_width) {
+ goto next_row;
+ }
+ }
+
+ // Non-padding output. Factor out base pointer calculation.
+ while (row < m_active_height && input_x < m_convolver.m_params.input_width) {
+ row_ptr[row++] = base_ptr + (input_x * m_parent.m_input_stride);
+
+ output_x++;
+ input_x+=m_convolver.m_params.output_stride_w;
+
+ if (output_x == m_convolver.m_params.output_width) {
+ goto next_row;
+ }
+ }
+
+ // Right padding.
+ while (row < m_active_height && output_x < m_convolver.m_params.output_width) {
+ row_ptr[row++] = pad_ptr;
+ output_x++;
+ }
+
+ // Update output indices for next row. Used as a "goto"
+ // target due to end-of-row checks in nested loops.
+next_row:
+ output_x=0;
+ output_y++;
}
m_current_pos++;