diff options
Diffstat (limited to 'arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp')
-rw-r--r-- | arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp | 110 |
1 file changed, 34 insertions, 76 deletions
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp index dd67e97035..bc067fd07a 100644 --- a/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp +++ b/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp @@ -31,6 +31,7 @@ #include "arm_compute/core/NEON/kernels/convolution/common/tensor.hpp" #include "arm_compute/core/NEON/kernels/convolution/common/utils.hpp" + #include <thread> #include <utility> #include <vector> @@ -135,15 +136,21 @@ class WinogradGEMM /** Apply the transform to a tensor. */ static void execute( - const T *inptr, - const Tensor4DShape& input_shape, - const PaddingType padding_type, + const T* const input, /** Input tensor data */ + const int n_batches, /** Number of batches in input tensor. */ + const int in_batch_stride, /** Stride between batches of the input. */ + const int n_rows, /** Number of rows in input tensor. */ + const int in_row_stride, /** Stride between rows of the input. */ + const int n_cols, /** Number of columns in input tensor. */ + const int in_col_stride, /** Stride between columns of the input. */ + const int n_channels, /** Number of channels in input tensor. */ + const PaddingType padding, /** Padding type. */ const int tile_M, const int tile_N, - T *outptr_base, - const int matrix_stride, - const int matrix_batch_stride, - const int matrix_row_stride + T* const output, /** Base of output matrices. */ + const int matrix_stride, /** Stride between output matrices. */ + const int matrix_batch_stride, /** Stride between batches within the matrix. */ + const int matrix_row_stride /** Stride within matrices. */ ); /***********************************************************************/ @@ -159,11 +166,15 @@ class WinogradGEMM const PaddingType padding, /** Padding type. */ T* const output, /** Base of output matrices. */ const int matrix_stride, /** Stride between output matrices. 
*/ - const int matrix_row_stride /** Stride within matrices. */ + const int matrix_row_stride, /** Stride within matrices. */ + const int in_batch_stride=0, /** Stride between input batches. */ + const int in_row_stride=0, /** Stride between input rows. */ + const int in_col_stride=0 /** Stride between input columns. */ ); - /** Get the winodw of work a given operator can perform. */ + /** Get the window of work a given operator can perform. */ unsigned int get_window() const; + static constexpr unsigned int WINDOW_BLOCK = 16; // Base size of window /** Perform work upon a window of the input. */ void run(const unsigned int start, const unsigned int stop); @@ -201,6 +212,7 @@ class WinogradGEMM T* const _outptr; const int _n_batches, _n_rows, _n_cols, _n_channels, _matrix_stride, _matrix_row_stride, _tiles_M, _tiles_N; + const int _in_col_stride, _in_row_stride, _in_batch_stride; const PaddingType _padding_type; }; @@ -220,7 +232,13 @@ class WinogradGEMM /** Apply the transform to create a tensor. */ static void execute( - const Tensor4DShape &output_shape, + const int n_batches, + const int out_batch_stride, + const int n_rows, + const int out_row_stride, + const int n_cols, + const int out_col_stride, + const int n_channels, const T* const matrix_base, const int matrix_stride, const int matrix_row_stride, @@ -241,11 +259,15 @@ class WinogradGEMM const int n_batches, /** Number of batches in output tensor. */ const int n_rows, /** Number of rows in output tensor. */ const int n_cols, /** Number of columns in output tensor. */ - const int n_channels /** Number of channels in output tensor. */ + const int n_channels, /** Number of channels in output tensor. */ + const int out_batch_stride=0, /** Output batch stride. */ + const int out_row_stride=0, /** Output row stride. */ + const int out_col_stride=0 /** Output column stride. */ ); /** Get the window of work a given operator can perform. 
*/ unsigned int get_window() const; + static constexpr unsigned int WINDOW_BLOCK = 16; // Base size of window /** Perform work upon a window of the input. */ void run(const unsigned int start, const unsigned int stop); @@ -284,6 +306,7 @@ class WinogradGEMM const int _matrix_stride, _matrix_row_stride; T* const _outptr; const int _n_batches, _n_rows, _n_cols, _n_channels, _tile_M, _tile_N; + const int _out_col_stride, _out_row_stride, _out_batch_stride; }; /** Perform a convolution. @@ -296,54 +319,6 @@ class WinogradGEMM typedef TOut OutputType; typedef TIn InputType; - /** Create a new Winograd operator. */ - Convolution( - const KernelShape &kernel_shape, - const Tensor4DShape &input_shape, - const PaddingType padding, - void *kernel_storage=NULL - ); - - Convolution(const Convolution&) = delete; - Convolution operator=(const Convolution&) = delete; - - /** Create a new Winograd operator and initialise the weights. */ - Convolution( - const KernelShape &kernel_shape, - const Tensor4DShape &input_shape, - const PaddingType padding, - const TIn* const kernel, - void *kernel_storage=NULL, - void *transform_working_space=NULL - ); - - /** Clean up a convolution engine. */ - ~Convolution(); - - /** Transform the weights into the Winograd domain. */ - template <typename WeightsTransform=WeightsTransform<TIn>> - void transform_weights( - const TIn* const kernel, - void *transform_working_space=NULL - ); - - /* Apply the Winograd operator to some input. */ - void execute( - TOut* const output, - const TIn* const input, - const TOut* const biases, - void* working_space=NULL, - const int n_threads=1 - ); - - /* Apply the Winograd operator to some input. */ - void execute( - TOut* const output, - const TIn* const input, - const TOut* const biases, - const int n_threads - ); - /** Get the output shape of a convolution. 
*/ static Tensor4DShape get_output_shape( const KernelShape &kernel_shape, @@ -421,23 +396,6 @@ class WinogradGEMM static constexpr int M_BLOCK = 4; /** Size of block used by GEMM. */ static constexpr int N_BLOCK = 16; /** Size of block used by GEMM. */ - - private: - const KernelShape kernel_shape; /** Shape of the kernel to be applied. */ - TIn *kernel_matrices[N_GEMMS]; /** Pointers into the kernel matrices. */ - const int kernel_matrix_row_stride; /** Stride within the kernel matrices. */ - - const bool manage_kernel_storage; /** Kernel storage is managed by the instance. */ - void* const _kernel_storage; /** Base pointer for kernel storage. */ - - const Tensor4DShape input_shape; /** Shape of the input tensor. */ - const PaddingType padding; /** Padding applied by the operator. */ - - const Tensor4DShape output_shape; /** Output shape produced by the operator. */ - - const int tile_rows; /** Number of rows of tiles. */ - const int tile_cols; /** Number of columns of tiles. */ - const int M, K, N; /** Sizes of underlying fundamental matrix multiplications. */ }; }; |