diff options
Diffstat (limited to 'arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp')
-rw-r--r-- | arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp | 95 |
1 files changed, 7 insertions, 88 deletions
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp index 31aee35fab..71b5fd516f 100644 --- a/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp +++ b/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp @@ -31,6 +31,7 @@ #include "arm_compute/core/NEON/kernels/convolution/common/tensor.hpp" #include "arm_compute/core/NEON/kernels/convolution/common/utils.hpp" #include "winograd_input_transform.hpp" +#include "winograd_output_transform.hpp" #include <thread> #include <utility> @@ -124,95 +125,13 @@ class WinogradGEMM /** Transform output feature maps from the Winograd to the spatial domain. */ template <typename T> - struct OutputTransform - { - /** Get the bytes read during the transform. */ - static size_t bytes_read(const Tensor4DShape &shape); - - /** Get the bytes written during the transform. */ - static size_t bytes_written(const Tensor4DShape &shape); - - /** Get the count of operations performed by the transform. */ - static int ops_performed(const Tensor4DShape &shape); - - /** Apply the transform to create a tensor. */ - static void execute( - const int n_batches, - const int out_batch_stride, - const int n_rows, - const int out_row_stride, - const int n_cols, - const int out_col_stride, - const int n_channels, - const T* const matrix_base, - const int matrix_stride, - const int matrix_row_stride, - const T* const biases, - T* const output - ); - - /***********************************************************************/ - /** Create an OutputTransform operator fixed on a given problem and set - * of pointers. - */ - OutputTransform( - const T* const matrix_base, /** Pointer to base of matrices. */ - const int matrix_stride, /** Stride between matrices. */ - const int matrix_row_stride, /** Stride within a matrix. */ - const T* const biases, /** Pointer to biases vector. */ - T* const output, /** Pointer to output tensor. */ - const int n_batches, /** Number of batches in output tensor. */ - const int n_rows, /** Number of rows in output tensor. */ - const int n_cols, /** Number of columns in output tensor. */ - const int n_channels, /** Number of channels in output tensor. */ - const int out_batch_stride=0, /** Output batch stride. */ - const int out_row_stride=0, /** Output row stride. */ - const int out_col_stride=0 /** Output column stride. */ - ); - - /** Get the window of work a given operator can perform. */ - unsigned int get_window() const; - static constexpr unsigned int WINDOW_BLOCK = 16; // Base size of window - - /** Perform work upon a window of the input. */ - void run(const unsigned int start, const unsigned int stop); - /***********************************************************************/ - - private: - static void process_tile_row( - const int tile_N, - const int n_channels, - const T* const matrix_base, - const int matrix_stride, - const int matrix_row_stride, - const T* const biases, - T* const output, - const int output_row_stride, - const int output_col_stride, - const int row_pad_bottom, - const int row_pad_right - ); - - // Limits on the amount of anti-padding to be applied - static constexpr int max_pad_bottom = output_tile_rows; - static constexpr int max_pad_right = output_tile_cols; - - /** Prepare a single tile of the output tensor. */ - template <int pad_bottom, int pad_right> - static void process_tile(int, const T*, int, const T*, T*, int, int); - - // Array of methods to produce tiles of output tensor. - typedef void (*TileFn)(int, const T*, int, const T*, T*, int, int); - static const TileFn tile_fns[max_pad_bottom][max_pad_right]; + using OutputTransform = OutputTransform< + KernelRows, KernelCols, + (OutputTileRows + KernelRows - 1), + (OutputTileCols + KernelCols - 1), + T + >; - /** Member constants for instances of the transform. */ - const T* const _matrix_base; - const T* const _biases; - const int _matrix_stride, _matrix_row_stride; - T* const _outptr; - const int _n_batches, _n_rows, _n_cols, _n_channels, _tile_M, _tile_N; - const int _out_col_stride, _out_row_stride, _out_batch_stride; - }; /** Perform a convolution. */ |