about summary refs log tree commit diff
path: root/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp')
-rw-r--r-- arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp 12
1 files changed, 9 insertions, 3 deletions
diff --git a/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp b/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp
index adca48a6d6..2ea70f182b 100644
--- a/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp
+++ b/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp
@@ -183,7 +183,7 @@ class WinogradGEMM
const int row_pad_top,
const int row_pad_left,
const int row_pad_bottom,
- const int row_pad_right
+ const int n_cols
);
static constexpr int max_pad_bottom = inner_tile_rows - 1;
@@ -225,6 +225,7 @@ class WinogradGEMM
const T* const matrix_base,
const int matrix_stride,
const int matrix_row_stride,
+ const T* const biases,
T* const output
);
@@ -236,6 +237,7 @@ class WinogradGEMM
const T* const matrix_base, /** Pointer to base of matrices. */
const int matrix_stride, /** Stride between matrices. */
const int matrix_row_stride, /** Stride within a matrix. */
+ const T* const biases, /** Pointer to biases vector. */
T* const output, /** Pointer to output tensor. */
const int n_batches, /** Number of batches in output tensor. */
const int n_rows, /** Number of rows in output tensor. */
@@ -257,6 +259,7 @@ class WinogradGEMM
const T* const matrix_base,
const int matrix_stride,
const int matrix_row_stride,
+ const T* const biases,
T* const output,
const int output_row_stride,
const int output_col_stride,
@@ -270,14 +273,15 @@ class WinogradGEMM
/** Prepare a single tile of the output tensor. */
template <int pad_bottom, int pad_right>
- static void process_tile(int, const T*, int, T*, int, int);
+ static void process_tile(int, const T*, int, const T*, T*, int, int);
// Array of methods to produce tiles of output tensor.
- typedef void (*TileFn)(int, const T*, int, T*, int, int);
+ typedef void (*TileFn)(int, const T*, int, const T*, T*, int, int);
static const TileFn tile_fns[max_pad_bottom][max_pad_right];
/** Member constants for instances of the transform. */
const T* const _matrix_base;
+ const T* const _biases;
const int _matrix_stride, _matrix_row_stride;
T* const _outptr;
const int _n_batches, _n_rows, _n_cols, _n_channels, _tile_M, _tile_N;
@@ -328,6 +332,7 @@ class WinogradGEMM
void execute(
TOut* const output,
const TIn* const input,
+ const TOut* const biases,
void* working_space=NULL,
const int n_threads=1
);
@@ -336,6 +341,7 @@ class WinogradGEMM
void execute(
TOut* const output,
const TIn* const input,
+ const TOut* const biases,
const int n_threads
);