diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/convolver.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/convolver.hpp | 88 |
1 files changed, 76 insertions, 12 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/convolver.hpp b/src/core/NEON/kernels/arm_gemm/convolver.hpp index 879d95f5bb..b15f669132 100644 --- a/src/core/NEON/kernels/arm_gemm/convolver.hpp +++ b/src/core/NEON/kernels/arm_gemm/convolver.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 Arm Limited. + * Copyright (c) 2020,2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -103,11 +103,15 @@ private: return (m_length_remaining == 0); } + // Compute a block of output pointers, accounting for padding. + // This is performance critical. std::tuple<unsigned int, unsigned int> next_block(const T ** const row_ptr) { if (finished()) { return std::make_tuple(0, 0); } + const T *pad_ptr = m_convolver.m_pad_row.data(); + // "in_width" in the amount of data that will be read in (copied) // "out_width" is the total amount of data that will be produced (including padding) unsigned int offset = (m_current_pos == m_parent.m_start_pos) ? m_parent.m_start_offset : 0; @@ -117,23 +121,83 @@ private: unsigned int output_y = m_start_output_y; unsigned int output_x = m_start_output_x; - for (unsigned int row=0; row<m_active_height; row++) { + // Loop over "row" (output points), but really there is one + // trip through this outer loop per row of output to + // minimize redundant padding calculations. + unsigned int row=0; + while (row < m_active_height) { int input_y = (output_y * m_convolver.m_params.output_stride_h) + m_convolver.m_kernel_y[m_current_pos]; int input_x = (output_x * m_convolver.m_params.output_stride_w) + m_convolver.m_kernel_x[m_current_pos]; - // Out-of-bounds points will read the padding data, - // otherwise find the correct address in the input image. - if (input_y < 0 || input_y >= m_convolver.m_params.input_height || input_x < 0 || input_x >= m_convolver.m_params.input_width) { - row_ptr[row] = m_convolver.m_pad_row.data(); - } else { - row_ptr[row] = m_parent.m_input_base + ((input_y * m_convolver.m_params.input_width) + input_x) * m_parent.m_input_stride; + // Factor out base pointer computation. + const T *base_ptr = m_parent.m_input_base + + (input_y * m_convolver.m_params.input_width * m_parent.m_input_stride); + + // To start with, check the input row is in-bounds. If + // not, (at least) this entire output row must be + // padding so handle accordingly. + + // If input_y is off the bottom of the input, we are + // going to get padding for every remanining output + // point. + if (input_y >= m_convolver.m_params.input_height) { + while (row < m_active_height) { + row_ptr[row++] = pad_ptr; + } + break; } - output_x++; - if (output_x == m_convolver.m_params.output_width) { - output_y++; - output_x=0; + // If input_y is less than zero, we are going to get + // padding for the rest of this output row. + if (input_y < 0) { + while (output_x < m_convolver.m_params.output_width && row<m_active_height) { + row_ptr[row++] = pad_ptr; + output_x++; + } + goto next_row; } + + // The input row is in bounds - so handle left + // padding, then non-padding output, then right + // padding. + + // Left padding + while (row < m_active_height && input_x < 0) { + row_ptr[row++] = pad_ptr; + + output_x++; + input_x+=m_convolver.m_params.output_stride_w; + + // Need to detect the end of the row, in case it's + // all padding. + if (output_x == m_convolver.m_params.output_width) { + goto next_row; + } + } + + // Non-padding output. Factor out base pointer calculation. + while (row < m_active_height && input_x < m_convolver.m_params.input_width) { + row_ptr[row++] = base_ptr + (input_x * m_parent.m_input_stride); + + output_x++; + input_x+=m_convolver.m_params.output_stride_w; + + if (output_x == m_convolver.m_params.output_width) { + goto next_row; + } + } + + // Right padding. + while (row < m_active_height && output_x < m_convolver.m_params.output_width) { + row_ptr[row++] = pad_ptr; + output_x++; + } + + // Update output indices for next row. Used as a "goto" + // target due to end-of-row checks in nested loops. +next_row: + output_x=0; + output_y++; } m_current_pos++; |