diff options
Diffstat (limited to 'arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp')
-rw-r--r-- | arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp b/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp index adca48a6d6..2ea70f182b 100644 --- a/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp +++ b/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp @@ -183,7 +183,7 @@ class WinogradGEMM const int row_pad_top, const int row_pad_left, const int row_pad_bottom, - const int row_pad_right + const int n_cols ); static constexpr int max_pad_bottom = inner_tile_rows - 1; @@ -225,6 +225,7 @@ class WinogradGEMM const T* const matrix_base, const int matrix_stride, const int matrix_row_stride, + const T* const biases, T* const output ); @@ -236,6 +237,7 @@ class WinogradGEMM const T* const matrix_base, /** Pointer to base of matrices. */ const int matrix_stride, /** Stride between matrices. */ const int matrix_row_stride, /** Stride within a matrix. */ + const T* const biases, /** Pointer to biases vector. */ T* const output, /** Pointer to output tensor. */ const int n_batches, /** Number of batches in output tensor. */ const int n_rows, /** Number of rows in output tensor. */ @@ -257,6 +259,7 @@ class WinogradGEMM const T* const matrix_base, const int matrix_stride, const int matrix_row_stride, + const T* const biases, T* const output, const int output_row_stride, const int output_col_stride, @@ -270,14 +273,15 @@ class WinogradGEMM /** Prepare a single tile of the output tensor. */ template <int pad_bottom, int pad_right> - static void process_tile(int, const T*, int, T*, int, int); + static void process_tile(int, const T*, int, const T*, T*, int, int); // Array of methods to produce tiles of output tensor. - typedef void (*TileFn)(int, const T*, int, T*, int, int); + typedef void (*TileFn)(int, const T*, int, const T*, T*, int, int); static const TileFn tile_fns[max_pad_bottom][max_pad_right]; /** Member constants for instances of the transform. */ const T* const _matrix_base; + const T* const _biases; const int _matrix_stride, _matrix_row_stride; T* const _outptr; const int _n_batches, _n_rows, _n_cols, _n_channels, _tile_M, _tile_N; @@ -328,6 +332,7 @@ class WinogradGEMM void execute( TOut* const output, const TIn* const input, + const TOut* const biases, void* working_space=NULL, const int n_threads=1 ); @@ -336,6 +341,7 @@ class WinogradGEMM void execute( TOut* const output, const TIn* const input, + const TOut* const biases, const int n_threads ); |