aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp')
-rw-r--r--arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp20
1 files changed, 11 insertions, 9 deletions
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp
index 6dd8f5460a..fc4b255a9c 100644
--- a/arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp
@@ -23,7 +23,7 @@
*/
#pragma once
-#include "arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp"
+#include "../winograd_gemm.hpp"
namespace winograd
{
@@ -45,9 +45,8 @@ namespace winograd
)
{
// Compute the padding required on each edge of the image
- const bool base_padding = (padding_type == PADDING_SAME) ? 1 : 0;
- const int pad_top = base_padding;
- const int pad_left = base_padding;
+ const int pad_top = (padding_type == PADDING_SAME) ? (kernel_rows - 1) / 2 : 0;
+ const int pad_left = (padding_type == PADDING_SAME) ? (kernel_cols - 1) / 2 : 0;
const int tile_overlap = kernel_rows - 1;
// Compute striding values (assuming NHWC ordered data)
@@ -68,8 +67,7 @@ namespace winograd
for (int tile_i = 0; tile_i < tile_M; tile_i++)
{
// Pointer to the row
- const int row_offset = (tile_i == 0) ?
- 0 : ((padding_type == PADDING_VALID) ? 0 : 1);
+ const int row_offset = (tile_i == 0) ? 0 : pad_top;
const T* const input_base_row = (
input_base_batch + ((inner_tile_rows - (kernel_rows - 1))*tile_i - row_offset)*input_row_stride
);
@@ -129,7 +127,9 @@ namespace winograd
T* const outptr = matrix_base + tile_j*matrix_row_stride;
// Apply the specific tile processing function
- tile_fns[pad_top][t_pad_left][pad_bottom][t_pad_right](
+ const int f_pad_top = pad_top ? 1 : 0;
+ const int f_pad_left = t_pad_left ? 1 : 0;
+ tile_fns[f_pad_top][f_pad_left][pad_bottom][t_pad_right](
n_channels,
input_base_col,
input_row_stride,
@@ -156,8 +156,10 @@ namespace winograd
) : _inptr(input), _outptr(output),
_n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols), _n_channels(n_channels),
_matrix_stride(matrix_stride), _matrix_row_stride(matrix_row_stride),
- _tiles_M(iceildiv((padding == PADDING_SAME) ? n_rows : n_rows - 2, output_tile_rows)),
- _tiles_N(iceildiv((padding == PADDING_SAME) ? n_cols : n_cols - 2, output_tile_cols)),
+ _tiles_M(iceildiv((padding == PADDING_SAME) ? n_rows : n_rows - kr + 1,
+ output_tile_rows)),
+ _tiles_N(iceildiv((padding == PADDING_SAME) ? n_cols : n_cols - kc + 1,
+ output_tile_cols)),
_padding_type(padding)
{
}