From 5264b7d5555ec980f9c52c719122479d0d676af8 Mon Sep 17 00:00:00 2001 From: Pablo Tello Date: Mon, 21 Oct 2019 14:25:41 +0100 Subject: COMPMID-2576: Fuse activation in Winograd output transform. Change-Id: I26dd1307847adeaaefae0a7374b9858c07d71372 Signed-off-by: Pablo Tello Reviewed-on: https://review.mlplatform.org/c/2172 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Gian Marco Iodice --- .../winograd/winograd_transforms/output.hpp | 55 +++++++++++++--------- 1 file changed, 32 insertions(+), 23 deletions(-) (limited to 'src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp') diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp index d97af21a43..fe47ccbde9 100644 --- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp +++ b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. 
* * SPDX-License-Identifier: MIT * @@ -41,24 +41,30 @@ namespace winograd { -MEMBERFN()::OutputTransform( - const int n_batches, - const int n_rows, - const int n_cols, - const int n_channels -) : _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols), _n_channels(n_channels), - _matrix_base(nullptr), - _biases(nullptr), - _matrix_stride(0), _matrix_row_stride(0), _matrix_batch_stride(0), - _outptr(nullptr), - _tiles_M(iceildiv(n_rows, output_tile_rows)), - _tiles_N(iceildiv(n_cols, output_tile_cols)), - _out_col_stride(0), _out_row_stride(0), _out_batch_stride(0), - _working_space_col_stride(n_channels), - _working_space_row_stride(output_tile_cols * _working_space_col_stride), - _working_space(nullptr) -{ -} +MEMBERFN() +::OutputTransform(const int n_batches, const int n_rows, const int n_cols, + const int n_channels, const arm_gemm::Activation &activation) + : _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols), + _n_channels(n_channels), + _output_min((activation.type == arm_gemm::Activation::Type::ReLU || + activation.type == arm_gemm::Activation::Type::BoundedReLU) + ? static_cast<TOut>(0.0f) + : (std::numeric_limits<TOut>::has_infinity) + ? -std::numeric_limits<TOut>::infinity() + : std::numeric_limits<TOut>::lowest()), + _output_max((activation.type == arm_gemm::Activation::Type::BoundedReLU) + ? static_cast<TOut>(activation.param1) + : (std::numeric_limits<TOut>::has_infinity) + ? 
std::numeric_limits<TOut>::infinity() + : std::numeric_limits<TOut>::max()), + _matrix_base(nullptr), _biases(nullptr), _matrix_stride(0), + _matrix_row_stride(0), _matrix_batch_stride(0), _outptr(nullptr), + _tiles_M(iceildiv(n_rows, output_tile_rows)), + _tiles_N(iceildiv(n_cols, output_tile_cols)), _out_col_stride(0), + _out_row_stride(0), _out_batch_stride(0), + _working_space_col_stride(n_channels), + _working_space_row_stride(output_tile_cols * _working_space_col_stride), + _working_space(nullptr) {} MEMBERFN(void)::set_input_matrices(const void * const mptr, const int ldmatrix, const int ldrow) { @@ -100,9 +106,10 @@ Nx1MEMBERFN()::OutputTransform( const int n_batches, const int n_rows, const int n_cols, - const int n_channels + const int n_channels, + const arm_gemm::Activation &activation ) : OutputTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots>::OutputTransform( - n_batches, n_cols, n_rows, n_channels /* Transpose rows and columns */ + n_batches, n_cols, n_rows, n_channels, activation /* Transpose rows and columns */ ) { } @@ -212,7 +219,8 @@ MEMBERFN(void)::transform_uncropped_tile( { transform_tile( n_channels, inptr, _matrix_stride, biases, - outptr, _out_row_stride, _out_col_stride + outptr, _out_row_stride, _out_col_stride, + _output_min, _output_max ); } @@ -230,7 +238,8 @@ MEMBERFN(void)::transform_cropped_tile( TOut *wsptr = static_cast<TOut *>(get_working_space(threadid)); transform_tile( n_channels, inptr, _matrix_stride, biases, - wsptr, _working_space_row_stride, _working_space_col_stride + wsptr, _working_space_row_stride, _working_space_col_stride, + _output_min, _output_max ); padding::crop_and_copy_tile( -- cgit v1.2.1