diff options
Diffstat (limited to 'tests/validation/reference/Winograd.cpp')
-rw-r--r-- | tests/validation/reference/Winograd.cpp | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/tests/validation/reference/Winograd.cpp b/tests/validation/reference/Winograd.cpp index 75b1b51d46..194a78e95f 100644 --- a/tests/validation/reference/Winograd.cpp +++ b/tests/validation/reference/Winograd.cpp @@ -331,7 +331,7 @@ SimpleTensor<T> winograd_filter_transform(const SimpleTensor<T> &in, const Tenso } template <typename T> -SimpleTensor<T> winograd_output_transform(const SimpleTensor<T> &in, const TensorShape &output_shape, const WinogradInfo &winograd_info) +SimpleTensor<T> winograd_output_transform(const SimpleTensor<T> &in, const SimpleTensor<T> &b, const TensorShape &output_shape, const WinogradInfo &winograd_info) { ARM_COMPUTE_ERROR_ON_MSG(winograd_info.output_data_layout != DataLayout::NCHW, "Only supported NCHW data format"); @@ -444,6 +444,9 @@ SimpleTensor<T> winograd_output_transform(const SimpleTensor<T> &in, const Tenso if((xo + xi < w_out) && (yo + yi < h_out)) { out[output_offset + yi * stridey_out + xi] = output_tile[xi + yi * out_tile_w]; + + // Add bias + out[output_offset + yi * stridey_out + xi] += b[zo]; } } } @@ -456,7 +459,7 @@ SimpleTensor<T> winograd_output_transform(const SimpleTensor<T> &in, const Tenso template SimpleTensor<float> winograd_filter_transform(const SimpleTensor<float> &in, const TensorShape &output_shape, const WinogradInfo &winograd_info); template SimpleTensor<float> winograd_input_transform(const SimpleTensor<float> &in, const TensorShape &output_shape, const WinogradInfo &winograd_info); -template SimpleTensor<float> winograd_output_transform(const SimpleTensor<float> &in, const TensorShape &output_shape, const WinogradInfo &winograd_info); +template SimpleTensor<float> winograd_output_transform(const SimpleTensor<float> &in, const SimpleTensor<float> &b, const TensorShape &output_shape, const WinogradInfo &winograd_info); } // namespace reference } // namespace validation } // namespace test |