diff options
Diffstat (limited to 'arm_compute/core/NEON/kernels/convolution')
-rw-r--r-- | arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp | 20 |
1 files changed, 11 insertions, 9 deletions
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp index 6dd8f5460a..fc4b255a9c 100644 --- a/arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp +++ b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp @@ -23,7 +23,7 @@ */ #pragma once -#include "arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp" +#include "../winograd_gemm.hpp" namespace winograd { @@ -45,9 +45,8 @@ namespace winograd ) { // Compute the padding required on each edge of the image - const bool base_padding = (padding_type == PADDING_SAME) ? 1 : 0; - const int pad_top = base_padding; - const int pad_left = base_padding; + const int pad_top = (padding_type == PADDING_SAME) ? (kernel_rows - 1) / 2 : 0; + const int pad_left = (padding_type == PADDING_SAME) ? (kernel_cols - 1) / 2 : 0; const int tile_overlap = kernel_rows - 1; // Compute striding values (assuming NHWC ordered data) @@ -68,8 +67,7 @@ namespace winograd for (int tile_i = 0; tile_i < tile_M; tile_i++) { // Pointer to the row - const int row_offset = (tile_i == 0) ? - 0 : ((padding_type == PADDING_VALID) ? 0 : 1); + const int row_offset = (tile_i == 0) ? 0 : pad_top; const T* const input_base_row = ( input_base_batch + ((inner_tile_rows - (kernel_rows - 1))*tile_i - row_offset)*input_row_stride ); @@ -129,7 +127,9 @@ namespace winograd T* const outptr = matrix_base + tile_j*matrix_row_stride; // Apply the specific tile processing function - tile_fns[pad_top][t_pad_left][pad_bottom][t_pad_right]( + const int f_pad_top = pad_top ? 1 : 0; + const int f_pad_left = t_pad_left ? 1 : 0; + tile_fns[f_pad_top][f_pad_left][pad_bottom][t_pad_right]( n_channels, input_base_col, input_row_stride, @@ -156,8 +156,10 @@ namespace winograd ) : _inptr(input), _outptr(output), _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols), _n_channels(n_channels), _matrix_stride(matrix_stride), _matrix_row_stride(matrix_row_stride), - _tiles_M(iceildiv((padding == PADDING_SAME) ? n_rows : n_rows - 2, output_tile_rows)), - _tiles_N(iceildiv((padding == PADDING_SAME) ? n_cols : n_cols - 2, output_tile_cols)), + _tiles_M(iceildiv((padding == PADDING_SAME) ? n_rows : n_rows - kr + 1, + output_tile_rows)), + _tiles_N(iceildiv((padding == PADDING_SAME) ? n_cols : n_cols - kc + 1, + output_tile_cols)), _padding_type(padding) { } |