aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp')
-rw-r--r--arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp123
1 files changed, 7 insertions, 116 deletions
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp
index 7098fc48a1..31aee35fab 100644
--- a/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp
@@ -30,7 +30,7 @@
#include "arm_compute/core/NEON/kernels/convolution/common/shims.hpp"
#include "arm_compute/core/NEON/kernels/convolution/common/tensor.hpp"
#include "arm_compute/core/NEON/kernels/convolution/common/utils.hpp"
-
+#include "winograd_input_transform.hpp"
#include <thread>
#include <utility>
@@ -114,121 +114,12 @@ class WinogradGEMM
/** Transform input feature maps from the spatial to the Winograd domain.
*/
template <typename T>
- struct InputTransform
- {
- /** Get the bytes read during the transform. */
- static size_t bytes_read(const Tensor4DShape &shape)
- {
- return shape.size() * sizeof(T);
- }
-
- /** Get the bytes written during the transform. */
- static size_t bytes_written(const Tensor4DShape &shape)
- {
- const int M = iceildiv(shape.n_rows, inner_tile_rows) *
- iceildiv(shape.n_cols, inner_tile_cols);
- const int K = shape.n_channels;
- return inner_tile_rows * inner_tile_cols * M * K * sizeof(T);
- }
-
- /** Get the count of operations performed by the transform. */
- static int ops_performed(const Tensor4DShape &shape);
-
- /** Apply the transform to a tensor. */
- static void execute(
- const T* const input, /** Input tensor data */
- const int n_batches, /** Number of batches in input tensor. */
- const int in_batch_stride, /** Stride between batches of the input. */
- const int n_rows, /** Number of rows in input tensor. */
- const int in_row_stride, /** Stride between rows of the input. */
- const int n_cols, /** Number of columns in input tensor. */
- const int in_col_stride, /** Stride between columns of the input. */
- const int n_channels, /** Number of channels in input tensor. */
- const PaddingType padding, /** Padding type. */
- const int tile_M,
- const int tile_N,
- T* const output, /** Base of output matrices. */
- const int matrix_stride, /** Stride between output matrices. */
- const int matrix_batch_stride, /** Stride between batches within the matrix. */
- const int matrix_row_stride /** Stride within matrices. */
- );
-
- /***********************************************************************/
- /** Create an InputTransform operator fixed on a given problem and set of
- * pointers.
- */
- InputTransform(
- const T* const input, /** Input tensor data */
- const int n_batches, /** Number of batches in input tensor. */
- const int n_rows, /** Number of rows in input tensor. */
- const int n_cols, /** Number of columns in input tensor. */
- const int n_channels, /** Number of channels in input tensor. */
- const PaddingType padding, /** Padding type. */
- T* const output, /** Base of output matrices. */
- const int matrix_stride, /** Stride between output matrices. */
- const int matrix_row_stride, /** Stride within matrices. */
- const int in_batch_stride=0, /** Stride between input batches. */
- const int in_row_stride=0, /** Stride between input rows. */
- const int in_col_stride=0 /** Stride between input columns. */
- );
-
- /** Get the window of work a given operator can perform. */
- unsigned int get_window() const;
- static constexpr unsigned int WINDOW_BLOCK = 16; // Base size of window
-
- /** Perform work upon a window of the input. */
- void run(const unsigned int start, const unsigned int stop);
- /***********************************************************************/
-
- private:
- static void process_tile_row(
- const int tile_N,
- int n_channels,
- const T* const input_base,
- const int input_row_stride,
- const int input_col_stride,
- T* const matrix_base,
- const int matrix_stride,
- const int matrix_row_stride,
- const int row_pad_top,
- const int row_pad_left,
- const int row_pad_bottom,
- const int n_cols
- );
-
- // Tile overlaps
- static constexpr int overlap_rows = kernel_rows - 1;
- static constexpr int overlap_cols = kernel_cols - 1;
-
- // Maximum padding and number of distinct paddings
- static constexpr int max_pad_top = kernel_rows / 2;
- static constexpr int n_pad_top = 1 + iceildiv(max_pad_top, inner_tile_rows - overlap_rows);
-
- static constexpr int max_pad_left = kernel_cols / 2;
- static constexpr int n_pad_left = 1 + iceildiv(max_pad_left, inner_tile_cols - overlap_cols);
-
- static constexpr int n_pad_bottom = inner_tile_rows;
- static constexpr int n_pad_right = inner_tile_cols;
-
-
-
- /** Process a single tile of the input tensor. */
- template <int pad_top, int pad_left, int pad_bottom, int pad_right>
- static void process_tile(int, const T*, int, int, T*, int);
-
- // Array of methods to transform tiles of the input tensor.
- typedef void (*TileFn)(int, const T*, int, int, T*, int);
- static const TileFn
- tile_fns[n_pad_top][n_pad_left][n_pad_bottom][n_pad_right];
-
- /* Member values for instance-based API. */
- const T* const _inptr;
- T* const _outptr;
- const int _n_batches, _n_rows, _n_cols, _n_channels, _matrix_stride,
- _matrix_row_stride, _tiles_M, _tiles_N;
- const int _in_col_stride, _in_row_stride, _in_batch_stride;
- const PaddingType _padding_type;
- };
+ using InputTransform = InputTransform<
+ KernelRows, KernelCols,
+ (OutputTileRows + KernelRows - 1),
+ (OutputTileCols + KernelCols - 1),
+ T
+ >;
/** Transform output feature maps from the Winograd to the spatial domain.
*/