author    Pablo Tello <pablo.tello@arm.com>    2019-10-21 14:25:41 +0100
committer Gian Marco Iodice <gianmarco.iodice@arm.com>    2019-11-08 12:07:21 +0000
commit    5264b7d5555ec980f9c52c719122479d0d676af8 (patch)
tree      78260be4ee31d89d00705acbf1e0ed2361144bd4
parent    68adf4449b1f92dd2362d88bb0fd565c2c06d22c (diff)
download  ComputeLibrary-5264b7d5555ec980f9c52c719122479d0d676af8.tar.gz
COMPMID-2576: Fuse activation in Winograd output transform.
Change-Id: I26dd1307847adeaaefae0a7374b9858c07d71372
Signed-off-by: Pablo Tello <pablo.tello@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2172
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
-rw-r--r--  arm_compute/core/NEON/kernels/NEWinogradConvolutionLayerKernel.h | 126
-rw-r--r--  arm_compute/core/NEON/kernels/convolution/winograd/winograd.hpp | 105
-rw-r--r--  arm_compute/core/NEON/kernels/convolution/winograd/winograd_layer.hpp | 36
-rw-r--r--  src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp | 69
-rw-r--r--  src/core/NEON/kernels/convolution/winograd/winograd.cpp | 250
-rw-r--r--  src/core/NEON/kernels/convolution/winograd/winograd_transforms/input.hpp | 7
-rw-r--r--  src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp | 55
-rw-r--r--  src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2_7_fp32_fp32_integers.cpp | 16
-rw-r--r--  src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_3x3_fp32_fp32_integers.cpp | 19
-rw-r--r--  src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_5x5_fp32_fp32_integers.cpp | 19
-rw-r--r--  src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4_5_fp32_fp32_integers.cpp | 19
-rw-r--r--  src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4x4_3x3_fp32_fp32_integers.cpp | 1769
-rw-r--r--  src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_6_3_fp32_fp32_integers.cpp | 16
14 files changed, 518 insertions(+), 2056 deletions(-)
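In short, this patch threads an arm_gemm::Activation through the Winograd output-transform API so that ReLU and bounded-ReLU activations are applied while each output tile is written back, rather than in a separate kernel pass. A minimal sketch of the idea, assuming only the Activation fields (type, param1) used later in this diff; names here are illustrative, and the real entry points follow below:

#include <algorithm>
#include <limits>

// Illustrative: the activation is folded into a [lo, hi] clamp applied right
// after the bias add, saving one full pass over the output tensor.
struct Activation {
    enum class Type { None, ReLU, BoundedReLU } type = Type::None;
    float param1 = 0.0f; // upper bound for BoundedReLU
};

inline float fused_output(float gemm_acc, float bias, const Activation &act)
{
    const float lo = (act.type == Activation::Type::None)
                         ? -std::numeric_limits<float>::infinity()
                         : 0.0f;
    const float hi = (act.type == Activation::Type::BoundedReLU)
                         ? act.param1
                         : std::numeric_limits<float>::infinity();
    return std::max(std::min(gemm_acc + bias, hi), lo);
}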
diff --git a/arm_compute/core/NEON/kernels/NEWinogradConvolutionLayerKernel.h b/arm_compute/core/NEON/kernels/NEWinogradConvolutionLayerKernel.h
index f6b189cb1c..962037dd4f 100644
--- a/arm_compute/core/NEON/kernels/NEWinogradConvolutionLayerKernel.h
+++ b/arm_compute/core/NEON/kernels/NEWinogradConvolutionLayerKernel.h
@@ -64,13 +64,15 @@ public:
/** Gets the stride between matrices in the input workspace
*
- * @param[in] kernel_shape The shape of the weights tensor.
- * @param[in] input_shape The shape of the input tensor.
- * @param[in] padding_type The type of padding to be used.
+ * @param[in] num_batches Number of batches in the input tensor.
+ * @param[in] num_channels Number of feature maps in the input tensor.
+ * @param[in] num_rows Number of rows in each feature map.
+ * @param[in] num_cols Number of columns in each feature map.
+ * @param[in] same_padding Use "SAME" padding, otherwise use "VALID".
*
* @return Stride expressed in bytes.
*/
- virtual int get_matrix_stride(const KernelShape &kernel_shape, const Tensor4DShape &input_shape, const PaddingType padding_type) const = 0;
+ virtual int get_matrix_stride(int num_batches, int num_channels, int num_rows, int num_cols, bool same_padding) const = 0;
/** Configure the output transform kernel.
*
@@ -141,13 +143,20 @@ public:
/** Gets the stride between matrices in the input workspace
*
- * @param[in] kernel_shape The shape of the weights tensor.
- * @param[in] input_shape The shape of the input tensor.
- * @param[in] padding_type The type of padding to be used.
+ * @param[in] num_batches Number of batches in the input tensor.
+ * @param[in] num_channels Number of feature maps in the input tensor.
+ * @param[in] num_rows Number of rows in each feature map.
+ * @param[in] num_cols Number of columns in each feature map.
+ * @param[in] same_padding Use "SAME" padding, otherwise use "VALID".
*
* @return Stride expressed in bytes.
*/
- int get_matrix_stride(const KernelShape &kernel_shape, const Tensor4DShape &input_shape, const PaddingType padding_type) const override;
+ int get_matrix_stride(
+ int num_batches,
+ int num_channels,
+ int num_rows,
+ int num_cols,
+ bool same_padding) const override;
/** Default constructor */
NEWinogradLayerTransformInputKernel();
@@ -241,31 +250,35 @@ public:
* @param[in] num_rows Number of rows in each feature map of the input tensor.
* @param[in] num_cols Number of columns in each feature map of the input tensor.
* @param[in] num_output_channels Number of feature maps in the output tensor.
- * @param[in] same_padding Use "SAME" padding, otherwise use "VALID".
*
* @return Storage size (in units of TOut) required.
*/
- virtual unsigned int get_output_storage_size(int num_batches, int num_rows, int num_cols, int num_output_channels, bool same_padding) const = 0;
+ virtual unsigned int get_output_storage_size(int num_batches, int num_rows, int num_cols, int num_output_channels) const = 0;
/** Gets the stride between matrices in the output workspace
*
- * @param[in] kernel_shape The shape of the weights tensor.
- * @param[in] input_shape The shape of the input tensor.
- * @param[in] padding_type The type of padding to be used.
+ * @param[in] num_batches Number of batches in the output tensor.
+ * @param[in] num_rows Number of rows in each feature map of the input tensor.
+ * @param[in] num_cols Number of columns in each feature map of the input tensor.
+ * @param[in] num_output_channels Number of feature maps in the output tensor.
*
* @return Stride expressed in bytes.
*/
- virtual int get_matrix_stride(const KernelShape &kernel_shape, const Tensor4DShape &input_shape, const PaddingType padding_type) const = 0;
+ virtual int get_matrix_stride(int num_batches, int num_rows, int num_cols, int num_output_channels) const = 0;
/** Get the output shape of a convolution.
*
- * @param[in] kernel_shape The shape of the weights tensor.
- * @param[in] in_shape The shape of the input tensor.
- * @param[in] padding The type of padding to be used.
+ * @param[in] num_rows Number of rows in each feature map of the input tensor.
+ * @param[in] num_cols Number of columns in each feature map of the input tensor.
+ * @param[in] padding_same True if padding is SAME, false otherwise
*
- * @return Stride expressed in bytes.
+ * @return Shape of the output tensor
*/
- virtual Tensor4DShape get_output_shape(const KernelShape &kernel_shape, const Tensor4DShape &in_shape, const PaddingType padding) const = 0;
+ virtual std::pair<unsigned int, unsigned int> get_output_shape(
+ int num_rows, /* Number of rows in each feature map of the input tensor. */
+ int num_cols, /* Number of columns in each feature map of the input tensor. */
+ bool padding_same /* True if padding is SAME, false otherwise */
+ ) const = 0;
/** Configure the output transform kernel.
*
@@ -278,17 +291,19 @@ public:
* @param[in] num_cols Number of columns in output tensor.
* @param[in] num_channels Number of feature maps in the output tensor.
* @param[in] workspace Tensor to be used as the working space during the computation.
+ * @param[in] activation Activation to be used
*/
virtual void configure(
- const ITensor *biases,
- const ITensor *transformed_output,
- const int matrix_stride,
- ITensor *output_nhwc,
- const int num_batches,
- const int num_rows,
- const int num_cols,
- const int num_channels,
- ITensor *workspace) = 0;
+ const ITensor *biases,
+ const ITensor *transformed_output,
+ const int matrix_stride,
+ ITensor *output_nhwc,
+ const int num_batches,
+ const int num_rows,
+ const int num_cols,
+ const int num_channels,
+ ITensor *workspace,
+ const arm_gemm::Activation &activation) = 0;
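A hedged call-site sketch of the extended configure(); the tensor pointers and dimensions are placeholders for values the layer sets up, and arm_gemm::Activation is assumed to be constructible from a type and an upper bound as used elsewhere in this patch:

// Hypothetical call site: fuse a ReLU6 into the output transform.
arm_gemm::Activation act(arm_gemm::Activation::Type::BoundedReLU, 6.0f);
transform_output_kernel->configure(biases, transformed_output, matrix_stride,
                                   output_nhwc, num_batches, num_rows, num_cols,
                                   num_channels, workspace, act);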
virtual ~INEWinogradLayerTransformOutputKernel()
{
@@ -326,30 +341,33 @@ public:
* @param[in] num_rows Number of rows in each feature map of the input tensor.
* @param[in] num_cols Number of columns in each feature map of the input tensor.
* @param[in] num_output_channels Number of feature maps in the output tensor.
- * @param[in] same_padding Use "SAME" padding, otherwise use "VALID".
*
* @return Storage size (in units of TOut) required.
*/
- unsigned int get_output_storage_size(int num_batches, int num_rows, int num_cols, int num_output_channels, bool same_padding) const override;
+ unsigned int get_output_storage_size(int num_batches, int num_rows, int num_cols, int num_output_channels) const override;
/** Gets the stride between matrices in the output workspace
*
- * @param[in] kernel_shape The shape of the weights tensor.
- * @param[in] input_shape The shape of the input tensor.
- * @param[in] padding_type The type of padding to be used.
+ * @param[in] num_batches Number of batches in the output tensor.
+ * @param[in] num_rows Number of rows in each feature map of the input tensor.
+ * @param[in] num_cols Number of columns in each feature map of the input tensor.
+ * @param[in] num_output_channels Number of feature maps in the output tensor.
*
* @return Stride expressed in bytes.
*/
- int get_matrix_stride(const KernelShape &kernel_shape, const Tensor4DShape &input_shape, const PaddingType padding_type) const override;
+ int get_matrix_stride(int num_batches, int num_rows, int num_cols, int num_output_channels) const override;
/** Get the output shape of a convolution.
*
- * @param[in] kernel_shape The shape of the weights tensor.
- * @param[in] in_shape The shape of the input tensor.
- * @param[in] padding The type of padding to be used.
+ * @param[in] num_rows Number of rows in each feature map of the input tensor.
+ * @param[in] num_cols Number of columns in each feature map of the input tensor.
+ * @param[in] padding_same True if padding is SAME, false otherwise
*
- * @return Stride expressed in bytes.
+ * @return Shape of the output tensor
*/
- Tensor4DShape get_output_shape(const KernelShape &kernel_shape, const Tensor4DShape &in_shape, const PaddingType padding) const override;
+ std::pair<unsigned int, unsigned int> get_output_shape(
+ int num_rows, /* Number of rows in each feature map of the input tensor. */
+ int num_cols, /* Number of columns in each feature map of the input tensor. */
+ bool padding_same) const override;
/** Get the working space required to perform the transformation.
*
@@ -374,17 +392,19 @@ public:
* @param[in] num_cols Number of columns in output tensor.
* @param[in] num_channels Number of feature maps in the output tensor.
* @param[in] workspace Tensor to be used as the working space during the computation.
+ * @param[in] activation Activation to be used
*/
void configure(
- const ITensor *biases,
- const ITensor *transformed_output,
- const int matrix_stride,
- ITensor *output_nhwc,
- const int num_batches,
- const int num_rows,
- const int num_cols,
- const int num_channels,
- ITensor *workspace) override;
+ const ITensor *biases,
+ const ITensor *transformed_output,
+ const int matrix_stride,
+ ITensor *output_nhwc,
+ const int num_batches,
+ const int num_rows,
+ const int num_cols,
+ const int num_channels,
+ ITensor *workspace,
+ const arm_gemm::Activation &activation) override;
void run(const Window &window, const ThreadInfo &info) override;
@@ -448,11 +468,12 @@ public:
virtual unsigned int get_weight_storage_size(int num_output_channels, int num_input_channels) const = 0;
/** Gets the stride between matrices in the kernel workspace
*
- * @param[in] kernel_shape The shape of the weights tensor.
+ * @param[in] num_output_channels Number of output feature maps.
+ * @param[in] num_input_channels Number of input feature maps.
*
* @return Stride expressed in bytes.
*/
- virtual int get_matrix_stride(const KernelShape &kernel_shape) const = 0;
+ virtual int get_matrix_stride(int num_output_channels, int num_input_channels) const = 0;
/** Configure the weights transform kernel.
*
@@ -535,11 +556,12 @@ public:
/** Gets the stride between matrices in the kernel workspace
*
- * @param[in] kernel_shape The shape of the weights tensor.
+ * @param[in] num_output_channels Number of output feature maps.
+ * @param[in] num_input_channels Number of input feature maps.
*
* @return Stride expressed in bytes.
*/
- int get_matrix_stride(const KernelShape &kernel_shape) const override;
+ int get_matrix_stride(int num_output_channels, int num_input_channels) const override;
void run(const Window &window, const ThreadInfo &info) override;
bool is_parallelisable() const override;
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/winograd.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/winograd.hpp
index 183c9c1061..bc0d9d4296 100644
--- a/arm_compute/core/NEON/kernels/convolution/winograd/winograd.hpp
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/winograd.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,9 +24,10 @@
#pragma once
-#include "convolution.hpp"
-#include "tensor.hpp"
-#include "utils.hpp"
+#include "arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp"
+
+#include <cstddef>
+#include <utility>
namespace winograd
{
@@ -308,7 +309,8 @@ class OutputTransform : public IOutputTransform
int n_batches, /**< Number of batches in output tensor. */
int n_rows, /**< Number of rows in output tensor. */
int n_cols, /**< Number of columns in output tensor. */
- int n_channels /**< Number of channels in output tensor. */
+ int n_channels, /**< Number of channels in output tensor. */
+ const arm_gemm::Activation &activation
);
OutputTransform(OutputTransform&) = delete;
@@ -344,6 +346,7 @@ class OutputTransform : public IOutputTransform
static constexpr int output_tile_cols = InnerTileCols - KernelCols + 1;
const int _n_batches, _n_rows, _n_cols, _n_channels;
+ const TOut _output_min, _output_max;
private:
void transform_uncropped_tile(
@@ -372,7 +375,9 @@ class OutputTransform : public IOutputTransform
const TOut* biases,
TOut* output,
int output_row_stride,
- int output_col_stride
+ int output_col_stride,
+ TOut output_min,
+ TOut output_max
);
/** Get the working space for a thread. */
@@ -405,7 +410,8 @@ class OutputTransform<KernelRows, 1, InnerTileRows, 1, TIn, TOut, Roots> :
int n_batches, /**< Number of batches in output tensor. */
int n_rows, /**< Number of rows in output tensor. */
int n_cols, /**< Number of columns in output tensor. */
- int n_channels /**< Number of channels in output tensor. */
+ int n_channels, /**< Number of channels in output tensor. */
+ const arm_gemm::Activation &activation
);
/** Set pointers to the output tensor written by the transform. */
@@ -528,79 +534,84 @@ class WinogradGEMM
typedef TIn InputType;
/** Get the output shape of a convolution. */
- static Tensor4DShape get_output_shape(
- const KernelShape &kernel_shape,
- const Tensor4DShape &in_shape,
- const PaddingType padding
- );
-
- /* Get the memory required to transform the kernel.
- */
- static size_t get_kernel_transform_working_size(const KernelShape &shape);
+ static std::pair<unsigned int, unsigned int> get_output_shape(
+ const std::pair<unsigned int, unsigned int> input_shape,
+ bool padding_same);
/** Get the memory required to store the kernel transformed into the
* Winograd domain.
*/
- static size_t get_kernel_storage_size(const KernelShape &shape);
+ static size_t get_kernel_storage_size(unsigned int n_input_channels,
+ unsigned int n_output_channels);
/** Get the memory required to store the input tensor transformed into
* the Winograd domain.
*/
static size_t get_input_storage_size(
- const KernelShape &kernel_shape,
- const Tensor4DShape &input_shape,
- const PaddingType padding_type
- );
+ unsigned int n_batches, // Number of batches
+ unsigned int n_rows, // Number of input rows
+ unsigned int n_cols, // Number of input columns
+ unsigned int n_channels, // Number of input channels
+ bool padding_same);
/** Get the memory required to store the output tensor in the Winograd
* domain.
*/
static size_t get_output_storage_size(
- const KernelShape &kernel_shape,
- const Tensor4DShape &input_shape,
- const PaddingType padding_type
- );
+ unsigned int n_batches, // Number of batches
+ unsigned int n_rows, // Number of output rows
+ unsigned int n_cols, // Number of output columns
+ unsigned int n_channels // Number of output channels
+ );
/** Get the memory required to apply a Winograd operator to some input.
*/
static size_t get_working_space_size(
- const KernelShape &kernel_shape,
- const Tensor4DShape &input_shape,
- const PaddingType padding_type
- );
+ unsigned int n_batches,
+ unsigned int n_rows, // Number of input rows
+ unsigned int n_cols, // Number of input columns
+ unsigned int n_input_channels, // Number of input channels
+ unsigned int n_output_channels, // Number of output channels
+ bool padding_same);
/* Get the memory required by a single "input" matrix.
*/
static size_t get_input_matrix_size(
- const KernelShape &kernel_shape,
- const Tensor4DShape &input_shape,
- const PaddingType padding_type
- );
+ unsigned int n_batches, // Number of batches
+ unsigned int n_rows, // Number of input rows
+ unsigned int n_cols, // Number of input columns
+ unsigned int n_channels, // Number of input channels
+ bool padding_same);
static int get_input_matrix_stride(
- const KernelShape &kernel_shape,
- const Tensor4DShape &input_shape,
- const PaddingType padding_type
- );
+ unsigned int n_batches, // Number of batches
+ unsigned int n_rows, // Number of input rows
+ unsigned int n_cols, // Number of input columns
+ unsigned int n_channels, // Number of input channels
+ bool padding_same);
/* Get the memory required by a single "output" matrix.
*/
static size_t get_output_matrix_size(
- const KernelShape &kernel_shape,
- const Tensor4DShape &input_shape,
- const PaddingType padding_type
- );
+ unsigned int n_batches, // Number of batches
+ unsigned int n_rows, // Number of output rows
+ unsigned int n_cols, // Number of output columns
+ unsigned int n_channels // Number of output channels
+ );
static int get_output_matrix_stride(
- const KernelShape &kernel_shape,
- const Tensor4DShape &input_shape,
- const PaddingType padding_type
- );
+ unsigned int n_batches, // Number of batches
+ unsigned int n_rows, // Number of output rows
+ unsigned int n_cols, // Number of output columns
+ unsigned int n_channels // Number of output channels
+ );
/* Get the memory required by a single "kernel" matrix.
*/
- static size_t get_kernel_matrix_size(const KernelShape &shape);
- static int get_kernel_matrix_stride(const KernelShape &shape);
+ static size_t get_kernel_matrix_size(unsigned int n_input_channels,
+ unsigned int n_output_channels);
+ static int get_kernel_matrix_stride(unsigned int n_input_channels,
+ unsigned int n_output_channels);
static constexpr int M_BLOCK = 4; /** Size of block used by GEMM. */
static constexpr int N_BLOCK = 16; /** Size of block used by GEMM. */
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/winograd_layer.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/winograd_layer.hpp
index 9d418bebb4..ed8fede385 100644
--- a/arm_compute/core/NEON/kernels/convolution/winograd/winograd_layer.hpp
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/winograd_layer.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,9 +23,6 @@
*/
#pragma once
-
-#include <utility>
-
#include "arm_gemm_local.hpp"
#include "arm_gemm.hpp"
#include "winograd.hpp"
@@ -42,8 +39,8 @@ class IWinogradConvolutionLayer
virtual unsigned int weight_transform_get_window(void) const = 0;
virtual void weight_transform_run(unsigned int start, unsigned int stop) = 0;
- virtual ITransform& input_transform(void) = 0; // Expose the input transform
- virtual ITransform& output_transform(void) = 0; // Expose the output transform
+ virtual IInputTransform& input_transform(void) = 0; // Expose the input transform
+ virtual IOutputTransform& output_transform(void) = 0; // Expose the output transform
virtual arm_gemm::IGemmCommon *gemm(void) = 0; // Expose the underlying GEMM
};
@@ -65,15 +62,18 @@ template <int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols
WinogradRoots Roots>
class WinogradConvolutionLayer : public IWinogradConvolutionLayer
{
+ public:
+ using WinogradBase = winograd::WinogradGEMM<OutputTileRows, OutputTileCols, KernelRows, KernelCols, Roots>;
+ using WeightsTransform = typename WinogradBase::template WeightsTransform<TIn, TInGEMM>;
+ using InputTransform = typename WinogradBase::template InputTransform<TIn, TInGEMM>;
+ using WinogradConv = typename WinogradBase::template Convolution<TOut, TIn, TInGEMM, TOutGEMM>;
+ using OutputTransform = typename WinogradBase::template OutputTransform<TOutGEMM, TOut>;
+
private:
static constexpr int InnerTileRows = OutputTileRows + KernelRows - 1;
static constexpr int InnerTileCols = OutputTileCols + KernelCols - 1;
static constexpr int N_GEMMS = InnerTileRows * InnerTileCols;
- const KernelShape _kernel_shape;
- const Tensor4DShape _input_shape;
- const PaddingType _padding;
- const Tensor4DShape _output_shape;
const int _n_output_rows, _n_output_cols;
const int _kernel_matrix_stride, _kernel_matrix_row_stride;
const int _input_matrix_stride, _input_matrix_row_stride;
@@ -81,19 +81,14 @@ class WinogradConvolutionLayer : public IWinogradConvolutionLayer
const int _tile_rows, _tile_cols;
const int _m, _k, _n;
- public:
- using WinogradBase = winograd::WinogradGEMM<OutputTileRows, OutputTileCols, KernelRows, KernelCols, Roots>;
- using WeightsTransform = typename WinogradBase::template WeightsTransform<TIn, TInGEMM>;
- using InputTransform = typename WinogradBase::template InputTransform<TIn, TInGEMM>;
- using WinogradConv = typename WinogradBase::template Convolution<TOut, TIn, TInGEMM, TOutGEMM>;
- using OutputTransform = typename WinogradBase::template OutputTransform<TOutGEMM, TOut>;
-
- /* Public member variables. */
WeightsTransform weights_transform; /** Operator to transform weights to Winograd domain. */
InputTransform _input_transform; /** Operator to transform input to Winograd domain. */
+ const arm_gemm::GemmArgs gemm_args;
arm_gemm::UniqueGemmCommon<TInGEMM, TOutGEMM> gemms; /** Operator to perform multiple GEMMs. */
OutputTransform _output_transform; /** Operator to transform output from Winograd domain. */
+ public:
+
/** Determine how much memory (in units of TIn) to allocate for the
* transformed weights.
*/
@@ -186,6 +181,7 @@ class WinogradConvolutionLayer : public IWinogradConvolutionLayer
const int n_input_cols, /** Number of columns in a feature map of the input tensor. */
const int n_output_channels, /** Number of feature maps in the output tensor. */
const bool same_padding, /** Use "SAME" padding, otherwise use "VALID". */
+ const arm_gemm::Activation &activation,
const TIn* const weights, /** Pointer to weight tensor in spatial domain. Must be ordered as "Height x Width x Input Feature Maps x Output Feature Maps". */
TInGEMM* const weights_storage, /** Pointer to storage for weight tensor in the Winograd domain. Must be at least the size returned by `get_weight_storage_size`. */
const TIn* const input, /** Pointer to NHWC ordered input tensor, in the spatial domain. */
@@ -201,8 +197,8 @@ class WinogradConvolutionLayer : public IWinogradConvolutionLayer
unsigned int weight_transform_get_window(void) const;
void weight_transform_run(const unsigned int start, const unsigned int stop);
- ITransform& input_transform(void);
- ITransform& output_transform(void);
+ IInputTransform& input_transform(void);
+ IOutputTransform& output_transform(void);
/* Get a pointer to the GEMM underlying the Winograd transform. */
arm_gemm::IGemmCommon *gemm(void);
diff --git a/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp b/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp
index 263ded0b84..fda384bc62 100644
--- a/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp
@@ -28,6 +28,7 @@
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/NEON/kernels/convolution/common/utils.hpp"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
@@ -233,7 +234,7 @@ unsigned int NEWinogradLayerTransformWeightsKernel<T, OutputTileRows, OutputTile
const KernelShape shape(num_output_channels, KernelRows, KernelCols, num_input_channels);
return static_cast<unsigned int>(
// WinogradConv returns the size in bytes; we divide by `sizeof(T)` to express that in units of T
- WinogradConv::get_kernel_storage_size(shape) / sizeof(T));
+ WinogradConv::get_kernel_storage_size(num_input_channels, num_output_channels) / sizeof(T));
}
template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
@@ -243,9 +244,9 @@ NEWinogradLayerTransformWeightsKernel<T, OutputTileRows, OutputTileCols, KernelR
}
template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
-int NEWinogradLayerTransformWeightsKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_matrix_stride(const KernelShape &kernel_shape) const
+int NEWinogradLayerTransformWeightsKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_matrix_stride(int num_output_channels, int num_input_channels) const
{
- return WinogradConv::get_kernel_matrix_stride(kernel_shape);
+ return WinogradConv::get_kernel_matrix_stride(num_input_channels, num_output_channels);
}
#ifndef DOXYGEN_SKIP_THIS
@@ -325,9 +326,8 @@ unsigned int NEWinogradLayerTransformInputKernel<T, OutputTileRows, OutputTileCo
// Construct shapes for the input and kernel tensors.
const Tensor4DShape input_shape(num_batches, num_rows, num_cols, num_channels);
const KernelShape kern_shape(1, KernelRows, KernelCols, num_channels);
- const PaddingType padding = (same_padding) ? PADDING_SAME : PADDING_VALID;
// Return the size, converted into units of TIn
- return static_cast<unsigned int>(WinogradConv::get_input_storage_size(kern_shape, input_shape, padding) / sizeof(T));
+ return static_cast<unsigned int>(WinogradConv::get_input_storage_size(num_batches, num_rows, num_cols, num_channels, same_padding) / sizeof(T));
}
template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
@@ -338,9 +338,13 @@ unsigned int NEWinogradLayerTransformInputKernel<T, OutputTileRows, OutputTileCo
template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
int NEWinogradLayerTransformInputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_matrix_stride(
- const KernelShape &kernel_shape, const Tensor4DShape &input_shape, const PaddingType padding_type) const
+ int num_batches, /* Number of batches in the input tensor. */
+ int num_channels, /* Number of feature maps in the input tensor. */
+ int num_rows, /* Number of rows in each feature map. */
+ int num_cols, /* Number of columns in each feature map. */
+ bool same_padding /* Use "SAME" padding, otherwise use "VALID". */) const
{
- return WinogradConv::get_input_matrix_stride(kernel_shape, input_shape, padding_type);
+ return WinogradConv::get_input_matrix_stride(num_batches, num_rows, num_cols, num_channels, same_padding);
}
template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
@@ -446,21 +450,18 @@ template class NEWinogradLayerTransformInputKernel<float, 2, 1, 7, 1>;
template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
unsigned int NEWinogradLayerTransformOutputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_output_storage_size(
- int num_batches, /* Number of batches in the output tensor. */
- int num_rows, /* Number of rows in each feature map of the input tensor. */
- int num_cols, /* Number of columns in each feature map of the input tensor. */
- int num_output_channels, /* Number of feature maps in the output tensor. */
- bool same_padding /* Use "SAME" padding, otherwise use "VALID". */
+ int num_batches, /* Number of batches in the output tensor. */
+ int num_rows, /* Number of rows in each feature map of the input tensor. */
+ int num_cols, /* Number of columns in each feature map of the input tensor. */
+ int num_output_channels /* Number of feature maps in the output tensor. */
) const
{
// Construct shapes for the input and kernel tensors.
const Tensor4DShape input_shape(num_batches, num_rows, num_cols, 1);
const KernelShape kern_shape(num_output_channels, KernelRows, KernelCols, 1);
- const PaddingType padding = (same_padding) ? PADDING_SAME : PADDING_VALID;
-
// Return the size, converted into units of TOut
return static_cast<unsigned int>(
- WinogradConv::get_output_storage_size(kern_shape, input_shape, padding) / sizeof(T));
+ WinogradConv::get_output_storage_size(num_batches, num_rows, num_cols, num_output_channels) / sizeof(T));
}
template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
@@ -478,28 +479,36 @@ unsigned int NEWinogradLayerTransformOutputKernel<T, OutputTileRows, OutputTileC
template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
int NEWinogradLayerTransformOutputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_matrix_stride(
- const KernelShape &kernel_shape, const Tensor4DShape &input_shape, const PaddingType padding_type) const
+ int num_batches, /* Number of batches in the output tensor. */
+ int num_rows, /* Number of rows in each feature map of the input tensor. */
+ int num_cols, /* Number of columns in each feature map of the input tensor. */
+ int num_output_channels /* Number of feature maps in the output tensor. */
+) const
{
- return WinogradConv::get_output_matrix_stride(kernel_shape, input_shape, padding_type);
+ return WinogradConv::get_output_matrix_stride(num_batches, num_rows, num_cols, num_output_channels);
}
+
template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
-Tensor4DShape NEWinogradLayerTransformOutputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_output_shape(
- const KernelShape &kernel_shape, const Tensor4DShape &in_shape, const PaddingType padding) const
+std::pair<unsigned int, unsigned int> NEWinogradLayerTransformOutputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_output_shape(
+ int num_rows, /* Number of rows in each feature map of the input tensor. */
+ int num_cols, /* Number of columns in each feature map of the input tensor. */
+ bool padding_same) const
{
- return WinogradConv::get_output_shape(kernel_shape, in_shape, padding);
+ return WinogradConv::get_output_shape(std::make_pair<unsigned int, unsigned int>(num_rows, num_cols), padding_same);
}
template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
void NEWinogradLayerTransformOutputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::configure(
- const ITensor *biases,
- const ITensor *transformed_output,
- const int matrix_stride,
- ITensor *output_nhwc,
- const int num_batches,
- const int num_rows,
- const int num_cols,
- const int num_channels,
- ITensor *workspace)
+ const ITensor *biases,
+ const ITensor *transformed_output,
+ const int matrix_stride,
+ ITensor *output_nhwc,
+ const int num_batches,
+ const int num_rows,
+ const int num_cols,
+ const int num_channels,
+ ITensor *workspace,
+ const arm_gemm::Activation &activation)
{
_biases = biases;
_workspace = workspace;
@@ -512,7 +521,7 @@ void NEWinogradLayerTransformOutputKernel<T, OutputTileRows, OutputTileCols, Ker
_num_cols = num_cols;
_num_channels = num_channels;
// We don't have the biases buffer at this stage as it hasn't been allocated; we pass in nullptr. OutputTransform is only used here to compute the window.
- _transform = arm_compute::support::cpp14::make_unique<OutputTransform>(num_batches, num_rows, num_cols, num_channels);
+ _transform = arm_compute::support::cpp14::make_unique<OutputTransform>(num_batches, num_rows, num_cols, num_channels, activation);
Window win;
auto win_last = _transform->get_window();
win.set(Window::DimX, Window::Dimension(0, win_last, 1));
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd.cpp b/src/core/NEON/kernels/convolution/winograd/winograd.cpp
index 226f303c7d..a4eb9fce59 100644
--- a/src/core/NEON/kernels/convolution/winograd/winograd.cpp
+++ b/src/core/NEON/kernels/convolution/winograd/winograd.cpp
@@ -21,205 +21,147 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
+
#include <cstring>
+#include "utils.hpp"
#include "winograd.hpp"
+
using namespace winograd;
+using array2 = std::pair<unsigned int, unsigned int>;
-/** Get the output shape of a convolution. */
-template <int kr, int kc, int itr, int itc, WinogradRoots R>
-template <typename TOut, typename TIn, typename TInGEMM, typename TOutGEMM>
-Tensor4DShape WinogradGEMM<kr, kc, itr, itc, R>::Convolution<TOut, TIn, TInGEMM, TOutGEMM>::get_output_shape(
- const KernelShape &kernel_shape,
- const Tensor4DShape &in_shape,
- const PaddingType padding
-)
-{
- return Tensor4DShape {
- in_shape.n_batches,
- (padding == PADDING_SAME) ? in_shape.n_rows : in_shape.n_rows - (kernel_rows - 1),
- (padding == PADDING_SAME) ? in_shape.n_cols : in_shape.n_cols - (kernel_cols - 1),
- kernel_shape.n_output_channels,
- in_shape.ordering
- };
-}
+#define MEMBERFN(RTYPE) \
+ template <int output_tile_rows, int output_tile_cols, int kernel_rows, \
+ int kernel_cols, WinogradRoots roots> \
+ template <typename TOut, typename TIn, typename TGEMMIn, typename TGEMMOut> \
+ RTYPE WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, \
+ kernel_cols, \
+ roots>::Convolution<TOut, TIn, TGEMMIn, TGEMMOut>
-/* Get the memory required to transform the kernel.
- */
-template <int kernel_rows, int kernel_cols,
- int output_tile_rows, int output_tile_cols, WinogradRoots roots>
-template <typename TOut, typename TIn, typename TGIn, typename TGOut>
-size_t WinogradGEMM<kernel_rows, kernel_cols, output_tile_rows, output_tile_cols, roots>::Convolution<TOut, TIn, TGIn, TGOut>::get_kernel_transform_working_size(const KernelShape &shape)
-{
- if (shape.ordering == HWIO)
- {
- // Kernel is already in the correct order, so no additional memory is
- // required.
- return 0;
- }
- else
- {
- // Need to re-order the kernel into HWIO form, require enough space to
- // represent the tensor.
- return sizeof(TIn) * shape.size();
- }
+/** Get the output shape of a convolution. */
+MEMBERFN(array2)
+::get_output_shape(const std::pair<unsigned int, unsigned int> input_shape,
+ const bool padding_same) {
+ const unsigned int n_rows =
+ padding_same ? input_shape.first : input_shape.first - (kernel_rows - 1);
+ const unsigned int n_cols = padding_same
+ ? input_shape.second
+ : input_shape.second - (kernel_cols - 1);
+ return {n_rows, n_cols};
}
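For reference, a standalone restatement of the shape rule above (hypothetical free function): SAME padding keeps the spatial size, VALID shrinks it by kernel_size - 1 in each dimension.

#include <utility>

std::pair<unsigned int, unsigned int>
conv_output_shape(std::pair<unsigned int, unsigned int> in,
                  unsigned int kernel_rows, unsigned int kernel_cols,
                  bool padding_same)
{
    // VALID padding trims (kernel - 1) rows/columns; SAME preserves the size.
    return { padding_same ? in.first  : in.first  - (kernel_rows - 1),
             padding_same ? in.second : in.second - (kernel_cols - 1) };
}
// e.g. conv_output_shape({56, 56}, 3, 3, false) == {54, 54}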
/** Get the memory required to store the kernel transformed into the
* Winograd domain.
*/
-template <int kernel_rows, int kernel_cols, int output_tile_rows, int output_tile_cols, WinogradRoots roots>
-template <typename TOut, typename TIn, typename TGIn, typename TGOut>
-size_t WinogradGEMM<kernel_rows, kernel_cols, output_tile_rows, output_tile_cols, roots>::Convolution<TOut, TIn, TGIn, TGOut>::get_kernel_storage_size(const KernelShape &shape)
-{
- return N_GEMMS * get_kernel_matrix_size(shape);
+MEMBERFN(size_t)
+::get_kernel_storage_size(const unsigned int n_input_channels,
+ const unsigned int n_output_channels) {
+ return N_GEMMS * get_kernel_matrix_size(n_input_channels, n_output_channels);
}
-
-template <int kernel_rows, int kernel_cols, int output_tile_rows, int output_tile_cols, WinogradRoots roots>
-template <typename TOut, typename TIn, typename TGIn, typename TGOut>
-size_t WinogradGEMM<kernel_rows, kernel_cols, output_tile_rows, output_tile_cols, roots>::Convolution<TOut, TIn, TGIn, TGOut>::get_input_storage_size(
- const KernelShape &kernel_shape,
- const Tensor4DShape &input_shape,
- const PaddingType padding
-)
-{
- return N_GEMMS * get_input_matrix_size(kernel_shape, input_shape, padding);
+MEMBERFN(size_t)
+::get_input_storage_size(const unsigned int n_batches,
+ const unsigned int n_rows, const unsigned int n_cols,
+ const unsigned int n_channels,
+ const bool same_padding) {
+ return N_GEMMS * get_input_matrix_size(n_batches, n_rows, n_cols, n_channels,
+ same_padding);
}
-
-template <int kernel_rows, int kernel_cols, int output_tile_rows, int output_tile_cols, WinogradRoots roots>
-template <typename TOut, typename TIn, typename TGIn, typename TGOut>
-size_t WinogradGEMM<kernel_rows, kernel_cols, output_tile_rows, output_tile_cols, roots>::Convolution<TOut, TIn, TGIn, TGOut>::get_output_storage_size(
- const KernelShape &kernel_shape,
- const Tensor4DShape &input_shape,
- const PaddingType padding
-)
-{
- return N_GEMMS * get_output_matrix_size(kernel_shape, input_shape, padding);
+MEMBERFN(size_t)
+::get_output_storage_size(const unsigned int n_batches,
+ const unsigned int n_rows, const unsigned int n_cols,
+ const unsigned int n_channels) {
+ return N_GEMMS *
+ get_output_matrix_size(n_batches, n_rows, n_cols, n_channels);
}
-
/** Get the memory required to apply a Winograd operator to some input.
*/
-template <int kernel_rows, int kernel_cols, int output_tile_rows, int output_tile_cols, WinogradRoots roots>
-template <typename TOut, typename TIn, typename TGIn, typename TGOut>
-size_t WinogradGEMM<kernel_rows, kernel_cols, output_tile_rows, output_tile_cols, roots>::Convolution<TOut, TIn, TGIn, TGOut>::get_working_space_size(
- const KernelShape &kernel_shape,
- const Tensor4DShape &input_shape,
- const PaddingType padding_type
-)
-{
- const auto output_shape = get_output_shape(kernel_shape, input_shape, padding_type);
+MEMBERFN(size_t)
+::get_working_space_size(const unsigned int n_batches,
+ const unsigned int n_rows, const unsigned int n_cols,
+ const unsigned int n_input_channels,
+ const unsigned int n_output_channels,
+ const bool padding_same) {
+ const auto output_shape = get_output_shape({n_rows, n_cols}, padding_same);
// Get the memory required to store the matrices
- const size_t matrix_sizes = N_GEMMS * (
- get_input_matrix_size(kernel_shape, input_shape, padding_type) +
- get_output_matrix_size(kernel_shape, input_shape, padding_type)
- );
-
- // Add additional space to re-order the input and output if the input tensor
- // is not in NHWC format.
- if (input_shape.ordering == NHWC)
- {
- return matrix_sizes; // No extra spacing required
- }
- else // NCHW, must reorder the input and output tensors
- {
- // We only need to re-order the input or output at any one time, so request
- // enough memory to do the largest of these.
- const size_t extra_memory = std::max(
- sizeof(TIn) * input_shape.size(),
- sizeof(TOut) * output_shape.size()
- );
- return matrix_sizes + extra_memory;
- }
+ const size_t matrix_sizes =
+ N_GEMMS *
+ (get_input_matrix_size(n_batches, n_rows, n_cols, n_input_channels,
+ padding_same) +
+ get_output_matrix_size(n_batches, output_shape.first,
+ output_shape.second, n_output_channels));
+ return matrix_sizes;
}
-
/* Get the memory required by a single "input" matrix.
*/
-template <int kernel_rows, int kernel_cols, int output_tile_rows, int output_tile_cols, WinogradRoots roots>
-template <typename TOut, typename TIn, typename TGIn, typename TGOut>
-size_t WinogradGEMM<kernel_rows, kernel_cols, output_tile_rows, output_tile_cols, roots>::Convolution<TOut, TIn, TGIn, TGOut>::get_input_matrix_size(
- const KernelShape &kernel_shape,
- const Tensor4DShape &input_shape,
- const PaddingType padding_type
-)
-{
- return get_input_matrix_stride(kernel_shape, input_shape, padding_type) * sizeof(TGIn);
+MEMBERFN(size_t)
+::get_input_matrix_size(const unsigned int n_batches, const unsigned int n_rows,
+ const unsigned int n_cols,
+ const unsigned int n_channels,
+ const bool same_padding) {
+ return get_input_matrix_stride(n_batches, n_rows, n_cols, n_channels,
+ same_padding) *
+ sizeof(TGEMMIn);
}
-template <int kernel_rows, int kernel_cols, int output_tile_rows, int output_tile_cols, WinogradRoots roots>
-template <typename TOut, typename TIn, typename TGIn, typename TGOut>
-int WinogradGEMM<kernel_rows, kernel_cols, output_tile_rows, output_tile_cols, roots>::Convolution<TOut, TIn, TGIn, TGOut>::get_input_matrix_stride(
- const KernelShape &kernel_shape,
- const Tensor4DShape &input_shape,
- const PaddingType padding_type
-)
-{
- // Compute shape for the GEMM
- const auto output_shape = get_output_shape(kernel_shape, input_shape, padding_type);
- const int tile_rows = iceildiv(output_shape.n_rows, output_tile_rows);
- const int tile_cols = iceildiv(output_shape.n_cols, output_tile_cols);
- const int M = roundup(input_shape.n_batches * tile_rows * tile_cols, M_BLOCK);
- const int K = kernel_shape.n_input_channels;
+MEMBERFN(int)
+::get_input_matrix_stride(const unsigned int n_batches, const unsigned int n_rows,
+ const unsigned int n_cols,
+ const unsigned int n_channels,
+ const bool same_padding) {
+ const auto output_shape = get_output_shape({n_rows, n_cols}, same_padding);
+ const unsigned int tile_rows = iceildiv(output_shape.first, output_tile_rows);
+ const unsigned int tile_cols =
+ iceildiv(output_shape.second, output_tile_cols);
+ const unsigned int M =
+ roundup<unsigned int>(n_batches * tile_rows * tile_cols, M_BLOCK);
+ const unsigned int K = n_channels;
return M * K;
}
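To make the arithmetic concrete, a self-contained sketch with local stand-ins for the library's iceildiv/roundup helpers (M_BLOCK = 4 as declared in winograd.hpp):

static unsigned int iceildiv(unsigned int a, unsigned int b) { return (a + b - 1) / b; }
static unsigned int roundup(unsigned int a, unsigned int b)  { return iceildiv(a, b) * b; }

// Stride (in elements) of one Winograd "input" matrix: M rows of K channels,
// where M counts output tiles across the whole batch, padded up to M_BLOCK.
unsigned int input_matrix_stride(unsigned int n_batches,
                                 unsigned int out_rows, unsigned int out_cols,
                                 unsigned int n_channels,
                                 unsigned int tile_h, unsigned int tile_w)
{
    const unsigned int tiles = iceildiv(out_rows, tile_h) * iceildiv(out_cols, tile_w);
    const unsigned int M     = roundup(n_batches * tiles, 4u); // M_BLOCK == 4
    return M * n_channels;                                     // K == n_channels
}
// e.g. 1 batch, 56x56 output, 4x4 tiles: M = roundup(14 * 14, 4) = 196.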
-
/* Get the memory required by a single "output" matrix.
*/
-template <int kernel_rows, int kernel_cols, int output_tile_rows, int output_tile_cols, WinogradRoots roots>
-template <typename TOut, typename TIn, typename TGIn, typename TGOut>
-size_t WinogradGEMM<kernel_rows, kernel_cols, output_tile_rows, output_tile_cols, roots>::Convolution<TOut, TIn, TGIn, TGOut>::get_output_matrix_size(
- const KernelShape &kernel_shape,
- const Tensor4DShape &input_shape,
- const PaddingType padding_type
-)
-{
- return get_output_matrix_stride(kernel_shape, input_shape, padding_type) * sizeof(TGOut);
+MEMBERFN(size_t)
+::get_output_matrix_size(const unsigned int n_batches,
+ const unsigned int n_rows, const unsigned int n_cols,
+ const unsigned int n_channels) {
+ return get_output_matrix_stride(n_batches, n_rows, n_cols, n_channels) *
+ sizeof(TGEMMOut);
}
-
-template <int kernel_rows, int kernel_cols, int output_tile_rows, int output_tile_cols, WinogradRoots roots>
-template <typename TOut, typename TIn, typename TGIn, typename TGOut>
-int WinogradGEMM<kernel_rows, kernel_cols, output_tile_rows, output_tile_cols, roots>::Convolution<TOut, TIn, TGIn, TGOut>::get_output_matrix_stride(
- const KernelShape &kernel_shape,
- const Tensor4DShape &input_shape,
- const PaddingType padding_type
-)
-{
+MEMBERFN(int)
+::get_output_matrix_stride(const unsigned int n_batches,
+ const unsigned int n_rows, const unsigned int n_cols,
+ const unsigned int n_channels) {
// Compute shape for the GEMM
- const auto output_shape = get_output_shape(kernel_shape, input_shape, padding_type);
- const int tile_rows = iceildiv(output_shape.n_rows, output_tile_rows);
- const int tile_cols = iceildiv(output_shape.n_cols, output_tile_cols);
- const int M = roundup(tile_rows * tile_cols, M_BLOCK);
- const int N = roundup(kernel_shape.n_output_channels, N_BLOCK);
+ const int tile_rows = iceildiv(n_rows, output_tile_rows);
+ const int tile_cols = iceildiv(n_cols, output_tile_cols);
+ const int M = roundup<int>(tile_rows * tile_cols, M_BLOCK);
+ const int N = roundup<int>(n_channels, N_BLOCK);
- return input_shape.n_batches * M * N;
+ return n_batches * M * N;
}
/* Get the memory required by a single "kernel" matrix.
*/
-template <int kernel_rows, int kernel_cols, int output_tile_rows, int output_tile_cols, WinogradRoots roots>
-template <typename TOut, typename TIn, typename TGIn, typename TGOut>
-size_t WinogradGEMM<kernel_rows, kernel_cols, output_tile_rows, output_tile_cols, roots>::Convolution<TOut, TIn, TGIn, TGOut>::get_kernel_matrix_size(const KernelShape &shape)
-{
- return sizeof(TGIn) * get_kernel_matrix_stride(shape);
+MEMBERFN(size_t)
+::get_kernel_matrix_size(const unsigned int n_input_channels,
+ const unsigned int n_output_channels) {
+ return sizeof(TGEMMIn) *
+ get_kernel_matrix_stride(n_input_channels, n_output_channels);
}
-template <int kernel_rows, int kernel_cols, int output_tile_rows, int output_tile_cols, WinogradRoots roots>
-template <typename TOut, typename TIn, typename TGIn, typename TGOut>
-int WinogradGEMM<kernel_rows, kernel_cols, output_tile_rows, output_tile_cols, roots>::Convolution<TOut, TIn, TGIn, TGOut>::get_kernel_matrix_stride(const KernelShape &shape)
-{
- const int K = shape.n_input_channels;
- const int N = roundup(shape.n_output_channels, N_BLOCK);
- return K * N;
+MEMBERFN(int)
+::get_kernel_matrix_stride(const unsigned int n_input_channels,
+ const unsigned int n_output_channels) {
+ return n_input_channels * roundup<int>(n_output_channels, N_BLOCK);
}
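Similarly, the kernel-matrix stride is K * N with N padded to the GEMM block width (N_BLOCK = 16 above); a small sketch:

// One transformed-kernel matrix: K = n_input_channels rows by
// N = n_output_channels columns, with N rounded up to N_BLOCK for the GEMM.
int kernel_matrix_stride(int n_input_channels, int n_output_channels)
{
    const int N = ((n_output_channels + 15) / 16) * 16; // roundup to N_BLOCK == 16
    return n_input_channels * N;
}
// e.g. 64 input and 70 output channels -> 64 * 80 = 5120 elements per matrix.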
-
// Instantiate required implementations
template class WinogradGEMM<2, 2, 3, 3, WinogradRoots::Integers>::Convolution<float, float, float, float>;
template class WinogradGEMM<4, 4, 3, 3, WinogradRoots::Integers>::Convolution<float, float, float, float>;
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input.hpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input.hpp
index fcbd21fcd0..8e4bebcd20 100644
--- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input.hpp
+++ b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,8 +24,11 @@
#pragma once
-#include "winograd.hpp"
+#include <algorithm>
+
#include "padding.hpp"
+#include "utils.hpp"
+#include "winograd.hpp"
#define MEMBERFN(RTYPE) template <\
int InnerTileRows, int InnerTileCols,\
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp
index d97af21a43..fe47ccbde9 100644
--- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp
+++ b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -41,24 +41,30 @@
namespace winograd
{
-MEMBERFN()::OutputTransform(
- const int n_batches,
- const int n_rows,
- const int n_cols,
- const int n_channels
-) : _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols), _n_channels(n_channels),
- _matrix_base(nullptr),
- _biases(nullptr),
- _matrix_stride(0), _matrix_row_stride(0), _matrix_batch_stride(0),
- _outptr(nullptr),
- _tiles_M(iceildiv(n_rows, output_tile_rows)),
- _tiles_N(iceildiv(n_cols, output_tile_cols)),
- _out_col_stride(0), _out_row_stride(0), _out_batch_stride(0),
- _working_space_col_stride(n_channels),
- _working_space_row_stride(output_tile_cols * _working_space_col_stride),
- _working_space(nullptr)
-{
-}
+MEMBERFN()
+::OutputTransform(const int n_batches, const int n_rows, const int n_cols,
+ const int n_channels, const arm_gemm::Activation &activation)
+ : _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols),
+ _n_channels(n_channels),
+ _output_min((activation.type == arm_gemm::Activation::Type::ReLU ||
+ activation.type == arm_gemm::Activation::Type::BoundedReLU)
+ ? static_cast<TOut>(0.0f)
+ : (std::numeric_limits<TOut>::has_infinity)
+ ? -std::numeric_limits<TOut>::infinity()
+ : std::numeric_limits<TOut>::lowest()),
+ _output_max((activation.type == arm_gemm::Activation::Type::BoundedReLU)
+ ? static_cast<TOut>(activation.param1)
+ : (std::numeric_limits<TOut>::has_infinity)
+ ? std::numeric_limits<TOut>::infinity()
+ : std::numeric_limits<TOut>::max()),
+ _matrix_base(nullptr), _biases(nullptr), _matrix_stride(0),
+ _matrix_row_stride(0), _matrix_batch_stride(0), _outptr(nullptr),
+ _tiles_M(iceildiv(n_rows, output_tile_rows)),
+ _tiles_N(iceildiv(n_cols, output_tile_cols)), _out_col_stride(0),
+ _out_row_stride(0), _out_batch_stride(0),
+ _working_space_col_stride(n_channels),
+ _working_space_row_stride(output_tile_cols * _working_space_col_stride),
+ _working_space(nullptr) {}
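The bound selection in this initializer list reduces to a small table; a hedged restatement, assuming the arm_gemm::Activation fields used above:

#include <limits>
#include <utility>

//   None        -> [lowest or -inf, max or +inf]  (clamp is a no-op)
//   ReLU        -> [0,              max or +inf]
//   BoundedReLU -> [0,              param1]
template <typename TOut>
std::pair<TOut, TOut> clamp_bounds(const arm_gemm::Activation &act)
{
    const bool relu_like = act.type == arm_gemm::Activation::Type::ReLU ||
                           act.type == arm_gemm::Activation::Type::BoundedReLU;
    const TOut lo = relu_like ? static_cast<TOut>(0)
                              : std::numeric_limits<TOut>::has_infinity
                                    ? -std::numeric_limits<TOut>::infinity()
                                    : std::numeric_limits<TOut>::lowest();
    const TOut hi = (act.type == arm_gemm::Activation::Type::BoundedReLU)
                        ? static_cast<TOut>(act.param1)
                        : std::numeric_limits<TOut>::has_infinity
                              ? std::numeric_limits<TOut>::infinity()
                              : std::numeric_limits<TOut>::max();
    return {lo, hi};
}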
MEMBERFN(void)::set_input_matrices(const void * const mptr, const int ldmatrix, const int ldrow)
{
@@ -100,9 +106,10 @@ Nx1MEMBERFN()::OutputTransform(
const int n_batches,
const int n_rows,
const int n_cols,
- const int n_channels
+ const int n_channels,
+ const arm_gemm::Activation &activation
) : OutputTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots>::OutputTransform(
- n_batches, n_cols, n_rows, n_channels /* Transpose rows and columns */
+ n_batches, n_cols, n_rows, n_channels, activation /* Transpose rows and columns */
)
{
}
@@ -212,7 +219,8 @@ MEMBERFN(void)::transform_uncropped_tile(
{
transform_tile(
n_channels, inptr, _matrix_stride, biases,
- outptr, _out_row_stride, _out_col_stride
+ outptr, _out_row_stride, _out_col_stride,
+ _output_min, _output_max
);
}
@@ -230,7 +238,8 @@ MEMBERFN(void)::transform_cropped_tile(
TOut *wsptr = static_cast<TOut *>(get_working_space(threadid));
transform_tile(
n_channels, inptr, _matrix_stride, biases,
- wsptr, _working_space_row_stride, _working_space_col_stride
+ wsptr, _working_space_row_stride, _working_space_col_stride,
+ _output_min, _output_max
);
padding::crop_and_copy_tile(
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2_7_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2_7_fp32_fp32_integers.cpp
index c32d7f2f58..f231bdd67e 100644
--- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2_7_fp32_fp32_integers.cpp
+++ b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2_7_fp32_fp32_integers.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -36,7 +36,9 @@ void OutputTransform<1, 7, 1, 8, float, float, WinogradRoots::Integers>::transfo
const float* bptr,
float* const output,
const int, // No need to stride across rows
- const int output_col_stride
+ const int output_col_stride,
+ const float output_min,
+ const float output_max
)
{
// Construct a map to the output cells
@@ -72,7 +74,9 @@ void OutputTransform<1, 7, 1, 8, float, float, WinogradRoots::Integers>::transfo
}
for (int j = 0; j < output_tile_cols; j++)
{
- vst1q_f32(outptrs[j], f[j] + b);
+ const auto y = vminq_f32(vmaxq_f32(f[j] + b, vdupq_n_f32(output_min)),
+ vdupq_n_f32(output_max));
+ vst1q_f32(outptrs[j], y);
outptrs[j] += 4;
}
}
@@ -99,7 +103,9 @@ void OutputTransform<1, 7, 1, 8, float, float, WinogradRoots::Integers>::transfo
}
for (int j = 0; j < output_tile_cols; j++)
{
- vst1_f32(outptrs[j], f[j] + b);
+ const auto y = vmin_f32(vmax_f32(f[j] + b, vdup_n_f32(output_min)),
+ vdup_n_f32(output_max));
+ vst1_f32(outptrs[j], y);
outptrs[j] += 2;
}
}
@@ -126,7 +132,7 @@ void OutputTransform<1, 7, 1, 8, float, float, WinogradRoots::Integers>::transfo
}
for (int j = 0; j < output_tile_cols; j++)
{
- *(outptrs[j]++) = f[j] + b;
+ *(outptrs[j]++) = std::max(std::min(f[j] + b, output_max), output_min);
}
}
}
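Every store path in these kernels now performs the same bias-add-then-clamp; a minimal NEON sketch of the 4-lane case used above:

#include <arm_neon.h>

// Add the bias, then clamp into [lo, hi] with vmax/vmin, matching the
// vectorized store loops in the transforms above.
static inline float32x4_t bias_and_clamp(float32x4_t f, float32x4_t bias,
                                         float lo, float hi)
{
    const float32x4_t y = vaddq_f32(f, bias);
    return vminq_f32(vmaxq_f32(y, vdupq_n_f32(lo)), vdupq_n_f32(hi));
}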
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_3x3_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_3x3_fp32_fp32_integers.cpp
index d6ebf44f41..5136bc15c4 100644
--- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_3x3_fp32_fp32_integers.cpp
+++ b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_3x3_fp32_fp32_integers.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -36,7 +36,9 @@ void OutputTransform<3, 3, 4, 4, float, float, WinogradRoots::Integers>::transfo
const float* bptr,
float* const output,
const int output_row_stride,
- const int output_col_stride
+ const int output_col_stride,
+ const float output_min,
+ const float output_max
)
{
// Construct a map to the output cells
@@ -103,7 +105,10 @@ void OutputTransform<3, 3, 4, 4, float, float, WinogradRoots::Integers>::transfo
{
for (int j = 0; j < output_tile_cols; j++)
{
- vst1q_f32(outptrs[i][j], vaddq_f32(f[i][j], b));
+ const auto y =
+ vmaxq_f32(vminq_f32(vaddq_f32(f[i][j], b), vdupq_n_f32(output_max)),
+ vdupq_n_f32(output_min));
+ vst1q_f32(outptrs[i][j], y);
outptrs[i][j] += 4;
}
}
@@ -161,7 +166,10 @@ void OutputTransform<3, 3, 4, 4, float, float, WinogradRoots::Integers>::transfo
{
for (int j = 0; j < output_tile_cols; j++)
{
- vst1_f32(outptrs[i][j], vadd_f32(f[i][j], b));
+ const auto y =
+ vmax_f32(vmin_f32(vadd_f32(f[i][j], b), vdup_n_f32(output_max)),
+ vdup_n_f32(output_min));
+ vst1_f32(outptrs[i][j], y);
outptrs[i][j] += 2;
}
}
@@ -211,7 +219,8 @@ void OutputTransform<3, 3, 4, 4, float, float, WinogradRoots::Integers>::transfo
{
for (int j = 0; j < output_tile_cols; j++)
{
- *(outptrs[i][j]++) = f[i][j] + b;
+ const auto y = std::max(std::min(f[i][j] + b, output_max), output_min);
+ *(outptrs[i][j]++) = y;
}
}
}
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_5x5_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_5x5_fp32_fp32_integers.cpp
index d93d9e234a..0f911f14a3 100644
--- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_5x5_fp32_fp32_integers.cpp
+++ b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_5x5_fp32_fp32_integers.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -36,7 +36,9 @@ void OutputTransform<5, 5, 6, 6, float, float, WinogradRoots::Integers>::transfo
const float* bptr,
float* const output,
const int output_row_stride,
- const int output_col_stride
+ const int output_col_stride,
+ const float output_min,
+ const float output_max
)
{
// Construct a map to the output cells
@@ -101,7 +103,10 @@ void OutputTransform<5, 5, 6, 6, float, float, WinogradRoots::Integers>::transfo
{
for (int j = 0; j < output_tile_cols; j++)
{
- vst1q_f32(outptrs[i][j], vaddq_f32(f[i][j], b));
+ const auto y =
+ vmaxq_f32(vminq_f32(vaddq_f32(f[i][j], b), vdupq_n_f32(output_max)),
+ vdupq_n_f32(output_min));
+ vst1q_f32(outptrs[i][j], y);
outptrs[i][j] += 4;
}
}
@@ -157,7 +162,10 @@ void OutputTransform<5, 5, 6, 6, float, float, WinogradRoots::Integers>::transfo
{
for (int j = 0; j < output_tile_cols; j++)
{
- vst1_f32(outptrs[i][j], vadd_f32(f[i][j], b));
+ const auto y =
+ vmax_f32(vmin_f32(vadd_f32(f[i][j], b), vdup_n_f32(output_max)),
+ vdup_n_f32(output_min));
+ vst1_f32(outptrs[i][j], y);
outptrs[i][j] += 2;
}
}
@@ -205,7 +213,8 @@ void OutputTransform<5, 5, 6, 6, float, float, WinogradRoots::Integers>::transfo
{
for (int j = 0; j < output_tile_cols; j++)
{
- *(outptrs[i][j]++) = f[i][j] + b;
+ const auto y = std::max(std::min(f[i][j] + b, output_max), output_min);
+ *(outptrs[i][j]++) = y;
}
}
}
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4_5_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4_5_fp32_fp32_integers.cpp
index 7187ef2d20..49a3f41182 100644
--- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4_5_fp32_fp32_integers.cpp
+++ b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4_5_fp32_fp32_integers.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -36,7 +36,9 @@ void OutputTransform<1, 5, 1, 8, float, float, WinogradRoots::Integers>::transfo
const float* bptr,
float* const output,
const int, // No need to stride across rows
- const int output_col_stride
+ const int output_col_stride,
+ const float output_min,
+ const float output_max
)
{
// Construct a map to the output cells
@@ -74,7 +76,10 @@ void OutputTransform<1, 5, 1, 8, float, float, WinogradRoots::Integers>::transfo
}
for (int j = 0; j < output_tile_cols; j++)
{
- vst1q_f32(outptrs[j], f[j] + b);
+ const auto y =
+ vmaxq_f32(vminq_f32(vaddq_f32(f[j], b), vdupq_n_f32(output_max)),
+ vdupq_n_f32(output_min));
+ vst1q_f32(outptrs[j], y);
outptrs[j] += 4;
}
}
@@ -103,7 +108,10 @@ void OutputTransform<1, 5, 1, 8, float, float, WinogradRoots::Integers>::transfo
}
for (int j = 0; j < output_tile_cols; j++)
{
- vst1_f32(outptrs[j], f[j] + b);
+ const auto y =
+ vmax_f32(vmin_f32(vadd_f32(f[j], b), vdup_n_f32(output_max)),
+ vdup_n_f32(output_min));
+ vst1_f32(outptrs[j], y);
outptrs[j] += 2;
}
}
@@ -132,7 +140,8 @@ void OutputTransform<1, 5, 1, 8, float, float, WinogradRoots::Integers>::transfo
}
for (int j = 0; j < output_tile_cols; j++)
{
- *(outptrs[j]++) = f[j] + b;
+ const auto y = std::max(std::min(f[j] + b, output_max), output_min);
+ *(outptrs[j]++) = y;
}
}
}
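
The (output_min, output_max) pair is what lets one clamp express several activations: identity maps to (-inf, +inf), ReLU to (0, +inf), and bounded ReLU to (0, upper). A hypothetical helper showing that mapping — the real derivation happens from ActivationLayerInfo inside NEWinogradConvolutionLayer.cpp and may differ in detail:

#include <limits>
#include <utility>

// Hypothetical mapping from activation kind to clamp bounds.
std::pair<float, float> activation_bounds(bool is_relu, bool is_bounded_relu, float upper_bound)
{
    float lo = -std::numeric_limits<float>::infinity();
    float hi =  std::numeric_limits<float>::infinity();
    if (is_relu || is_bounded_relu)
    {
        lo = 0.0f;        // ReLU variants clamp below at zero
    }
    if (is_bounded_relu)
    {
        hi = upper_bound; // Bounded ReLU also clamps above
    }
    return {lo, hi};      // Identity: the clamp becomes a no-op
}
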
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4x4_3x3_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4x4_3x3_fp32_fp32_integers.cpp
index fd16a4df1c..292999c8bc 100644
--- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4x4_3x3_fp32_fp32_integers.cpp
+++ b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4x4_3x3_fp32_fp32_integers.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -28,1683 +28,6 @@
namespace winograd
{
-#ifdef __aarch64__
-
-template <>
-void OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots::Integers>::transform_tile(
- int n_channels,
- const float* inptr,
- const int matrix_stride,
- const float* bptr,
- float* output,
- const int output_row_stride,
- const int output_col_stride
-)
-{
- const float coeffs[2] = {2.0f, 4.0f};
- if (bptr != nullptr)
- {
- __asm__ __volatile__ (
- "ldr d0, [%[pcoeffs]]\n"
- "add x21, %[in_col_stride1], %[in_col_stride1]\n"
- "add x22, x21, %[in_col_stride1]\n"
- "add x25, %[inptr0], %[in_row_stride]\n"
- "add x15, %[output_col_stride1], %[output_col_stride1]\n"
- "add x23, x22, %[in_col_stride1]\n"
- "add x13, x25, %[in_row_stride]\n"
- "add x16, x15, %[output_col_stride1]\n"
- "add x24, x23, %[in_col_stride1]\n"
- "add x26, x13, %[in_row_stride]\n"
- "add x17, %[outptr0], %[output_row_stride]\n"
- "add x14, x26, %[in_row_stride]\n"
- "add x28, x17, %[output_row_stride]\n"
- "lsr x19, %[n_channels], #2\n"
- "add x27, x14, %[in_row_stride]\n"
- "add x18, x28, %[output_row_stride]\n"
- "and x20, %[n_channels], #3\n"
- "cbz x19, 4f\n"
- "1:\n"
- "ldr q19, [%[inptr0]]\n"
- "subs x19, x19, #1\n"
- "ldr q20, [%[inptr0], %[in_col_stride1]]\n"
- "ldr q4, [%[inptr0], x21]\n"
- "fadd v1.4s, v20.4s, v4.4s\n"
- "ldr q17, [%[inptr0], x22]\n"
- "fsub v7.4s, v20.4s, v4.4s\n"
- "ldr q22, [%[inptr0], x23]\n"
- "fadd v5.4s, v17.4s, v22.4s\n"
- "ldr q18, [%[inptr0], x24]\n"
- "fsub v10.4s, v17.4s, v22.4s\n"
- "ldr q25, [x25]\n"
- "fadd v8.4s, v19.4s, v1.4s\n"
- "ldr q12, [x25, %[in_col_stride1]]\n"
- "mov v4.16b, v1.16b\n"
- "ldr q23, [x25, x21]\n"
- "mov v1.16b, v7.16b\n"
- "ldr q9, [x25, x22]\n"
- "fmul v10.4s, v10.4s, v0.s[0]\n"
- "ldr q11, [x25, x23]\n"
- "fadd v8.4s, v8.4s, v5.4s\n"
- "ldr q6, [x25, x24]\n"
- "fmla v4.4s, v5.4s, v0.s[1]\n"
- "fadd v7.4s, v7.4s, v10.4s\n"
- "fmla v1.4s, v10.4s, v0.s[1]\n"
- "fadd v1.4s, v1.4s, v18.4s\n"
- "beq 3f\n"
- "2:\n"
- "fadd v3.4s, v12.4s, v23.4s\n"
- "ldr q2, [x13]\n"
- "fadd v27.4s, v9.4s, v11.4s\n"
- "ldr q21, [x13, %[in_col_stride1]]\n"
- "fsub v16.4s, v12.4s, v23.4s\n"
- "ldr q26, [x13, x21]\n"
- "fsub v9.4s, v9.4s, v11.4s\n"
- "ldr q17, [x13, x22]\n"
- "fadd v14.4s, v25.4s, v3.4s\n"
- "ldr q19, [x13, x23]\n"
- "mov v11.16b, v3.16b\n"
- "ldr q10, [x13, x24]\n"
- "mov v3.16b, v16.16b\n"
- "ldr q15, [x26]\n"
- "fmul v9.4s, v9.4s, v0.s[0]\n"
- "ldr q12, [x26, %[in_col_stride1]]\n"
- "fadd v14.4s, v14.4s, v27.4s\n"
- "ldr q20, [x26, x21]\n"
- "fmla v11.4s, v27.4s, v0.s[1]\n"
- "ldr q24, [x26, x22]\n"
- "fadd v23.4s, v21.4s, v26.4s\n"
- "ldr q29, [x26, x23]\n"
- "fadd v13.4s, v16.4s, v9.4s\n"
- "ldr q5, [x26, x24]\n"
- "fmla v3.4s, v9.4s, v0.s[1]\n"
- "ldr q18, [x14]\n"
- "fadd v30.4s, v17.4s, v19.4s\n"
- "add %[inptr0], %[inptr0], #16\n"
- "fadd v16.4s, v2.4s, v23.4s\n"
- "add x25, x25, #16\n"
- "fsub v21.4s, v21.4s, v26.4s\n"
- "ldr q22, [x14, %[in_col_stride1]]\n"
- "fadd v3.4s, v3.4s, v6.4s\n"
- "ldr q28, [x14, x21]\n"
- "fsub v19.4s, v17.4s, v19.4s\n"
- "add x13, x13, #16\n"
- "fadd v16.4s, v16.4s, v30.4s\n"
- "add x26, x26, #16\n"
- "mov v17.16b, v23.16b\n"
- "subs x19, x19, #1\n"
- "fadd v26.4s, v12.4s, v20.4s\n"
- "fsub v9.4s, v12.4s, v20.4s\n"
- "fmul v19.4s, v19.4s, v0.s[0]\n"
- "ldr q20, [x14, x22]\n"
- "fmla v17.4s, v30.4s, v0.s[1]\n"
- "fadd v25.4s, v24.4s, v29.4s\n"
- "fsub v12.4s, v24.4s, v29.4s\n"
- "fadd v24.4s, v22.4s, v28.4s\n"
- "fadd v23.4s, v15.4s, v26.4s\n"
- "mov v15.16b, v26.16b\n"
- "fsub v22.4s, v22.4s, v28.4s\n"
- "fadd v29.4s, v14.4s, v16.4s\n"
- "fsub v16.4s, v14.4s, v16.4s\n"
- "ldr q28, [x14, x23]\n"
- "fmul v12.4s, v12.4s, v0.s[0]\n"
- "fmla v15.4s, v25.4s, v0.s[1]\n"
- "fadd v23.4s, v23.4s, v25.4s\n"
- "mov v6.16b, v21.16b\n"
- "fadd v30.4s, v21.4s, v19.4s\n"
- "fadd v26.4s, v18.4s, v24.4s\n"
- "mov v25.16b, v24.16b\n"
- "fadd v18.4s, v8.4s, v29.4s\n"
- "fmla v6.4s, v19.4s, v0.s[1]\n"
- "fadd v27.4s, v20.4s, v28.4s\n"
- "fsub v21.4s, v20.4s, v28.4s\n"
- "mov v19.16b, v29.16b\n"
- "fadd v29.4s, v13.4s, v30.4s\n"
- "fsub v8.4s, v13.4s, v30.4s\n"
- "fadd v14.4s, v9.4s, v12.4s\n"
- "fadd v6.4s, v6.4s, v10.4s\n"
- "ldr q20, [x14, x24]\n"
- "fadd v26.4s, v26.4s, v27.4s\n"
- "add x14, x14, #16\n"
- "fmla v9.4s, v12.4s, v0.s[1]\n"
- "ldr q24, [x27]\n"
- "fmul v21.4s, v21.4s, v0.s[0]\n"
- "fmla v25.4s, v27.4s, v0.s[1]\n"
- "fadd v10.4s, v7.4s, v29.4s\n"
- "ldr q2, [%[bptr]]\n"
- "mov v7.16b, v29.16b\n"
- "add %[bptr], %[bptr], #16\n"
- "fadd v9.4s, v9.4s, v5.4s\n"
- "fadd v13.4s, v23.4s, v26.4s\n"
- "fsub v23.4s, v23.4s, v26.4s\n"
- "fadd v27.4s, v11.4s, v17.4s\n"
- "fsub v11.4s, v11.4s, v17.4s\n"
- "fadd v30.4s, v15.4s, v25.4s\n"
- "fsub v15.4s, v15.4s, v25.4s\n"
- "ldr q28, [x27, %[in_col_stride1]]\n"
- "fadd v18.4s, v18.4s, v13.4s\n"
- "fmla v19.4s, v13.4s, v0.s[1]\n"
- "fadd v26.4s, v22.4s, v21.4s\n"
- "mov v12.16b, v22.16b\n"
- "fmul v23.4s, v23.4s, v0.s[0]\n"
- "fadd v17.4s, v4.4s, v27.4s\n"
- "fmul v15.4s, v15.4s, v0.s[0]\n"
- "mov v4.16b, v27.16b\n"
- "fmla v12.4s, v21.4s, v0.s[1]\n"
- "ldr q22, [x27, x21]\n"
- "fadd v18.4s, v18.4s, v2.4s\n"
- "fadd v19.4s, v19.4s, v2.4s\n"
- "fadd v17.4s, v17.4s, v30.4s\n"
- "fmla v4.4s, v30.4s, v0.s[1]\n"
- "fadd v25.4s, v28.4s, v22.4s\n"
- "fsub v27.4s, v28.4s, v22.4s\n"
- "fadd v12.4s, v12.4s, v20.4s\n"
- "ldr q29, [x27, x22]\n"
- "str q18, [%[outptr0]]\n"
- "fadd v22.4s, v16.4s, v23.4s\n"
- "str q19, [x28]\n"
- "fadd v28.4s, v24.4s, v25.4s\n"
- "ldr q30, [x27, x23]\n"
- "fadd v20.4s, v29.4s, v30.4s\n"
- "fsub v18.4s, v29.4s, v30.4s\n"
- "mov v21.16b, v25.16b\n"
- "ldr q25, [x27, x24]\n"
- "fmla v16.4s, v23.4s, v0.s[1]\n"
- "ldr q19, [%[inptr0]]\n"
- "fadd v17.4s, v17.4s, v2.4s\n"
- "add x27, x27, #16\n"
- "fadd v28.4s, v28.4s, v20.4s\n"
- "fmul v18.4s, v18.4s, v0.s[0]\n"
- "fmla v21.4s, v20.4s, v0.s[1]\n"
- "ldr q20, [%[inptr0], %[in_col_stride1]]\n"
- "fadd v22.4s, v22.4s, v2.4s\n"
- "fadd v4.4s, v4.4s, v2.4s\n"
- "str q17, [%[outptr0], x15]\n"
- "mov v24.16b, v27.16b\n"
- "fadd v23.4s, v27.4s, v18.4s\n"
- "fadd v16.4s, v16.4s, v28.4s\n"
- "fadd v13.4s, v14.4s, v26.4s\n"
- "fsub v30.4s, v14.4s, v26.4s\n"
- "str q22, [x17]\n"
- "fmla v24.4s, v18.4s, v0.s[1]\n"
- "str q4, [x28, x15]\n"
- "mov v14.16b, v8.16b\n"
- "fadd v29.4s, v11.4s, v15.4s\n"
- "ldr q4, [%[inptr0], x21]\n"
- "fadd v10.4s, v10.4s, v13.4s\n"
- "ldr q17, [%[inptr0], x22]\n"
- "fadd v24.4s, v24.4s, v25.4s\n"
- "ldr q22, [%[inptr0], x23]\n"
- "fmul v30.4s, v30.4s, v0.s[0]\n"
- "fmla v7.4s, v13.4s, v0.s[1]\n"
- "mov v26.16b, v11.16b\n"
- "fadd v13.4s, v3.4s, v6.4s\n"
- "fsub v3.4s, v3.4s, v6.4s\n"
- "ldr q18, [%[inptr0], x24]\n"
- "fadd v10.4s, v10.4s, v2.4s\n"
- "fadd v29.4s, v29.4s, v2.4s\n"
- "fadd v8.4s, v8.4s, v30.4s\n"
- "fmla v14.4s, v30.4s, v0.s[1]\n"
- "fmla v26.4s, v15.4s, v0.s[1]\n"
- "ldr q25, [x25]\n"
- "fadd v27.4s, v9.4s, v12.4s\n"
- "fadd v1.4s, v1.4s, v13.4s\n"
- "str q10, [%[outptr0], %[output_col_stride1]]\n"
- "fsub v6.4s, v9.4s, v12.4s\n"
- "str q29, [x17, x15]\n"
- "fadd v14.4s, v14.4s, v23.4s\n"
- "fadd v26.4s, v26.4s, v21.4s\n"
- "ldr q12, [x25, %[in_col_stride1]]\n"
- "fadd v1.4s, v1.4s, v27.4s\n"
- "ldr q23, [x25, x21]\n"
- "fmul v6.4s, v6.4s, v0.s[0]\n"
- "ldr q9, [x25, x22]\n"
- "mov v5.16b, v13.16b\n"
- "ldr q11, [x25, x23]\n"
- "mov v13.16b, v3.16b\n"
- "fadd v8.4s, v8.4s, v2.4s\n"
- "fadd v1.4s, v1.4s, v2.4s\n"
- "fadd v7.4s, v7.4s, v2.4s\n"
- "fadd v10.4s, v3.4s, v6.4s\n"
- "fmla v5.4s, v27.4s, v0.s[1]\n"
- "fmla v13.4s, v6.4s, v0.s[1]\n"
- "ldr q6, [x25, x24]\n"
- "str q8, [x17, %[output_col_stride1]]\n"
- "fadd v16.4s, v16.4s, v2.4s\n"
- "str q1, [%[outptr0], x16]\n"
- "fadd v14.4s, v14.4s, v2.4s\n"
- "str q7, [x28, %[output_col_stride1]]\n"
- "fadd v10.4s, v10.4s, v2.4s\n"
- "fadd v13.4s, v13.4s, v24.4s\n"
- "add %[outptr0], %[outptr0], #16\n"
- "str q16, [x18]\n"
- "fadd v5.4s, v5.4s, v2.4s\n"
- "str q14, [x18, %[output_col_stride1]]\n"
- "fadd v26.4s, v26.4s, v2.4s\n"
- "str q10, [x17, x16]\n"
- "fadd v1.4s, v20.4s, v4.4s\n"
- "fadd v13.4s, v13.4s, v2.4s\n"
- "add x17, x17, #16\n"
- "str q5, [x28, x16]\n"
- "fadd v5.4s, v17.4s, v22.4s\n"
- "str q26, [x18, x15]\n"
- "fsub v7.4s, v20.4s, v4.4s\n"
- "fadd v8.4s, v19.4s, v1.4s\n"
- "add x28, x28, #16\n"
- "str q13, [x18, x16]\n"
- "mov v4.16b, v1.16b\n"
- "fsub v10.4s, v17.4s, v22.4s\n"
- "add x18, x18, #16\n"
- "mov v1.16b, v7.16b\n"
- "fadd v8.4s, v8.4s, v5.4s\n"
- "fmla v4.4s, v5.4s, v0.s[1]\n"
- "fmul v10.4s, v10.4s, v0.s[0]\n"
- "fadd v7.4s, v7.4s, v10.4s\n"
- "fmla v1.4s, v10.4s, v0.s[1]\n"
- "fadd v1.4s, v1.4s, v18.4s\n"
- "bne 2b\n"
- "3:\n"
- "fadd v3.4s, v12.4s, v23.4s\n"
- "ldr q2, [x13]\n"
- "fadd v27.4s, v9.4s, v11.4s\n"
- "ldr q21, [x13, %[in_col_stride1]]\n"
- "fsub v16.4s, v12.4s, v23.4s\n"
- "ldr q26, [x13, x21]\n"
- "fsub v9.4s, v9.4s, v11.4s\n"
- "ldr q17, [x13, x22]\n"
- "fadd v14.4s, v25.4s, v3.4s\n"
- "ldr q19, [x13, x23]\n"
- "mov v11.16b, v3.16b\n"
- "ldr q10, [x13, x24]\n"
- "mov v3.16b, v16.16b\n"
- "ldr q15, [x26]\n"
- "fmul v9.4s, v9.4s, v0.s[0]\n"
- "ldr q12, [x26, %[in_col_stride1]]\n"
- "fadd v14.4s, v14.4s, v27.4s\n"
- "ldr q20, [x26, x21]\n"
- "fmla v11.4s, v27.4s, v0.s[1]\n"
- "ldr q24, [x26, x22]\n"
- "fadd v23.4s, v21.4s, v26.4s\n"
- "ldr q29, [x26, x23]\n"
- "fadd v13.4s, v16.4s, v9.4s\n"
- "ldr q5, [x26, x24]\n"
- "fmla v3.4s, v9.4s, v0.s[1]\n"
- "ldr q18, [x14]\n"
- "fadd v30.4s, v17.4s, v19.4s\n"
- "add %[inptr0], %[inptr0], #16\n"
- "fadd v16.4s, v2.4s, v23.4s\n"
- "add x25, x25, #16\n"
- "fsub v21.4s, v21.4s, v26.4s\n"
- "ldr q22, [x14, %[in_col_stride1]]\n"
- "fadd v3.4s, v3.4s, v6.4s\n"
- "ldr q28, [x14, x21]\n"
- "fsub v19.4s, v17.4s, v19.4s\n"
- "add x13, x13, #16\n"
- "fadd v16.4s, v16.4s, v30.4s\n"
- "add x26, x26, #16\n"
- "mov v17.16b, v23.16b\n"
- "fadd v26.4s, v12.4s, v20.4s\n"
- "fsub v9.4s, v12.4s, v20.4s\n"
- "ldr q2, [%[bptr]]\n"
- "fmul v19.4s, v19.4s, v0.s[0]\n"
- "add %[bptr], %[bptr], #16\n"
- "fmla v17.4s, v30.4s, v0.s[1]\n"
- "fadd v25.4s, v24.4s, v29.4s\n"
- "fadd v23.4s, v15.4s, v26.4s\n"
- "fsub v12.4s, v24.4s, v29.4s\n"
- "mov v15.16b, v26.16b\n"
- "fadd v24.4s, v22.4s, v28.4s\n"
- "fsub v22.4s, v22.4s, v28.4s\n"
- "fadd v29.4s, v14.4s, v16.4s\n"
- "fsub v16.4s, v14.4s, v16.4s\n"
- "ldr q20, [x14, x22]\n"
- "fadd v23.4s, v23.4s, v25.4s\n"
- "fmul v12.4s, v12.4s, v0.s[0]\n"
- "fmla v15.4s, v25.4s, v0.s[1]\n"
- "mov v6.16b, v21.16b\n"
- "fadd v30.4s, v21.4s, v19.4s\n"
- "fadd v26.4s, v18.4s, v24.4s\n"
- "mov v25.16b, v24.16b\n"
- "fadd v18.4s, v8.4s, v29.4s\n"
- "fmla v6.4s, v19.4s, v0.s[1]\n"
- "mov v19.16b, v29.16b\n"
- "fadd v27.4s, v11.4s, v17.4s\n"
- "fsub v11.4s, v11.4s, v17.4s\n"
- "fadd v29.4s, v13.4s, v30.4s\n"
- "fsub v8.4s, v13.4s, v30.4s\n"
- "fadd v14.4s, v9.4s, v12.4s\n"
- "fadd v6.4s, v6.4s, v10.4s\n"
- "ldr q28, [x14, x23]\n"
- "fadd v17.4s, v4.4s, v27.4s\n"
- "mov v4.16b, v27.16b\n"
- "fmla v9.4s, v12.4s, v0.s[1]\n"
- "fadd v27.4s, v20.4s, v28.4s\n"
- "fsub v21.4s, v20.4s, v28.4s\n"
- "fadd v10.4s, v7.4s, v29.4s\n"
- "mov v7.16b, v29.16b\n"
- "fadd v13.4s, v3.4s, v6.4s\n"
- "fsub v3.4s, v3.4s, v6.4s\n"
- "ldr q20, [x14, x24]\n"
- "fadd v9.4s, v9.4s, v5.4s\n"
- "fadd v26.4s, v26.4s, v27.4s\n"
- "fmul v21.4s, v21.4s, v0.s[0]\n"
- "add x14, x14, #16\n"
- "fmla v25.4s, v27.4s, v0.s[1]\n"
- "mov v12.16b, v22.16b\n"
- "fadd v1.4s, v1.4s, v13.4s\n"
- "mov v5.16b, v13.16b\n"
- "fadd v13.4s, v23.4s, v26.4s\n"
- "fsub v23.4s, v23.4s, v26.4s\n"
- "fadd v26.4s, v22.4s, v21.4s\n"
- "ldr q24, [x27]\n"
- "fmla v12.4s, v21.4s, v0.s[1]\n"
- "fadd v30.4s, v15.4s, v25.4s\n"
- "fsub v15.4s, v15.4s, v25.4s\n"
- "ldr q28, [x27, %[in_col_stride1]]\n"
- "fadd v18.4s, v18.4s, v13.4s\n"
- "fmul v23.4s, v23.4s, v0.s[0]\n"
- "fmla v19.4s, v13.4s, v0.s[1]\n"
- "ldr q22, [x27, x21]\n"
- "fadd v12.4s, v12.4s, v20.4s\n"
- "ldr q29, [x27, x22]\n"
- "fadd v17.4s, v17.4s, v30.4s\n"
- "fmul v15.4s, v15.4s, v0.s[0]\n"
- "fmla v4.4s, v30.4s, v0.s[1]\n"
- "fadd v25.4s, v28.4s, v22.4s\n"
- "fsub v27.4s, v28.4s, v22.4s\n"
- "fadd v22.4s, v16.4s, v23.4s\n"
- "fadd v18.4s, v18.4s, v2.4s\n"
- "fadd v17.4s, v17.4s, v2.4s\n"
- "fadd v19.4s, v19.4s, v2.4s\n"
- "fadd v28.4s, v24.4s, v25.4s\n"
- "mov v21.16b, v25.16b\n"
- "fmla v16.4s, v23.4s, v0.s[1]\n"
- "ldr q30, [x27, x23]\n"
- "str q18, [%[outptr0]]\n"
- "fadd v20.4s, v29.4s, v30.4s\n"
- "str q17, [%[outptr0], x15]\n"
- "fsub v18.4s, v29.4s, v30.4s\n"
- "str q19, [x28]\n"
- "mov v24.16b, v27.16b\n"
- "fadd v13.4s, v14.4s, v26.4s\n"
- "ldr q25, [x27, x24]\n"
- "fadd v28.4s, v28.4s, v20.4s\n"
- "add x27, x27, #16\n"
- "fmul v18.4s, v18.4s, v0.s[0]\n"
- "fmla v21.4s, v20.4s, v0.s[1]\n"
- "fsub v30.4s, v14.4s, v26.4s\n"
- "mov v14.16b, v8.16b\n"
- "fadd v10.4s, v10.4s, v13.4s\n"
- "fmla v7.4s, v13.4s, v0.s[1]\n"
- "fadd v16.4s, v16.4s, v28.4s\n"
- "fadd v29.4s, v11.4s, v15.4s\n"
- "fadd v23.4s, v27.4s, v18.4s\n"
- "fmla v24.4s, v18.4s, v0.s[1]\n"
- "fmul v30.4s, v30.4s, v0.s[0]\n"
- "mov v26.16b, v11.16b\n"
- "fadd v27.4s, v9.4s, v12.4s\n"
- "fsub v6.4s, v9.4s, v12.4s\n"
- "mov v13.16b, v3.16b\n"
- "fadd v10.4s, v10.4s, v2.4s\n"
- "fadd v24.4s, v24.4s, v25.4s\n"
- "fmla v26.4s, v15.4s, v0.s[1]\n"
- "fadd v8.4s, v8.4s, v30.4s\n"
- "fmla v14.4s, v30.4s, v0.s[1]\n"
- "fadd v1.4s, v1.4s, v27.4s\n"
- "fmul v6.4s, v6.4s, v0.s[0]\n"
- "str q10, [%[outptr0], %[output_col_stride1]]\n"
- "fmla v5.4s, v27.4s, v0.s[1]\n"
- "fadd v26.4s, v26.4s, v21.4s\n"
- "fadd v22.4s, v22.4s, v2.4s\n"
- "fadd v14.4s, v14.4s, v23.4s\n"
- "fadd v8.4s, v8.4s, v2.4s\n"
- "fadd v10.4s, v3.4s, v6.4s\n"
- "fmla v13.4s, v6.4s, v0.s[1]\n"
- "fadd v1.4s, v1.4s, v2.4s\n"
- "fadd v29.4s, v29.4s, v2.4s\n"
- "str q22, [x17]\n"
- "fadd v7.4s, v7.4s, v2.4s\n"
- "str q8, [x17, %[output_col_stride1]]\n"
- "fadd v4.4s, v4.4s, v2.4s\n"
- "fadd v13.4s, v13.4s, v24.4s\n"
- "fadd v10.4s, v10.4s, v2.4s\n"
- "str q1, [%[outptr0], x16]\n"
- "fadd v5.4s, v5.4s, v2.4s\n"
- "str q29, [x17, x15]\n"
- "fadd v16.4s, v16.4s, v2.4s\n"
- "str q7, [x28, %[output_col_stride1]]\n"
- "fadd v14.4s, v14.4s, v2.4s\n"
- "str q10, [x17, x16]\n"
- "fadd v26.4s, v26.4s, v2.4s\n"
- "str q4, [x28, x15]\n"
- "fadd v13.4s, v13.4s, v2.4s\n"
- "str q5, [x28, x16]\n"
- "add %[outptr0], %[outptr0], #16\n"
- "str q16, [x18]\n"
- "add x17, x17, #16\n"
- "str q14, [x18, %[output_col_stride1]]\n"
- "add x28, x28, #16\n"
- "str q26, [x18, x15]\n"
- "str q13, [x18, x16]\n"
- "add x18, x18, #16\n"
- "4:\n"
- "cmp x20, #2\n"
- "blt 5f\n"
- "ldr d19, [%[inptr0]]\n"
- "ldr d20, [%[inptr0], %[in_col_stride1]]\n"
- "sub x20, x20, #2\n"
- "ldr d4, [%[inptr0], x21]\n"
- "ldr d17, [%[inptr0], x22]\n"
- "fadd v1.4s, v20.4s, v4.4s\n"
- "ldr d22, [%[inptr0], x23]\n"
- "fadd v5.4s, v17.4s, v22.4s\n"
- "ldr d18, [%[inptr0], x24]\n"
- "fsub v7.4s, v20.4s, v4.4s\n"
- "ldr d25, [x25]\n"
- "fsub v10.4s, v17.4s, v22.4s\n"
- "ldr d12, [x25, %[in_col_stride1]]\n"
- "fadd v8.4s, v19.4s, v1.4s\n"
- "ldr d23, [x25, x21]\n"
- "mov v4.16b, v1.16b\n"
- "ldr d9, [x25, x22]\n"
- "mov v1.16b, v7.16b\n"
- "ldr d11, [x25, x23]\n"
- "fmul v10.4s, v10.4s, v0.s[0]\n"
- "ldr d6, [x25, x24]\n"
- "fadd v8.4s, v8.4s, v5.4s\n"
- "ldr d2, [x13]\n"
- "fmla v4.4s, v5.4s, v0.s[1]\n"
- "ldr d21, [x13, %[in_col_stride1]]\n"
- "fadd v3.4s, v12.4s, v23.4s\n"
- "ldr d26, [x13, x21]\n"
- "fadd v7.4s, v7.4s, v10.4s\n"
- "ldr d17, [x13, x22]\n"
- "fmla v1.4s, v10.4s, v0.s[1]\n"
- "ldr d19, [x13, x23]\n"
- "fadd v27.4s, v9.4s, v11.4s\n"
- "ldr d10, [x13, x24]\n"
- "fadd v14.4s, v25.4s, v3.4s\n"
- "ldr d15, [x26]\n"
- "fsub v16.4s, v12.4s, v23.4s\n"
- "ldr d12, [x26, %[in_col_stride1]]\n"
- "fadd v1.4s, v1.4s, v18.4s\n"
- "ldr d20, [x26, x21]\n"
- "fsub v9.4s, v9.4s, v11.4s\n"
- "ldr d24, [x26, x22]\n"
- "fadd v14.4s, v14.4s, v27.4s\n"
- "ldr d29, [x26, x23]\n"
- "mov v11.16b, v3.16b\n"
- "ldr d5, [x26, x24]\n"
- "mov v3.16b, v16.16b\n"
- "ldr d18, [x14]\n"
- "fmul v9.4s, v9.4s, v0.s[0]\n"
- "add %[inptr0], %[inptr0], #8\n"
- "fmla v11.4s, v27.4s, v0.s[1]\n"
- "add x25, x25, #8\n"
- "fadd v23.4s, v21.4s, v26.4s\n"
- "add x13, x13, #8\n"
- "fsub v21.4s, v21.4s, v26.4s\n"
- "ldr d22, [x14, %[in_col_stride1]]\n"
- "fadd v13.4s, v16.4s, v9.4s\n"
- "add x26, x26, #8\n"
- "fmla v3.4s, v9.4s, v0.s[1]\n"
- "fadd v30.4s, v17.4s, v19.4s\n"
- "fadd v16.4s, v2.4s, v23.4s\n"
- "fsub v19.4s, v17.4s, v19.4s\n"
- "mov v17.16b, v23.16b\n"
- "fadd v26.4s, v12.4s, v20.4s\n"
- "fsub v9.4s, v12.4s, v20.4s\n"
- "ldr d28, [x14, x21]\n"
- "fadd v3.4s, v3.4s, v6.4s\n"
- "ldr d20, [x14, x22]\n"
- "fadd v16.4s, v16.4s, v30.4s\n"
- "fmul v19.4s, v19.4s, v0.s[0]\n"
- "fmla v17.4s, v30.4s, v0.s[1]\n"
- "fadd v25.4s, v24.4s, v29.4s\n"
- "fadd v23.4s, v15.4s, v26.4s\n"
- "fsub v12.4s, v24.4s, v29.4s\n"
- "mov v15.16b, v26.16b\n"
- "fadd v24.4s, v22.4s, v28.4s\n"
- "fsub v22.4s, v22.4s, v28.4s\n"
- "fadd v29.4s, v14.4s, v16.4s\n"
- "fsub v16.4s, v14.4s, v16.4s\n"
- "ldr d28, [x14, x23]\n"
- "fadd v23.4s, v23.4s, v25.4s\n"
- "fmul v12.4s, v12.4s, v0.s[0]\n"
- "fmla v15.4s, v25.4s, v0.s[1]\n"
- "mov v6.16b, v21.16b\n"
- "fadd v30.4s, v21.4s, v19.4s\n"
- "fadd v26.4s, v18.4s, v24.4s\n"
- "mov v25.16b, v24.16b\n"
- "fadd v18.4s, v8.4s, v29.4s\n"
- "fmla v6.4s, v19.4s, v0.s[1]\n"
- "fadd v27.4s, v20.4s, v28.4s\n"
- "fsub v21.4s, v20.4s, v28.4s\n"
- "mov v19.16b, v29.16b\n"
- "fadd v29.4s, v13.4s, v30.4s\n"
- "fsub v8.4s, v13.4s, v30.4s\n"
- "fadd v14.4s, v9.4s, v12.4s\n"
- "fadd v6.4s, v6.4s, v10.4s\n"
- "ldr d20, [x14, x24]\n"
- "fadd v26.4s, v26.4s, v27.4s\n"
- "add x14, x14, #8\n"
- "fmla v9.4s, v12.4s, v0.s[1]\n"
- "ldr d24, [x27]\n"
- "fmul v21.4s, v21.4s, v0.s[0]\n"
- "fmla v25.4s, v27.4s, v0.s[1]\n"
- "fadd v10.4s, v7.4s, v29.4s\n"
- "ldr d2, [%[bptr]]\n"
- "mov v7.16b, v29.16b\n"
- "add %[bptr], %[bptr], #8\n"
- "fadd v9.4s, v9.4s, v5.4s\n"
- "fadd v13.4s, v23.4s, v26.4s\n"
- "fsub v23.4s, v23.4s, v26.4s\n"
- "fadd v27.4s, v11.4s, v17.4s\n"
- "fsub v11.4s, v11.4s, v17.4s\n"
- "fadd v30.4s, v15.4s, v25.4s\n"
- "fsub v15.4s, v15.4s, v25.4s\n"
- "ldr d28, [x27, %[in_col_stride1]]\n"
- "fadd v18.4s, v18.4s, v13.4s\n"
- "fmla v19.4s, v13.4s, v0.s[1]\n"
- "fadd v26.4s, v22.4s, v21.4s\n"
- "mov v12.16b, v22.16b\n"
- "fmul v23.4s, v23.4s, v0.s[0]\n"
- "fadd v17.4s, v4.4s, v27.4s\n"
- "fmul v15.4s, v15.4s, v0.s[0]\n"
- "mov v4.16b, v27.16b\n"
- "fmla v12.4s, v21.4s, v0.s[1]\n"
- "ldr d22, [x27, x21]\n"
- "fadd v18.4s, v18.4s, v2.4s\n"
- "fadd v19.4s, v19.4s, v2.4s\n"
- "fadd v17.4s, v17.4s, v30.4s\n"
- "fmla v4.4s, v30.4s, v0.s[1]\n"
- "fadd v25.4s, v28.4s, v22.4s\n"
- "fsub v27.4s, v28.4s, v22.4s\n"
- "fadd v12.4s, v12.4s, v20.4s\n"
- "ldr d29, [x27, x22]\n"
- "str d18, [%[outptr0]]\n"
- "fadd v22.4s, v16.4s, v23.4s\n"
- "str d19, [x28]\n"
- "fadd v28.4s, v24.4s, v25.4s\n"
- "ldr d30, [x27, x23]\n"
- "fadd v20.4s, v29.4s, v30.4s\n"
- "fsub v18.4s, v29.4s, v30.4s\n"
- "mov v21.16b, v25.16b\n"
- "ldr d25, [x27, x24]\n"
- "fmla v16.4s, v23.4s, v0.s[1]\n"
- "add x27, x27, #8\n"
- "mov v24.16b, v27.16b\n"
- "fadd v17.4s, v17.4s, v2.4s\n"
- "fadd v28.4s, v28.4s, v20.4s\n"
- "fmul v18.4s, v18.4s, v0.s[0]\n"
- "fmla v21.4s, v20.4s, v0.s[1]\n"
- "fadd v13.4s, v14.4s, v26.4s\n"
- "fsub v30.4s, v14.4s, v26.4s\n"
- "mov v14.16b, v8.16b\n"
- "str d17, [%[outptr0], x15]\n"
- "fadd v29.4s, v11.4s, v15.4s\n"
- "fadd v23.4s, v27.4s, v18.4s\n"
- "fmla v24.4s, v18.4s, v0.s[1]\n"
- "fadd v16.4s, v16.4s, v28.4s\n"
- "fadd v10.4s, v10.4s, v13.4s\n"
- "fmul v30.4s, v30.4s, v0.s[0]\n"
- "fmla v7.4s, v13.4s, v0.s[1]\n"
- "mov v26.16b, v11.16b\n"
- "fadd v13.4s, v3.4s, v6.4s\n"
- "fadd v24.4s, v24.4s, v25.4s\n"
- "fadd v27.4s, v9.4s, v12.4s\n"
- "fsub v3.4s, v3.4s, v6.4s\n"
- "fsub v6.4s, v9.4s, v12.4s\n"
- "fadd v8.4s, v8.4s, v30.4s\n"
- "fmla v14.4s, v30.4s, v0.s[1]\n"
- "fmla v26.4s, v15.4s, v0.s[1]\n"
- "fadd v1.4s, v1.4s, v13.4s\n"
- "mov v5.16b, v13.16b\n"
- "fadd v10.4s, v10.4s, v2.4s\n"
- "fmul v6.4s, v6.4s, v0.s[0]\n"
- "mov v13.16b, v3.16b\n"
- "fadd v14.4s, v14.4s, v23.4s\n"
- "fadd v22.4s, v22.4s, v2.4s\n"
- "fadd v26.4s, v26.4s, v21.4s\n"
- "fadd v1.4s, v1.4s, v27.4s\n"
- "str d10, [%[outptr0], %[output_col_stride1]]\n"
- "fmla v5.4s, v27.4s, v0.s[1]\n"
- "fadd v10.4s, v3.4s, v6.4s\n"
- "fmla v13.4s, v6.4s, v0.s[1]\n"
- "str d22, [x17]\n"
- "fadd v8.4s, v8.4s, v2.4s\n"
- "fadd v1.4s, v1.4s, v2.4s\n"
- "fadd v29.4s, v29.4s, v2.4s\n"
- "fadd v7.4s, v7.4s, v2.4s\n"
- "fadd v4.4s, v4.4s, v2.4s\n"
- "fadd v13.4s, v13.4s, v24.4s\n"
- "fadd v10.4s, v10.4s, v2.4s\n"
- "str d8, [x17, %[output_col_stride1]]\n"
- "fadd v5.4s, v5.4s, v2.4s\n"
- "str d1, [%[outptr0], x16]\n"
- "fadd v16.4s, v16.4s, v2.4s\n"
- "str d29, [x17, x15]\n"
- "fadd v14.4s, v14.4s, v2.4s\n"
- "str d10, [x17, x16]\n"
- "fadd v26.4s, v26.4s, v2.4s\n"
- "str d7, [x28, %[output_col_stride1]]\n"
- "fadd v13.4s, v13.4s, v2.4s\n"
- "str d4, [x28, x15]\n"
- "add %[outptr0], %[outptr0], #8\n"
- "str d5, [x28, x16]\n"
- "add x17, x17, #8\n"
- "str d16, [x18]\n"
- "add x28, x28, #8\n"
- "str d14, [x18, %[output_col_stride1]]\n"
- "str d26, [x18, x15]\n"
- "str d13, [x18, x16]\n"
- "add x18, x18, #8\n"
- "5:\n"
- "cbz x20, 6f\n"
- "ldr s19, [%[inptr0]]\n"
- "ldr s20, [%[inptr0], %[in_col_stride1]]\n"
- "ldr s4, [%[inptr0], x21]\n"
- "fadd v1.4s, v20.4s, v4.4s\n"
- "ldr s17, [%[inptr0], x22]\n"
- "fsub v7.4s, v20.4s, v4.4s\n"
- "ldr s22, [%[inptr0], x23]\n"
- "fadd v5.4s, v17.4s, v22.4s\n"
- "ldr s18, [%[inptr0], x24]\n"
- "fsub v10.4s, v17.4s, v22.4s\n"
- "ldr s25, [x25]\n"
- "fadd v8.4s, v19.4s, v1.4s\n"
- "ldr s12, [x25, %[in_col_stride1]]\n"
- "mov v4.16b, v1.16b\n"
- "ldr s23, [x25, x21]\n"
- "mov v1.16b, v7.16b\n"
- "ldr s9, [x25, x22]\n"
- "fmul v10.4s, v10.4s, v0.s[0]\n"
- "ldr s11, [x25, x23]\n"
- "fadd v8.4s, v8.4s, v5.4s\n"
- "ldr s6, [x25, x24]\n"
- "fmla v4.4s, v5.4s, v0.s[1]\n"
- "ldr s2, [x13]\n"
- "fadd v3.4s, v12.4s, v23.4s\n"
- "ldr s21, [x13, %[in_col_stride1]]\n"
- "fadd v7.4s, v7.4s, v10.4s\n"
- "ldr s26, [x13, x21]\n"
- "fmla v1.4s, v10.4s, v0.s[1]\n"
- "ldr s17, [x13, x22]\n"
- "fadd v27.4s, v9.4s, v11.4s\n"
- "ldr s19, [x13, x23]\n"
- "fadd v14.4s, v25.4s, v3.4s\n"
- "ldr s10, [x13, x24]\n"
- "fsub v16.4s, v12.4s, v23.4s\n"
- "ldr s15, [x26]\n"
- "fadd v1.4s, v1.4s, v18.4s\n"
- "ldr s12, [x26, %[in_col_stride1]]\n"
- "fsub v9.4s, v9.4s, v11.4s\n"
- "ldr s20, [x26, x21]\n"
- "fadd v14.4s, v14.4s, v27.4s\n"
- "ldr s24, [x26, x22]\n"
- "mov v11.16b, v3.16b\n"
- "ldr s29, [x26, x23]\n"
- "mov v3.16b, v16.16b\n"
- "ldr s5, [x26, x24]\n"
- "fmul v9.4s, v9.4s, v0.s[0]\n"
- "ldr s18, [x14]\n"
- "fmla v11.4s, v27.4s, v0.s[1]\n"
- "fadd v23.4s, v21.4s, v26.4s\n"
- "fsub v21.4s, v21.4s, v26.4s\n"
- "fadd v30.4s, v17.4s, v19.4s\n"
- "fsub v19.4s, v17.4s, v19.4s\n"
- "ldr s22, [x14, %[in_col_stride1]]\n"
- "fadd v13.4s, v16.4s, v9.4s\n"
- "fmla v3.4s, v9.4s, v0.s[1]\n"
- "fadd v16.4s, v2.4s, v23.4s\n"
- "mov v17.16b, v23.16b\n"
- "fadd v26.4s, v12.4s, v20.4s\n"
- "fsub v9.4s, v12.4s, v20.4s\n"
- "fmul v19.4s, v19.4s, v0.s[0]\n"
- "ldr s28, [x14, x21]\n"
- "fadd v3.4s, v3.4s, v6.4s\n"
- "ldr s20, [x14, x22]\n"
- "fadd v16.4s, v16.4s, v30.4s\n"
- "fmla v17.4s, v30.4s, v0.s[1]\n"
- "fadd v25.4s, v24.4s, v29.4s\n"
- "fadd v23.4s, v15.4s, v26.4s\n"
- "fsub v12.4s, v24.4s, v29.4s\n"
- "mov v15.16b, v26.16b\n"
- "fadd v24.4s, v22.4s, v28.4s\n"
- "fsub v22.4s, v22.4s, v28.4s\n"
- "fadd v30.4s, v21.4s, v19.4s\n"
- "mov v6.16b, v21.16b\n"
- "fadd v23.4s, v23.4s, v25.4s\n"
- "fmla v15.4s, v25.4s, v0.s[1]\n"
- "fmul v12.4s, v12.4s, v0.s[0]\n"
- "ldr s28, [x14, x23]\n"
- "fmla v6.4s, v19.4s, v0.s[1]\n"
- "fadd v27.4s, v20.4s, v28.4s\n"
- "fadd v26.4s, v18.4s, v24.4s\n"
- "fsub v21.4s, v20.4s, v28.4s\n"
- "mov v25.16b, v24.16b\n"
- "fadd v29.4s, v14.4s, v16.4s\n"
- "fsub v16.4s, v14.4s, v16.4s\n"
- "ldr s20, [x14, x24]\n"
- "fadd v6.4s, v6.4s, v10.4s\n"
- "ldr s24, [x27]\n"
- "fadd v26.4s, v26.4s, v27.4s\n"
- "fmul v21.4s, v21.4s, v0.s[0]\n"
- "fmla v25.4s, v27.4s, v0.s[1]\n"
- "fadd v18.4s, v8.4s, v29.4s\n"
- "mov v19.16b, v29.16b\n"
- "fadd v29.4s, v13.4s, v30.4s\n"
- "fsub v8.4s, v13.4s, v30.4s\n"
- "fadd v27.4s, v11.4s, v17.4s\n"
- "fsub v11.4s, v11.4s, v17.4s\n"
- "fadd v13.4s, v23.4s, v26.4s\n"
- "fsub v23.4s, v23.4s, v26.4s\n"
- "ldr s28, [x27, %[in_col_stride1]]\n"
- "fadd v10.4s, v7.4s, v29.4s\n"
- "mov v7.16b, v29.16b\n"
- "fadd v17.4s, v4.4s, v27.4s\n"
- "mov v4.16b, v27.16b\n"
- "fadd v18.4s, v18.4s, v13.4s\n"
- "fmla v19.4s, v13.4s, v0.s[1]\n"
- "fmul v23.4s, v23.4s, v0.s[0]\n"
- "fadd v30.4s, v15.4s, v25.4s\n"
- "fsub v15.4s, v15.4s, v25.4s\n"
- "fadd v13.4s, v3.4s, v6.4s\n"
- "fsub v3.4s, v3.4s, v6.4s\n"
- "ldr s2, [%[bptr]]\n"
- "fadd v18.4s, v18.4s, v2.4s\n"
- "fadd v19.4s, v19.4s, v2.4s\n"
- "fadd v17.4s, v17.4s, v30.4s\n"
- "fmla v4.4s, v30.4s, v0.s[1]\n"
- "fadd v14.4s, v9.4s, v12.4s\n"
- "fmul v15.4s, v15.4s, v0.s[0]\n"
- "fadd v1.4s, v1.4s, v13.4s\n"
- "str s18, [%[outptr0]]\n"
- "fadd v26.4s, v22.4s, v21.4s\n"
- "str s19, [x28]\n"
- "fmla v9.4s, v12.4s, v0.s[1]\n"
- "mov v12.16b, v22.16b\n"
- "ldr s22, [x27, x21]\n"
- "fadd v25.4s, v28.4s, v22.4s\n"
- "fsub v27.4s, v28.4s, v22.4s\n"
- "fadd v22.4s, v16.4s, v23.4s\n"
- "fadd v9.4s, v9.4s, v5.4s\n"
- "ldr s29, [x27, x22]\n"
- "fmla v12.4s, v21.4s, v0.s[1]\n"
- "ldr s30, [x27, x23]\n"
- "fadd v28.4s, v24.4s, v25.4s\n"
- "mov v21.16b, v25.16b\n"
- "fmla v16.4s, v23.4s, v0.s[1]\n"
- "ldr s25, [x27, x24]\n"
- "mov v5.16b, v13.16b\n"
- "fadd v17.4s, v17.4s, v2.4s\n"
- "fadd v12.4s, v12.4s, v20.4s\n"
- "fadd v20.4s, v29.4s, v30.4s\n"
- "fsub v18.4s, v29.4s, v30.4s\n"
- "mov v24.16b, v27.16b\n"
- "fadd v22.4s, v22.4s, v2.4s\n"
- "fadd v4.4s, v4.4s, v2.4s\n"
- "str s17, [%[outptr0], x15]\n"
- "fadd v13.4s, v14.4s, v26.4s\n"
- "fadd v28.4s, v28.4s, v20.4s\n"
- "fmla v21.4s, v20.4s, v0.s[1]\n"
- "fmul v18.4s, v18.4s, v0.s[0]\n"
- "fsub v30.4s, v14.4s, v26.4s\n"
- "str s22, [x17]\n"
- "mov v14.16b, v8.16b\n"
- "str s4, [x28, x15]\n"
- "fadd v10.4s, v10.4s, v13.4s\n"
- "fadd v16.4s, v16.4s, v28.4s\n"
- "fmla v7.4s, v13.4s, v0.s[1]\n"
- "fadd v23.4s, v27.4s, v18.4s\n"
- "fmla v24.4s, v18.4s, v0.s[1]\n"
- "fmul v30.4s, v30.4s, v0.s[0]\n"
- "fadd v29.4s, v11.4s, v15.4s\n"
- "mov v26.16b, v11.16b\n"
- "fadd v27.4s, v9.4s, v12.4s\n"
- "fsub v6.4s, v9.4s, v12.4s\n"
- "mov v13.16b, v3.16b\n"
- "fadd v24.4s, v24.4s, v25.4s\n"
- "fadd v10.4s, v10.4s, v2.4s\n"
- "fadd v8.4s, v8.4s, v30.4s\n"
- "fmla v14.4s, v30.4s, v0.s[1]\n"
- "fmla v26.4s, v15.4s, v0.s[1]\n"
- "fadd v1.4s, v1.4s, v27.4s\n"
- "fmul v6.4s, v6.4s, v0.s[0]\n"
- "fmla v5.4s, v27.4s, v0.s[1]\n"
- "str s10, [%[outptr0], %[output_col_stride1]]\n"
- "fadd v29.4s, v29.4s, v2.4s\n"
- "fadd v14.4s, v14.4s, v23.4s\n"
- "fadd v8.4s, v8.4s, v2.4s\n"
- "fadd v26.4s, v26.4s, v21.4s\n"
- "fadd v1.4s, v1.4s, v2.4s\n"
- "fadd v10.4s, v3.4s, v6.4s\n"
- "fmla v13.4s, v6.4s, v0.s[1]\n"
- "str s29, [x17, x15]\n"
- "fadd v7.4s, v7.4s, v2.4s\n"
- "str s8, [x17, %[output_col_stride1]]\n"
- "fadd v5.4s, v5.4s, v2.4s\n"
- "str s1, [%[outptr0], x16]\n"
- "fadd v16.4s, v16.4s, v2.4s\n"
- "fadd v13.4s, v13.4s, v24.4s\n"
- "fadd v10.4s, v10.4s, v2.4s\n"
- "str s7, [x28, %[output_col_stride1]]\n"
- "fadd v14.4s, v14.4s, v2.4s\n"
- "str s5, [x28, x16]\n"
- "fadd v26.4s, v26.4s, v2.4s\n"
- "str s16, [x18]\n"
- "fadd v13.4s, v13.4s, v2.4s\n"
- "str s10, [x17, x16]\n"
- "str s14, [x18, %[output_col_stride1]]\n"
- "str s26, [x18, x15]\n"
- "str s13, [x18, x16]\n"
- "6:\n"
- : [bptr] "+r" (bptr), [outptr0] "+r" (output), [inptr0] "+r" (inptr)
- : [output_row_stride] "r" (output_row_stride * sizeof(float)), [output_col_stride1] "r" (output_col_stride * sizeof(float)), [pcoeffs] "r" (coeffs), [n_channels] "r" ((long) n_channels), [in_row_stride] "r" (6 * matrix_stride * sizeof(float)), [in_col_stride1] "r" (matrix_stride * sizeof(float))
- : "cc", "v0", "v1", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v2", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v3", "v30", "v4", "v5", "v6", "v7", "v8", "v9", "x13", "x14", "x15", "x16", "x17", "x18", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "memory"
- );
- }
- else
- {
- __asm__ __volatile__ (
- "ldr d0, [%[pcoeffs]]\n"
- "add x21, %[in_col_stride1], %[in_col_stride1]\n" // Compute input column stride 2
- "add x22, x21, %[in_col_stride1]\n" // Compute input column stride 3
- "add x25, %[inptr0], %[in_row_stride]\n" // Compute input row pointers
- "add x15, %[output_col_stride1], %[output_col_stride1]\n" // Compute output column stride 2
- "add x23, x22, %[in_col_stride1]\n" // Compute input column stride 4
- "add x13, x25, %[in_row_stride]\n" // Compute input row pointers
- "add x16, x15, %[output_col_stride1]\n" // Compute output column stride 3
- "add x24, x23, %[in_col_stride1]\n" // Compute input column stride 5
- "add x26, x13, %[in_row_stride]\n" // Compute input row pointers
- "add x17, %[outptr0], %[output_row_stride]\n" // Compute output row pointer 1
- "add x14, x26, %[in_row_stride]\n" // Compute input row pointers
- "add x28, x17, %[output_row_stride]\n" // Compute output row pointer 2
- "lsr x19, %[n_channels], #2\n"
- "add x27, x14, %[in_row_stride]\n" // Compute input row pointers
- "add x18, x28, %[output_row_stride]\n" // Compute output row pointer 3
- "and x20, %[n_channels], #3\n"
- "cbz x19, 4f\n"
- "1:\n" // Quad head
- "ldr q17, [%[inptr0]]\n"
- "subs x19, x19, #1\n"
- "ldr q23, [%[inptr0], %[in_col_stride1]]\n"
- "ldr q27, [%[inptr0], x21]\n"
- "fadd v4.4s, v23.4s, v27.4s\n"
- "ldr q24, [%[inptr0], x22]\n"
- "fsub v13.4s, v23.4s, v27.4s\n"
- "ldr q11, [%[inptr0], x23]\n"
- "fadd v10.4s, v24.4s, v11.4s\n"
- "ldr q12, [%[inptr0], x24]\n"
- "fsub v11.4s, v24.4s, v11.4s\n"
- "ldr q20, [x25]\n"
- "fadd v7.4s, v17.4s, v4.4s\n"
- "ldr q19, [x25, %[in_col_stride1]]\n"
- "mov v4.16b, v4.16b\n"
- "ldr q22, [x25, x21]\n"
- "mov v1.16b, v13.16b\n"
- "ldr q14, [x25, x22]\n"
- "fmul v11.4s, v11.4s, v0.s[0]\n"
- "ldr q18, [x25, x23]\n"
- "fadd v7.4s, v7.4s, v10.4s\n"
- "ldr q3, [x25, x24]\n"
- "fmla v4.4s, v10.4s, v0.s[1]\n"
- "fadd v8.4s, v13.4s, v11.4s\n"
- "fmla v1.4s, v11.4s, v0.s[1]\n"
- "fadd v1.4s, v1.4s, v12.4s\n"
- "beq 3f\n"
- "2:\n" // Quad loop
- "fadd v2.4s, v19.4s, v22.4s\n"
- "ldr q16, [x13]\n"
- "fadd v23.4s, v14.4s, v18.4s\n"
- "ldr q21, [x13, %[in_col_stride1]]\n"
- "fsub v15.4s, v19.4s, v22.4s\n"
- "ldr q24, [x13, x21]\n"
- "fsub v31.4s, v14.4s, v18.4s\n"
- "ldr q25, [x13, x22]\n"
- "fadd v11.4s, v20.4s, v2.4s\n"
- "ldr q17, [x13, x23]\n"
- "mov v13.16b, v2.16b\n"
- "ldr q9, [x13, x24]\n"
- "mov v2.16b, v15.16b\n"
- "ldr q6, [x26]\n"
- "fmul v31.4s, v31.4s, v0.s[0]\n"
- "ldr q19, [x26, %[in_col_stride1]]\n"
- "fadd v11.4s, v11.4s, v23.4s\n"
- "ldr q22, [x26, x21]\n"
- "fmla v13.4s, v23.4s, v0.s[1]\n"
- "ldr q12, [x26, x22]\n"
- "fadd v29.4s, v21.4s, v24.4s\n"
- "ldr q26, [x26, x23]\n"
- "fadd v15.4s, v15.4s, v31.4s\n"
- "ldr q5, [x26, x24]\n"
- "fmla v2.4s, v31.4s, v0.s[1]\n"
- "ldr q10, [x14]\n"
- "fadd v18.4s, v25.4s, v17.4s\n"
- "add %[inptr0], %[inptr0], #16\n"
- "fadd v27.4s, v16.4s, v29.4s\n"
- "add x25, x25, #16\n"
- "fsub v14.4s, v21.4s, v24.4s\n"
- "ldr q30, [x14, %[in_col_stride1]]\n"
- "fadd v2.4s, v2.4s, v3.4s\n"
- "ldr q31, [x14, x21]\n"
- "fsub v28.4s, v25.4s, v17.4s\n"
- "add x13, x13, #16\n"
- "fadd v27.4s, v27.4s, v18.4s\n"
- "add x26, x26, #16\n"
- "mov v21.16b, v29.16b\n"
- "subs x19, x19, #1\n"
- "fadd v20.4s, v19.4s, v22.4s\n"
- "fsub v17.4s, v19.4s, v22.4s\n"
- "fmul v28.4s, v28.4s, v0.s[0]\n"
- "ldr q23, [x14, x22]\n"
- "fmla v21.4s, v18.4s, v0.s[1]\n"
- "fadd v29.4s, v12.4s, v26.4s\n"
- "fsub v16.4s, v12.4s, v26.4s\n"
- "fadd v25.4s, v30.4s, v31.4s\n"
- "fadd v24.4s, v6.4s, v20.4s\n"
- "mov v6.16b, v20.16b\n"
- "fsub v22.4s, v30.4s, v31.4s\n"
- "fadd v31.4s, v11.4s, v27.4s\n"
- "fsub v12.4s, v11.4s, v27.4s\n"
- "ldr q26, [x14, x23]\n"
- "fmul v16.4s, v16.4s, v0.s[0]\n"
- "fmla v6.4s, v29.4s, v0.s[1]\n"
- "fadd v24.4s, v24.4s, v29.4s\n"
- "mov v3.16b, v14.16b\n"
- "fadd v20.4s, v14.4s, v28.4s\n"
- "fadd v29.4s, v10.4s, v25.4s\n"
- "mov v10.16b, v25.16b\n"
- "fadd v25.4s, v7.4s, v31.4s\n"
- "fmla v3.4s, v28.4s, v0.s[1]\n"
- "fadd v14.4s, v23.4s, v26.4s\n"
- "fsub v23.4s, v23.4s, v26.4s\n"
- "mov v26.16b, v31.16b\n"
- "fadd v31.4s, v15.4s, v20.4s\n"
- "fsub v11.4s, v15.4s, v20.4s\n"
- "fadd v20.4s, v17.4s, v16.4s\n"
- "mov v7.16b, v17.16b\n"
- "fadd v3.4s, v3.4s, v9.4s\n"
- "ldr q18, [x14, x24]\n"
- "fadd v29.4s, v29.4s, v14.4s\n"
- "add x14, x14, #16\n"
- "fmla v7.4s, v16.4s, v0.s[1]\n"
- "ldr q19, [x27]\n"
- "fmul v23.4s, v23.4s, v0.s[0]\n"
- "fmla v10.4s, v14.4s, v0.s[1]\n"
- "fadd v15.4s, v8.4s, v31.4s\n"
- "mov v14.16b, v31.16b\n"
- "fadd v28.4s, v24.4s, v29.4s\n"
- "fsub v24.4s, v24.4s, v29.4s\n"
- "fadd v7.4s, v7.4s, v5.4s\n"
- "ldr q27, [x27, %[in_col_stride1]]\n"
- "fadd v30.4s, v13.4s, v21.4s\n"
- "fsub v9.4s, v13.4s, v21.4s\n"
- "fadd v17.4s, v22.4s, v23.4s\n"
- "mov v8.16b, v22.16b\n"
- "fadd v25.4s, v25.4s, v28.4s\n"
- "fmul v24.4s, v24.4s, v0.s[0]\n"
- "fmla v26.4s, v28.4s, v0.s[1]\n"
- "ldr q29, [x27, x21]\n"
- "fmla v8.4s, v23.4s, v0.s[1]\n"
- "ldr q28, [x27, x22]\n"
- "fadd v13.4s, v4.4s, v30.4s\n"
- "mov v4.16b, v30.16b\n"
- "str q25, [%[outptr0]]\n" // Store output (0, 0)
- "fadd v16.4s, v27.4s, v29.4s\n"
- "str q26, [x28]\n" // Store output (2, 0)
- "fsub v29.4s, v27.4s, v29.4s\n"
- "fadd v8.4s, v8.4s, v18.4s\n"
- "ldr q23, [x27, x23]\n"
- "fadd v30.4s, v28.4s, v23.4s\n"
- "ldr q25, [x27, x24]\n"
- "fadd v19.4s, v19.4s, v16.4s\n"
- "add x27, x27, #16\n"
- "fsub v27.4s, v28.4s, v23.4s\n"
- "mov v16.16b, v16.16b\n"
- "fadd v22.4s, v20.4s, v17.4s\n"
- "fsub v20.4s, v20.4s, v17.4s\n"
- "fadd v21.4s, v12.4s, v24.4s\n"
- "mov v26.16b, v12.16b\n"
- "fadd v19.4s, v19.4s, v30.4s\n"
- "fmla v16.4s, v30.4s, v0.s[1]\n"
- "fmul v27.4s, v27.4s, v0.s[0]\n"
- "ldr q17, [%[inptr0]]\n"
- "fmla v26.4s, v24.4s, v0.s[1]\n"
- "ldr q23, [%[inptr0], %[in_col_stride1]]\n"
- "str q21, [x17]\n" // Store output (1, 0)
- "mov v5.16b, v29.16b\n"
- "fadd v15.4s, v15.4s, v22.4s\n"
- "fmul v20.4s, v20.4s, v0.s[0]\n"
- "fadd v18.4s, v29.4s, v27.4s\n"
- "fmla v14.4s, v22.4s, v0.s[1]\n"
- "fmla v5.4s, v27.4s, v0.s[1]\n"
- "ldr q27, [%[inptr0], x21]\n"
- "fadd v26.4s, v26.4s, v19.4s\n"
- "ldr q24, [%[inptr0], x22]\n"
- "str q15, [%[outptr0], %[output_col_stride1]]\n" // Store output (0, 1)
- "fadd v12.4s, v11.4s, v20.4s\n"
- "str q14, [x28, %[output_col_stride1]]\n" // Store output (2, 1)
- "mov v28.16b, v11.16b\n"
- "fadd v5.4s, v5.4s, v25.4s\n"
- "ldr q11, [%[inptr0], x23]\n"
- "str q26, [x18]\n" // Store output (3, 0)
- "fadd v21.4s, v6.4s, v10.4s\n"
- "str q12, [x17, %[output_col_stride1]]\n" // Store output (1, 1)
- "fmla v28.4s, v20.4s, v0.s[1]\n"
- "fsub v10.4s, v6.4s, v10.4s\n"
- "ldr q12, [%[inptr0], x24]\n"
- "mov v15.16b, v9.16b\n"
- "ldr q20, [x25]\n"
- "fadd v13.4s, v13.4s, v21.4s\n"
- "ldr q19, [x25, %[in_col_stride1]]\n"
- "fadd v28.4s, v28.4s, v18.4s\n"
- "ldr q22, [x25, x21]\n"
- "fmul v10.4s, v10.4s, v0.s[0]\n"
- "ldr q14, [x25, x22]\n"
- "fmla v4.4s, v21.4s, v0.s[1]\n"
- "ldr q18, [x25, x23]\n"
- "str q13, [%[outptr0], x15]\n" // Store output (0, 2)
- "fadd v6.4s, v2.4s, v3.4s\n"
- "str q28, [x18, %[output_col_stride1]]\n" // Store output (3, 1)
- "fadd v30.4s, v7.4s, v8.4s\n"
- "fadd v13.4s, v9.4s, v10.4s\n"
- "fmla v15.4s, v10.4s, v0.s[1]\n"
- "str q4, [x28, x15]\n" // Store output (2, 2)
- "fsub v2.4s, v2.4s, v3.4s\n"
- "fadd v1.4s, v1.4s, v6.4s\n"
- "ldr q3, [x25, x24]\n"
- "fsub v8.4s, v7.4s, v8.4s\n"
- "mov v6.16b, v6.16b\n"
- "str q13, [x17, x15]\n" // Store output (1, 2)
- "fadd v15.4s, v15.4s, v16.4s\n"
- "mov v9.16b, v2.16b\n"
- "fadd v4.4s, v23.4s, v27.4s\n"
- "fadd v1.4s, v1.4s, v30.4s\n"
- "fmla v6.4s, v30.4s, v0.s[1]\n"
- "fmul v8.4s, v8.4s, v0.s[0]\n"
- "fadd v10.4s, v24.4s, v11.4s\n"
- "str q15, [x18, x15]\n" // Store output (3, 2)
- "fsub v13.4s, v23.4s, v27.4s\n"
- "fadd v7.4s, v17.4s, v4.4s\n"
- "fsub v11.4s, v24.4s, v11.4s\n"
- "str q1, [%[outptr0], x16]\n" // Store output (0, 3)
- "mov v4.16b, v4.16b\n"
- "str q6, [x28, x16]\n" // Store output (2, 3)
- "fadd v2.4s, v2.4s, v8.4s\n"
- "fmla v9.4s, v8.4s, v0.s[1]\n"
- "add %[outptr0], %[outptr0], #16\n"
- "fadd v7.4s, v7.4s, v10.4s\n"
- "add x28, x28, #16\n"
- "fmul v11.4s, v11.4s, v0.s[0]\n"
- "fmla v4.4s, v10.4s, v0.s[1]\n"
- "str q2, [x17, x16]\n" // Store output (1, 3)
- "mov v1.16b, v13.16b\n"
- "fadd v9.4s, v9.4s, v5.4s\n"
- "add x17, x17, #16\n"
- "fadd v8.4s, v13.4s, v11.4s\n"
- "fmla v1.4s, v11.4s, v0.s[1]\n"
- "str q9, [x18, x16]\n" // Store output (3, 3)
- "add x18, x18, #16\n"
- "fadd v1.4s, v1.4s, v12.4s\n"
- "bne 2b\n"
- "3:\n" // Quad tail
- "fadd v2.4s, v19.4s, v22.4s\n"
- "ldr q16, [x13]\n"
- "fadd v23.4s, v14.4s, v18.4s\n"
- "ldr q21, [x13, %[in_col_stride1]]\n"
- "fsub v15.4s, v19.4s, v22.4s\n"
- "ldr q24, [x13, x21]\n"
- "fsub v31.4s, v14.4s, v18.4s\n"
- "ldr q25, [x13, x22]\n"
- "fadd v11.4s, v20.4s, v2.4s\n"
- "ldr q17, [x13, x23]\n"
- "mov v13.16b, v2.16b\n"
- "ldr q9, [x13, x24]\n"
- "mov v2.16b, v15.16b\n"
- "ldr q6, [x26]\n"
- "fmul v31.4s, v31.4s, v0.s[0]\n"
- "ldr q19, [x26, %[in_col_stride1]]\n"
- "fadd v11.4s, v11.4s, v23.4s\n"
- "ldr q22, [x26, x21]\n"
- "fmla v13.4s, v23.4s, v0.s[1]\n"
- "ldr q12, [x26, x22]\n"
- "fadd v29.4s, v21.4s, v24.4s\n"
- "ldr q26, [x26, x23]\n"
- "fadd v15.4s, v15.4s, v31.4s\n"
- "ldr q5, [x26, x24]\n"
- "fmla v2.4s, v31.4s, v0.s[1]\n"
- "ldr q10, [x14]\n"
- "fadd v18.4s, v25.4s, v17.4s\n"
- "add %[inptr0], %[inptr0], #16\n"
- "fadd v27.4s, v16.4s, v29.4s\n"
- "add x25, x25, #16\n"
- "fsub v14.4s, v21.4s, v24.4s\n"
- "ldr q30, [x14, %[in_col_stride1]]\n"
- "fadd v2.4s, v2.4s, v3.4s\n"
- "ldr q31, [x14, x21]\n"
- "fsub v28.4s, v25.4s, v17.4s\n"
- "add x13, x13, #16\n"
- "fadd v27.4s, v27.4s, v18.4s\n"
- "add x26, x26, #16\n"
- "mov v21.16b, v29.16b\n"
- "fadd v20.4s, v19.4s, v22.4s\n"
- "fsub v17.4s, v19.4s, v22.4s\n"
- "fadd v29.4s, v12.4s, v26.4s\n"
- "fmul v28.4s, v28.4s, v0.s[0]\n"
- "fsub v16.4s, v12.4s, v26.4s\n"
- "fmla v21.4s, v18.4s, v0.s[1]\n"
- "ldr q23, [x14, x22]\n"
- "fadd v24.4s, v6.4s, v20.4s\n"
- "mov v6.16b, v20.16b\n"
- "fadd v25.4s, v30.4s, v31.4s\n"
- "fsub v22.4s, v30.4s, v31.4s\n"
- "fadd v20.4s, v14.4s, v28.4s\n"
- "mov v3.16b, v14.16b\n"
- "fmul v16.4s, v16.4s, v0.s[0]\n"
- "fmla v6.4s, v29.4s, v0.s[1]\n"
- "fadd v24.4s, v24.4s, v29.4s\n"
- "ldr q26, [x14, x23]\n"
- "fmla v3.4s, v28.4s, v0.s[1]\n"
- "fadd v14.4s, v23.4s, v26.4s\n"
- "fadd v29.4s, v10.4s, v25.4s\n"
- "fsub v23.4s, v23.4s, v26.4s\n"
- "mov v10.16b, v25.16b\n"
- "fadd v31.4s, v11.4s, v27.4s\n"
- "fsub v12.4s, v11.4s, v27.4s\n"
- "ldr q18, [x14, x24]\n"
- "fadd v3.4s, v3.4s, v9.4s\n"
- "ldr q19, [x27]\n"
- "fadd v29.4s, v29.4s, v14.4s\n"
- "add x14, x14, #16\n"
- "fmul v23.4s, v23.4s, v0.s[0]\n"
- "fmla v10.4s, v14.4s, v0.s[1]\n"
- "fadd v25.4s, v7.4s, v31.4s\n"
- "mov v26.16b, v31.16b\n"
- "fadd v31.4s, v15.4s, v20.4s\n"
- "fsub v11.4s, v15.4s, v20.4s\n"
- "fadd v28.4s, v24.4s, v29.4s\n"
- "fsub v24.4s, v24.4s, v29.4s\n"
- "fadd v30.4s, v13.4s, v21.4s\n"
- "fsub v9.4s, v13.4s, v21.4s\n"
- "fadd v20.4s, v17.4s, v16.4s\n"
- "mov v7.16b, v17.16b\n"
- "fadd v15.4s, v8.4s, v31.4s\n"
- "mov v14.16b, v31.16b\n"
- "fadd v25.4s, v25.4s, v28.4s\n"
- "fmul v24.4s, v24.4s, v0.s[0]\n"
- "fmla v7.4s, v16.4s, v0.s[1]\n"
- "ldr q27, [x27, %[in_col_stride1]]\n"
- "fmla v26.4s, v28.4s, v0.s[1]\n"
- "ldr q29, [x27, x21]\n"
- "fadd v13.4s, v4.4s, v30.4s\n"
- "mov v4.16b, v30.16b\n"
- "str q25, [%[outptr0]]\n" // Store output (0, 0)
- "fadd v17.4s, v22.4s, v23.4s\n"
- "fadd v7.4s, v7.4s, v5.4s\n"
- "ldr q28, [x27, x22]\n"
- "str q26, [x28]\n" // Store output (2, 0)
- "mov v8.16b, v22.16b\n"
- "fadd v16.4s, v27.4s, v29.4s\n"
- "fsub v29.4s, v27.4s, v29.4s\n"
- "fadd v21.4s, v12.4s, v24.4s\n"
- "mov v26.16b, v12.16b\n"
- "fmla v8.4s, v23.4s, v0.s[1]\n"
- "fadd v22.4s, v20.4s, v17.4s\n"
- "fsub v20.4s, v20.4s, v17.4s\n"
- "ldr q23, [x27, x23]\n"
- "fadd v19.4s, v19.4s, v16.4s\n"
- "mov v16.16b, v16.16b\n"
- "str q21, [x17]\n" // Store output (1, 0)
- "fadd v30.4s, v28.4s, v23.4s\n"
- "fadd v8.4s, v8.4s, v18.4s\n"
- "ldr q25, [x27, x24]\n"
- "fsub v27.4s, v28.4s, v23.4s\n"
- "add x27, x27, #16\n"
- "mov v5.16b, v29.16b\n"
- "fmla v26.4s, v24.4s, v0.s[1]\n"
- "fadd v19.4s, v19.4s, v30.4s\n"
- "fmla v16.4s, v30.4s, v0.s[1]\n"
- "fadd v15.4s, v15.4s, v22.4s\n"
- "fmul v20.4s, v20.4s, v0.s[0]\n"
- "fmul v27.4s, v27.4s, v0.s[0]\n"
- "fmla v14.4s, v22.4s, v0.s[1]\n"
- "mov v28.16b, v11.16b\n"
- "fadd v21.4s, v6.4s, v10.4s\n"
- "fadd v26.4s, v26.4s, v19.4s\n"
- "fsub v10.4s, v6.4s, v10.4s\n"
- "str q15, [%[outptr0], %[output_col_stride1]]\n" // Store output (0, 1)
- "fadd v12.4s, v11.4s, v20.4s\n"
- "str q14, [x28, %[output_col_stride1]]\n" // Store output (2, 1)
- "fadd v18.4s, v29.4s, v27.4s\n"
- "fmla v5.4s, v27.4s, v0.s[1]\n"
- "fmla v28.4s, v20.4s, v0.s[1]\n"
- "str q26, [x18]\n" // Store output (3, 0)
- "fadd v13.4s, v13.4s, v21.4s\n"
- "str q12, [x17, %[output_col_stride1]]\n" // Store output (1, 1)
- "fmul v10.4s, v10.4s, v0.s[0]\n"
- "fmla v4.4s, v21.4s, v0.s[1]\n"
- "mov v15.16b, v9.16b\n"
- "fadd v5.4s, v5.4s, v25.4s\n"
- "fadd v28.4s, v28.4s, v18.4s\n"
- "str q13, [%[outptr0], x15]\n" // Store output (0, 2)
- "fadd v6.4s, v2.4s, v3.4s\n"
- "fadd v13.4s, v9.4s, v10.4s\n"
- "fmla v15.4s, v10.4s, v0.s[1]\n"
- "str q4, [x28, x15]\n" // Store output (2, 2)
- "fadd v30.4s, v7.4s, v8.4s\n"
- "str q28, [x18, %[output_col_stride1]]\n" // Store output (3, 1)
- "fsub v2.4s, v2.4s, v3.4s\n"
- "fadd v1.4s, v1.4s, v6.4s\n"
- "fsub v8.4s, v7.4s, v8.4s\n"
- "str q13, [x17, x15]\n" // Store output (1, 2)
- "fadd v15.4s, v15.4s, v16.4s\n"
- "mov v6.16b, v6.16b\n"
- "mov v9.16b, v2.16b\n"
- "fadd v1.4s, v1.4s, v30.4s\n"
- "fmul v8.4s, v8.4s, v0.s[0]\n"
- "str q15, [x18, x15]\n" // Store output (3, 2)
- "fmla v6.4s, v30.4s, v0.s[1]\n"
- "str q1, [%[outptr0], x16]\n" // Store output (0, 3)
- "fadd v2.4s, v2.4s, v8.4s\n"
- "str q6, [x28, x16]\n" // Store output (2, 3)
- "fmla v9.4s, v8.4s, v0.s[1]\n"
- "add %[outptr0], %[outptr0], #16\n"
- "add x28, x28, #16\n"
- "str q2, [x17, x16]\n" // Store output (1, 3)
- "fadd v9.4s, v9.4s, v5.4s\n"
- "add x17, x17, #16\n"
- "str q9, [x18, x16]\n" // Store output (3, 3)
- "add x18, x18, #16\n"
- "4:\n" // Double
- "cmp x20, #2\n"
- "blt 5f\n"
- "ldr d17, [%[inptr0]]\n"
- "ldr d23, [%[inptr0], %[in_col_stride1]]\n"
- "sub x20, x20, #2\n"
- "ldr d27, [%[inptr0], x21]\n"
- "ldr d24, [%[inptr0], x22]\n"
- "fadd v4.4s, v23.4s, v27.4s\n"
- "ldr d11, [%[inptr0], x23]\n"
- "fadd v10.4s, v24.4s, v11.4s\n"
- "ldr d12, [%[inptr0], x24]\n"
- "fsub v13.4s, v23.4s, v27.4s\n"
- "ldr d20, [x25]\n"
- "fsub v11.4s, v24.4s, v11.4s\n"
- "ldr d19, [x25, %[in_col_stride1]]\n"
- "fadd v7.4s, v17.4s, v4.4s\n"
- "ldr d22, [x25, x21]\n"
- "mov v4.16b, v4.16b\n"
- "ldr d14, [x25, x22]\n"
- "mov v1.16b, v13.16b\n"
- "ldr d18, [x25, x23]\n"
- "fmul v11.4s, v11.4s, v0.s[0]\n"
- "ldr d3, [x25, x24]\n"
- "fadd v7.4s, v7.4s, v10.4s\n"
- "ldr d16, [x13]\n"
- "fmla v4.4s, v10.4s, v0.s[1]\n"
- "ldr d21, [x13, %[in_col_stride1]]\n"
- "fadd v2.4s, v19.4s, v22.4s\n"
- "ldr d24, [x13, x21]\n"
- "fadd v8.4s, v13.4s, v11.4s\n"
- "ldr d25, [x13, x22]\n"
- "fmla v1.4s, v11.4s, v0.s[1]\n"
- "ldr d17, [x13, x23]\n"
- "fadd v23.4s, v14.4s, v18.4s\n"
- "ldr d9, [x13, x24]\n"
- "fadd v11.4s, v20.4s, v2.4s\n"
- "ldr d6, [x26]\n"
- "fsub v15.4s, v19.4s, v22.4s\n"
- "ldr d19, [x26, %[in_col_stride1]]\n"
- "fadd v1.4s, v1.4s, v12.4s\n"
- "ldr d22, [x26, x21]\n"
- "fsub v31.4s, v14.4s, v18.4s\n"
- "ldr d12, [x26, x22]\n"
- "fadd v11.4s, v11.4s, v23.4s\n"
- "ldr d26, [x26, x23]\n"
- "mov v13.16b, v2.16b\n"
- "ldr d5, [x26, x24]\n"
- "mov v2.16b, v15.16b\n"
- "ldr d10, [x14]\n"
- "fmul v31.4s, v31.4s, v0.s[0]\n"
- "add %[inptr0], %[inptr0], #8\n"
- "fmla v13.4s, v23.4s, v0.s[1]\n"
- "add x25, x25, #8\n"
- "fadd v29.4s, v21.4s, v24.4s\n"
- "add x13, x13, #8\n"
- "fsub v14.4s, v21.4s, v24.4s\n"
- "ldr d30, [x14, %[in_col_stride1]]\n"
- "fadd v15.4s, v15.4s, v31.4s\n"
- "add x26, x26, #8\n"
- "fmla v2.4s, v31.4s, v0.s[1]\n"
- "fadd v18.4s, v25.4s, v17.4s\n"
- "fadd v27.4s, v16.4s, v29.4s\n"
- "fsub v28.4s, v25.4s, v17.4s\n"
- "mov v21.16b, v29.16b\n"
- "fadd v20.4s, v19.4s, v22.4s\n"
- "fsub v17.4s, v19.4s, v22.4s\n"
- "ldr d31, [x14, x21]\n"
- "fadd v2.4s, v2.4s, v3.4s\n"
- "ldr d23, [x14, x22]\n"
- "fadd v27.4s, v27.4s, v18.4s\n"
- "fmul v28.4s, v28.4s, v0.s[0]\n"
- "fmla v21.4s, v18.4s, v0.s[1]\n"
- "fadd v29.4s, v12.4s, v26.4s\n"
- "fadd v24.4s, v6.4s, v20.4s\n"
- "fsub v16.4s, v12.4s, v26.4s\n"
- "mov v6.16b, v20.16b\n"
- "fadd v25.4s, v30.4s, v31.4s\n"
- "fsub v22.4s, v30.4s, v31.4s\n"
- "fadd v31.4s, v11.4s, v27.4s\n"
- "fsub v12.4s, v11.4s, v27.4s\n"
- "ldr d26, [x14, x23]\n"
- "fadd v24.4s, v24.4s, v29.4s\n"
- "fmul v16.4s, v16.4s, v0.s[0]\n"
- "fmla v6.4s, v29.4s, v0.s[1]\n"
- "mov v3.16b, v14.16b\n"
- "fadd v20.4s, v14.4s, v28.4s\n"
- "fadd v29.4s, v10.4s, v25.4s\n"
- "mov v10.16b, v25.16b\n"
- "fadd v25.4s, v7.4s, v31.4s\n"
- "fmla v3.4s, v28.4s, v0.s[1]\n"
- "fadd v14.4s, v23.4s, v26.4s\n"
- "fsub v23.4s, v23.4s, v26.4s\n"
- "mov v26.16b, v31.16b\n"
- "fadd v31.4s, v15.4s, v20.4s\n"
- "fsub v11.4s, v15.4s, v20.4s\n"
- "fadd v20.4s, v17.4s, v16.4s\n"
- "mov v7.16b, v17.16b\n"
- "fadd v3.4s, v3.4s, v9.4s\n"
- "ldr d18, [x14, x24]\n"
- "fadd v29.4s, v29.4s, v14.4s\n"
- "add x14, x14, #8\n"
- "fmla v7.4s, v16.4s, v0.s[1]\n"
- "ldr d19, [x27]\n"
- "fmul v23.4s, v23.4s, v0.s[0]\n"
- "fmla v10.4s, v14.4s, v0.s[1]\n"
- "fadd v15.4s, v8.4s, v31.4s\n"
- "mov v14.16b, v31.16b\n"
- "fadd v28.4s, v24.4s, v29.4s\n"
- "fsub v24.4s, v24.4s, v29.4s\n"
- "fadd v7.4s, v7.4s, v5.4s\n"
- "ldr d27, [x27, %[in_col_stride1]]\n"
- "fadd v30.4s, v13.4s, v21.4s\n"
- "fsub v9.4s, v13.4s, v21.4s\n"
- "fadd v17.4s, v22.4s, v23.4s\n"
- "mov v8.16b, v22.16b\n"
- "fadd v25.4s, v25.4s, v28.4s\n"
- "fmul v24.4s, v24.4s, v0.s[0]\n"
- "fmla v26.4s, v28.4s, v0.s[1]\n"
- "ldr d29, [x27, x21]\n"
- "fmla v8.4s, v23.4s, v0.s[1]\n"
- "ldr d28, [x27, x22]\n"
- "fadd v13.4s, v4.4s, v30.4s\n"
- "mov v4.16b, v30.16b\n"
- "str d25, [%[outptr0]]\n" // Store output (0, 0)
- "fadd v16.4s, v27.4s, v29.4s\n"
- "str d26, [x28]\n" // Store output (2, 0)
- "fsub v29.4s, v27.4s, v29.4s\n"
- "fadd v8.4s, v8.4s, v18.4s\n"
- "ldr d23, [x27, x23]\n"
- "fadd v30.4s, v28.4s, v23.4s\n"
- "ldr d25, [x27, x24]\n"
- "fadd v19.4s, v19.4s, v16.4s\n"
- "add x27, x27, #8\n"
- "fsub v27.4s, v28.4s, v23.4s\n"
- "mov v16.16b, v16.16b\n"
- "fadd v22.4s, v20.4s, v17.4s\n"
- "fsub v20.4s, v20.4s, v17.4s\n"
- "fadd v21.4s, v12.4s, v24.4s\n"
- "mov v26.16b, v12.16b\n"
- "fadd v19.4s, v19.4s, v30.4s\n"
- "fmla v16.4s, v30.4s, v0.s[1]\n"
- "fmul v27.4s, v27.4s, v0.s[0]\n"
- "mov v5.16b, v29.16b\n"
- "fmla v26.4s, v24.4s, v0.s[1]\n"
- "fadd v15.4s, v15.4s, v22.4s\n"
- "str d21, [x17]\n" // Store output (1, 0)
- "fmul v20.4s, v20.4s, v0.s[0]\n"
- "fmla v14.4s, v22.4s, v0.s[1]\n"
- "mov v28.16b, v11.16b\n"
- "fadd v18.4s, v29.4s, v27.4s\n"
- "fmla v5.4s, v27.4s, v0.s[1]\n"
- "str d15, [%[outptr0], %[output_col_stride1]]\n" // Store output (0, 1)
- "fadd v26.4s, v26.4s, v19.4s\n"
- "fadd v12.4s, v11.4s, v20.4s\n"
- "fmla v28.4s, v20.4s, v0.s[1]\n"
- "str d14, [x28, %[output_col_stride1]]\n" // Store output (2, 1)
- "fadd v21.4s, v6.4s, v10.4s\n"
- "fadd v5.4s, v5.4s, v25.4s\n"
- "fsub v10.4s, v6.4s, v10.4s\n"
- "str d26, [x18]\n" // Store output (3, 0)
- "mov v15.16b, v9.16b\n"
- "str d12, [x17, %[output_col_stride1]]\n" // Store output (1, 1)
- "fadd v28.4s, v28.4s, v18.4s\n"
- "fadd v13.4s, v13.4s, v21.4s\n"
- "fmla v4.4s, v21.4s, v0.s[1]\n"
- "fmul v10.4s, v10.4s, v0.s[0]\n"
- "fadd v6.4s, v2.4s, v3.4s\n"
- "fadd v30.4s, v7.4s, v8.4s\n"
- "fsub v2.4s, v2.4s, v3.4s\n"
- "str d28, [x18, %[output_col_stride1]]\n" // Store output (3, 1)
- "fsub v8.4s, v7.4s, v8.4s\n"
- "str d13, [%[outptr0], x15]\n" // Store output (0, 2)
- "str d4, [x28, x15]\n" // Store output (2, 2)
- "fadd v13.4s, v9.4s, v10.4s\n"
- "fmla v15.4s, v10.4s, v0.s[1]\n"
- "fadd v1.4s, v1.4s, v6.4s\n"
- "mov v6.16b, v6.16b\n"
- "fmul v8.4s, v8.4s, v0.s[0]\n"
- "mov v9.16b, v2.16b\n"
- "str d13, [x17, x15]\n" // Store output (1, 2)
- "fadd v15.4s, v15.4s, v16.4s\n"
- "fadd v1.4s, v1.4s, v30.4s\n"
- "fmla v6.4s, v30.4s, v0.s[1]\n"
- "fadd v2.4s, v2.4s, v8.4s\n"
- "fmla v9.4s, v8.4s, v0.s[1]\n"
- "str d15, [x18, x15]\n" // Store output (3, 2)
- "str d1, [%[outptr0], x16]\n" // Store output (0, 3)
- "str d2, [x17, x16]\n" // Store output (1, 3)
- "fadd v9.4s, v9.4s, v5.4s\n"
- "str d6, [x28, x16]\n" // Store output (2, 3)
- "add %[outptr0], %[outptr0], #8\n"
- "add x17, x17, #8\n"
- "add x28, x28, #8\n"
- "str d9, [x18, x16]\n" // Store output (3, 3)
- "add x18, x18, #8\n"
- "5:\n" // Scalar
- "cbz x20, 6f\n"
- "ldr s17, [%[inptr0]]\n"
- "ldr s23, [%[inptr0], %[in_col_stride1]]\n"
- "ldr s27, [%[inptr0], x21]\n"
- "fadd v4.4s, v23.4s, v27.4s\n"
- "ldr s24, [%[inptr0], x22]\n"
- "fsub v13.4s, v23.4s, v27.4s\n"
- "ldr s11, [%[inptr0], x23]\n"
- "fadd v10.4s, v24.4s, v11.4s\n"
- "ldr s12, [%[inptr0], x24]\n"
- "fsub v11.4s, v24.4s, v11.4s\n"
- "ldr s20, [x25]\n"
- "fadd v7.4s, v17.4s, v4.4s\n"
- "ldr s19, [x25, %[in_col_stride1]]\n"
- "mov v4.16b, v4.16b\n"
- "ldr s22, [x25, x21]\n"
- "mov v1.16b, v13.16b\n"
- "ldr s14, [x25, x22]\n"
- "fmul v11.4s, v11.4s, v0.s[0]\n"
- "ldr s18, [x25, x23]\n"
- "fadd v7.4s, v7.4s, v10.4s\n"
- "ldr s3, [x25, x24]\n"
- "fmla v4.4s, v10.4s, v0.s[1]\n"
- "ldr s16, [x13]\n"
- "fadd v2.4s, v19.4s, v22.4s\n"
- "ldr s21, [x13, %[in_col_stride1]]\n"
- "fadd v8.4s, v13.4s, v11.4s\n"
- "ldr s24, [x13, x21]\n"
- "fmla v1.4s, v11.4s, v0.s[1]\n"
- "ldr s25, [x13, x22]\n"
- "fadd v23.4s, v14.4s, v18.4s\n"
- "ldr s17, [x13, x23]\n"
- "fadd v11.4s, v20.4s, v2.4s\n"
- "ldr s9, [x13, x24]\n"
- "fsub v15.4s, v19.4s, v22.4s\n"
- "ldr s6, [x26]\n"
- "fadd v1.4s, v1.4s, v12.4s\n"
- "ldr s19, [x26, %[in_col_stride1]]\n"
- "fsub v31.4s, v14.4s, v18.4s\n"
- "ldr s22, [x26, x21]\n"
- "fadd v11.4s, v11.4s, v23.4s\n"
- "ldr s12, [x26, x22]\n"
- "mov v13.16b, v2.16b\n"
- "ldr s26, [x26, x23]\n"
- "mov v2.16b, v15.16b\n"
- "ldr s5, [x26, x24]\n"
- "fmul v31.4s, v31.4s, v0.s[0]\n"
- "ldr s10, [x14]\n"
- "fmla v13.4s, v23.4s, v0.s[1]\n"
- "fadd v29.4s, v21.4s, v24.4s\n"
- "fsub v14.4s, v21.4s, v24.4s\n"
- "fadd v18.4s, v25.4s, v17.4s\n"
- "fsub v28.4s, v25.4s, v17.4s\n"
- "ldr s30, [x14, %[in_col_stride1]]\n"
- "fadd v15.4s, v15.4s, v31.4s\n"
- "fmla v2.4s, v31.4s, v0.s[1]\n"
- "fadd v27.4s, v16.4s, v29.4s\n"
- "mov v21.16b, v29.16b\n"
- "fadd v20.4s, v19.4s, v22.4s\n"
- "fsub v17.4s, v19.4s, v22.4s\n"
- "fmul v28.4s, v28.4s, v0.s[0]\n"
- "ldr s31, [x14, x21]\n"
- "fadd v2.4s, v2.4s, v3.4s\n"
- "ldr s23, [x14, x22]\n"
- "fadd v27.4s, v27.4s, v18.4s\n"
- "fmla v21.4s, v18.4s, v0.s[1]\n"
- "fadd v29.4s, v12.4s, v26.4s\n"
- "fadd v24.4s, v6.4s, v20.4s\n"
- "fsub v16.4s, v12.4s, v26.4s\n"
- "mov v6.16b, v20.16b\n"
- "fadd v25.4s, v30.4s, v31.4s\n"
- "fsub v22.4s, v30.4s, v31.4s\n"
- "fadd v20.4s, v14.4s, v28.4s\n"
- "mov v3.16b, v14.16b\n"
- "fadd v24.4s, v24.4s, v29.4s\n"
- "fmla v6.4s, v29.4s, v0.s[1]\n"
- "fmul v16.4s, v16.4s, v0.s[0]\n"
- "ldr s26, [x14, x23]\n"
- "fmla v3.4s, v28.4s, v0.s[1]\n"
- "fadd v14.4s, v23.4s, v26.4s\n"
- "fadd v29.4s, v10.4s, v25.4s\n"
- "fsub v23.4s, v23.4s, v26.4s\n"
- "mov v10.16b, v25.16b\n"
- "fadd v31.4s, v11.4s, v27.4s\n"
- "fsub v12.4s, v11.4s, v27.4s\n"
- "ldr s18, [x14, x24]\n"
- "fadd v3.4s, v3.4s, v9.4s\n"
- "ldr s19, [x27]\n"
- "fadd v29.4s, v29.4s, v14.4s\n"
- "fmul v23.4s, v23.4s, v0.s[0]\n"
- "fmla v10.4s, v14.4s, v0.s[1]\n"
- "fadd v25.4s, v7.4s, v31.4s\n"
- "mov v26.16b, v31.16b\n"
- "fadd v31.4s, v15.4s, v20.4s\n"
- "fsub v11.4s, v15.4s, v20.4s\n"
- "fadd v30.4s, v13.4s, v21.4s\n"
- "fsub v9.4s, v13.4s, v21.4s\n"
- "fadd v28.4s, v24.4s, v29.4s\n"
- "fsub v24.4s, v24.4s, v29.4s\n"
- "ldr s27, [x27, %[in_col_stride1]]\n"
- "fadd v15.4s, v8.4s, v31.4s\n"
- "mov v14.16b, v31.16b\n"
- "fadd v13.4s, v4.4s, v30.4s\n"
- "mov v4.16b, v30.16b\n"
- "fadd v25.4s, v25.4s, v28.4s\n"
- "fmla v26.4s, v28.4s, v0.s[1]\n"
- "fmul v24.4s, v24.4s, v0.s[0]\n"
- "fadd v21.4s, v6.4s, v10.4s\n"
- "fsub v10.4s, v6.4s, v10.4s\n"
- "fadd v6.4s, v2.4s, v3.4s\n"
- "fsub v2.4s, v2.4s, v3.4s\n"
- "ldr s29, [x27, x21]\n"
- "str s25, [%[outptr0]]\n" // Store output (0, 0)
- "fadd v20.4s, v17.4s, v16.4s\n"
- "str s26, [x28]\n" // Store output (2, 0)
- "mov v7.16b, v17.16b\n"
- "fadd v17.4s, v22.4s, v23.4s\n"
- "mov v8.16b, v22.16b\n"
- "fadd v13.4s, v13.4s, v21.4s\n"
- "fmul v10.4s, v10.4s, v0.s[0]\n"
- "fmla v7.4s, v16.4s, v0.s[1]\n"
- "ldr s28, [x27, x22]\n"
- "fmla v8.4s, v23.4s, v0.s[1]\n"
- "ldr s23, [x27, x23]\n"
- "fmla v4.4s, v21.4s, v0.s[1]\n"
- "ldr s25, [x27, x24]\n"
- "str s13, [%[outptr0], x15]\n" // Store output (0, 2)
- "fadd v16.4s, v27.4s, v29.4s\n"
- "fadd v7.4s, v7.4s, v5.4s\n"
- "fadd v30.4s, v28.4s, v23.4s\n"
- "fadd v8.4s, v8.4s, v18.4s\n"
- "fsub v29.4s, v27.4s, v29.4s\n"
- "str s4, [x28, x15]\n" // Store output (2, 2)
- "fsub v27.4s, v28.4s, v23.4s\n"
- "fadd v19.4s, v19.4s, v16.4s\n"
- "mov v16.16b, v16.16b\n"
- "fadd v21.4s, v12.4s, v24.4s\n"
- "mov v26.16b, v12.16b\n"
- "mov v5.16b, v29.16b\n"
- "fadd v22.4s, v20.4s, v17.4s\n"
- "fmul v27.4s, v27.4s, v0.s[0]\n"
- "fmla v16.4s, v30.4s, v0.s[1]\n"
- "fadd v19.4s, v19.4s, v30.4s\n"
- "fmla v26.4s, v24.4s, v0.s[1]\n"
- "str s21, [x17]\n" // Store output (1, 0)
- "fsub v20.4s, v20.4s, v17.4s\n"
- "fadd v15.4s, v15.4s, v22.4s\n"
- "fmla v14.4s, v22.4s, v0.s[1]\n"
- "fadd v18.4s, v29.4s, v27.4s\n"
- "fmla v5.4s, v27.4s, v0.s[1]\n"
- "fadd v26.4s, v26.4s, v19.4s\n"
- "mov v28.16b, v11.16b\n"
- "fmul v20.4s, v20.4s, v0.s[0]\n"
- "fadd v13.4s, v9.4s, v10.4s\n"
- "str s15, [%[outptr0], %[output_col_stride1]]\n" // Store output (0, 1)
- "mov v15.16b, v9.16b\n"
- "str s14, [x28, %[output_col_stride1]]\n" // Store output (2, 1)
- "fadd v5.4s, v5.4s, v25.4s\n"
- "str s26, [x18]\n" // Store output (3, 0)
- "fadd v30.4s, v7.4s, v8.4s\n"
- "str s13, [x17, x15]\n" // Store output (1, 2)
- "fadd v12.4s, v11.4s, v20.4s\n"
- "fmla v28.4s, v20.4s, v0.s[1]\n"
- "fmla v15.4s, v10.4s, v0.s[1]\n"
- "fadd v1.4s, v1.4s, v6.4s\n"
- "fsub v8.4s, v7.4s, v8.4s\n"
- "mov v6.16b, v6.16b\n"
- "mov v9.16b, v2.16b\n"
- "str s12, [x17, %[output_col_stride1]]\n" // Store output (1, 1)
- "fadd v28.4s, v28.4s, v18.4s\n"
- "fadd v15.4s, v15.4s, v16.4s\n"
- "fadd v1.4s, v1.4s, v30.4s\n"
- "fmul v8.4s, v8.4s, v0.s[0]\n"
- "fmla v6.4s, v30.4s, v0.s[1]\n"
- "str s28, [x18, %[output_col_stride1]]\n" // Store output (3, 1)
- "str s1, [%[outptr0], x16]\n" // Store output (0, 3)
- "str s6, [x28, x16]\n" // Store output (2, 3)
- "fadd v2.4s, v2.4s, v8.4s\n"
- "str s15, [x18, x15]\n" // Store output (3, 2)
- "fmla v9.4s, v8.4s, v0.s[1]\n"
- "str s2, [x17, x16]\n" // Store output (1, 3)
- "fadd v9.4s, v9.4s, v5.4s\n"
- "str s9, [x18, x16]\n" // Store output (3, 3)
- "6:\n" // End
- : [outptr0] "+r" (output), [inptr0] "+r" (inptr)
- : [output_col_stride1] "r" (output_col_stride * sizeof(float)), [pcoeffs] "r" (coeffs), [n_channels] "r" ((long) n_channels), [in_row_stride] "r" (6 * matrix_stride * sizeof(float)), [in_col_stride1] "r" (matrix_stride * sizeof(float)), [output_row_stride] "r" (output_row_stride * sizeof(float))
- : "cc", "v0", "v1", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v2", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v3", "v30", "v31", "v4", "v5", "v6", "v7", "v8", "v9", "x13", "x14", "x15", "x16", "x17", "x18", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "memory"
- );
- }
-}
-
-#else
-
template <>
void winograd::OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots::Integers>::transform_tile(
const int n_channels,
@@ -1713,7 +36,9 @@ void winograd::OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots
const float* bptr,
float* const output,
const int output_row_stride,
- const int output_col_stride
+ const int output_col_stride,
+ const float output_min,
+ const float output_max
)
{
// Construct a map to the output cells
@@ -1728,7 +53,79 @@ void winograd::OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots
// For each channel of the output
int channels_remaining = n_channels;
-#ifdef __arm__
+
+#ifdef __aarch64__
+ for (; channels_remaining >= 4; channels_remaining -= 4)
+ {
+ // Matrices used and computed during this transform
+ float32x4_t F[6][6], FZ[6][4], f[4][4], b;
+
+ // Read a 6x6 tile in the Winograd domain
+ for (int i = 0, m = 0; i < 6; i++)
+ {
+ for (int j = 0; j < 6; j++, m++)
+ {
+ F[i][j] = vld1q_f32(inptr + m*matrix_stride);
+ }
+ }
+ inptr += 4;
+
+ // Compute the matrix F Z
+ for (int i = 0; i < 6; i++)
+ {
+ // FZ[i][0] = 1*F[i][0] + 1*F[i][1] + 1*F[i][2] + 1*F[i][3] + 1*F[i][4];
+ FZ[i][0] = vaddq_f32(vaddq_f32(vaddq_f32(F[i][0], F[i][1]), vaddq_f32(F[i][2], F[i][3])), F[i][4]);
+
+ // FZ[i][1] = 1*F[i][1] + -1*F[i][2] + 2*F[i][3] + -2*F[i][4];
+ FZ[i][1] = vmlaq_n_f32(vsubq_f32(F[i][1], F[i][2]), vsubq_f32(F[i][3], F[i][4]), 2.0f);
+
+ // FZ[i][2] = 1*F[i][1] + 1*F[i][2] + 4*F[i][3] + 4*F[i][4];
+ FZ[i][2] = vmlaq_n_f32(vaddq_f32(F[i][1], F[i][2]), vaddq_f32(F[i][3], F[i][4]), 4.0f);
+
+ // FZ[i][3] = 1*F[i][1] + -1*F[i][2] + 8*F[i][3] + -8*F[i][4] + 1*F[i][5];
+ FZ[i][3] = vaddq_f32(vmlaq_n_f32(vsubq_f32(F[i][1], F[i][2]), vsubq_f32(F[i][3], F[i][4]), 8.0f), F[i][5]);
+ }
+
+ // Compute the output tile f = ZT F Z
+ for (int j = 0; j < 4; j++)
+ {
+ // f[0][j] = 1*FZ[0][j] + 1*FZ[1][j] + 1*FZ[2][j] + 1*FZ[3][j] + 1*FZ[4][j];
+ f[0][j] = vaddq_f32(vaddq_f32(vaddq_f32(FZ[0][j], FZ[1][j]), vaddq_f32(FZ[2][j], FZ[3][j])), FZ[4][j]);
+
+ // f[1][j] = 1*FZ[1][j] + -1*FZ[2][j] + 2*FZ[3][j] + -2*FZ[4][j];
+ f[1][j] = vmlaq_n_f32(vsubq_f32(FZ[1][j], FZ[2][j]), vsubq_f32(FZ[3][j], FZ[4][j]), 2.0f);
+
+ // f[2][j] = 1*FZ[1][j] + 1*FZ[2][j] + 4*FZ[3][j] + 4*FZ[4][j];
+ f[2][j] = vmlaq_n_f32(vaddq_f32(FZ[1][j], FZ[2][j]), vaddq_f32(FZ[3][j], FZ[4][j]), 4.0f);
+
+ // f[3][j] = 1*FZ[1][j] + -1*FZ[2][j] + 8*FZ[3][j] + -8*FZ[4][j] + 1*FZ[5][j];
+ f[3][j] = vaddq_f32(vmlaq_n_f32(vsubq_f32(FZ[1][j], FZ[2][j]), vsubq_f32(FZ[3][j], FZ[4][j]), 8.0f), FZ[5][j]);
+ }
+
+ // Write out the output tile
+ if (bptr != nullptr)
+ {
+ b = vld1q_f32(bptr);
+ bptr += 4;
+ }
+ else
+ {
+ b = vdupq_n_f32(0.0f);
+ }
+ for (int i = 0; i < output_tile_rows; i++)
+ {
+ for (int j = 0; j < output_tile_cols; j++)
+ {
+ const auto y =
+ vmaxq_f32(vminq_f32(vaddq_f32(f[i][j], b), vdupq_n_f32(output_max)),
+ vdupq_n_f32(output_min));
+ vst1q_f32(outptrs[i][j], y);
+ outptrs[i][j] += 4;
+ }
+ }
+ }
+#endif // __aarch64__
+#ifdef __arm_any__
for (; channels_remaining >= 2; channels_remaining -= 2)
{
// Matrices used and computed during this transform
@@ -1790,12 +187,15 @@ void winograd::OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots
{
for (int j = 0; j < output_tile_cols; j++)
{
- vst1_f32(outptrs[i][j], vadd_f32(f[i][j], b));
+ const auto y =
+ vmax_f32(vmin_f32(vadd_f32(f[i][j], b), vdup_n_f32(output_max)),
+ vdup_n_f32(output_min));
+ vst1_f32(outptrs[i][j], y);
outptrs[i][j] += 2;
}
}
}
-#endif // __arm__
+#endif // __arm_any__
for (; channels_remaining; channels_remaining--)
{
// Matrices used and computed during this transform
@@ -1842,14 +242,13 @@ void winograd::OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots
{
for (int j = 0; j < output_tile_cols; j++)
{
- *(outptrs[i][j]++) = f[i][j] + b;
+ const auto y = std::max(std::min(f[i][j] + b, output_max), output_min);
+ *(outptrs[i][j]++) = y;
}
}
}
}
-#endif
-
template class OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots::Integers>;
} // namespace winograd
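
The intrinsic loop added above computes the F(4x4, 3x3) inverse transform f = Z^T F Z four channels at a time, and clamps each value at store time so that a fused ReLU or bounded ReLU costs no extra pass. A scalar restatement of the same arithmetic, as a reference sketch (our own illustrative code, not part of the patch):

#include <algorithm>

// Scalar sketch of the F(4x4, 3x3) output transform above (illustrative,
// not the library's code). F is one channel's 6x6 Winograd-domain tile;
// out receives the 4x4 spatial tile with bias and clamp fused in.
void output_transform_4x4_3x3_ref(const float F[6][6], float bias,
                                  float output_min, float output_max,
                                  float out[4][4])
{
    float FZ[6][4];
    for (int i = 0; i < 6; i++)
    {
        FZ[i][0] = F[i][0] + F[i][1] + F[i][2] + F[i][3] + F[i][4];
        FZ[i][1] = F[i][1] - F[i][2] + 2.0f * (F[i][3] - F[i][4]);
        FZ[i][2] = F[i][1] + F[i][2] + 4.0f * (F[i][3] + F[i][4]);
        FZ[i][3] = F[i][1] - F[i][2] + 8.0f * (F[i][3] - F[i][4]) + F[i][5];
    }
    for (int j = 0; j < 4; j++)
    {
        const float f[4] = {
            FZ[0][j] + FZ[1][j] + FZ[2][j] + FZ[3][j] + FZ[4][j],
            FZ[1][j] - FZ[2][j] + 2.0f * (FZ[3][j] - FZ[4][j]),
            FZ[1][j] + FZ[2][j] + 4.0f * (FZ[3][j] + FZ[4][j]),
            FZ[1][j] - FZ[2][j] + 8.0f * (FZ[3][j] - FZ[4][j]) + FZ[5][j]
        };
        for (int i = 0; i < 4; i++)
        {
            // Same store-time clamp as the vectorised paths above.
            out[i][j] = std::max(std::min(f[i] + bias, output_max), output_min);
        }
    }
}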
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_6_3_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_6_3_fp32_fp32_integers.cpp
index ce921cea01..05f06a81ee 100644
--- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_6_3_fp32_fp32_integers.cpp
+++ b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_6_3_fp32_fp32_integers.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -36,7 +36,9 @@ void OutputTransform<1, 3, 1, 8, float, float, WinogradRoots::Integers>::transfo
const float* bptr,
float* const output,
const int, // No need to stride across rows
- const int output_col_stride
+ const int output_col_stride,
+ const float output_min,
+ const float output_max
)
{
// Construct a map to the output cells
@@ -76,7 +78,9 @@ void OutputTransform<1, 3, 1, 8, float, float, WinogradRoots::Integers>::transfo
}
for (int j = 0; j < output_tile_cols; j++)
{
- vst1q_f32(outptrs[j], f[j] + b);
+ const auto y = vminq_f32(vmaxq_f32(f[j] + b, vdupq_n_f32(output_min)),
+ vdupq_n_f32(output_max));
+ vst1q_f32(outptrs[j], y);
outptrs[j] += 4;
}
}
@@ -107,7 +111,9 @@ void OutputTransform<1, 3, 1, 8, float, float, WinogradRoots::Integers>::transfo
}
for (int j = 0; j < output_tile_cols; j++)
{
- vst1_f32(outptrs[j], f[j] + b);
+ const auto y = vmin_f32(vmax_f32(f[j] + b, vdup_n_f32(output_min)),
+ vdup_n_f32(output_max));
+ vst1_f32(outptrs[j], y);
outptrs[j] += 2;
}
}
@@ -138,7 +144,7 @@ void OutputTransform<1, 3, 1, 8, float, float, WinogradRoots::Integers>::transfo
}
for (int j = 0; j < output_tile_cols; j++)
{
- *(outptrs[j]++) = f[j] + b;
+ *(outptrs[j]++) = std::max(std::min(f[j] + b, output_max), output_min);
}
}
}
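
Note that this file clamps with vmin(vmax(...)) while the 4x4/3x3 transform above uses vmax(vmin(...)); the two orderings agree whenever output_min <= output_max, which holds for every activation the patch fuses. The shared pattern, factored into a standalone helper (a sketch of ours, not a library API):

#include <arm_neon.h>

// The store-time clamp implemented by each path above: a fused activation
// is just min/max against precomputed bounds, four lanes at a time.
static inline float32x4_t clamp_f32x4(float32x4_t v, float lo, float hi)
{
    return vmaxq_f32(vminq_f32(v, vdupq_n_f32(hi)), vdupq_n_f32(lo));
}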
diff --git a/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp b/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp
index e699ad1815..6983c1c01b 100644
--- a/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp
@@ -33,6 +33,7 @@
#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
#include "support/ToolchainSupport.h"
+#include "arm_compute/core/NEON/kernels/convolution/common/utils.hpp"
#include "arm_compute/core/NEON/kernels/convolution/winograd/winograd.hpp"
namespace arm_compute
@@ -232,6 +233,31 @@ bool check_support_fast_math(const Size2D &output_tile, const Size2D &kernel_siz
return std::find(fast_math_winograd.begin(), fast_math_winograd.end(), p) != fast_math_winograd.end();
}
+inline bool fuse_function_supported(const ActivationLayerInfo &act_info)
+{
+ return act_info.activation() == ActivationLayerInfo::ActivationFunction::RELU ||
+ act_info.activation() == ActivationLayerInfo::ActivationFunction::BOUNDED_RELU;
+}
+
+arm_gemm::Activation arm_gemm_activation_from_acl_activation(const ActivationLayerInfo &act_info)
+{
+ switch(act_info.activation())
+ {
+ case ActivationLayerInfo::ActivationFunction::RELU:
+ {
+ return arm_gemm::Activation(arm_gemm::Activation::Type::ReLU, act_info.a(), act_info.b());
+ }
+ case ActivationLayerInfo::ActivationFunction::BOUNDED_RELU:
+ {
+ return arm_gemm::Activation(arm_gemm::Activation::Type::BoundedReLU, act_info.a(), act_info.b());
+ }
+ default:
+ {
+ return arm_gemm::Activation(arm_gemm::Activation::Type::None);
+ }
+ }
+}
+
} //namespace
NEWinogradConvolutionLayer::NEWinogradConvolutionLayer(const std::shared_ptr<IMemoryManager> &memory_manager)
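
Only ReLU and bounded ReLU are accepted by fuse_function_supported() because both reduce to a per-element clamp against constant bounds, which is exactly what the output transforms implement. A minimal, self-contained statement of their semantics (illustrative code; the bound 'a' corresponds to act_info.a()):

#include <algorithm>
#include <cassert>

// Reference semantics of the two fusable activations:
//   ReLU:         y = max(x, 0)
//   BoundedReLU:  y = min(max(x, 0), a)
float bounded_relu(float x, float a) { return std::min(std::max(x, 0.0f), a); }

int main()
{
    assert(bounded_relu(-1.0f, 6.0f) == 0.0f); // clamped below
    assert(bounded_relu( 3.5f, 6.0f) == 3.5f); // passed through
    assert(bounded_relu( 9.0f, 6.0f) == 6.0f); // clamped above
    return 0;
}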
@@ -257,6 +283,8 @@ void NEWinogradConvolutionLayer::configure(const ITensor *input, const ITensor *
const Size2D kernel_size = Size2D(weights->info()->dimension(width_idx), weights->info()->dimension(height_idx));
const Size2D output_tile = winograd_output_tile(input_dims, kernel_size);
// Check if the Winograd configuration requires fast math
if(!enable_fast_math)
{
@@ -388,21 +416,15 @@ void NEWinogradConvolutionLayer::configure(const ITensor *input, const ITensor *
* data_type_size;
// Output storage
- const size_t output_storage_size = transform_output_kernel->get_output_storage_size(in_shape.n_batches, in_shape.n_rows, in_shape.n_cols, out_channels,
- use_same_padding)
- * data_type_size;
- ;
- const KernelShape kernel_shape({ out_channels, static_cast<int>(kernel_size.height), static_cast<int>(kernel_size.width), in_channels });
- const int kernel_matrix_stride = transform_weights_kernel->get_matrix_stride(kernel_shape);
-
- const int output_matrix_stride = transform_output_kernel->get_matrix_stride(kernel_shape, in_shape, use_padding_type);
- const auto output_shape(transform_output_kernel->get_output_shape(kernel_shape, in_shape, use_padding_type));
-
- const int input_matrix_stride = transform_input_kernel->get_matrix_stride(kernel_shape, in_shape, use_padding_type);
+ const size_t output_storage_size = transform_output_kernel->get_output_storage_size(in_shape.n_batches, in_shape.n_rows, in_shape.n_cols, out_channels) * data_type_size;
+ const int kernel_matrix_stride = transform_weights_kernel->get_matrix_stride(out_channels, in_channels);
+ const int output_matrix_stride = transform_output_kernel->get_matrix_stride(in_shape.n_batches, in_shape.n_rows, in_shape.n_cols, out_channels);
+ const auto output_shape = transform_output_kernel->get_output_shape(in_shape.n_rows, in_shape.n_cols, use_padding_type == PADDING_SAME);
+ const int input_matrix_stride = transform_input_kernel->get_matrix_stride(in_shape.n_batches, in_channels, in_shape.n_rows, in_shape.n_cols, use_padding_type == PADDING_SAME);
// Configure GEMM
- const int tile_rows = iceildiv(output_shape.n_rows, output_tile.height);
- const int tile_cols = iceildiv(output_shape.n_cols, output_tile.width);
+ const int tile_rows = iceildiv(output_shape.first, output_tile.height);
+ const int tile_cols = iceildiv(output_shape.second, output_tile.width);
const int m = in_shape.n_batches * tile_rows * tile_cols;
const int k = in_shape.n_channels;
const int n = out_channels;
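
The GEMM is sized so that every output tile of every batch contributes one row: m counts tiles, k the input channels, n the output channels. A small sketch of the same sizing with invented dimensions (names are ours):

// Worked sketch of the GEMM sizing above (sizes invented for illustration).
struct GemmDims { int m, k, n; };

GemmDims winograd_gemm_dims(int n_batches, int out_rows, int out_cols,
                            int in_channels, int out_channels,
                            int tile_h, int tile_w)
{
    const int tile_rows = (out_rows + tile_h - 1) / tile_h; // iceildiv
    const int tile_cols = (out_cols + tile_w - 1) / tile_w;
    return { n_batches * tile_rows * tile_cols, // one GEMM row per output tile
             in_channels,
             out_channels };
}
// e.g. winograd_gemm_dims(2, 224, 224, 64, 128, 4, 4) gives m = 2*56*56 = 6272.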
@@ -489,9 +511,19 @@ void NEWinogradConvolutionLayer::configure(const ITensor *input, const ITensor *
_memory_group.manage(&_output_nhwc);
output_to_use = &_output_nhwc;
}
- transform_output_kernel->configure(biases, &_output_transformed,
- output_matrix_stride, output_to_use,
- in_shape.n_batches, output_shape.n_rows, output_shape.n_cols, out_channels, &_output_workspace);
+ const arm_gemm::Activation activation = arm_gemm_activation_from_acl_activation(act_info);
+
+ transform_output_kernel->configure(biases,
+ &_output_transformed,
+ output_matrix_stride,
+ output_to_use,
+ in_shape.n_batches,
+ output_shape.first,
+ output_shape.second,
+ out_channels,
+ &_output_workspace,
+ activation);
+
const size_t output_workspace_size = transform_output_kernel->get_working_space_size(max_num_threads);
TensorInfo output_workspace_info(TensorShape(output_workspace_size), 1, _output->info()->data_type());
_output_workspace.allocator()->init(output_workspace_info);
@@ -510,7 +542,7 @@ void NEWinogradConvolutionLayer::configure(const ITensor *input, const ITensor *
_transform_output_kernel = std::move(transform_output_kernel);
//Configure Activation Layer
- _is_activationlayer_enabled = act_info.enabled();
+    _is_activationlayer_enabled = act_info.enabled() && !fuse_function_supported(act_info);
if(_is_activationlayer_enabled)
{
_activationlayer_function.configure(_output, nullptr, act_info);
@@ -546,7 +578,7 @@ void NEWinogradConvolutionLayer::run()
_permute_output.run();
}
    if(_is_activationlayer_enabled)
{
_activationlayer_function.run();
}
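
What the diff leaves implicit is the final link: the arm_gemm::Activation chosen in configure() must become the (output_min, output_max) pair consumed by the output-transform kernels above. A hedged sketch of that mapping, using a stand-in struct rather than the real arm_gemm type:

#include <limits>

// 'FusedActivation' is our stand-in for arm_gemm::Activation, assumed to
// carry a type tag plus the upper bound taken from act_info.a().
struct FusedActivation
{
    enum class Type { None, ReLU, BoundedReLU } type = Type::None;
    float upper = 0.0f;
};

inline void activation_to_bounds(const FusedActivation &act,
                                 float &output_min, float &output_max)
{
    output_min = -std::numeric_limits<float>::infinity(); // clamp defaults to a no-op
    output_max =  std::numeric_limits<float>::infinity();
    switch (act.type)
    {
        case FusedActivation::Type::BoundedReLU:
            output_max = act.upper; // y <= a
            // fall through: bounded ReLU also clamps below at zero
        case FusedActivation::Type::ReLU:
            output_min = 0.0f;      // y >= 0
            break;
        case FusedActivation::Type::None:
            break;
    }
}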