aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp')
-rw-r--r--arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp24
1 files changed, 19 insertions, 5 deletions
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp
index bc067fd07a..7098fc48a1 100644
--- a/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp
@@ -49,8 +49,8 @@ class WinogradGEMM
static constexpr int output_tile_cols = OutputTileCols;
static constexpr int kernel_rows = KernelRows;
static constexpr int kernel_cols = KernelCols;
- static constexpr int inner_tile_rows = output_tile_rows + kernel_rows - 1; // TODO Check
- static constexpr int inner_tile_cols = output_tile_cols + kernel_cols - 1; // TODO Check
+ static constexpr int inner_tile_rows = output_tile_rows + kernel_rows - 1;
+ static constexpr int inner_tile_cols = output_tile_cols + kernel_cols - 1;
static constexpr int N_GEMMS = inner_tile_rows * inner_tile_cols;
/** Transform weights from the spatial to the Winograd domain. */
@@ -196,8 +196,21 @@ class WinogradGEMM
const int n_cols
);
- static constexpr int max_pad_bottom = inner_tile_rows - 1;
- static constexpr int max_pad_right = inner_tile_cols - 1;
+ // Tile overlaps
+ static constexpr int overlap_rows = kernel_rows - 1;
+ static constexpr int overlap_cols = kernel_cols - 1;
+
+ // Maximum padding and number of distinct paddings
+ static constexpr int max_pad_top = kernel_rows / 2;
+ static constexpr int n_pad_top = 1 + iceildiv(max_pad_top, inner_tile_rows - overlap_rows);
+
+ static constexpr int max_pad_left = kernel_cols / 2;
+ static constexpr int n_pad_left = 1 + iceildiv(max_pad_left, inner_tile_cols - overlap_cols);
+
+ static constexpr int n_pad_bottom = inner_tile_rows;
+ static constexpr int n_pad_right = inner_tile_cols;
+
+
/** Process a single tile of the input tensor. */
template <int pad_top, int pad_left, int pad_bottom, int pad_right>
@@ -205,7 +218,8 @@ class WinogradGEMM
// Array of methods to transform tiles of the input tensor.
typedef void (*TileFn)(int, const T*, int, int, T*, int);
- static const TileFn tile_fns[2][2][max_pad_bottom][max_pad_right];
+ static const TileFn
+ tile_fns[n_pad_top][n_pad_left][n_pad_bottom][n_pad_right];
/* Member values for instance-based API. */
const T* const _inptr;