about summary refs log tree commit diff
path: root/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp')
-rw-r--r-- src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp 55
1 files changed, 32 insertions, 23 deletions
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp
index d97af21a43..fe47ccbde9 100644
--- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp
+++ b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -41,24 +41,30 @@
namespace winograd
{
-MEMBERFN()::OutputTransform(
- const int n_batches,
- const int n_rows,
- const int n_cols,
- const int n_channels
-) : _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols), _n_channels(n_channels),
- _matrix_base(nullptr),
- _biases(nullptr),
- _matrix_stride(0), _matrix_row_stride(0), _matrix_batch_stride(0),
- _outptr(nullptr),
- _tiles_M(iceildiv(n_rows, output_tile_rows)),
- _tiles_N(iceildiv(n_cols, output_tile_cols)),
- _out_col_stride(0), _out_row_stride(0), _out_batch_stride(0),
- _working_space_col_stride(n_channels),
- _working_space_row_stride(output_tile_cols * _working_space_col_stride),
- _working_space(nullptr)
-{
-}
+MEMBERFN()
+::OutputTransform(const int n_batches, const int n_rows, const int n_cols,
+ const int n_channels, const arm_gemm::Activation &activation)
+ : _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols),
+ _n_channels(n_channels),
+ _output_min((activation.type == arm_gemm::Activation::Type::ReLU ||
+ activation.type == arm_gemm::Activation::Type::BoundedReLU)
+ ? static_cast<TOut>(0.0f)
+ : (std::numeric_limits<TOut>::has_infinity)
+ ? -std::numeric_limits<TOut>::infinity()
+ : std::numeric_limits<TOut>::lowest()),
+ _output_max((activation.type == arm_gemm::Activation::Type::BoundedReLU)
+ ? static_cast<TOut>(activation.param1)
+ : (std::numeric_limits<TOut>::has_infinity)
+ ? std::numeric_limits<TOut>::infinity()
+ : std::numeric_limits<TOut>::max()),
+ _matrix_base(nullptr), _biases(nullptr), _matrix_stride(0),
+ _matrix_row_stride(0), _matrix_batch_stride(0), _outptr(nullptr),
+ _tiles_M(iceildiv(n_rows, output_tile_rows)),
+ _tiles_N(iceildiv(n_cols, output_tile_cols)), _out_col_stride(0),
+ _out_row_stride(0), _out_batch_stride(0),
+ _working_space_col_stride(n_channels),
+ _working_space_row_stride(output_tile_cols * _working_space_col_stride),
+ _working_space(nullptr) {}
MEMBERFN(void)::set_input_matrices(const void * const mptr, const int ldmatrix, const int ldrow)
{
@@ -100,9 +106,10 @@ Nx1MEMBERFN()::OutputTransform(
const int n_batches,
const int n_rows,
const int n_cols,
- const int n_channels
+ const int n_channels,
+ const arm_gemm::Activation &activation
) : OutputTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots>::OutputTransform(
- n_batches, n_cols, n_rows, n_channels /* Transpose rows and columns */
+ n_batches, n_cols, n_rows, n_channels, activation /* Transpose rows and columns */
)
{
}
@@ -212,7 +219,8 @@ MEMBERFN(void)::transform_uncropped_tile(
{
transform_tile(
n_channels, inptr, _matrix_stride, biases,
- outptr, _out_row_stride, _out_col_stride
+ outptr, _out_row_stride, _out_col_stride,
+ _output_min, _output_max
);
}
@@ -230,7 +238,8 @@ MEMBERFN(void)::transform_cropped_tile(
TOut *wsptr = static_cast<TOut *>(get_working_space(threadid));
transform_tile(
n_channels, inptr, _matrix_stride, biases,
- wsptr, _working_space_row_stride, _working_space_col_stride
+ wsptr, _working_space_row_stride, _working_space_col_stride,
+ _output_min, _output_max
);
padding::crop_and_copy_tile(