From 5264b7d5555ec980f9c52c719122479d0d676af8 Mon Sep 17 00:00:00 2001 From: Pablo Tello Date: Mon, 21 Oct 2019 14:25:41 +0100 Subject: COMPMID-2576: Fuse activation in Winograd output transform. Change-Id: I26dd1307847adeaaefae0a7374b9858c07d71372 Signed-off-by: Pablo Tello Reviewed-on: https://review.mlplatform.org/c/2172 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Gian Marco Iodice --- .../kernels/NEWinogradConvolutionLayerKernel.h | 126 +- .../NEON/kernels/convolution/winograd/winograd.hpp | 105 +- .../convolution/winograd/winograd_layer.hpp | 36 +- .../kernels/NEWinogradConvolutionLayerKernel.cpp | 69 +- .../NEON/kernels/convolution/winograd/winograd.cpp | 250 ++- .../winograd/winograd_transforms/input.hpp | 7 +- .../winograd/winograd_transforms/output.hpp | 55 +- .../output_2_7_fp32_fp32_integers.cpp | 16 +- .../output_2x2_3x3_fp32_fp32_integers.cpp | 19 +- .../output_2x2_5x5_fp32_fp32_integers.cpp | 19 +- .../output_4_5_fp32_fp32_integers.cpp | 19 +- .../output_4x4_3x3_fp32_fp32_integers.cpp | 1769 +------------------- .../output_6_3_fp32_fp32_integers.cpp | 16 +- .../NEON/functions/NEWinogradConvolutionLayer.cpp | 68 +- 14 files changed, 518 insertions(+), 2056 deletions(-) diff --git a/arm_compute/core/NEON/kernels/NEWinogradConvolutionLayerKernel.h b/arm_compute/core/NEON/kernels/NEWinogradConvolutionLayerKernel.h index f6b189cb1c..962037dd4f 100644 --- a/arm_compute/core/NEON/kernels/NEWinogradConvolutionLayerKernel.h +++ b/arm_compute/core/NEON/kernels/NEWinogradConvolutionLayerKernel.h @@ -64,13 +64,15 @@ public: /** Gets the stride between matrices in the input worspace * - * @param[in] kernel_shape The shape of the weights tensor. - * @param[in] input_shape The shape of the input tensor. - * @param[in] padding_type The type of padding to be used. + * @param[in] num_batches Number of batches in the input tensor. + * @param[in] num_channels Number of feature maps in the input tensor. + * @param[in] num_rows Number of rows in each feature map. + * @param[in] num_cols Number of columns in each feature map. + * @param[in] same_padding Use "SAME" padding, otherwise use "VALID". * * @return Stride expressed in bytes. */ - virtual int get_matrix_stride(const KernelShape &kernel_shape, const Tensor4DShape &input_shape, const PaddingType padding_type) const = 0; + virtual int get_matrix_stride(int num_batches, int num_channels, int num_rows, int num_cols, bool same_padding) const = 0; /** Configure the output transform kernel. * @@ -141,13 +143,20 @@ public: /** Gets the stride between matrices in the input worspace * - * @param[in] kernel_shape The shape of the weights tensor. - * @param[in] input_shape The shape of the input tensor. - * @param[in] padding_type The type of padding to be used. + * @param[in] num_batches Number of batches in the input tensor. + * @param[in] num_channels Number of feature maps in the input tensor. + * @param[in] num_rows Number of rows in each feature map. + * @param[in] num_cols Number of columns in each feature map. + * @param[in] same_padding Use "SAME" padding, otherwise use "VALID". * * @return Stride expressed in bytes. */ - int get_matrix_stride(const KernelShape &kernel_shape, const Tensor4DShape &input_shape, const PaddingType padding_type) const override; + int get_matrix_stride( + int num_batches, + int num_channels, + int num_rows, + int num_cols, + bool same_padding) const override; /** Default constructor */ NEWinogradLayerTransformInputKernel(); @@ -241,31 +250,35 @@ public: * @param[in] num_rows Number of rows in each feature map of the input tensor. * @param[in] num_cols Number of columns in each feature map of the input tensor. * @param[in] num_output_channels Number of feature maps in the output tensor. - * @param[in] same_padding Use "SAME" padding, otherwise use "VALID". * * @return Storage size (in units of TOut) required. */ - virtual unsigned int get_output_storage_size(int num_batches, int num_rows, int num_cols, int num_output_channels, bool same_padding) const = 0; + virtual unsigned int get_output_storage_size(int num_batches, int num_rows, int num_cols, int num_output_channels) const = 0; /** Gets the stride between matrices in the output worspace * - * @param[in] kernel_shape The shape of the weights tensor. - * @param[in] input_shape The shape of the input tensor. - * @param[in] padding_type The type of padding to be used. + * @param[in] num_batches Number of batches in the output tensor. + * @param[in] num_rows Number of rows in each feature map of the input tensor. + * @param[in] num_cols Number of columns in each feature map of the input tensor. + * @param[in] num_output_channels Number of feature maps in the output tensor. * * @return Stride expressed in bytes. */ - virtual int get_matrix_stride(const KernelShape &kernel_shape, const Tensor4DShape &input_shape, const PaddingType padding_type) const = 0; + virtual int get_matrix_stride(int num_batches, int num_rows, int num_cols, int num_output_channels) const = 0; /** Get the output shape of a convolution. * - * @param[in] kernel_shape The shape of the weights tensor. - * @param[in] in_shape The shape of the input tensor. - * @param[in] padding The type of padding to be used. + * @param[in] num_rows Number of rows in each feature map of the input tensor. + * @param[in] num_cols Number of columns in each feature map of the input tensor. + * @param[in] padding_same True if padding is SAME, false otherwise * - * @return Stride expressed in bytes. + * @return Shape of the output tensor */ - virtual Tensor4DShape get_output_shape(const KernelShape &kernel_shape, const Tensor4DShape &in_shape, const PaddingType padding) const = 0; + virtual std::pair get_output_shape( + int num_rows, /* Number of rows in each feature map of the input tensor. */ + int num_cols, /* Number of columns in each feature map of the input tensor. */ + bool padding_same /* True if padding is SAME, false otherwise */ + ) const = 0; /** Configure the output transform kernel. * @@ -278,17 +291,19 @@ public: * @param[in] num_cols Number of columns in output tensor. * @param[in] num_channels Number of feature maps in the output tensor. * @param[in] workspace Tensor to be used as the working space during the computation. + * @param[in] activation Activation to be used */ virtual void configure( - const ITensor *biases, - const ITensor *transformed_output, - const int matrix_stride, - ITensor *output_nhwc, - const int num_batches, - const int num_rows, - const int num_cols, - const int num_channels, - ITensor *workspace) = 0; + const ITensor *biases, + const ITensor *transformed_output, + const int matrix_stride, + ITensor *output_nhwc, + const int num_batches, + const int num_rows, + const int num_cols, + const int num_channels, + ITensor *workspace, + const arm_gemm::Activation &activation) = 0; virtual ~INEWinogradLayerTransformOutputKernel() { @@ -326,30 +341,33 @@ public: * @param[in] num_rows Number of rows in each feature map of the input tensor. * @param[in] num_cols Number of columns in each feature map of the input tensor. * @param[in] num_output_channels Number of feature maps in the output tensor. - * @param[in] same_padding Use "SAME" padding, otherwise use "VALID". * * @return Storage size (in units of TOut) required. */ - unsigned int get_output_storage_size(int num_batches, int num_rows, int num_cols, int num_output_channels, bool same_padding) const override; + unsigned int get_output_storage_size(int num_batches, int num_rows, int num_cols, int num_output_channels) const override; /** Gets the stride between matrices in the output worspace * - * @param[in] kernel_shape The shape of the weights tensor. - * @param[in] input_shape The shape of the input tensor. - * @param[in] padding_type The type of padding to be used. + * @param[in] num_batches Number of batches in the output tensor. + * @param[in] num_rows Number of rows in each feature map of the input tensor. + * @param[in] num_cols Number of columns in each feature map of the input tensor. + * @param[in] num_output_channels Number of feature maps in the output tensor. * * @return Stride expressed in bytes. */ - int get_matrix_stride(const KernelShape &kernel_shape, const Tensor4DShape &input_shape, const PaddingType padding_type) const override; + int get_matrix_stride(int num_batches, int num_rows, int num_cols, int num_output_channels) const override; /** Get the output shape of a convolution. * - * @param[in] kernel_shape The shape of the weights tensor. - * @param[in] in_shape The shape of the input tensor. - * @param[in] padding The type of padding to be used. + * @param[in] num_rows Number of rows in each feature map of the input tensor. + * @param[in] num_cols Number of columns in each feature map of the input tensor. + * @param[in] padding_same True if padding is SAME, false otherwise * - * @return Stride expressed in bytes. + * @return Shape of the output tensor */ - Tensor4DShape get_output_shape(const KernelShape &kernel_shape, const Tensor4DShape &in_shape, const PaddingType padding) const override; + std::pair get_output_shape( + int num_rows, /* Number of rows in each feature map of the input tensor. */ + int num_cols, /* Number of columns in each feature map of the input tensor. */ + bool padding_same) const override; /** Get the working space required to perform the transformation. * @@ -374,17 +392,19 @@ public: * @param[in] num_cols Number of columns in output tensor. * @param[in] num_channels Number of feature maps in the output tensor. * @param[in] workspace Tensor to be used as the working space during the computation. + * @param[in] activation Activation to be used */ void configure( - const ITensor *biases, - const ITensor *transformed_output, - const int matrix_stride, - ITensor *output_nhwc, - const int num_batches, - const int num_rows, - const int num_cols, - const int num_channels, - ITensor *workspace) override; + const ITensor *biases, + const ITensor *transformed_output, + const int matrix_stride, + ITensor *output_nhwc, + const int num_batches, + const int num_rows, + const int num_cols, + const int num_channels, + ITensor *workspace, + const arm_gemm::Activation &activation) override; void run(const Window &window, const ThreadInfo &info) override; @@ -448,11 +468,12 @@ public: virtual unsigned int get_weight_storage_size(int num_output_channels, int num_input_channels) const = 0; /** Gets the stride between matrices in the kernel worspace * - * @param[in] kernel_shape The shape of the weights tensor. + * @param[in] num_output_channels Number of output feature maps. + * @param[in] num_input_channels Number of input feature maps. * * @return Stride expressed in bytes. */ - virtual int get_matrix_stride(const KernelShape &kernel_shape) const = 0; + virtual int get_matrix_stride(int num_output_channels, int num_input_channels) const = 0; /** Configure the weights transform kernel. * @@ -535,11 +556,12 @@ public: /** Gets the stride between matrices in the input worspace * - * @param[in] kernel_shape The shape of the weights tensor. + * @param[in] num_output_channels Number of output feature maps. + * @param[in] num_input_channels Number of input feature maps. * * @return Stride expressed in bytes. */ - int get_matrix_stride(const KernelShape &kernel_shape) const override; + int get_matrix_stride(int num_output_channels, int num_input_channels) const override; void run(const Window &window, const ThreadInfo &info) override; bool is_parallelisable() const override; diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/winograd.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/winograd.hpp index 183c9c1061..bc0d9d4296 100644 --- a/arm_compute/core/NEON/kernels/convolution/winograd/winograd.hpp +++ b/arm_compute/core/NEON/kernels/convolution/winograd/winograd.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -24,9 +24,10 @@ #pragma once -#include "convolution.hpp" -#include "tensor.hpp" -#include "utils.hpp" +#include "arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp" + +#include +#include namespace winograd { @@ -308,7 +309,8 @@ class OutputTransform : public IOutputTransform int n_batches, /**< Number of batches in output tensor. */ int n_rows, /**< Number of rows in output tensor. */ int n_cols, /**< Number of columns in output tensor. */ - int n_channels /**< Number of channels in output tensor. */ + int n_channels, /**< Number of channels in output tensor. */ + const arm_gemm::Activation &activation ); OutputTransform(OutputTransform&) = delete; @@ -344,6 +346,7 @@ class OutputTransform : public IOutputTransform static constexpr int output_tile_cols = InnerTileCols - KernelCols + 1; const int _n_batches, _n_rows, _n_cols, _n_channels; + const TOut _output_min, _output_max; private: void transform_uncropped_tile( @@ -372,7 +375,9 @@ class OutputTransform : public IOutputTransform const TOut* biases, TOut* output, int output_row_stride, - int output_col_stride + int output_col_stride, + TOut output_min, + TOut output_max ); /** Get the working space for a thread. */ @@ -405,7 +410,8 @@ class OutputTransform : int n_batches, /**< Number of batches in output tensor. */ int n_rows, /**< Number of rows in output tensor. */ int n_cols, /**< Number of columns in output tensor. */ - int n_channels /**< Number of channels in output tensor. */ + int n_channels, /**< Number of channels in output tensor. */ + const arm_gemm::Activation &activation ); /** Set pointers to the output tensor written by the transform. */ @@ -528,79 +534,84 @@ class WinogradGEMM typedef TIn InputType; /** Get the output shape of a convolution. */ - static Tensor4DShape get_output_shape( - const KernelShape &kernel_shape, - const Tensor4DShape &in_shape, - const PaddingType padding - ); - - /* Get the memory required to transform the kernel. - */ - static size_t get_kernel_transform_working_size(const KernelShape &shape); + static std::pair get_output_shape( + const std::pair input_shape, + bool padding_same); /** Get the memory required to store the kernel transformed into the * Winograd domain. */ - static size_t get_kernel_storage_size(const KernelShape &shape); + static size_t get_kernel_storage_size(unsigned int n_input_channels, + unsigned int n_output_channels); /** Get the memory required to store the input tensor transformed into * the Winograd domain. */ static size_t get_input_storage_size( - const KernelShape &kernel_shape, - const Tensor4DShape &input_shape, - const PaddingType padding_type - ); + unsigned int n_batches, // Number of batches + unsigned int n_rows, // Number of input rows + unsigned int n_cols, // Number of input columns + unsigned int n_channels, // Number of input channels + bool padding_same); /** Get the memory required to store the output tensor in the Winograd * domain. */ static size_t get_output_storage_size( - const KernelShape &kernel_shape, - const Tensor4DShape &input_shape, - const PaddingType padding_type - ); + unsigned int n_batches, // Number of batches + unsigned int n_rows, // Number of output rows + unsigned int n_cols, // Number of output columns + unsigned int n_channels // Number of output channels + ); /** Get the memory required to apply a Winograd operator to some input. */ static size_t get_working_space_size( - const KernelShape &kernel_shape, - const Tensor4DShape &input_shape, - const PaddingType padding_type - ); + unsigned int n_batches, + unsigned int n_rows, // Number of input rows + unsigned int n_cols, // Number of input columns + unsigned int n_input_channels, // Number of input channels + unsigned int n_output_channels, // Number of output channels + bool padding_same); /* Get the memory required by a single "input" matrix. */ static size_t get_input_matrix_size( - const KernelShape &kernel_shape, - const Tensor4DShape &input_shape, - const PaddingType padding_type - ); + unsigned int n_batches, // Number of batches + unsigned int n_rows, // Number of input rows + unsigned int n_cols, // Number of input columns + unsigned int n_channels, // Number of input channels + bool padding_same); static int get_input_matrix_stride( - const KernelShape &kernel_shape, - const Tensor4DShape &input_shape, - const PaddingType padding_type - ); + unsigned int n_batches, // Number of batches + unsigned int n_rows, // Number of input rows + unsigned int n_cols, // Number of input columns + unsigned int n_channels, // Number of input channels + bool padding_same); /* Get the memory required by a single "output" matrix. */ static size_t get_output_matrix_size( - const KernelShape &kernel_shape, - const Tensor4DShape &input_shape, - const PaddingType padding_type - ); + unsigned int n_batches, // Number of batches + unsigned int n_rows, // Number of output rows + unsigned int n_cols, // Number of output columns + unsigned int n_channels // Number of output channels + ); static int get_output_matrix_stride( - const KernelShape &kernel_shape, - const Tensor4DShape &input_shape, - const PaddingType padding_type - ); + unsigned int n_batches, // Number of batches + unsigned int n_rows, // Number of output rows + unsigned int n_cols, // Number of output columns + unsigned int n_channels // Number of output channels + ); /* Get the memory required by a single "kernel" matrix. */ - static size_t get_kernel_matrix_size(const KernelShape &shape); - static int get_kernel_matrix_stride(const KernelShape &shape); + static size_t get_kernel_matrix_size(unsigned int n_input_channels, + unsigned int n_output_channels); + static int get_kernel_matrix_stride(unsigned int n_input_channels, + unsigned int n_output_channels); static constexpr int M_BLOCK = 4; /** Size of block used by GEMM. */ static constexpr int N_BLOCK = 16; /** Size of block used by GEMM. */ diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/winograd_layer.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/winograd_layer.hpp index 9d418bebb4..ed8fede385 100644 --- a/arm_compute/core/NEON/kernels/convolution/winograd/winograd_layer.hpp +++ b/arm_compute/core/NEON/kernels/convolution/winograd/winograd_layer.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -23,9 +23,6 @@ */ #pragma once - -#include - #include "arm_gemm_local.hpp" #include "arm_gemm.hpp" #include "winograd.hpp" @@ -42,8 +39,8 @@ class IWinogradConvolutionLayer virtual unsigned int weight_transform_get_window(void) const = 0; virtual void weight_transform_run(unsigned int start, unsigned int stop) = 0; - virtual ITransform& input_transform(void) = 0; // Expose the input transform - virtual ITransform& output_transform(void) = 0; // Expose the output transform + virtual IInputTransform& input_transform(void) = 0; // Expose the input transform + virtual IOutputTransform& output_transform(void) = 0; // Expose the output transform virtual arm_gemm::IGemmCommon *gemm(void) = 0; // Expose the underlying GEMM }; @@ -65,15 +62,18 @@ template class WinogradConvolutionLayer : public IWinogradConvolutionLayer { + public: + using WinogradBase = winograd::WinogradGEMM; + using WeightsTransform = typename WinogradBase::template WeightsTransform; + using InputTransform = typename WinogradBase::template InputTransform; + using WinogradConv = typename WinogradBase::template Convolution; + using OutputTransform = typename WinogradBase::template OutputTransform; + private: static constexpr int InnerTileRows = OutputTileRows + KernelRows - 1; static constexpr int InnerTileCols = OutputTileCols + KernelCols - 1; static constexpr int N_GEMMS = InnerTileRows * InnerTileCols; - const KernelShape _kernel_shape; - const Tensor4DShape _input_shape; - const PaddingType _padding; - const Tensor4DShape _output_shape; const int _n_output_rows, _n_output_cols; const int _kernel_matrix_stride, _kernel_matrix_row_stride; const int _input_matrix_stride, _input_matrix_row_stride; @@ -81,19 +81,14 @@ class WinogradConvolutionLayer : public IWinogradConvolutionLayer const int _tile_rows, _tile_cols; const int _m, _k, _n; - public: - using WinogradBase = winograd::WinogradGEMM; - using WeightsTransform = typename WinogradBase::template WeightsTransform; - using InputTransform = typename WinogradBase::template InputTransform; - using WinogradConv = typename WinogradBase::template Convolution; - using OutputTransform = typename WinogradBase::template OutputTransform; - - /* Public member variables. */ WeightsTransform weights_transform; /** Operator to transform weights to Winograd domain. */ InputTransform _input_transform; /** Operator to transform input to Winograd domain. */ + const arm_gemm::GemmArgs gemm_args; arm_gemm::UniqueGemmCommon gemms; /** Operator to perform multiple GEMMs. */ OutputTransform _output_transform; /** Operator to transform output from Winograd domain. */ + public: + /** Determine how much memory (in units of TIn) to allocate for the * transformed weights. */ @@ -186,6 +181,7 @@ class WinogradConvolutionLayer : public IWinogradConvolutionLayer const int n_input_cols, /** Number of columns in a feature map of the input tensor. */ const int n_output_channels, /** Number of feature maps in the output tensor. */ const bool same_padding, /** Use "SAME" padding, otherwise use "VALID". */ + const arm_gemm::Activation &activation, const TIn* const weights, /** Pointer to weight tensor in spatial domain. Must be ordered as "Height x Rows x Input Feature Maps x Output Feature Maps. */ TInGEMM* const weights_storage, /** Pointer to storage for weight tensor in the Winograd domain. Must be at least the size returned by `get_weight_storage_size`. */ const TIn* const input, /** Pointer to NHWC ordered input tensor, in the spatial domain. */ @@ -201,8 +197,8 @@ class WinogradConvolutionLayer : public IWinogradConvolutionLayer unsigned int weight_transform_get_window(void) const; void weight_transform_run(const unsigned int start, const unsigned int stop); - ITransform& input_transform(void); - ITransform& output_transform(void); + IInputTransform& input_transform(void); + IOutputTransform& output_transform(void); /* Get a pointer to the GEMM underlying the Winograd transform. */ arm_gemm::IGemmCommon *gemm(void); diff --git a/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp b/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp index 263ded0b84..fda384bc62 100644 --- a/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp +++ b/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp @@ -28,6 +28,7 @@ #include "arm_compute/core/Helpers.h" #include "arm_compute/core/IAccessWindow.h" #include "arm_compute/core/ITensor.h" +#include "arm_compute/core/NEON/kernels/convolution/common/utils.hpp" #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Validate.h" #include "arm_compute/core/Window.h" @@ -233,7 +234,7 @@ unsigned int NEWinogradLayerTransformWeightsKernel( // WinogradConv returns the size in bytes, we divide by `sizeof(T)` to express that in units of T - WinogradConv::get_kernel_storage_size(shape) / sizeof(T)); + WinogradConv::get_kernel_storage_size(num_input_channels, num_output_channels) / sizeof(T)); } template @@ -243,9 +244,9 @@ NEWinogradLayerTransformWeightsKernel -int NEWinogradLayerTransformWeightsKernel::get_matrix_stride(const KernelShape &kernel_shape) const +int NEWinogradLayerTransformWeightsKernel::get_matrix_stride(int num_output_channels, int num_input_channels) const { - return WinogradConv::get_kernel_matrix_stride(kernel_shape); + return WinogradConv::get_kernel_matrix_stride(num_input_channels, num_output_channels); } #ifndef DOXYGEN_SKIP_THIS @@ -325,9 +326,8 @@ unsigned int NEWinogradLayerTransformInputKernel(WinogradConv::get_input_storage_size(kern_shape, input_shape, padding) / sizeof(T)); + return static_cast(WinogradConv::get_input_storage_size(num_batches, num_rows, num_cols, num_channels, same_padding) / sizeof(T)); } template @@ -338,9 +338,13 @@ unsigned int NEWinogradLayerTransformInputKernel int NEWinogradLayerTransformInputKernel::get_matrix_stride( - const KernelShape &kernel_shape, const Tensor4DShape &input_shape, const PaddingType padding_type) const + int num_batches, /* Number of batches in the input tensor. */ + int num_channels, /* Number of feature maps in the input tensor. */ + int num_rows, /* Number of rows in each feature map. */ + int num_cols, /* Number of columns in each feature map. */ + bool same_padding /* Use "SAME" padding, otherwise use "VALID". */) const { - return WinogradConv::get_input_matrix_stride(kernel_shape, input_shape, padding_type); + return WinogradConv::get_input_matrix_stride(num_batches, num_rows, num_cols, num_channels, same_padding); } template @@ -446,21 +450,18 @@ template class NEWinogradLayerTransformInputKernel; template unsigned int NEWinogradLayerTransformOutputKernel::get_output_storage_size( - int num_batches, /* Number of batches in the output tensor. */ - int num_rows, /* Number of rows in each feature map of the input tensor. */ - int num_cols, /* Number of columns in each feature map of the input tensor. */ - int num_output_channels, /* Number of feature maps in the output tensor. */ - bool same_padding /* Use "SAME" padding, otherwise use "VALID". */ + int num_batches, /* Number of batches in the output tensor. */ + int num_rows, /* Number of rows in each feature map of the input tensor. */ + int num_cols, /* Number of columns in each feature map of the input tensor. */ + int num_output_channels /* Number of feature maps in the output tensor. */ ) const { // Construct shapes for the input and kernel tensors. const Tensor4DShape input_shape(num_batches, num_rows, num_cols, 1); const KernelShape kern_shape(num_output_channels, KernelRows, KernelCols, 1); - const PaddingType padding = (same_padding) ? PADDING_SAME : PADDING_VALID; - // Return the size, converted into units of TOut return static_cast( - WinogradConv::get_output_storage_size(kern_shape, input_shape, padding) / sizeof(T)); + WinogradConv::get_output_storage_size(num_batches, num_rows, num_cols, num_output_channels) / sizeof(T)); } template @@ -478,28 +479,36 @@ unsigned int NEWinogradLayerTransformOutputKernel int NEWinogradLayerTransformOutputKernel::get_matrix_stride( - const KernelShape &kernel_shape, const Tensor4DShape &input_shape, const PaddingType padding_type) const + int num_batches, /* Number of batches in the output tensor. */ + int num_rows, /* Number of rows in each feature map of the input tensor. */ + int num_cols, /* Number of columns in each feature map of the input tensor. */ + int num_output_channels /* Number of feature maps in the output tensor. */ +) const { - return WinogradConv::get_output_matrix_stride(kernel_shape, input_shape, padding_type); + return WinogradConv::get_output_matrix_stride(num_batches, num_rows, num_cols, num_output_channels); } + template -Tensor4DShape NEWinogradLayerTransformOutputKernel::get_output_shape( - const KernelShape &kernel_shape, const Tensor4DShape &in_shape, const PaddingType padding) const +std::pair NEWinogradLayerTransformOutputKernel::get_output_shape( + int num_rows, /* Number of rows in each feature map of the input tensor. */ + int num_cols, /* Number of columns in each feature map of the input tensor. */ + bool padding_same) const { - return WinogradConv::get_output_shape(kernel_shape, in_shape, padding); + return WinogradConv::get_output_shape(std::make_pair(num_rows, num_cols), padding_same); } template void NEWinogradLayerTransformOutputKernel::configure( - const ITensor *biases, - const ITensor *transformed_output, - const int matrix_stride, - ITensor *output_nhwc, - const int num_batches, - const int num_rows, - const int num_cols, - const int num_channels, - ITensor *workspace) + const ITensor *biases, + const ITensor *transformed_output, + const int matrix_stride, + ITensor *output_nhwc, + const int num_batches, + const int num_rows, + const int num_cols, + const int num_channels, + ITensor *workspace, + const arm_gemm::Activation &activation) { _biases = biases; _workspace = workspace; @@ -512,7 +521,7 @@ void NEWinogradLayerTransformOutputKernel(num_batches, num_rows, num_cols, num_channels); + _transform = arm_compute::support::cpp14::make_unique(num_batches, num_rows, num_cols, num_channels, activation); Window win; auto win_last = _transform->get_window(); win.set(Window::DimX, Window::Dimension(0, win_last, 1)); diff --git a/src/core/NEON/kernels/convolution/winograd/winograd.cpp b/src/core/NEON/kernels/convolution/winograd/winograd.cpp index 226f303c7d..a4eb9fce59 100644 --- a/src/core/NEON/kernels/convolution/winograd/winograd.cpp +++ b/src/core/NEON/kernels/convolution/winograd/winograd.cpp @@ -21,205 +21,147 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ + #include +#include "utils.hpp" #include "winograd.hpp" + using namespace winograd; +using array2 = std::pair; -/** Get the output shape of a convolution. */ -template -template -Tensor4DShape WinogradGEMM::Convolution::get_output_shape( - const KernelShape &kernel_shape, - const Tensor4DShape &in_shape, - const PaddingType padding -) -{ - return Tensor4DShape { - in_shape.n_batches, - (padding == PADDING_SAME) ? in_shape.n_rows : in_shape.n_rows - (kernel_rows - 1), - (padding == PADDING_SAME) ? in_shape.n_cols : in_shape.n_cols - (kernel_cols - 1), - kernel_shape.n_output_channels, - in_shape.ordering - }; -} +#define MEMBERFN(RTYPE) \ + template \ + template \ + RTYPE WinogradGEMM::Convolution -/* Get the memory required to transform the kernel. - */ -template -template -size_t WinogradGEMM::Convolution::get_kernel_transform_working_size(const KernelShape &shape) -{ - if (shape.ordering == HWIO) - { - // Kernel is already in the correct order, so no additional memory is - // required. - return 0; - } - else - { - // Need to re-order the kernel into HWIO form, require enough space to - // represent the tensor. - return sizeof(TIn) * shape.size(); - } +/** Get the output shape of a convolution. */ +MEMBERFN(array2) +::get_output_shape(const std::pair input_shape, + const bool padding_same) { + const unsigned int n_rows = + padding_same ? input_shape.first : input_shape.first - (kernel_rows - 1); + const unsigned int n_cols = padding_same + ? input_shape.second + : input_shape.second - (kernel_cols - 1); + return {n_rows, n_cols}; } /** Get the memory required to store the kernel transformed into the * Winograd domain. */ -template -template -size_t WinogradGEMM::Convolution::get_kernel_storage_size(const KernelShape &shape) -{ - return N_GEMMS * get_kernel_matrix_size(shape); +MEMBERFN(size_t) +::get_kernel_storage_size(const unsigned int n_input_channels, + const unsigned int n_output_channels) { + return N_GEMMS * get_kernel_matrix_size(n_input_channels, n_output_channels); } - -template -template -size_t WinogradGEMM::Convolution::get_input_storage_size( - const KernelShape &kernel_shape, - const Tensor4DShape &input_shape, - const PaddingType padding -) -{ - return N_GEMMS * get_input_matrix_size(kernel_shape, input_shape, padding); +MEMBERFN(size_t) +::get_input_storage_size(const unsigned int n_batches, + const unsigned int n_rows, const unsigned int n_cols, + const unsigned int n_channels, + const bool same_padding) { + return N_GEMMS * get_input_matrix_size(n_batches, n_rows, n_cols, n_channels, + same_padding); } - -template -template -size_t WinogradGEMM::Convolution::get_output_storage_size( - const KernelShape &kernel_shape, - const Tensor4DShape &input_shape, - const PaddingType padding -) -{ - return N_GEMMS * get_output_matrix_size(kernel_shape, input_shape, padding); +MEMBERFN(size_t) +::get_output_storage_size(const unsigned int n_batches, + const unsigned int n_rows, const unsigned int n_cols, + const unsigned int n_channels) { + return N_GEMMS * + get_output_matrix_size(n_batches, n_rows, n_cols, n_channels); } - /** Get the memory required to apply a Winograd operator to some input. */ -template -template -size_t WinogradGEMM::Convolution::get_working_space_size( - const KernelShape &kernel_shape, - const Tensor4DShape &input_shape, - const PaddingType padding_type -) -{ - const auto output_shape = get_output_shape(kernel_shape, input_shape, padding_type); +MEMBERFN(size_t) +::get_working_space_size(const unsigned int n_batches, + const unsigned int n_rows, const unsigned int n_cols, + const unsigned int n_input_channels, + const unsigned int n_output_channels, + const bool padding_same) { + const auto output_shape = get_output_shape({n_rows, n_cols}, padding_same); // Get the memory required to store the matrices - const size_t matrix_sizes = N_GEMMS * ( - get_input_matrix_size(kernel_shape, input_shape, padding_type) + - get_output_matrix_size(kernel_shape, input_shape, padding_type) - ); - - // Add additional space to re-order the input and output if the input tensor - // is not in NHWC format. - if (input_shape.ordering == NHWC) - { - return matrix_sizes; // No extra spacing required - } - else // NCHW, must reorder the input and output tensors - { - // We only need to re-order the input or output at any one time, so request - // enough memory to do the largest of these. - const size_t extra_memory = std::max( - sizeof(TIn) * input_shape.size(), - sizeof(TOut) * output_shape.size() - ); - return matrix_sizes + extra_memory; - } + const size_t matrix_sizes = + N_GEMMS * + (get_input_matrix_size(n_batches, n_rows, n_cols, n_input_channels, + padding_same) + + get_output_matrix_size(n_batches, output_shape.first, + output_shape.second, n_output_channels)); + return matrix_sizes; } - /* Get the memory required by a single "input" matrix. */ -template -template -size_t WinogradGEMM::Convolution::get_input_matrix_size( - const KernelShape &kernel_shape, - const Tensor4DShape &input_shape, - const PaddingType padding_type -) -{ - return get_input_matrix_stride(kernel_shape, input_shape, padding_type) * sizeof(TGIn); +MEMBERFN(size_t) +::get_input_matrix_size(const unsigned int n_batches, const unsigned int n_rows, + const unsigned int n_cols, + const unsigned int n_channels, + const bool same_padding) { + return get_input_matrix_stride(n_batches, n_rows, n_cols, n_channels, + same_padding) * + sizeof(TGEMMIn); } -template -template -int WinogradGEMM::Convolution::get_input_matrix_stride( - const KernelShape &kernel_shape, - const Tensor4DShape &input_shape, - const PaddingType padding_type -) -{ - // Compute shape for the GEMM - const auto output_shape = get_output_shape(kernel_shape, input_shape, padding_type); - const int tile_rows = iceildiv(output_shape.n_rows, output_tile_rows); - const int tile_cols = iceildiv(output_shape.n_cols, output_tile_cols); - const int M = roundup(input_shape.n_batches * tile_rows * tile_cols, M_BLOCK); - const int K = kernel_shape.n_input_channels; +MEMBERFN(int) +::get_input_matrix_stride(const unsigned int n_batches, const unsigned int n_rows, + const unsigned int n_cols, + const unsigned int n_channels, + const bool same_padding) { + const auto output_shape = get_output_shape({n_rows, n_cols}, same_padding); + const unsigned int tile_rows = iceildiv(output_shape.first, output_tile_rows); + const unsigned int tile_cols = + iceildiv(output_shape.second, output_tile_cols); + const unsigned int M = + roundup(n_batches * tile_rows * tile_cols, M_BLOCK); + const unsigned int K = n_channels; return M * K; } - /* Get the memory required by a single "output" matrix. */ -template -template -size_t WinogradGEMM::Convolution::get_output_matrix_size( - const KernelShape &kernel_shape, - const Tensor4DShape &input_shape, - const PaddingType padding_type -) -{ - return get_output_matrix_stride(kernel_shape, input_shape, padding_type) * sizeof(TGOut); +MEMBERFN(size_t) +::get_output_matrix_size(const unsigned int n_batches, + const unsigned int n_rows, const unsigned int n_cols, + const unsigned int n_channels) { + return get_output_matrix_stride(n_batches, n_rows, n_cols, n_channels) * + sizeof(TGEMMOut); } - -template -template -int WinogradGEMM::Convolution::get_output_matrix_stride( - const KernelShape &kernel_shape, - const Tensor4DShape &input_shape, - const PaddingType padding_type -) -{ +MEMBERFN(int) +::get_output_matrix_stride(const unsigned int n_batches, + const unsigned int n_rows, const unsigned int n_cols, + const unsigned int n_channels) { // Compute shape for the GEMM - const auto output_shape = get_output_shape(kernel_shape, input_shape, padding_type); - const int tile_rows = iceildiv(output_shape.n_rows, output_tile_rows); - const int tile_cols = iceildiv(output_shape.n_cols, output_tile_cols); - const int M = roundup(tile_rows * tile_cols, M_BLOCK); - const int N = roundup(kernel_shape.n_output_channels, N_BLOCK); + const int tile_rows = iceildiv(n_rows, output_tile_rows); + const int tile_cols = iceildiv(n_cols, output_tile_cols); + const int M = roundup(tile_rows * tile_cols, M_BLOCK); + const int N = roundup(n_channels, N_BLOCK); - return input_shape.n_batches * M * N; + return n_batches * M * N; } /* Get the memory required by a single "kernel" matrix. */ -template -template -size_t WinogradGEMM::Convolution::get_kernel_matrix_size(const KernelShape &shape) -{ - return sizeof(TGIn) * get_kernel_matrix_stride(shape); +MEMBERFN(size_t) +::get_kernel_matrix_size(const unsigned int n_input_channels, + const unsigned int n_output_channels) { + return sizeof(TGEMMIn) * + get_kernel_matrix_stride(n_input_channels, n_output_channels); } -template -template -int WinogradGEMM::Convolution::get_kernel_matrix_stride(const KernelShape &shape) -{ - const int K = shape.n_input_channels; - const int N = roundup(shape.n_output_channels, N_BLOCK); - return K * N; +MEMBERFN(int) +::get_kernel_matrix_stride(const unsigned int n_input_channels, + const unsigned int n_output_channels) { + return n_input_channels * roundup(n_output_channels, N_BLOCK); } - // Instantiate required implementations template class WinogradGEMM<2, 2, 3, 3, WinogradRoots::Integers>::Convolution; template class WinogradGEMM<4, 4, 3, 3, WinogradRoots::Integers>::Convolution; diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input.hpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input.hpp index fcbd21fcd0..8e4bebcd20 100644 --- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input.hpp +++ b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -24,8 +24,11 @@ #pragma once -#include "winograd.hpp" +#include + #include "padding.hpp" +#include "utils.hpp" +#include "winograd.hpp" #define MEMBERFN(RTYPE) template <\ int InnerTileRows, int InnerTileCols,\ diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp index d97af21a43..fe47ccbde9 100644 --- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp +++ b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -41,24 +41,30 @@ namespace winograd { -MEMBERFN()::OutputTransform( - const int n_batches, - const int n_rows, - const int n_cols, - const int n_channels -) : _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols), _n_channels(n_channels), - _matrix_base(nullptr), - _biases(nullptr), - _matrix_stride(0), _matrix_row_stride(0), _matrix_batch_stride(0), - _outptr(nullptr), - _tiles_M(iceildiv(n_rows, output_tile_rows)), - _tiles_N(iceildiv(n_cols, output_tile_cols)), - _out_col_stride(0), _out_row_stride(0), _out_batch_stride(0), - _working_space_col_stride(n_channels), - _working_space_row_stride(output_tile_cols * _working_space_col_stride), - _working_space(nullptr) -{ -} +MEMBERFN() +::OutputTransform(const int n_batches, const int n_rows, const int n_cols, + const int n_channels, const arm_gemm::Activation &activation) + : _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols), + _n_channels(n_channels), + _output_min((activation.type == arm_gemm::Activation::Type::ReLU || + activation.type == arm_gemm::Activation::Type::BoundedReLU) + ? static_cast(0.0f) + : (std::numeric_limits::has_infinity) + ? -std::numeric_limits::infinity() + : std::numeric_limits::lowest()), + _output_max((activation.type == arm_gemm::Activation::Type::BoundedReLU) + ? static_cast(activation.param1) + : (std::numeric_limits::has_infinity) + ? std::numeric_limits::infinity() + : std::numeric_limits::max()), + _matrix_base(nullptr), _biases(nullptr), _matrix_stride(0), + _matrix_row_stride(0), _matrix_batch_stride(0), _outptr(nullptr), + _tiles_M(iceildiv(n_rows, output_tile_rows)), + _tiles_N(iceildiv(n_cols, output_tile_cols)), _out_col_stride(0), + _out_row_stride(0), _out_batch_stride(0), + _working_space_col_stride(n_channels), + _working_space_row_stride(output_tile_cols * _working_space_col_stride), + _working_space(nullptr) {} MEMBERFN(void)::set_input_matrices(const void * const mptr, const int ldmatrix, const int ldrow) { @@ -100,9 +106,10 @@ Nx1MEMBERFN()::OutputTransform( const int n_batches, const int n_rows, const int n_cols, - const int n_channels + const int n_channels, + const arm_gemm::Activation &activation ) : OutputTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots>::OutputTransform( - n_batches, n_cols, n_rows, n_channels /* Transpose rows and columns */ + n_batches, n_cols, n_rows, n_channels, activation /* Transpose rows and columns */ ) { } @@ -212,7 +219,8 @@ MEMBERFN(void)::transform_uncropped_tile( { transform_tile( n_channels, inptr, _matrix_stride, biases, - outptr, _out_row_stride, _out_col_stride + outptr, _out_row_stride, _out_col_stride, + _output_min, _output_max ); } @@ -230,7 +238,8 @@ MEMBERFN(void)::transform_cropped_tile( TOut *wsptr = static_cast(get_working_space(threadid)); transform_tile( n_channels, inptr, _matrix_stride, biases, - wsptr, _working_space_row_stride, _working_space_col_stride + wsptr, _working_space_row_stride, _working_space_col_stride, + _output_min, _output_max ); padding::crop_and_copy_tile( diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2_7_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2_7_fp32_fp32_integers.cpp index c32d7f2f58..f231bdd67e 100644 --- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2_7_fp32_fp32_integers.cpp +++ b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2_7_fp32_fp32_integers.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -36,7 +36,9 @@ void OutputTransform<1, 7, 1, 8, float, float, WinogradRoots::Integers>::transfo const float* bptr, float* const output, const int, // No need to stride across rows - const int output_col_stride + const int output_col_stride, + const float output_min, + const float output_max ) { // Construct a map to the output cells @@ -72,7 +74,9 @@ void OutputTransform<1, 7, 1, 8, float, float, WinogradRoots::Integers>::transfo } for (int j = 0; j < output_tile_cols; j++) { - vst1q_f32(outptrs[j], f[j] + b); + const auto y = vminq_f32(vmaxq_f32(f[j] + b, vdupq_n_f32(output_min)), + vdupq_n_f32(output_max)); + vst1q_f32(outptrs[j], y); outptrs[j] += 4; } } @@ -99,7 +103,9 @@ void OutputTransform<1, 7, 1, 8, float, float, WinogradRoots::Integers>::transfo } for (int j = 0; j < output_tile_cols; j++) { - vst1_f32(outptrs[j], f[j] + b); + const auto y = vmin_f32(vmax_f32(f[j] + b, vdup_n_f32(output_min)), + vdup_n_f32(output_max)); + vst1_f32(outptrs[j], y); outptrs[j] += 2; } } @@ -126,7 +132,7 @@ void OutputTransform<1, 7, 1, 8, float, float, WinogradRoots::Integers>::transfo } for (int j = 0; j < output_tile_cols; j++) { - *(outptrs[j]++) = f[j] + b; + *(outptrs[j]++) = std::max(std::min(f[j] + b, output_max), output_min); } } } diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_3x3_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_3x3_fp32_fp32_integers.cpp index d6ebf44f41..5136bc15c4 100644 --- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_3x3_fp32_fp32_integers.cpp +++ b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_3x3_fp32_fp32_integers.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -36,7 +36,9 @@ void OutputTransform<3, 3, 4, 4, float, float, WinogradRoots::Integers>::transfo const float* bptr, float* const output, const int output_row_stride, - const int output_col_stride + const int output_col_stride, + const float output_min, + const float output_max ) { // Construct a map to the output cells @@ -103,7 +105,10 @@ void OutputTransform<3, 3, 4, 4, float, float, WinogradRoots::Integers>::transfo { for (int j = 0; j < output_tile_cols; j++) { - vst1q_f32(outptrs[i][j], vaddq_f32(f[i][j], b)); + const auto y = + vmaxq_f32(vminq_f32(vaddq_f32(f[i][j], b), vdupq_n_f32(output_max)), + vdupq_n_f32(output_min)); + vst1q_f32(outptrs[i][j], y); outptrs[i][j] += 4; } } @@ -161,7 +166,10 @@ void OutputTransform<3, 3, 4, 4, float, float, WinogradRoots::Integers>::transfo { for (int j = 0; j < output_tile_cols; j++) { - vst1_f32(outptrs[i][j], vadd_f32(f[i][j], b)); + const auto y = + vmax_f32(vmin_f32(vadd_f32(f[i][j], b), vdup_n_f32(output_max)), + vdup_n_f32(output_min)); + vst1_f32(outptrs[i][j], y); outptrs[i][j] += 2; } } @@ -211,7 +219,8 @@ void OutputTransform<3, 3, 4, 4, float, float, WinogradRoots::Integers>::transfo { for (int j = 0; j < output_tile_cols; j++) { - *(outptrs[i][j]++) = f[i][j] + b; + const auto y = std::max(std::min(f[i][j] + b, output_max), output_min); + *(outptrs[i][j]++) = y; } } } diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_5x5_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_5x5_fp32_fp32_integers.cpp index d93d9e234a..0f911f14a3 100644 --- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_5x5_fp32_fp32_integers.cpp +++ b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_5x5_fp32_fp32_integers.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -36,7 +36,9 @@ void OutputTransform<5, 5, 6, 6, float, float, WinogradRoots::Integers>::transfo const float* bptr, float* const output, const int output_row_stride, - const int output_col_stride + const int output_col_stride, + const float output_min, + const float output_max ) { // Construct a map to the output cells @@ -101,7 +103,10 @@ void OutputTransform<5, 5, 6, 6, float, float, WinogradRoots::Integers>::transfo { for (int j = 0; j < output_tile_cols; j++) { - vst1q_f32(outptrs[i][j], vaddq_f32(f[i][j], b)); + const auto y = + vmaxq_f32(vminq_f32(vaddq_f32(f[i][j], b), vdupq_n_f32(output_max)), + vdupq_n_f32(output_min)); + vst1q_f32(outptrs[i][j], y); outptrs[i][j] += 4; } } @@ -157,7 +162,10 @@ void OutputTransform<5, 5, 6, 6, float, float, WinogradRoots::Integers>::transfo { for (int j = 0; j < output_tile_cols; j++) { - vst1_f32(outptrs[i][j], vadd_f32(f[i][j], b)); + const auto y = + vmax_f32(vmin_f32(vadd_f32(f[i][j], b), vdup_n_f32(output_max)), + vdup_n_f32(output_min)); + vst1_f32(outptrs[i][j], y); outptrs[i][j] += 2; } } @@ -205,7 +213,8 @@ void OutputTransform<5, 5, 6, 6, float, float, WinogradRoots::Integers>::transfo { for (int j = 0; j < output_tile_cols; j++) { - *(outptrs[i][j]++) = f[i][j] + b; + const auto y = std::max(std::min(f[i][j] + b, output_max), output_min); + *(outptrs[i][j]++) = y; } } } diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4_5_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4_5_fp32_fp32_integers.cpp index 7187ef2d20..49a3f41182 100644 --- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4_5_fp32_fp32_integers.cpp +++ b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4_5_fp32_fp32_integers.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -36,7 +36,9 @@ void OutputTransform<1, 5, 1, 8, float, float, WinogradRoots::Integers>::transfo const float* bptr, float* const output, const int, // No need to stride across rows - const int output_col_stride + const int output_col_stride, + const float output_min, + const float output_max ) { // Construct a map to the output cells @@ -74,7 +76,10 @@ void OutputTransform<1, 5, 1, 8, float, float, WinogradRoots::Integers>::transfo } for (int j = 0; j < output_tile_cols; j++) { - vst1q_f32(outptrs[j], f[j] + b); + const auto y = + vmaxq_f32(vminq_f32(vaddq_f32(f[j], b), vdupq_n_f32(output_max)), + vdupq_n_f32(output_min)); + vst1q_f32(outptrs[j], y); outptrs[j] += 4; } } @@ -103,7 +108,10 @@ void OutputTransform<1, 5, 1, 8, float, float, WinogradRoots::Integers>::transfo } for (int j = 0; j < output_tile_cols; j++) { - vst1_f32(outptrs[j], f[j] + b); + const auto y = + vmax_f32(vmin_f32(vadd_f32(f[j], b), vdup_n_f32(output_max)), + vdup_n_f32(output_min)); + vst1_f32(outptrs[j], y); outptrs[j] += 2; } } @@ -132,7 +140,8 @@ void OutputTransform<1, 5, 1, 8, float, float, WinogradRoots::Integers>::transfo } for (int j = 0; j < output_tile_cols; j++) { - *(outptrs[j]++) = f[j] + b; + const auto y = std::max(std::min(f[j] + b, output_max), output_min); + *(outptrs[j]++) = y; } } } diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4x4_3x3_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4x4_3x3_fp32_fp32_integers.cpp index fd16a4df1c..292999c8bc 100644 --- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4x4_3x3_fp32_fp32_integers.cpp +++ b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4x4_3x3_fp32_fp32_integers.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -28,1683 +28,6 @@ namespace winograd { -#ifdef __aarch64__ - -template <> -void OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots::Integers>::transform_tile( - int n_channels, - const float* inptr, - const int matrix_stride, - const float* bptr, - float* output, - const int output_row_stride, - const int output_col_stride -) -{ - const float coeffs[2] = {2.0f, 4.0f}; - if (bptr != nullptr) - { - __asm__ __volatile__ ( - "ldr d0, [%[pcoeffs]]\n" - "add x21, %[in_col_stride1], %[in_col_stride1]\n" - "add x22, x21, %[in_col_stride1]\n" - "add x25, %[inptr0], %[in_row_stride]\n" - "add x15, %[output_col_stride1], %[output_col_stride1]\n" - "add x23, x22, %[in_col_stride1]\n" - "add x13, x25, %[in_row_stride]\n" - "add x16, x15, %[output_col_stride1]\n" - "add x24, x23, %[in_col_stride1]\n" - "add x26, x13, %[in_row_stride]\n" - "add x17, %[outptr0], %[output_row_stride]\n" - "add x14, x26, %[in_row_stride]\n" - "add x28, x17, %[output_row_stride]\n" - "lsr x19, %[n_channels], #2\n" - "add x27, x14, %[in_row_stride]\n" - "add x18, x28, %[output_row_stride]\n" - "and x20, %[n_channels], #3\n" - "cbz x19, 4f\n" - "1:\n" - "ldr q19, [%[inptr0]]\n" - "subs x19, x19, #1\n" - "ldr q20, [%[inptr0], %[in_col_stride1]]\n" - "ldr q4, [%[inptr0], x21]\n" - "fadd v1.4s, v20.4s, v4.4s\n" - "ldr q17, [%[inptr0], x22]\n" - "fsub v7.4s, v20.4s, v4.4s\n" - "ldr q22, [%[inptr0], x23]\n" - "fadd v5.4s, v17.4s, v22.4s\n" - "ldr q18, [%[inptr0], x24]\n" - "fsub v10.4s, v17.4s, v22.4s\n" - "ldr q25, [x25]\n" - "fadd v8.4s, v19.4s, v1.4s\n" - "ldr q12, [x25, %[in_col_stride1]]\n" - "mov v4.16b, v1.16b\n" - "ldr q23, [x25, x21]\n" - "mov v1.16b, v7.16b\n" - "ldr q9, [x25, x22]\n" - "fmul v10.4s, v10.4s, v0.s[0]\n" - "ldr q11, [x25, x23]\n" - "fadd v8.4s, v8.4s, v5.4s\n" - "ldr q6, [x25, x24]\n" - "fmla v4.4s, v5.4s, v0.s[1]\n" - "fadd v7.4s, v7.4s, v10.4s\n" - "fmla v1.4s, v10.4s, v0.s[1]\n" - "fadd v1.4s, v1.4s, v18.4s\n" - "beq 3f\n" - "2:\n" - "fadd v3.4s, v12.4s, v23.4s\n" - "ldr q2, [x13]\n" - "fadd v27.4s, v9.4s, v11.4s\n" - "ldr q21, [x13, %[in_col_stride1]]\n" - "fsub v16.4s, v12.4s, v23.4s\n" - "ldr q26, [x13, x21]\n" - "fsub v9.4s, v9.4s, v11.4s\n" - "ldr q17, [x13, x22]\n" - "fadd v14.4s, v25.4s, v3.4s\n" - "ldr q19, [x13, x23]\n" - "mov v11.16b, v3.16b\n" - "ldr q10, [x13, x24]\n" - "mov v3.16b, v16.16b\n" - "ldr q15, [x26]\n" - "fmul v9.4s, v9.4s, v0.s[0]\n" - "ldr q12, [x26, %[in_col_stride1]]\n" - "fadd v14.4s, v14.4s, v27.4s\n" - "ldr q20, [x26, x21]\n" - "fmla v11.4s, v27.4s, v0.s[1]\n" - "ldr q24, [x26, x22]\n" - "fadd v23.4s, v21.4s, v26.4s\n" - "ldr q29, [x26, x23]\n" - "fadd v13.4s, v16.4s, v9.4s\n" - "ldr q5, [x26, x24]\n" - "fmla v3.4s, v9.4s, v0.s[1]\n" - "ldr q18, [x14]\n" - "fadd v30.4s, v17.4s, v19.4s\n" - "add %[inptr0], %[inptr0], #16\n" - "fadd v16.4s, v2.4s, v23.4s\n" - "add x25, x25, #16\n" - "fsub v21.4s, v21.4s, v26.4s\n" - "ldr q22, [x14, %[in_col_stride1]]\n" - "fadd v3.4s, v3.4s, v6.4s\n" - "ldr q28, [x14, x21]\n" - "fsub v19.4s, v17.4s, v19.4s\n" - "add x13, x13, #16\n" - "fadd v16.4s, v16.4s, v30.4s\n" - "add x26, x26, #16\n" - "mov v17.16b, v23.16b\n" - "subs x19, x19, #1\n" - "fadd v26.4s, v12.4s, v20.4s\n" - "fsub v9.4s, v12.4s, v20.4s\n" - "fmul v19.4s, v19.4s, v0.s[0]\n" - "ldr q20, [x14, x22]\n" - "fmla v17.4s, v30.4s, v0.s[1]\n" - "fadd v25.4s, v24.4s, v29.4s\n" - "fsub v12.4s, v24.4s, v29.4s\n" - "fadd v24.4s, v22.4s, v28.4s\n" - "fadd v23.4s, v15.4s, v26.4s\n" - "mov v15.16b, v26.16b\n" - "fsub v22.4s, v22.4s, v28.4s\n" - "fadd v29.4s, v14.4s, v16.4s\n" - "fsub v16.4s, v14.4s, v16.4s\n" - "ldr q28, [x14, x23]\n" - "fmul v12.4s, v12.4s, v0.s[0]\n" - "fmla v15.4s, v25.4s, v0.s[1]\n" - "fadd v23.4s, v23.4s, v25.4s\n" - "mov v6.16b, v21.16b\n" - "fadd v30.4s, v21.4s, v19.4s\n" - "fadd v26.4s, v18.4s, v24.4s\n" - "mov v25.16b, v24.16b\n" - "fadd v18.4s, v8.4s, v29.4s\n" - "fmla v6.4s, v19.4s, v0.s[1]\n" - "fadd v27.4s, v20.4s, v28.4s\n" - "fsub v21.4s, v20.4s, v28.4s\n" - "mov v19.16b, v29.16b\n" - "fadd v29.4s, v13.4s, v30.4s\n" - "fsub v8.4s, v13.4s, v30.4s\n" - "fadd v14.4s, v9.4s, v12.4s\n" - "fadd v6.4s, v6.4s, v10.4s\n" - "ldr q20, [x14, x24]\n" - "fadd v26.4s, v26.4s, v27.4s\n" - "add x14, x14, #16\n" - "fmla v9.4s, v12.4s, v0.s[1]\n" - "ldr q24, [x27]\n" - "fmul v21.4s, v21.4s, v0.s[0]\n" - "fmla v25.4s, v27.4s, v0.s[1]\n" - "fadd v10.4s, v7.4s, v29.4s\n" - "ldr q2, [%[bptr]]\n" - "mov v7.16b, v29.16b\n" - "add %[bptr], %[bptr], #16\n" - "fadd v9.4s, v9.4s, v5.4s\n" - "fadd v13.4s, v23.4s, v26.4s\n" - "fsub v23.4s, v23.4s, v26.4s\n" - "fadd v27.4s, v11.4s, v17.4s\n" - "fsub v11.4s, v11.4s, v17.4s\n" - "fadd v30.4s, v15.4s, v25.4s\n" - "fsub v15.4s, v15.4s, v25.4s\n" - "ldr q28, [x27, %[in_col_stride1]]\n" - "fadd v18.4s, v18.4s, v13.4s\n" - "fmla v19.4s, v13.4s, v0.s[1]\n" - "fadd v26.4s, v22.4s, v21.4s\n" - "mov v12.16b, v22.16b\n" - "fmul v23.4s, v23.4s, v0.s[0]\n" - "fadd v17.4s, v4.4s, v27.4s\n" - "fmul v15.4s, v15.4s, v0.s[0]\n" - "mov v4.16b, v27.16b\n" - "fmla v12.4s, v21.4s, v0.s[1]\n" - "ldr q22, [x27, x21]\n" - "fadd v18.4s, v18.4s, v2.4s\n" - "fadd v19.4s, v19.4s, v2.4s\n" - "fadd v17.4s, v17.4s, v30.4s\n" - "fmla v4.4s, v30.4s, v0.s[1]\n" - "fadd v25.4s, v28.4s, v22.4s\n" - "fsub v27.4s, v28.4s, v22.4s\n" - "fadd v12.4s, v12.4s, v20.4s\n" - "ldr q29, [x27, x22]\n" - "str q18, [%[outptr0]]\n" - "fadd v22.4s, v16.4s, v23.4s\n" - "str q19, [x28]\n" - "fadd v28.4s, v24.4s, v25.4s\n" - "ldr q30, [x27, x23]\n" - "fadd v20.4s, v29.4s, v30.4s\n" - "fsub v18.4s, v29.4s, v30.4s\n" - "mov v21.16b, v25.16b\n" - "ldr q25, [x27, x24]\n" - "fmla v16.4s, v23.4s, v0.s[1]\n" - "ldr q19, [%[inptr0]]\n" - "fadd v17.4s, v17.4s, v2.4s\n" - "add x27, x27, #16\n" - "fadd v28.4s, v28.4s, v20.4s\n" - "fmul v18.4s, v18.4s, v0.s[0]\n" - "fmla v21.4s, v20.4s, v0.s[1]\n" - "ldr q20, [%[inptr0], %[in_col_stride1]]\n" - "fadd v22.4s, v22.4s, v2.4s\n" - "fadd v4.4s, v4.4s, v2.4s\n" - "str q17, [%[outptr0], x15]\n" - "mov v24.16b, v27.16b\n" - "fadd v23.4s, v27.4s, v18.4s\n" - "fadd v16.4s, v16.4s, v28.4s\n" - "fadd v13.4s, v14.4s, v26.4s\n" - "fsub v30.4s, v14.4s, v26.4s\n" - "str q22, [x17]\n" - "fmla v24.4s, v18.4s, v0.s[1]\n" - "str q4, [x28, x15]\n" - "mov v14.16b, v8.16b\n" - "fadd v29.4s, v11.4s, v15.4s\n" - "ldr q4, [%[inptr0], x21]\n" - "fadd v10.4s, v10.4s, v13.4s\n" - "ldr q17, [%[inptr0], x22]\n" - "fadd v24.4s, v24.4s, v25.4s\n" - "ldr q22, [%[inptr0], x23]\n" - "fmul v30.4s, v30.4s, v0.s[0]\n" - "fmla v7.4s, v13.4s, v0.s[1]\n" - "mov v26.16b, v11.16b\n" - "fadd v13.4s, v3.4s, v6.4s\n" - "fsub v3.4s, v3.4s, v6.4s\n" - "ldr q18, [%[inptr0], x24]\n" - "fadd v10.4s, v10.4s, v2.4s\n" - "fadd v29.4s, v29.4s, v2.4s\n" - "fadd v8.4s, v8.4s, v30.4s\n" - "fmla v14.4s, v30.4s, v0.s[1]\n" - "fmla v26.4s, v15.4s, v0.s[1]\n" - "ldr q25, [x25]\n" - "fadd v27.4s, v9.4s, v12.4s\n" - "fadd v1.4s, v1.4s, v13.4s\n" - "str q10, [%[outptr0], %[output_col_stride1]]\n" - "fsub v6.4s, v9.4s, v12.4s\n" - "str q29, [x17, x15]\n" - "fadd v14.4s, v14.4s, v23.4s\n" - "fadd v26.4s, v26.4s, v21.4s\n" - "ldr q12, [x25, %[in_col_stride1]]\n" - "fadd v1.4s, v1.4s, v27.4s\n" - "ldr q23, [x25, x21]\n" - "fmul v6.4s, v6.4s, v0.s[0]\n" - "ldr q9, [x25, x22]\n" - "mov v5.16b, v13.16b\n" - "ldr q11, [x25, x23]\n" - "mov v13.16b, v3.16b\n" - "fadd v8.4s, v8.4s, v2.4s\n" - "fadd v1.4s, v1.4s, v2.4s\n" - "fadd v7.4s, v7.4s, v2.4s\n" - "fadd v10.4s, v3.4s, v6.4s\n" - "fmla v5.4s, v27.4s, v0.s[1]\n" - "fmla v13.4s, v6.4s, v0.s[1]\n" - "ldr q6, [x25, x24]\n" - "str q8, [x17, %[output_col_stride1]]\n" - "fadd v16.4s, v16.4s, v2.4s\n" - "str q1, [%[outptr0], x16]\n" - "fadd v14.4s, v14.4s, v2.4s\n" - "str q7, [x28, %[output_col_stride1]]\n" - "fadd v10.4s, v10.4s, v2.4s\n" - "fadd v13.4s, v13.4s, v24.4s\n" - "add %[outptr0], %[outptr0], #16\n" - "str q16, [x18]\n" - "fadd v5.4s, v5.4s, v2.4s\n" - "str q14, [x18, %[output_col_stride1]]\n" - "fadd v26.4s, v26.4s, v2.4s\n" - "str q10, [x17, x16]\n" - "fadd v1.4s, v20.4s, v4.4s\n" - "fadd v13.4s, v13.4s, v2.4s\n" - "add x17, x17, #16\n" - "str q5, [x28, x16]\n" - "fadd v5.4s, v17.4s, v22.4s\n" - "str q26, [x18, x15]\n" - "fsub v7.4s, v20.4s, v4.4s\n" - "fadd v8.4s, v19.4s, v1.4s\n" - "add x28, x28, #16\n" - "str q13, [x18, x16]\n" - "mov v4.16b, v1.16b\n" - "fsub v10.4s, v17.4s, v22.4s\n" - "add x18, x18, #16\n" - "mov v1.16b, v7.16b\n" - "fadd v8.4s, v8.4s, v5.4s\n" - "fmla v4.4s, v5.4s, v0.s[1]\n" - "fmul v10.4s, v10.4s, v0.s[0]\n" - "fadd v7.4s, v7.4s, v10.4s\n" - "fmla v1.4s, v10.4s, v0.s[1]\n" - "fadd v1.4s, v1.4s, v18.4s\n" - "bne 2b\n" - "3:\n" - "fadd v3.4s, v12.4s, v23.4s\n" - "ldr q2, [x13]\n" - "fadd v27.4s, v9.4s, v11.4s\n" - "ldr q21, [x13, %[in_col_stride1]]\n" - "fsub v16.4s, v12.4s, v23.4s\n" - "ldr q26, [x13, x21]\n" - "fsub v9.4s, v9.4s, v11.4s\n" - "ldr q17, [x13, x22]\n" - "fadd v14.4s, v25.4s, v3.4s\n" - "ldr q19, [x13, x23]\n" - "mov v11.16b, v3.16b\n" - "ldr q10, [x13, x24]\n" - "mov v3.16b, v16.16b\n" - "ldr q15, [x26]\n" - "fmul v9.4s, v9.4s, v0.s[0]\n" - "ldr q12, [x26, %[in_col_stride1]]\n" - "fadd v14.4s, v14.4s, v27.4s\n" - "ldr q20, [x26, x21]\n" - "fmla v11.4s, v27.4s, v0.s[1]\n" - "ldr q24, [x26, x22]\n" - "fadd v23.4s, v21.4s, v26.4s\n" - "ldr q29, [x26, x23]\n" - "fadd v13.4s, v16.4s, v9.4s\n" - "ldr q5, [x26, x24]\n" - "fmla v3.4s, v9.4s, v0.s[1]\n" - "ldr q18, [x14]\n" - "fadd v30.4s, v17.4s, v19.4s\n" - "add %[inptr0], %[inptr0], #16\n" - "fadd v16.4s, v2.4s, v23.4s\n" - "add x25, x25, #16\n" - "fsub v21.4s, v21.4s, v26.4s\n" - "ldr q22, [x14, %[in_col_stride1]]\n" - "fadd v3.4s, v3.4s, v6.4s\n" - "ldr q28, [x14, x21]\n" - "fsub v19.4s, v17.4s, v19.4s\n" - "add x13, x13, #16\n" - "fadd v16.4s, v16.4s, v30.4s\n" - "add x26, x26, #16\n" - "mov v17.16b, v23.16b\n" - "fadd v26.4s, v12.4s, v20.4s\n" - "fsub v9.4s, v12.4s, v20.4s\n" - "ldr q2, [%[bptr]]\n" - "fmul v19.4s, v19.4s, v0.s[0]\n" - "add %[bptr], %[bptr], #16\n" - "fmla v17.4s, v30.4s, v0.s[1]\n" - "fadd v25.4s, v24.4s, v29.4s\n" - "fadd v23.4s, v15.4s, v26.4s\n" - "fsub v12.4s, v24.4s, v29.4s\n" - "mov v15.16b, v26.16b\n" - "fadd v24.4s, v22.4s, v28.4s\n" - "fsub v22.4s, v22.4s, v28.4s\n" - "fadd v29.4s, v14.4s, v16.4s\n" - "fsub v16.4s, v14.4s, v16.4s\n" - "ldr q20, [x14, x22]\n" - "fadd v23.4s, v23.4s, v25.4s\n" - "fmul v12.4s, v12.4s, v0.s[0]\n" - "fmla v15.4s, v25.4s, v0.s[1]\n" - "mov v6.16b, v21.16b\n" - "fadd v30.4s, v21.4s, v19.4s\n" - "fadd v26.4s, v18.4s, v24.4s\n" - "mov v25.16b, v24.16b\n" - "fadd v18.4s, v8.4s, v29.4s\n" - "fmla v6.4s, v19.4s, v0.s[1]\n" - "mov v19.16b, v29.16b\n" - "fadd v27.4s, v11.4s, v17.4s\n" - "fsub v11.4s, v11.4s, v17.4s\n" - "fadd v29.4s, v13.4s, v30.4s\n" - "fsub v8.4s, v13.4s, v30.4s\n" - "fadd v14.4s, v9.4s, v12.4s\n" - "fadd v6.4s, v6.4s, v10.4s\n" - "ldr q28, [x14, x23]\n" - "fadd v17.4s, v4.4s, v27.4s\n" - "mov v4.16b, v27.16b\n" - "fmla v9.4s, v12.4s, v0.s[1]\n" - "fadd v27.4s, v20.4s, v28.4s\n" - "fsub v21.4s, v20.4s, v28.4s\n" - "fadd v10.4s, v7.4s, v29.4s\n" - "mov v7.16b, v29.16b\n" - "fadd v13.4s, v3.4s, v6.4s\n" - "fsub v3.4s, v3.4s, v6.4s\n" - "ldr q20, [x14, x24]\n" - "fadd v9.4s, v9.4s, v5.4s\n" - "fadd v26.4s, v26.4s, v27.4s\n" - "fmul v21.4s, v21.4s, v0.s[0]\n" - "add x14, x14, #16\n" - "fmla v25.4s, v27.4s, v0.s[1]\n" - "mov v12.16b, v22.16b\n" - "fadd v1.4s, v1.4s, v13.4s\n" - "mov v5.16b, v13.16b\n" - "fadd v13.4s, v23.4s, v26.4s\n" - "fsub v23.4s, v23.4s, v26.4s\n" - "fadd v26.4s, v22.4s, v21.4s\n" - "ldr q24, [x27]\n" - "fmla v12.4s, v21.4s, v0.s[1]\n" - "fadd v30.4s, v15.4s, v25.4s\n" - "fsub v15.4s, v15.4s, v25.4s\n" - "ldr q28, [x27, %[in_col_stride1]]\n" - "fadd v18.4s, v18.4s, v13.4s\n" - "fmul v23.4s, v23.4s, v0.s[0]\n" - "fmla v19.4s, v13.4s, v0.s[1]\n" - "ldr q22, [x27, x21]\n" - "fadd v12.4s, v12.4s, v20.4s\n" - "ldr q29, [x27, x22]\n" - "fadd v17.4s, v17.4s, v30.4s\n" - "fmul v15.4s, v15.4s, v0.s[0]\n" - "fmla v4.4s, v30.4s, v0.s[1]\n" - "fadd v25.4s, v28.4s, v22.4s\n" - "fsub v27.4s, v28.4s, v22.4s\n" - "fadd v22.4s, v16.4s, v23.4s\n" - "fadd v18.4s, v18.4s, v2.4s\n" - "fadd v17.4s, v17.4s, v2.4s\n" - "fadd v19.4s, v19.4s, v2.4s\n" - "fadd v28.4s, v24.4s, v25.4s\n" - "mov v21.16b, v25.16b\n" - "fmla v16.4s, v23.4s, v0.s[1]\n" - "ldr q30, [x27, x23]\n" - "str q18, [%[outptr0]]\n" - "fadd v20.4s, v29.4s, v30.4s\n" - "str q17, [%[outptr0], x15]\n" - "fsub v18.4s, v29.4s, v30.4s\n" - "str q19, [x28]\n" - "mov v24.16b, v27.16b\n" - "fadd v13.4s, v14.4s, v26.4s\n" - "ldr q25, [x27, x24]\n" - "fadd v28.4s, v28.4s, v20.4s\n" - "add x27, x27, #16\n" - "fmul v18.4s, v18.4s, v0.s[0]\n" - "fmla v21.4s, v20.4s, v0.s[1]\n" - "fsub v30.4s, v14.4s, v26.4s\n" - "mov v14.16b, v8.16b\n" - "fadd v10.4s, v10.4s, v13.4s\n" - "fmla v7.4s, v13.4s, v0.s[1]\n" - "fadd v16.4s, v16.4s, v28.4s\n" - "fadd v29.4s, v11.4s, v15.4s\n" - "fadd v23.4s, v27.4s, v18.4s\n" - "fmla v24.4s, v18.4s, v0.s[1]\n" - "fmul v30.4s, v30.4s, v0.s[0]\n" - "mov v26.16b, v11.16b\n" - "fadd v27.4s, v9.4s, v12.4s\n" - "fsub v6.4s, v9.4s, v12.4s\n" - "mov v13.16b, v3.16b\n" - "fadd v10.4s, v10.4s, v2.4s\n" - "fadd v24.4s, v24.4s, v25.4s\n" - "fmla v26.4s, v15.4s, v0.s[1]\n" - "fadd v8.4s, v8.4s, v30.4s\n" - "fmla v14.4s, v30.4s, v0.s[1]\n" - "fadd v1.4s, v1.4s, v27.4s\n" - "fmul v6.4s, v6.4s, v0.s[0]\n" - "str q10, [%[outptr0], %[output_col_stride1]]\n" - "fmla v5.4s, v27.4s, v0.s[1]\n" - "fadd v26.4s, v26.4s, v21.4s\n" - "fadd v22.4s, v22.4s, v2.4s\n" - "fadd v14.4s, v14.4s, v23.4s\n" - "fadd v8.4s, v8.4s, v2.4s\n" - "fadd v10.4s, v3.4s, v6.4s\n" - "fmla v13.4s, v6.4s, v0.s[1]\n" - "fadd v1.4s, v1.4s, v2.4s\n" - "fadd v29.4s, v29.4s, v2.4s\n" - "str q22, [x17]\n" - "fadd v7.4s, v7.4s, v2.4s\n" - "str q8, [x17, %[output_col_stride1]]\n" - "fadd v4.4s, v4.4s, v2.4s\n" - "fadd v13.4s, v13.4s, v24.4s\n" - "fadd v10.4s, v10.4s, v2.4s\n" - "str q1, [%[outptr0], x16]\n" - "fadd v5.4s, v5.4s, v2.4s\n" - "str q29, [x17, x15]\n" - "fadd v16.4s, v16.4s, v2.4s\n" - "str q7, [x28, %[output_col_stride1]]\n" - "fadd v14.4s, v14.4s, v2.4s\n" - "str q10, [x17, x16]\n" - "fadd v26.4s, v26.4s, v2.4s\n" - "str q4, [x28, x15]\n" - "fadd v13.4s, v13.4s, v2.4s\n" - "str q5, [x28, x16]\n" - "add %[outptr0], %[outptr0], #16\n" - "str q16, [x18]\n" - "add x17, x17, #16\n" - "str q14, [x18, %[output_col_stride1]]\n" - "add x28, x28, #16\n" - "str q26, [x18, x15]\n" - "str q13, [x18, x16]\n" - "add x18, x18, #16\n" - "4:\n" - "cmp x20, #2\n" - "blt 5f\n" - "ldr d19, [%[inptr0]]\n" - "ldr d20, [%[inptr0], %[in_col_stride1]]\n" - "sub x20, x20, #2\n" - "ldr d4, [%[inptr0], x21]\n" - "ldr d17, [%[inptr0], x22]\n" - "fadd v1.4s, v20.4s, v4.4s\n" - "ldr d22, [%[inptr0], x23]\n" - "fadd v5.4s, v17.4s, v22.4s\n" - "ldr d18, [%[inptr0], x24]\n" - "fsub v7.4s, v20.4s, v4.4s\n" - "ldr d25, [x25]\n" - "fsub v10.4s, v17.4s, v22.4s\n" - "ldr d12, [x25, %[in_col_stride1]]\n" - "fadd v8.4s, v19.4s, v1.4s\n" - "ldr d23, [x25, x21]\n" - "mov v4.16b, v1.16b\n" - "ldr d9, [x25, x22]\n" - "mov v1.16b, v7.16b\n" - "ldr d11, [x25, x23]\n" - "fmul v10.4s, v10.4s, v0.s[0]\n" - "ldr d6, [x25, x24]\n" - "fadd v8.4s, v8.4s, v5.4s\n" - "ldr d2, [x13]\n" - "fmla v4.4s, v5.4s, v0.s[1]\n" - "ldr d21, [x13, %[in_col_stride1]]\n" - "fadd v3.4s, v12.4s, v23.4s\n" - "ldr d26, [x13, x21]\n" - "fadd v7.4s, v7.4s, v10.4s\n" - "ldr d17, [x13, x22]\n" - "fmla v1.4s, v10.4s, v0.s[1]\n" - "ldr d19, [x13, x23]\n" - "fadd v27.4s, v9.4s, v11.4s\n" - "ldr d10, [x13, x24]\n" - "fadd v14.4s, v25.4s, v3.4s\n" - "ldr d15, [x26]\n" - "fsub v16.4s, v12.4s, v23.4s\n" - "ldr d12, [x26, %[in_col_stride1]]\n" - "fadd v1.4s, v1.4s, v18.4s\n" - "ldr d20, [x26, x21]\n" - "fsub v9.4s, v9.4s, v11.4s\n" - "ldr d24, [x26, x22]\n" - "fadd v14.4s, v14.4s, v27.4s\n" - "ldr d29, [x26, x23]\n" - "mov v11.16b, v3.16b\n" - "ldr d5, [x26, x24]\n" - "mov v3.16b, v16.16b\n" - "ldr d18, [x14]\n" - "fmul v9.4s, v9.4s, v0.s[0]\n" - "add %[inptr0], %[inptr0], #8\n" - "fmla v11.4s, v27.4s, v0.s[1]\n" - "add x25, x25, #8\n" - "fadd v23.4s, v21.4s, v26.4s\n" - "add x13, x13, #8\n" - "fsub v21.4s, v21.4s, v26.4s\n" - "ldr d22, [x14, %[in_col_stride1]]\n" - "fadd v13.4s, v16.4s, v9.4s\n" - "add x26, x26, #8\n" - "fmla v3.4s, v9.4s, v0.s[1]\n" - "fadd v30.4s, v17.4s, v19.4s\n" - "fadd v16.4s, v2.4s, v23.4s\n" - "fsub v19.4s, v17.4s, v19.4s\n" - "mov v17.16b, v23.16b\n" - "fadd v26.4s, v12.4s, v20.4s\n" - "fsub v9.4s, v12.4s, v20.4s\n" - "ldr d28, [x14, x21]\n" - "fadd v3.4s, v3.4s, v6.4s\n" - "ldr d20, [x14, x22]\n" - "fadd v16.4s, v16.4s, v30.4s\n" - "fmul v19.4s, v19.4s, v0.s[0]\n" - "fmla v17.4s, v30.4s, v0.s[1]\n" - "fadd v25.4s, v24.4s, v29.4s\n" - "fadd v23.4s, v15.4s, v26.4s\n" - "fsub v12.4s, v24.4s, v29.4s\n" - "mov v15.16b, v26.16b\n" - "fadd v24.4s, v22.4s, v28.4s\n" - "fsub v22.4s, v22.4s, v28.4s\n" - "fadd v29.4s, v14.4s, v16.4s\n" - "fsub v16.4s, v14.4s, v16.4s\n" - "ldr d28, [x14, x23]\n" - "fadd v23.4s, v23.4s, v25.4s\n" - "fmul v12.4s, v12.4s, v0.s[0]\n" - "fmla v15.4s, v25.4s, v0.s[1]\n" - "mov v6.16b, v21.16b\n" - "fadd v30.4s, v21.4s, v19.4s\n" - "fadd v26.4s, v18.4s, v24.4s\n" - "mov v25.16b, v24.16b\n" - "fadd v18.4s, v8.4s, v29.4s\n" - "fmla v6.4s, v19.4s, v0.s[1]\n" - "fadd v27.4s, v20.4s, v28.4s\n" - "fsub v21.4s, v20.4s, v28.4s\n" - "mov v19.16b, v29.16b\n" - "fadd v29.4s, v13.4s, v30.4s\n" - "fsub v8.4s, v13.4s, v30.4s\n" - "fadd v14.4s, v9.4s, v12.4s\n" - "fadd v6.4s, v6.4s, v10.4s\n" - "ldr d20, [x14, x24]\n" - "fadd v26.4s, v26.4s, v27.4s\n" - "add x14, x14, #8\n" - "fmla v9.4s, v12.4s, v0.s[1]\n" - "ldr d24, [x27]\n" - "fmul v21.4s, v21.4s, v0.s[0]\n" - "fmla v25.4s, v27.4s, v0.s[1]\n" - "fadd v10.4s, v7.4s, v29.4s\n" - "ldr d2, [%[bptr]]\n" - "mov v7.16b, v29.16b\n" - "add %[bptr], %[bptr], #8\n" - "fadd v9.4s, v9.4s, v5.4s\n" - "fadd v13.4s, v23.4s, v26.4s\n" - "fsub v23.4s, v23.4s, v26.4s\n" - "fadd v27.4s, v11.4s, v17.4s\n" - "fsub v11.4s, v11.4s, v17.4s\n" - "fadd v30.4s, v15.4s, v25.4s\n" - "fsub v15.4s, v15.4s, v25.4s\n" - "ldr d28, [x27, %[in_col_stride1]]\n" - "fadd v18.4s, v18.4s, v13.4s\n" - "fmla v19.4s, v13.4s, v0.s[1]\n" - "fadd v26.4s, v22.4s, v21.4s\n" - "mov v12.16b, v22.16b\n" - "fmul v23.4s, v23.4s, v0.s[0]\n" - "fadd v17.4s, v4.4s, v27.4s\n" - "fmul v15.4s, v15.4s, v0.s[0]\n" - "mov v4.16b, v27.16b\n" - "fmla v12.4s, v21.4s, v0.s[1]\n" - "ldr d22, [x27, x21]\n" - "fadd v18.4s, v18.4s, v2.4s\n" - "fadd v19.4s, v19.4s, v2.4s\n" - "fadd v17.4s, v17.4s, v30.4s\n" - "fmla v4.4s, v30.4s, v0.s[1]\n" - "fadd v25.4s, v28.4s, v22.4s\n" - "fsub v27.4s, v28.4s, v22.4s\n" - "fadd v12.4s, v12.4s, v20.4s\n" - "ldr d29, [x27, x22]\n" - "str d18, [%[outptr0]]\n" - "fadd v22.4s, v16.4s, v23.4s\n" - "str d19, [x28]\n" - "fadd v28.4s, v24.4s, v25.4s\n" - "ldr d30, [x27, x23]\n" - "fadd v20.4s, v29.4s, v30.4s\n" - "fsub v18.4s, v29.4s, v30.4s\n" - "mov v21.16b, v25.16b\n" - "ldr d25, [x27, x24]\n" - "fmla v16.4s, v23.4s, v0.s[1]\n" - "add x27, x27, #8\n" - "mov v24.16b, v27.16b\n" - "fadd v17.4s, v17.4s, v2.4s\n" - "fadd v28.4s, v28.4s, v20.4s\n" - "fmul v18.4s, v18.4s, v0.s[0]\n" - "fmla v21.4s, v20.4s, v0.s[1]\n" - "fadd v13.4s, v14.4s, v26.4s\n" - "fsub v30.4s, v14.4s, v26.4s\n" - "mov v14.16b, v8.16b\n" - "str d17, [%[outptr0], x15]\n" - "fadd v29.4s, v11.4s, v15.4s\n" - "fadd v23.4s, v27.4s, v18.4s\n" - "fmla v24.4s, v18.4s, v0.s[1]\n" - "fadd v16.4s, v16.4s, v28.4s\n" - "fadd v10.4s, v10.4s, v13.4s\n" - "fmul v30.4s, v30.4s, v0.s[0]\n" - "fmla v7.4s, v13.4s, v0.s[1]\n" - "mov v26.16b, v11.16b\n" - "fadd v13.4s, v3.4s, v6.4s\n" - "fadd v24.4s, v24.4s, v25.4s\n" - "fadd v27.4s, v9.4s, v12.4s\n" - "fsub v3.4s, v3.4s, v6.4s\n" - "fsub v6.4s, v9.4s, v12.4s\n" - "fadd v8.4s, v8.4s, v30.4s\n" - "fmla v14.4s, v30.4s, v0.s[1]\n" - "fmla v26.4s, v15.4s, v0.s[1]\n" - "fadd v1.4s, v1.4s, v13.4s\n" - "mov v5.16b, v13.16b\n" - "fadd v10.4s, v10.4s, v2.4s\n" - "fmul v6.4s, v6.4s, v0.s[0]\n" - "mov v13.16b, v3.16b\n" - "fadd v14.4s, v14.4s, v23.4s\n" - "fadd v22.4s, v22.4s, v2.4s\n" - "fadd v26.4s, v26.4s, v21.4s\n" - "fadd v1.4s, v1.4s, v27.4s\n" - "str d10, [%[outptr0], %[output_col_stride1]]\n" - "fmla v5.4s, v27.4s, v0.s[1]\n" - "fadd v10.4s, v3.4s, v6.4s\n" - "fmla v13.4s, v6.4s, v0.s[1]\n" - "str d22, [x17]\n" - "fadd v8.4s, v8.4s, v2.4s\n" - "fadd v1.4s, v1.4s, v2.4s\n" - "fadd v29.4s, v29.4s, v2.4s\n" - "fadd v7.4s, v7.4s, v2.4s\n" - "fadd v4.4s, v4.4s, v2.4s\n" - "fadd v13.4s, v13.4s, v24.4s\n" - "fadd v10.4s, v10.4s, v2.4s\n" - "str d8, [x17, %[output_col_stride1]]\n" - "fadd v5.4s, v5.4s, v2.4s\n" - "str d1, [%[outptr0], x16]\n" - "fadd v16.4s, v16.4s, v2.4s\n" - "str d29, [x17, x15]\n" - "fadd v14.4s, v14.4s, v2.4s\n" - "str d10, [x17, x16]\n" - "fadd v26.4s, v26.4s, v2.4s\n" - "str d7, [x28, %[output_col_stride1]]\n" - "fadd v13.4s, v13.4s, v2.4s\n" - "str d4, [x28, x15]\n" - "add %[outptr0], %[outptr0], #8\n" - "str d5, [x28, x16]\n" - "add x17, x17, #8\n" - "str d16, [x18]\n" - "add x28, x28, #8\n" - "str d14, [x18, %[output_col_stride1]]\n" - "str d26, [x18, x15]\n" - "str d13, [x18, x16]\n" - "add x18, x18, #8\n" - "5:\n" - "cbz x20, 6f\n" - "ldr s19, [%[inptr0]]\n" - "ldr s20, [%[inptr0], %[in_col_stride1]]\n" - "ldr s4, [%[inptr0], x21]\n" - "fadd v1.4s, v20.4s, v4.4s\n" - "ldr s17, [%[inptr0], x22]\n" - "fsub v7.4s, v20.4s, v4.4s\n" - "ldr s22, [%[inptr0], x23]\n" - "fadd v5.4s, v17.4s, v22.4s\n" - "ldr s18, [%[inptr0], x24]\n" - "fsub v10.4s, v17.4s, v22.4s\n" - "ldr s25, [x25]\n" - "fadd v8.4s, v19.4s, v1.4s\n" - "ldr s12, [x25, %[in_col_stride1]]\n" - "mov v4.16b, v1.16b\n" - "ldr s23, [x25, x21]\n" - "mov v1.16b, v7.16b\n" - "ldr s9, [x25, x22]\n" - "fmul v10.4s, v10.4s, v0.s[0]\n" - "ldr s11, [x25, x23]\n" - "fadd v8.4s, v8.4s, v5.4s\n" - "ldr s6, [x25, x24]\n" - "fmla v4.4s, v5.4s, v0.s[1]\n" - "ldr s2, [x13]\n" - "fadd v3.4s, v12.4s, v23.4s\n" - "ldr s21, [x13, %[in_col_stride1]]\n" - "fadd v7.4s, v7.4s, v10.4s\n" - "ldr s26, [x13, x21]\n" - "fmla v1.4s, v10.4s, v0.s[1]\n" - "ldr s17, [x13, x22]\n" - "fadd v27.4s, v9.4s, v11.4s\n" - "ldr s19, [x13, x23]\n" - "fadd v14.4s, v25.4s, v3.4s\n" - "ldr s10, [x13, x24]\n" - "fsub v16.4s, v12.4s, v23.4s\n" - "ldr s15, [x26]\n" - "fadd v1.4s, v1.4s, v18.4s\n" - "ldr s12, [x26, %[in_col_stride1]]\n" - "fsub v9.4s, v9.4s, v11.4s\n" - "ldr s20, [x26, x21]\n" - "fadd v14.4s, v14.4s, v27.4s\n" - "ldr s24, [x26, x22]\n" - "mov v11.16b, v3.16b\n" - "ldr s29, [x26, x23]\n" - "mov v3.16b, v16.16b\n" - "ldr s5, [x26, x24]\n" - "fmul v9.4s, v9.4s, v0.s[0]\n" - "ldr s18, [x14]\n" - "fmla v11.4s, v27.4s, v0.s[1]\n" - "fadd v23.4s, v21.4s, v26.4s\n" - "fsub v21.4s, v21.4s, v26.4s\n" - "fadd v30.4s, v17.4s, v19.4s\n" - "fsub v19.4s, v17.4s, v19.4s\n" - "ldr s22, [x14, %[in_col_stride1]]\n" - "fadd v13.4s, v16.4s, v9.4s\n" - "fmla v3.4s, v9.4s, v0.s[1]\n" - "fadd v16.4s, v2.4s, v23.4s\n" - "mov v17.16b, v23.16b\n" - "fadd v26.4s, v12.4s, v20.4s\n" - "fsub v9.4s, v12.4s, v20.4s\n" - "fmul v19.4s, v19.4s, v0.s[0]\n" - "ldr s28, [x14, x21]\n" - "fadd v3.4s, v3.4s, v6.4s\n" - "ldr s20, [x14, x22]\n" - "fadd v16.4s, v16.4s, v30.4s\n" - "fmla v17.4s, v30.4s, v0.s[1]\n" - "fadd v25.4s, v24.4s, v29.4s\n" - "fadd v23.4s, v15.4s, v26.4s\n" - "fsub v12.4s, v24.4s, v29.4s\n" - "mov v15.16b, v26.16b\n" - "fadd v24.4s, v22.4s, v28.4s\n" - "fsub v22.4s, v22.4s, v28.4s\n" - "fadd v30.4s, v21.4s, v19.4s\n" - "mov v6.16b, v21.16b\n" - "fadd v23.4s, v23.4s, v25.4s\n" - "fmla v15.4s, v25.4s, v0.s[1]\n" - "fmul v12.4s, v12.4s, v0.s[0]\n" - "ldr s28, [x14, x23]\n" - "fmla v6.4s, v19.4s, v0.s[1]\n" - "fadd v27.4s, v20.4s, v28.4s\n" - "fadd v26.4s, v18.4s, v24.4s\n" - "fsub v21.4s, v20.4s, v28.4s\n" - "mov v25.16b, v24.16b\n" - "fadd v29.4s, v14.4s, v16.4s\n" - "fsub v16.4s, v14.4s, v16.4s\n" - "ldr s20, [x14, x24]\n" - "fadd v6.4s, v6.4s, v10.4s\n" - "ldr s24, [x27]\n" - "fadd v26.4s, v26.4s, v27.4s\n" - "fmul v21.4s, v21.4s, v0.s[0]\n" - "fmla v25.4s, v27.4s, v0.s[1]\n" - "fadd v18.4s, v8.4s, v29.4s\n" - "mov v19.16b, v29.16b\n" - "fadd v29.4s, v13.4s, v30.4s\n" - "fsub v8.4s, v13.4s, v30.4s\n" - "fadd v27.4s, v11.4s, v17.4s\n" - "fsub v11.4s, v11.4s, v17.4s\n" - "fadd v13.4s, v23.4s, v26.4s\n" - "fsub v23.4s, v23.4s, v26.4s\n" - "ldr s28, [x27, %[in_col_stride1]]\n" - "fadd v10.4s, v7.4s, v29.4s\n" - "mov v7.16b, v29.16b\n" - "fadd v17.4s, v4.4s, v27.4s\n" - "mov v4.16b, v27.16b\n" - "fadd v18.4s, v18.4s, v13.4s\n" - "fmla v19.4s, v13.4s, v0.s[1]\n" - "fmul v23.4s, v23.4s, v0.s[0]\n" - "fadd v30.4s, v15.4s, v25.4s\n" - "fsub v15.4s, v15.4s, v25.4s\n" - "fadd v13.4s, v3.4s, v6.4s\n" - "fsub v3.4s, v3.4s, v6.4s\n" - "ldr s2, [%[bptr]]\n" - "fadd v18.4s, v18.4s, v2.4s\n" - "fadd v19.4s, v19.4s, v2.4s\n" - "fadd v17.4s, v17.4s, v30.4s\n" - "fmla v4.4s, v30.4s, v0.s[1]\n" - "fadd v14.4s, v9.4s, v12.4s\n" - "fmul v15.4s, v15.4s, v0.s[0]\n" - "fadd v1.4s, v1.4s, v13.4s\n" - "str s18, [%[outptr0]]\n" - "fadd v26.4s, v22.4s, v21.4s\n" - "str s19, [x28]\n" - "fmla v9.4s, v12.4s, v0.s[1]\n" - "mov v12.16b, v22.16b\n" - "ldr s22, [x27, x21]\n" - "fadd v25.4s, v28.4s, v22.4s\n" - "fsub v27.4s, v28.4s, v22.4s\n" - "fadd v22.4s, v16.4s, v23.4s\n" - "fadd v9.4s, v9.4s, v5.4s\n" - "ldr s29, [x27, x22]\n" - "fmla v12.4s, v21.4s, v0.s[1]\n" - "ldr s30, [x27, x23]\n" - "fadd v28.4s, v24.4s, v25.4s\n" - "mov v21.16b, v25.16b\n" - "fmla v16.4s, v23.4s, v0.s[1]\n" - "ldr s25, [x27, x24]\n" - "mov v5.16b, v13.16b\n" - "fadd v17.4s, v17.4s, v2.4s\n" - "fadd v12.4s, v12.4s, v20.4s\n" - "fadd v20.4s, v29.4s, v30.4s\n" - "fsub v18.4s, v29.4s, v30.4s\n" - "mov v24.16b, v27.16b\n" - "fadd v22.4s, v22.4s, v2.4s\n" - "fadd v4.4s, v4.4s, v2.4s\n" - "str s17, [%[outptr0], x15]\n" - "fadd v13.4s, v14.4s, v26.4s\n" - "fadd v28.4s, v28.4s, v20.4s\n" - "fmla v21.4s, v20.4s, v0.s[1]\n" - "fmul v18.4s, v18.4s, v0.s[0]\n" - "fsub v30.4s, v14.4s, v26.4s\n" - "str s22, [x17]\n" - "mov v14.16b, v8.16b\n" - "str s4, [x28, x15]\n" - "fadd v10.4s, v10.4s, v13.4s\n" - "fadd v16.4s, v16.4s, v28.4s\n" - "fmla v7.4s, v13.4s, v0.s[1]\n" - "fadd v23.4s, v27.4s, v18.4s\n" - "fmla v24.4s, v18.4s, v0.s[1]\n" - "fmul v30.4s, v30.4s, v0.s[0]\n" - "fadd v29.4s, v11.4s, v15.4s\n" - "mov v26.16b, v11.16b\n" - "fadd v27.4s, v9.4s, v12.4s\n" - "fsub v6.4s, v9.4s, v12.4s\n" - "mov v13.16b, v3.16b\n" - "fadd v24.4s, v24.4s, v25.4s\n" - "fadd v10.4s, v10.4s, v2.4s\n" - "fadd v8.4s, v8.4s, v30.4s\n" - "fmla v14.4s, v30.4s, v0.s[1]\n" - "fmla v26.4s, v15.4s, v0.s[1]\n" - "fadd v1.4s, v1.4s, v27.4s\n" - "fmul v6.4s, v6.4s, v0.s[0]\n" - "fmla v5.4s, v27.4s, v0.s[1]\n" - "str s10, [%[outptr0], %[output_col_stride1]]\n" - "fadd v29.4s, v29.4s, v2.4s\n" - "fadd v14.4s, v14.4s, v23.4s\n" - "fadd v8.4s, v8.4s, v2.4s\n" - "fadd v26.4s, v26.4s, v21.4s\n" - "fadd v1.4s, v1.4s, v2.4s\n" - "fadd v10.4s, v3.4s, v6.4s\n" - "fmla v13.4s, v6.4s, v0.s[1]\n" - "str s29, [x17, x15]\n" - "fadd v7.4s, v7.4s, v2.4s\n" - "str s8, [x17, %[output_col_stride1]]\n" - "fadd v5.4s, v5.4s, v2.4s\n" - "str s1, [%[outptr0], x16]\n" - "fadd v16.4s, v16.4s, v2.4s\n" - "fadd v13.4s, v13.4s, v24.4s\n" - "fadd v10.4s, v10.4s, v2.4s\n" - "str s7, [x28, %[output_col_stride1]]\n" - "fadd v14.4s, v14.4s, v2.4s\n" - "str s5, [x28, x16]\n" - "fadd v26.4s, v26.4s, v2.4s\n" - "str s16, [x18]\n" - "fadd v13.4s, v13.4s, v2.4s\n" - "str s10, [x17, x16]\n" - "str s14, [x18, %[output_col_stride1]]\n" - "str s26, [x18, x15]\n" - "str s13, [x18, x16]\n" - "6:\n" - : [bptr] "+r" (bptr), [outptr0] "+r" (output), [inptr0] "+r" (inptr) - : [output_row_stride] "r" (output_row_stride * sizeof(float)), [output_col_stride1] "r" (output_col_stride * sizeof(float)), [pcoeffs] "r" (coeffs), [n_channels] "r" ((long) n_channels), [in_row_stride] "r" (6 * matrix_stride * sizeof(float)), [in_col_stride1] "r" (matrix_stride * sizeof(float)) - : "cc", "v0", "v1", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v2", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v3", "v30", "v4", "v5", "v6", "v7", "v8", "v9", "x13", "x14", "x15", "x16", "x17", "x18", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "memory" - ); - } - else - { - __asm__ __volatile__ ( - "ldr d0, [%[pcoeffs]]\n" - "add x21, %[in_col_stride1], %[in_col_stride1]\n" // Compute input column stride 2 - "add x22, x21, %[in_col_stride1]\n" // Compute input column stride 3 - "add x25, %[inptr0], %[in_row_stride]\n" // Compute input row pointers - "add x15, %[output_col_stride1], %[output_col_stride1]\n" // Compute output column stride 2 - "add x23, x22, %[in_col_stride1]\n" // Compute input column stride 4 - "add x13, x25, %[in_row_stride]\n" // Compute input row pointers - "add x16, x15, %[output_col_stride1]\n" // Compute output column stride 3 - "add x24, x23, %[in_col_stride1]\n" // Compute input column stride 5 - "add x26, x13, %[in_row_stride]\n" // Compute input row pointers - "add x17, %[outptr0], %[output_row_stride]\n" // Compute output row pointer 1 - "add x14, x26, %[in_row_stride]\n" // Compute input row pointers - "add x28, x17, %[output_row_stride]\n" // Compute output row pointer 2 - "lsr x19, %[n_channels], #2\n" - "add x27, x14, %[in_row_stride]\n" // Compute input row pointers - "add x18, x28, %[output_row_stride]\n" // Compute output row pointer 3 - "and x20, %[n_channels], #3\n" - "cbz x19, 4f\n" - "1:\n" // Quad head - "ldr q17, [%[inptr0]]\n" - "subs x19, x19, #1\n" - "ldr q23, [%[inptr0], %[in_col_stride1]]\n" - "ldr q27, [%[inptr0], x21]\n" - "fadd v4.4s, v23.4s, v27.4s\n" - "ldr q24, [%[inptr0], x22]\n" - "fsub v13.4s, v23.4s, v27.4s\n" - "ldr q11, [%[inptr0], x23]\n" - "fadd v10.4s, v24.4s, v11.4s\n" - "ldr q12, [%[inptr0], x24]\n" - "fsub v11.4s, v24.4s, v11.4s\n" - "ldr q20, [x25]\n" - "fadd v7.4s, v17.4s, v4.4s\n" - "ldr q19, [x25, %[in_col_stride1]]\n" - "mov v4.16b, v4.16b\n" - "ldr q22, [x25, x21]\n" - "mov v1.16b, v13.16b\n" - "ldr q14, [x25, x22]\n" - "fmul v11.4s, v11.4s, v0.s[0]\n" - "ldr q18, [x25, x23]\n" - "fadd v7.4s, v7.4s, v10.4s\n" - "ldr q3, [x25, x24]\n" - "fmla v4.4s, v10.4s, v0.s[1]\n" - "fadd v8.4s, v13.4s, v11.4s\n" - "fmla v1.4s, v11.4s, v0.s[1]\n" - "fadd v1.4s, v1.4s, v12.4s\n" - "beq 3f\n" - "2:\n" // Quad loop - "fadd v2.4s, v19.4s, v22.4s\n" - "ldr q16, [x13]\n" - "fadd v23.4s, v14.4s, v18.4s\n" - "ldr q21, [x13, %[in_col_stride1]]\n" - "fsub v15.4s, v19.4s, v22.4s\n" - "ldr q24, [x13, x21]\n" - "fsub v31.4s, v14.4s, v18.4s\n" - "ldr q25, [x13, x22]\n" - "fadd v11.4s, v20.4s, v2.4s\n" - "ldr q17, [x13, x23]\n" - "mov v13.16b, v2.16b\n" - "ldr q9, [x13, x24]\n" - "mov v2.16b, v15.16b\n" - "ldr q6, [x26]\n" - "fmul v31.4s, v31.4s, v0.s[0]\n" - "ldr q19, [x26, %[in_col_stride1]]\n" - "fadd v11.4s, v11.4s, v23.4s\n" - "ldr q22, [x26, x21]\n" - "fmla v13.4s, v23.4s, v0.s[1]\n" - "ldr q12, [x26, x22]\n" - "fadd v29.4s, v21.4s, v24.4s\n" - "ldr q26, [x26, x23]\n" - "fadd v15.4s, v15.4s, v31.4s\n" - "ldr q5, [x26, x24]\n" - "fmla v2.4s, v31.4s, v0.s[1]\n" - "ldr q10, [x14]\n" - "fadd v18.4s, v25.4s, v17.4s\n" - "add %[inptr0], %[inptr0], #16\n" - "fadd v27.4s, v16.4s, v29.4s\n" - "add x25, x25, #16\n" - "fsub v14.4s, v21.4s, v24.4s\n" - "ldr q30, [x14, %[in_col_stride1]]\n" - "fadd v2.4s, v2.4s, v3.4s\n" - "ldr q31, [x14, x21]\n" - "fsub v28.4s, v25.4s, v17.4s\n" - "add x13, x13, #16\n" - "fadd v27.4s, v27.4s, v18.4s\n" - "add x26, x26, #16\n" - "mov v21.16b, v29.16b\n" - "subs x19, x19, #1\n" - "fadd v20.4s, v19.4s, v22.4s\n" - "fsub v17.4s, v19.4s, v22.4s\n" - "fmul v28.4s, v28.4s, v0.s[0]\n" - "ldr q23, [x14, x22]\n" - "fmla v21.4s, v18.4s, v0.s[1]\n" - "fadd v29.4s, v12.4s, v26.4s\n" - "fsub v16.4s, v12.4s, v26.4s\n" - "fadd v25.4s, v30.4s, v31.4s\n" - "fadd v24.4s, v6.4s, v20.4s\n" - "mov v6.16b, v20.16b\n" - "fsub v22.4s, v30.4s, v31.4s\n" - "fadd v31.4s, v11.4s, v27.4s\n" - "fsub v12.4s, v11.4s, v27.4s\n" - "ldr q26, [x14, x23]\n" - "fmul v16.4s, v16.4s, v0.s[0]\n" - "fmla v6.4s, v29.4s, v0.s[1]\n" - "fadd v24.4s, v24.4s, v29.4s\n" - "mov v3.16b, v14.16b\n" - "fadd v20.4s, v14.4s, v28.4s\n" - "fadd v29.4s, v10.4s, v25.4s\n" - "mov v10.16b, v25.16b\n" - "fadd v25.4s, v7.4s, v31.4s\n" - "fmla v3.4s, v28.4s, v0.s[1]\n" - "fadd v14.4s, v23.4s, v26.4s\n" - "fsub v23.4s, v23.4s, v26.4s\n" - "mov v26.16b, v31.16b\n" - "fadd v31.4s, v15.4s, v20.4s\n" - "fsub v11.4s, v15.4s, v20.4s\n" - "fadd v20.4s, v17.4s, v16.4s\n" - "mov v7.16b, v17.16b\n" - "fadd v3.4s, v3.4s, v9.4s\n" - "ldr q18, [x14, x24]\n" - "fadd v29.4s, v29.4s, v14.4s\n" - "add x14, x14, #16\n" - "fmla v7.4s, v16.4s, v0.s[1]\n" - "ldr q19, [x27]\n" - "fmul v23.4s, v23.4s, v0.s[0]\n" - "fmla v10.4s, v14.4s, v0.s[1]\n" - "fadd v15.4s, v8.4s, v31.4s\n" - "mov v14.16b, v31.16b\n" - "fadd v28.4s, v24.4s, v29.4s\n" - "fsub v24.4s, v24.4s, v29.4s\n" - "fadd v7.4s, v7.4s, v5.4s\n" - "ldr q27, [x27, %[in_col_stride1]]\n" - "fadd v30.4s, v13.4s, v21.4s\n" - "fsub v9.4s, v13.4s, v21.4s\n" - "fadd v17.4s, v22.4s, v23.4s\n" - "mov v8.16b, v22.16b\n" - "fadd v25.4s, v25.4s, v28.4s\n" - "fmul v24.4s, v24.4s, v0.s[0]\n" - "fmla v26.4s, v28.4s, v0.s[1]\n" - "ldr q29, [x27, x21]\n" - "fmla v8.4s, v23.4s, v0.s[1]\n" - "ldr q28, [x27, x22]\n" - "fadd v13.4s, v4.4s, v30.4s\n" - "mov v4.16b, v30.16b\n" - "str q25, [%[outptr0]]\n" // Store output (0, 0) - "fadd v16.4s, v27.4s, v29.4s\n" - "str q26, [x28]\n" // Store output (2, 0) - "fsub v29.4s, v27.4s, v29.4s\n" - "fadd v8.4s, v8.4s, v18.4s\n" - "ldr q23, [x27, x23]\n" - "fadd v30.4s, v28.4s, v23.4s\n" - "ldr q25, [x27, x24]\n" - "fadd v19.4s, v19.4s, v16.4s\n" - "add x27, x27, #16\n" - "fsub v27.4s, v28.4s, v23.4s\n" - "mov v16.16b, v16.16b\n" - "fadd v22.4s, v20.4s, v17.4s\n" - "fsub v20.4s, v20.4s, v17.4s\n" - "fadd v21.4s, v12.4s, v24.4s\n" - "mov v26.16b, v12.16b\n" - "fadd v19.4s, v19.4s, v30.4s\n" - "fmla v16.4s, v30.4s, v0.s[1]\n" - "fmul v27.4s, v27.4s, v0.s[0]\n" - "ldr q17, [%[inptr0]]\n" - "fmla v26.4s, v24.4s, v0.s[1]\n" - "ldr q23, [%[inptr0], %[in_col_stride1]]\n" - "str q21, [x17]\n" // Store output (1, 0) - "mov v5.16b, v29.16b\n" - "fadd v15.4s, v15.4s, v22.4s\n" - "fmul v20.4s, v20.4s, v0.s[0]\n" - "fadd v18.4s, v29.4s, v27.4s\n" - "fmla v14.4s, v22.4s, v0.s[1]\n" - "fmla v5.4s, v27.4s, v0.s[1]\n" - "ldr q27, [%[inptr0], x21]\n" - "fadd v26.4s, v26.4s, v19.4s\n" - "ldr q24, [%[inptr0], x22]\n" - "str q15, [%[outptr0], %[output_col_stride1]]\n" // Store output (0, 1) - "fadd v12.4s, v11.4s, v20.4s\n" - "str q14, [x28, %[output_col_stride1]]\n" // Store output (2, 1) - "mov v28.16b, v11.16b\n" - "fadd v5.4s, v5.4s, v25.4s\n" - "ldr q11, [%[inptr0], x23]\n" - "str q26, [x18]\n" // Store output (3, 0) - "fadd v21.4s, v6.4s, v10.4s\n" - "str q12, [x17, %[output_col_stride1]]\n" // Store output (1, 1) - "fmla v28.4s, v20.4s, v0.s[1]\n" - "fsub v10.4s, v6.4s, v10.4s\n" - "ldr q12, [%[inptr0], x24]\n" - "mov v15.16b, v9.16b\n" - "ldr q20, [x25]\n" - "fadd v13.4s, v13.4s, v21.4s\n" - "ldr q19, [x25, %[in_col_stride1]]\n" - "fadd v28.4s, v28.4s, v18.4s\n" - "ldr q22, [x25, x21]\n" - "fmul v10.4s, v10.4s, v0.s[0]\n" - "ldr q14, [x25, x22]\n" - "fmla v4.4s, v21.4s, v0.s[1]\n" - "ldr q18, [x25, x23]\n" - "str q13, [%[outptr0], x15]\n" // Store output (0, 2) - "fadd v6.4s, v2.4s, v3.4s\n" - "str q28, [x18, %[output_col_stride1]]\n" // Store output (3, 1) - "fadd v30.4s, v7.4s, v8.4s\n" - "fadd v13.4s, v9.4s, v10.4s\n" - "fmla v15.4s, v10.4s, v0.s[1]\n" - "str q4, [x28, x15]\n" // Store output (2, 2) - "fsub v2.4s, v2.4s, v3.4s\n" - "fadd v1.4s, v1.4s, v6.4s\n" - "ldr q3, [x25, x24]\n" - "fsub v8.4s, v7.4s, v8.4s\n" - "mov v6.16b, v6.16b\n" - "str q13, [x17, x15]\n" // Store output (1, 2) - "fadd v15.4s, v15.4s, v16.4s\n" - "mov v9.16b, v2.16b\n" - "fadd v4.4s, v23.4s, v27.4s\n" - "fadd v1.4s, v1.4s, v30.4s\n" - "fmla v6.4s, v30.4s, v0.s[1]\n" - "fmul v8.4s, v8.4s, v0.s[0]\n" - "fadd v10.4s, v24.4s, v11.4s\n" - "str q15, [x18, x15]\n" // Store output (3, 2) - "fsub v13.4s, v23.4s, v27.4s\n" - "fadd v7.4s, v17.4s, v4.4s\n" - "fsub v11.4s, v24.4s, v11.4s\n" - "str q1, [%[outptr0], x16]\n" // Store output (0, 3) - "mov v4.16b, v4.16b\n" - "str q6, [x28, x16]\n" // Store output (2, 3) - "fadd v2.4s, v2.4s, v8.4s\n" - "fmla v9.4s, v8.4s, v0.s[1]\n" - "add %[outptr0], %[outptr0], #16\n" - "fadd v7.4s, v7.4s, v10.4s\n" - "add x28, x28, #16\n" - "fmul v11.4s, v11.4s, v0.s[0]\n" - "fmla v4.4s, v10.4s, v0.s[1]\n" - "str q2, [x17, x16]\n" // Store output (1, 3) - "mov v1.16b, v13.16b\n" - "fadd v9.4s, v9.4s, v5.4s\n" - "add x17, x17, #16\n" - "fadd v8.4s, v13.4s, v11.4s\n" - "fmla v1.4s, v11.4s, v0.s[1]\n" - "str q9, [x18, x16]\n" // Store output (3, 3) - "add x18, x18, #16\n" - "fadd v1.4s, v1.4s, v12.4s\n" - "bne 2b\n" - "3:\n" // Quad tail - "fadd v2.4s, v19.4s, v22.4s\n" - "ldr q16, [x13]\n" - "fadd v23.4s, v14.4s, v18.4s\n" - "ldr q21, [x13, %[in_col_stride1]]\n" - "fsub v15.4s, v19.4s, v22.4s\n" - "ldr q24, [x13, x21]\n" - "fsub v31.4s, v14.4s, v18.4s\n" - "ldr q25, [x13, x22]\n" - "fadd v11.4s, v20.4s, v2.4s\n" - "ldr q17, [x13, x23]\n" - "mov v13.16b, v2.16b\n" - "ldr q9, [x13, x24]\n" - "mov v2.16b, v15.16b\n" - "ldr q6, [x26]\n" - "fmul v31.4s, v31.4s, v0.s[0]\n" - "ldr q19, [x26, %[in_col_stride1]]\n" - "fadd v11.4s, v11.4s, v23.4s\n" - "ldr q22, [x26, x21]\n" - "fmla v13.4s, v23.4s, v0.s[1]\n" - "ldr q12, [x26, x22]\n" - "fadd v29.4s, v21.4s, v24.4s\n" - "ldr q26, [x26, x23]\n" - "fadd v15.4s, v15.4s, v31.4s\n" - "ldr q5, [x26, x24]\n" - "fmla v2.4s, v31.4s, v0.s[1]\n" - "ldr q10, [x14]\n" - "fadd v18.4s, v25.4s, v17.4s\n" - "add %[inptr0], %[inptr0], #16\n" - "fadd v27.4s, v16.4s, v29.4s\n" - "add x25, x25, #16\n" - "fsub v14.4s, v21.4s, v24.4s\n" - "ldr q30, [x14, %[in_col_stride1]]\n" - "fadd v2.4s, v2.4s, v3.4s\n" - "ldr q31, [x14, x21]\n" - "fsub v28.4s, v25.4s, v17.4s\n" - "add x13, x13, #16\n" - "fadd v27.4s, v27.4s, v18.4s\n" - "add x26, x26, #16\n" - "mov v21.16b, v29.16b\n" - "fadd v20.4s, v19.4s, v22.4s\n" - "fsub v17.4s, v19.4s, v22.4s\n" - "fadd v29.4s, v12.4s, v26.4s\n" - "fmul v28.4s, v28.4s, v0.s[0]\n" - "fsub v16.4s, v12.4s, v26.4s\n" - "fmla v21.4s, v18.4s, v0.s[1]\n" - "ldr q23, [x14, x22]\n" - "fadd v24.4s, v6.4s, v20.4s\n" - "mov v6.16b, v20.16b\n" - "fadd v25.4s, v30.4s, v31.4s\n" - "fsub v22.4s, v30.4s, v31.4s\n" - "fadd v20.4s, v14.4s, v28.4s\n" - "mov v3.16b, v14.16b\n" - "fmul v16.4s, v16.4s, v0.s[0]\n" - "fmla v6.4s, v29.4s, v0.s[1]\n" - "fadd v24.4s, v24.4s, v29.4s\n" - "ldr q26, [x14, x23]\n" - "fmla v3.4s, v28.4s, v0.s[1]\n" - "fadd v14.4s, v23.4s, v26.4s\n" - "fadd v29.4s, v10.4s, v25.4s\n" - "fsub v23.4s, v23.4s, v26.4s\n" - "mov v10.16b, v25.16b\n" - "fadd v31.4s, v11.4s, v27.4s\n" - "fsub v12.4s, v11.4s, v27.4s\n" - "ldr q18, [x14, x24]\n" - "fadd v3.4s, v3.4s, v9.4s\n" - "ldr q19, [x27]\n" - "fadd v29.4s, v29.4s, v14.4s\n" - "add x14, x14, #16\n" - "fmul v23.4s, v23.4s, v0.s[0]\n" - "fmla v10.4s, v14.4s, v0.s[1]\n" - "fadd v25.4s, v7.4s, v31.4s\n" - "mov v26.16b, v31.16b\n" - "fadd v31.4s, v15.4s, v20.4s\n" - "fsub v11.4s, v15.4s, v20.4s\n" - "fadd v28.4s, v24.4s, v29.4s\n" - "fsub v24.4s, v24.4s, v29.4s\n" - "fadd v30.4s, v13.4s, v21.4s\n" - "fsub v9.4s, v13.4s, v21.4s\n" - "fadd v20.4s, v17.4s, v16.4s\n" - "mov v7.16b, v17.16b\n" - "fadd v15.4s, v8.4s, v31.4s\n" - "mov v14.16b, v31.16b\n" - "fadd v25.4s, v25.4s, v28.4s\n" - "fmul v24.4s, v24.4s, v0.s[0]\n" - "fmla v7.4s, v16.4s, v0.s[1]\n" - "ldr q27, [x27, %[in_col_stride1]]\n" - "fmla v26.4s, v28.4s, v0.s[1]\n" - "ldr q29, [x27, x21]\n" - "fadd v13.4s, v4.4s, v30.4s\n" - "mov v4.16b, v30.16b\n" - "str q25, [%[outptr0]]\n" // Store output (0, 0) - "fadd v17.4s, v22.4s, v23.4s\n" - "fadd v7.4s, v7.4s, v5.4s\n" - "ldr q28, [x27, x22]\n" - "str q26, [x28]\n" // Store output (2, 0) - "mov v8.16b, v22.16b\n" - "fadd v16.4s, v27.4s, v29.4s\n" - "fsub v29.4s, v27.4s, v29.4s\n" - "fadd v21.4s, v12.4s, v24.4s\n" - "mov v26.16b, v12.16b\n" - "fmla v8.4s, v23.4s, v0.s[1]\n" - "fadd v22.4s, v20.4s, v17.4s\n" - "fsub v20.4s, v20.4s, v17.4s\n" - "ldr q23, [x27, x23]\n" - "fadd v19.4s, v19.4s, v16.4s\n" - "mov v16.16b, v16.16b\n" - "str q21, [x17]\n" // Store output (1, 0) - "fadd v30.4s, v28.4s, v23.4s\n" - "fadd v8.4s, v8.4s, v18.4s\n" - "ldr q25, [x27, x24]\n" - "fsub v27.4s, v28.4s, v23.4s\n" - "add x27, x27, #16\n" - "mov v5.16b, v29.16b\n" - "fmla v26.4s, v24.4s, v0.s[1]\n" - "fadd v19.4s, v19.4s, v30.4s\n" - "fmla v16.4s, v30.4s, v0.s[1]\n" - "fadd v15.4s, v15.4s, v22.4s\n" - "fmul v20.4s, v20.4s, v0.s[0]\n" - "fmul v27.4s, v27.4s, v0.s[0]\n" - "fmla v14.4s, v22.4s, v0.s[1]\n" - "mov v28.16b, v11.16b\n" - "fadd v21.4s, v6.4s, v10.4s\n" - "fadd v26.4s, v26.4s, v19.4s\n" - "fsub v10.4s, v6.4s, v10.4s\n" - "str q15, [%[outptr0], %[output_col_stride1]]\n" // Store output (0, 1) - "fadd v12.4s, v11.4s, v20.4s\n" - "str q14, [x28, %[output_col_stride1]]\n" // Store output (2, 1) - "fadd v18.4s, v29.4s, v27.4s\n" - "fmla v5.4s, v27.4s, v0.s[1]\n" - "fmla v28.4s, v20.4s, v0.s[1]\n" - "str q26, [x18]\n" // Store output (3, 0) - "fadd v13.4s, v13.4s, v21.4s\n" - "str q12, [x17, %[output_col_stride1]]\n" // Store output (1, 1) - "fmul v10.4s, v10.4s, v0.s[0]\n" - "fmla v4.4s, v21.4s, v0.s[1]\n" - "mov v15.16b, v9.16b\n" - "fadd v5.4s, v5.4s, v25.4s\n" - "fadd v28.4s, v28.4s, v18.4s\n" - "str q13, [%[outptr0], x15]\n" // Store output (0, 2) - "fadd v6.4s, v2.4s, v3.4s\n" - "fadd v13.4s, v9.4s, v10.4s\n" - "fmla v15.4s, v10.4s, v0.s[1]\n" - "str q4, [x28, x15]\n" // Store output (2, 2) - "fadd v30.4s, v7.4s, v8.4s\n" - "str q28, [x18, %[output_col_stride1]]\n" // Store output (3, 1) - "fsub v2.4s, v2.4s, v3.4s\n" - "fadd v1.4s, v1.4s, v6.4s\n" - "fsub v8.4s, v7.4s, v8.4s\n" - "str q13, [x17, x15]\n" // Store output (1, 2) - "fadd v15.4s, v15.4s, v16.4s\n" - "mov v6.16b, v6.16b\n" - "mov v9.16b, v2.16b\n" - "fadd v1.4s, v1.4s, v30.4s\n" - "fmul v8.4s, v8.4s, v0.s[0]\n" - "str q15, [x18, x15]\n" // Store output (3, 2) - "fmla v6.4s, v30.4s, v0.s[1]\n" - "str q1, [%[outptr0], x16]\n" // Store output (0, 3) - "fadd v2.4s, v2.4s, v8.4s\n" - "str q6, [x28, x16]\n" // Store output (2, 3) - "fmla v9.4s, v8.4s, v0.s[1]\n" - "add %[outptr0], %[outptr0], #16\n" - "add x28, x28, #16\n" - "str q2, [x17, x16]\n" // Store output (1, 3) - "fadd v9.4s, v9.4s, v5.4s\n" - "add x17, x17, #16\n" - "str q9, [x18, x16]\n" // Store output (3, 3) - "add x18, x18, #16\n" - "4:\n" // Double - "cmp x20, #2\n" - "blt 5f\n" - "ldr d17, [%[inptr0]]\n" - "ldr d23, [%[inptr0], %[in_col_stride1]]\n" - "sub x20, x20, #2\n" - "ldr d27, [%[inptr0], x21]\n" - "ldr d24, [%[inptr0], x22]\n" - "fadd v4.4s, v23.4s, v27.4s\n" - "ldr d11, [%[inptr0], x23]\n" - "fadd v10.4s, v24.4s, v11.4s\n" - "ldr d12, [%[inptr0], x24]\n" - "fsub v13.4s, v23.4s, v27.4s\n" - "ldr d20, [x25]\n" - "fsub v11.4s, v24.4s, v11.4s\n" - "ldr d19, [x25, %[in_col_stride1]]\n" - "fadd v7.4s, v17.4s, v4.4s\n" - "ldr d22, [x25, x21]\n" - "mov v4.16b, v4.16b\n" - "ldr d14, [x25, x22]\n" - "mov v1.16b, v13.16b\n" - "ldr d18, [x25, x23]\n" - "fmul v11.4s, v11.4s, v0.s[0]\n" - "ldr d3, [x25, x24]\n" - "fadd v7.4s, v7.4s, v10.4s\n" - "ldr d16, [x13]\n" - "fmla v4.4s, v10.4s, v0.s[1]\n" - "ldr d21, [x13, %[in_col_stride1]]\n" - "fadd v2.4s, v19.4s, v22.4s\n" - "ldr d24, [x13, x21]\n" - "fadd v8.4s, v13.4s, v11.4s\n" - "ldr d25, [x13, x22]\n" - "fmla v1.4s, v11.4s, v0.s[1]\n" - "ldr d17, [x13, x23]\n" - "fadd v23.4s, v14.4s, v18.4s\n" - "ldr d9, [x13, x24]\n" - "fadd v11.4s, v20.4s, v2.4s\n" - "ldr d6, [x26]\n" - "fsub v15.4s, v19.4s, v22.4s\n" - "ldr d19, [x26, %[in_col_stride1]]\n" - "fadd v1.4s, v1.4s, v12.4s\n" - "ldr d22, [x26, x21]\n" - "fsub v31.4s, v14.4s, v18.4s\n" - "ldr d12, [x26, x22]\n" - "fadd v11.4s, v11.4s, v23.4s\n" - "ldr d26, [x26, x23]\n" - "mov v13.16b, v2.16b\n" - "ldr d5, [x26, x24]\n" - "mov v2.16b, v15.16b\n" - "ldr d10, [x14]\n" - "fmul v31.4s, v31.4s, v0.s[0]\n" - "add %[inptr0], %[inptr0], #8\n" - "fmla v13.4s, v23.4s, v0.s[1]\n" - "add x25, x25, #8\n" - "fadd v29.4s, v21.4s, v24.4s\n" - "add x13, x13, #8\n" - "fsub v14.4s, v21.4s, v24.4s\n" - "ldr d30, [x14, %[in_col_stride1]]\n" - "fadd v15.4s, v15.4s, v31.4s\n" - "add x26, x26, #8\n" - "fmla v2.4s, v31.4s, v0.s[1]\n" - "fadd v18.4s, v25.4s, v17.4s\n" - "fadd v27.4s, v16.4s, v29.4s\n" - "fsub v28.4s, v25.4s, v17.4s\n" - "mov v21.16b, v29.16b\n" - "fadd v20.4s, v19.4s, v22.4s\n" - "fsub v17.4s, v19.4s, v22.4s\n" - "ldr d31, [x14, x21]\n" - "fadd v2.4s, v2.4s, v3.4s\n" - "ldr d23, [x14, x22]\n" - "fadd v27.4s, v27.4s, v18.4s\n" - "fmul v28.4s, v28.4s, v0.s[0]\n" - "fmla v21.4s, v18.4s, v0.s[1]\n" - "fadd v29.4s, v12.4s, v26.4s\n" - "fadd v24.4s, v6.4s, v20.4s\n" - "fsub v16.4s, v12.4s, v26.4s\n" - "mov v6.16b, v20.16b\n" - "fadd v25.4s, v30.4s, v31.4s\n" - "fsub v22.4s, v30.4s, v31.4s\n" - "fadd v31.4s, v11.4s, v27.4s\n" - "fsub v12.4s, v11.4s, v27.4s\n" - "ldr d26, [x14, x23]\n" - "fadd v24.4s, v24.4s, v29.4s\n" - "fmul v16.4s, v16.4s, v0.s[0]\n" - "fmla v6.4s, v29.4s, v0.s[1]\n" - "mov v3.16b, v14.16b\n" - "fadd v20.4s, v14.4s, v28.4s\n" - "fadd v29.4s, v10.4s, v25.4s\n" - "mov v10.16b, v25.16b\n" - "fadd v25.4s, v7.4s, v31.4s\n" - "fmla v3.4s, v28.4s, v0.s[1]\n" - "fadd v14.4s, v23.4s, v26.4s\n" - "fsub v23.4s, v23.4s, v26.4s\n" - "mov v26.16b, v31.16b\n" - "fadd v31.4s, v15.4s, v20.4s\n" - "fsub v11.4s, v15.4s, v20.4s\n" - "fadd v20.4s, v17.4s, v16.4s\n" - "mov v7.16b, v17.16b\n" - "fadd v3.4s, v3.4s, v9.4s\n" - "ldr d18, [x14, x24]\n" - "fadd v29.4s, v29.4s, v14.4s\n" - "add x14, x14, #8\n" - "fmla v7.4s, v16.4s, v0.s[1]\n" - "ldr d19, [x27]\n" - "fmul v23.4s, v23.4s, v0.s[0]\n" - "fmla v10.4s, v14.4s, v0.s[1]\n" - "fadd v15.4s, v8.4s, v31.4s\n" - "mov v14.16b, v31.16b\n" - "fadd v28.4s, v24.4s, v29.4s\n" - "fsub v24.4s, v24.4s, v29.4s\n" - "fadd v7.4s, v7.4s, v5.4s\n" - "ldr d27, [x27, %[in_col_stride1]]\n" - "fadd v30.4s, v13.4s, v21.4s\n" - "fsub v9.4s, v13.4s, v21.4s\n" - "fadd v17.4s, v22.4s, v23.4s\n" - "mov v8.16b, v22.16b\n" - "fadd v25.4s, v25.4s, v28.4s\n" - "fmul v24.4s, v24.4s, v0.s[0]\n" - "fmla v26.4s, v28.4s, v0.s[1]\n" - "ldr d29, [x27, x21]\n" - "fmla v8.4s, v23.4s, v0.s[1]\n" - "ldr d28, [x27, x22]\n" - "fadd v13.4s, v4.4s, v30.4s\n" - "mov v4.16b, v30.16b\n" - "str d25, [%[outptr0]]\n" // Store output (0, 0) - "fadd v16.4s, v27.4s, v29.4s\n" - "str d26, [x28]\n" // Store output (2, 0) - "fsub v29.4s, v27.4s, v29.4s\n" - "fadd v8.4s, v8.4s, v18.4s\n" - "ldr d23, [x27, x23]\n" - "fadd v30.4s, v28.4s, v23.4s\n" - "ldr d25, [x27, x24]\n" - "fadd v19.4s, v19.4s, v16.4s\n" - "add x27, x27, #8\n" - "fsub v27.4s, v28.4s, v23.4s\n" - "mov v16.16b, v16.16b\n" - "fadd v22.4s, v20.4s, v17.4s\n" - "fsub v20.4s, v20.4s, v17.4s\n" - "fadd v21.4s, v12.4s, v24.4s\n" - "mov v26.16b, v12.16b\n" - "fadd v19.4s, v19.4s, v30.4s\n" - "fmla v16.4s, v30.4s, v0.s[1]\n" - "fmul v27.4s, v27.4s, v0.s[0]\n" - "mov v5.16b, v29.16b\n" - "fmla v26.4s, v24.4s, v0.s[1]\n" - "fadd v15.4s, v15.4s, v22.4s\n" - "str d21, [x17]\n" // Store output (1, 0) - "fmul v20.4s, v20.4s, v0.s[0]\n" - "fmla v14.4s, v22.4s, v0.s[1]\n" - "mov v28.16b, v11.16b\n" - "fadd v18.4s, v29.4s, v27.4s\n" - "fmla v5.4s, v27.4s, v0.s[1]\n" - "str d15, [%[outptr0], %[output_col_stride1]]\n" // Store output (0, 1) - "fadd v26.4s, v26.4s, v19.4s\n" - "fadd v12.4s, v11.4s, v20.4s\n" - "fmla v28.4s, v20.4s, v0.s[1]\n" - "str d14, [x28, %[output_col_stride1]]\n" // Store output (2, 1) - "fadd v21.4s, v6.4s, v10.4s\n" - "fadd v5.4s, v5.4s, v25.4s\n" - "fsub v10.4s, v6.4s, v10.4s\n" - "str d26, [x18]\n" // Store output (3, 0) - "mov v15.16b, v9.16b\n" - "str d12, [x17, %[output_col_stride1]]\n" // Store output (1, 1) - "fadd v28.4s, v28.4s, v18.4s\n" - "fadd v13.4s, v13.4s, v21.4s\n" - "fmla v4.4s, v21.4s, v0.s[1]\n" - "fmul v10.4s, v10.4s, v0.s[0]\n" - "fadd v6.4s, v2.4s, v3.4s\n" - "fadd v30.4s, v7.4s, v8.4s\n" - "fsub v2.4s, v2.4s, v3.4s\n" - "str d28, [x18, %[output_col_stride1]]\n" // Store output (3, 1) - "fsub v8.4s, v7.4s, v8.4s\n" - "str d13, [%[outptr0], x15]\n" // Store output (0, 2) - "str d4, [x28, x15]\n" // Store output (2, 2) - "fadd v13.4s, v9.4s, v10.4s\n" - "fmla v15.4s, v10.4s, v0.s[1]\n" - "fadd v1.4s, v1.4s, v6.4s\n" - "mov v6.16b, v6.16b\n" - "fmul v8.4s, v8.4s, v0.s[0]\n" - "mov v9.16b, v2.16b\n" - "str d13, [x17, x15]\n" // Store output (1, 2) - "fadd v15.4s, v15.4s, v16.4s\n" - "fadd v1.4s, v1.4s, v30.4s\n" - "fmla v6.4s, v30.4s, v0.s[1]\n" - "fadd v2.4s, v2.4s, v8.4s\n" - "fmla v9.4s, v8.4s, v0.s[1]\n" - "str d15, [x18, x15]\n" // Store output (3, 2) - "str d1, [%[outptr0], x16]\n" // Store output (0, 3) - "str d2, [x17, x16]\n" // Store output (1, 3) - "fadd v9.4s, v9.4s, v5.4s\n" - "str d6, [x28, x16]\n" // Store output (2, 3) - "add %[outptr0], %[outptr0], #8\n" - "add x17, x17, #8\n" - "add x28, x28, #8\n" - "str d9, [x18, x16]\n" // Store output (3, 3) - "add x18, x18, #8\n" - "5:\n" // Scalar - "cbz x20, 6f\n" - "ldr s17, [%[inptr0]]\n" - "ldr s23, [%[inptr0], %[in_col_stride1]]\n" - "ldr s27, [%[inptr0], x21]\n" - "fadd v4.4s, v23.4s, v27.4s\n" - "ldr s24, [%[inptr0], x22]\n" - "fsub v13.4s, v23.4s, v27.4s\n" - "ldr s11, [%[inptr0], x23]\n" - "fadd v10.4s, v24.4s, v11.4s\n" - "ldr s12, [%[inptr0], x24]\n" - "fsub v11.4s, v24.4s, v11.4s\n" - "ldr s20, [x25]\n" - "fadd v7.4s, v17.4s, v4.4s\n" - "ldr s19, [x25, %[in_col_stride1]]\n" - "mov v4.16b, v4.16b\n" - "ldr s22, [x25, x21]\n" - "mov v1.16b, v13.16b\n" - "ldr s14, [x25, x22]\n" - "fmul v11.4s, v11.4s, v0.s[0]\n" - "ldr s18, [x25, x23]\n" - "fadd v7.4s, v7.4s, v10.4s\n" - "ldr s3, [x25, x24]\n" - "fmla v4.4s, v10.4s, v0.s[1]\n" - "ldr s16, [x13]\n" - "fadd v2.4s, v19.4s, v22.4s\n" - "ldr s21, [x13, %[in_col_stride1]]\n" - "fadd v8.4s, v13.4s, v11.4s\n" - "ldr s24, [x13, x21]\n" - "fmla v1.4s, v11.4s, v0.s[1]\n" - "ldr s25, [x13, x22]\n" - "fadd v23.4s, v14.4s, v18.4s\n" - "ldr s17, [x13, x23]\n" - "fadd v11.4s, v20.4s, v2.4s\n" - "ldr s9, [x13, x24]\n" - "fsub v15.4s, v19.4s, v22.4s\n" - "ldr s6, [x26]\n" - "fadd v1.4s, v1.4s, v12.4s\n" - "ldr s19, [x26, %[in_col_stride1]]\n" - "fsub v31.4s, v14.4s, v18.4s\n" - "ldr s22, [x26, x21]\n" - "fadd v11.4s, v11.4s, v23.4s\n" - "ldr s12, [x26, x22]\n" - "mov v13.16b, v2.16b\n" - "ldr s26, [x26, x23]\n" - "mov v2.16b, v15.16b\n" - "ldr s5, [x26, x24]\n" - "fmul v31.4s, v31.4s, v0.s[0]\n" - "ldr s10, [x14]\n" - "fmla v13.4s, v23.4s, v0.s[1]\n" - "fadd v29.4s, v21.4s, v24.4s\n" - "fsub v14.4s, v21.4s, v24.4s\n" - "fadd v18.4s, v25.4s, v17.4s\n" - "fsub v28.4s, v25.4s, v17.4s\n" - "ldr s30, [x14, %[in_col_stride1]]\n" - "fadd v15.4s, v15.4s, v31.4s\n" - "fmla v2.4s, v31.4s, v0.s[1]\n" - "fadd v27.4s, v16.4s, v29.4s\n" - "mov v21.16b, v29.16b\n" - "fadd v20.4s, v19.4s, v22.4s\n" - "fsub v17.4s, v19.4s, v22.4s\n" - "fmul v28.4s, v28.4s, v0.s[0]\n" - "ldr s31, [x14, x21]\n" - "fadd v2.4s, v2.4s, v3.4s\n" - "ldr s23, [x14, x22]\n" - "fadd v27.4s, v27.4s, v18.4s\n" - "fmla v21.4s, v18.4s, v0.s[1]\n" - "fadd v29.4s, v12.4s, v26.4s\n" - "fadd v24.4s, v6.4s, v20.4s\n" - "fsub v16.4s, v12.4s, v26.4s\n" - "mov v6.16b, v20.16b\n" - "fadd v25.4s, v30.4s, v31.4s\n" - "fsub v22.4s, v30.4s, v31.4s\n" - "fadd v20.4s, v14.4s, v28.4s\n" - "mov v3.16b, v14.16b\n" - "fadd v24.4s, v24.4s, v29.4s\n" - "fmla v6.4s, v29.4s, v0.s[1]\n" - "fmul v16.4s, v16.4s, v0.s[0]\n" - "ldr s26, [x14, x23]\n" - "fmla v3.4s, v28.4s, v0.s[1]\n" - "fadd v14.4s, v23.4s, v26.4s\n" - "fadd v29.4s, v10.4s, v25.4s\n" - "fsub v23.4s, v23.4s, v26.4s\n" - "mov v10.16b, v25.16b\n" - "fadd v31.4s, v11.4s, v27.4s\n" - "fsub v12.4s, v11.4s, v27.4s\n" - "ldr s18, [x14, x24]\n" - "fadd v3.4s, v3.4s, v9.4s\n" - "ldr s19, [x27]\n" - "fadd v29.4s, v29.4s, v14.4s\n" - "fmul v23.4s, v23.4s, v0.s[0]\n" - "fmla v10.4s, v14.4s, v0.s[1]\n" - "fadd v25.4s, v7.4s, v31.4s\n" - "mov v26.16b, v31.16b\n" - "fadd v31.4s, v15.4s, v20.4s\n" - "fsub v11.4s, v15.4s, v20.4s\n" - "fadd v30.4s, v13.4s, v21.4s\n" - "fsub v9.4s, v13.4s, v21.4s\n" - "fadd v28.4s, v24.4s, v29.4s\n" - "fsub v24.4s, v24.4s, v29.4s\n" - "ldr s27, [x27, %[in_col_stride1]]\n" - "fadd v15.4s, v8.4s, v31.4s\n" - "mov v14.16b, v31.16b\n" - "fadd v13.4s, v4.4s, v30.4s\n" - "mov v4.16b, v30.16b\n" - "fadd v25.4s, v25.4s, v28.4s\n" - "fmla v26.4s, v28.4s, v0.s[1]\n" - "fmul v24.4s, v24.4s, v0.s[0]\n" - "fadd v21.4s, v6.4s, v10.4s\n" - "fsub v10.4s, v6.4s, v10.4s\n" - "fadd v6.4s, v2.4s, v3.4s\n" - "fsub v2.4s, v2.4s, v3.4s\n" - "ldr s29, [x27, x21]\n" - "str s25, [%[outptr0]]\n" // Store output (0, 0) - "fadd v20.4s, v17.4s, v16.4s\n" - "str s26, [x28]\n" // Store output (2, 0) - "mov v7.16b, v17.16b\n" - "fadd v17.4s, v22.4s, v23.4s\n" - "mov v8.16b, v22.16b\n" - "fadd v13.4s, v13.4s, v21.4s\n" - "fmul v10.4s, v10.4s, v0.s[0]\n" - "fmla v7.4s, v16.4s, v0.s[1]\n" - "ldr s28, [x27, x22]\n" - "fmla v8.4s, v23.4s, v0.s[1]\n" - "ldr s23, [x27, x23]\n" - "fmla v4.4s, v21.4s, v0.s[1]\n" - "ldr s25, [x27, x24]\n" - "str s13, [%[outptr0], x15]\n" // Store output (0, 2) - "fadd v16.4s, v27.4s, v29.4s\n" - "fadd v7.4s, v7.4s, v5.4s\n" - "fadd v30.4s, v28.4s, v23.4s\n" - "fadd v8.4s, v8.4s, v18.4s\n" - "fsub v29.4s, v27.4s, v29.4s\n" - "str s4, [x28, x15]\n" // Store output (2, 2) - "fsub v27.4s, v28.4s, v23.4s\n" - "fadd v19.4s, v19.4s, v16.4s\n" - "mov v16.16b, v16.16b\n" - "fadd v21.4s, v12.4s, v24.4s\n" - "mov v26.16b, v12.16b\n" - "mov v5.16b, v29.16b\n" - "fadd v22.4s, v20.4s, v17.4s\n" - "fmul v27.4s, v27.4s, v0.s[0]\n" - "fmla v16.4s, v30.4s, v0.s[1]\n" - "fadd v19.4s, v19.4s, v30.4s\n" - "fmla v26.4s, v24.4s, v0.s[1]\n" - "str s21, [x17]\n" // Store output (1, 0) - "fsub v20.4s, v20.4s, v17.4s\n" - "fadd v15.4s, v15.4s, v22.4s\n" - "fmla v14.4s, v22.4s, v0.s[1]\n" - "fadd v18.4s, v29.4s, v27.4s\n" - "fmla v5.4s, v27.4s, v0.s[1]\n" - "fadd v26.4s, v26.4s, v19.4s\n" - "mov v28.16b, v11.16b\n" - "fmul v20.4s, v20.4s, v0.s[0]\n" - "fadd v13.4s, v9.4s, v10.4s\n" - "str s15, [%[outptr0], %[output_col_stride1]]\n" // Store output (0, 1) - "mov v15.16b, v9.16b\n" - "str s14, [x28, %[output_col_stride1]]\n" // Store output (2, 1) - "fadd v5.4s, v5.4s, v25.4s\n" - "str s26, [x18]\n" // Store output (3, 0) - "fadd v30.4s, v7.4s, v8.4s\n" - "str s13, [x17, x15]\n" // Store output (1, 2) - "fadd v12.4s, v11.4s, v20.4s\n" - "fmla v28.4s, v20.4s, v0.s[1]\n" - "fmla v15.4s, v10.4s, v0.s[1]\n" - "fadd v1.4s, v1.4s, v6.4s\n" - "fsub v8.4s, v7.4s, v8.4s\n" - "mov v6.16b, v6.16b\n" - "mov v9.16b, v2.16b\n" - "str s12, [x17, %[output_col_stride1]]\n" // Store output (1, 1) - "fadd v28.4s, v28.4s, v18.4s\n" - "fadd v15.4s, v15.4s, v16.4s\n" - "fadd v1.4s, v1.4s, v30.4s\n" - "fmul v8.4s, v8.4s, v0.s[0]\n" - "fmla v6.4s, v30.4s, v0.s[1]\n" - "str s28, [x18, %[output_col_stride1]]\n" // Store output (3, 1) - "str s1, [%[outptr0], x16]\n" // Store output (0, 3) - "str s6, [x28, x16]\n" // Store output (2, 3) - "fadd v2.4s, v2.4s, v8.4s\n" - "str s15, [x18, x15]\n" // Store output (3, 2) - "fmla v9.4s, v8.4s, v0.s[1]\n" - "str s2, [x17, x16]\n" // Store output (1, 3) - "fadd v9.4s, v9.4s, v5.4s\n" - "str s9, [x18, x16]\n" // Store output (3, 3) - "6:\n" // End - : [outptr0] "+r" (output), [inptr0] "+r" (inptr) - : [output_col_stride1] "r" (output_col_stride * sizeof(float)), [pcoeffs] "r" (coeffs), [n_channels] "r" ((long) n_channels), [in_row_stride] "r" (6 * matrix_stride * sizeof(float)), [in_col_stride1] "r" (matrix_stride * sizeof(float)), [output_row_stride] "r" (output_row_stride * sizeof(float)) - : "cc", "v0", "v1", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v2", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v3", "v30", "v31", "v4", "v5", "v6", "v7", "v8", "v9", "x13", "x14", "x15", "x16", "x17", "x18", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "memory" - ); - } -} - -#else - template <> void winograd::OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots::Integers>::transform_tile( const int n_channels, @@ -1713,7 +36,9 @@ void winograd::OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots const float* bptr, float* const output, const int output_row_stride, - const int output_col_stride + const int output_col_stride, + const float output_min, + const float output_max ) { // Construct a map to the output cells @@ -1728,7 +53,79 @@ void winograd::OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots // For each channel of the output int channels_remaining = n_channels; -#ifdef __arm__ + +#ifdef __aarch64__ + for (; channels_remaining >= 4; channels_remaining -= 4) + { + // Matrices used and computed during this transform + float32x4_t F[6][6], FZ[6][4], f[4][4], b; + + // Read a 6x6 tile in the Winograd domain + for (int i = 0, m = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++, m++) + { + F[i][j] = vld1q_f32(inptr + m*matrix_stride); + } + } + inptr += 4; + + // Compute the matrix F Z + for (int i = 0; i < 6; i++) + { + // FZ[i][0] = 1*F[i][0] + 1*F[i][1] + 1*F[i][2] + 1*F[i][3] + 1*F[i][4]; + FZ[i][0] = vaddq_f32(vaddq_f32(vaddq_f32(F[i][0], F[i][1]), vaddq_f32(F[i][2], F[i][3])), F[i][4]); + + // FZ[i][1] = 1*F[i][1] + -1*F[i][2] + 2*F[i][3] + -2*F[i][4]; + FZ[i][1] = vmlaq_n_f32(vsubq_f32(F[i][1], F[i][2]), vsubq_f32(F[i][3], F[i][4]), 2.0f); + + // FZ[i][2] = 1*F[i][1] + 1*F[i][2] + 4*F[i][3] + 4*F[i][4]; + FZ[i][2] = vmlaq_n_f32(vaddq_f32(F[i][1], F[i][2]), vaddq_f32(F[i][3], F[i][4]), 4.0f); + + // FZ[i][3] = 1*F[i][1] + -1*F[i][2] + 8*F[i][3] + -8*F[i][4] + 1*F[i][5]; + FZ[i][3] = vaddq_f32(vmlaq_n_f32(vsubq_f32(F[i][1], F[i][2]), vsubq_f32(F[i][3], F[i][4]), 8.0f), F[i][5]); + } + + // Compute the output tile f = ZT F Z + for (int j = 0; j < 4; j++) + { + // f[0][j] = 1*FZ[0][j] + 1*FZ[1][j] + 1*FZ[2][j] + 1*FZ[3][j] + 1*FZ[4][j]; + f[0][j] = vaddq_f32(vaddq_f32(vaddq_f32(FZ[0][j], FZ[1][j]), vaddq_f32(FZ[2][j], FZ[3][j])), FZ[4][j]); + + // f[1][j] = 1*FZ[1][j] + -1*FZ[2][j] + 2*FZ[3][j] + -2*FZ[4][j]; + f[1][j] = vmlaq_n_f32(vsubq_f32(FZ[1][j], FZ[2][j]), vsubq_f32(FZ[3][j], FZ[4][j]), 2.0f); + + // f[2][j] = 1*FZ[1][j] + 1*FZ[2][j] + 4*FZ[3][j] + 4*FZ[4][j]; + f[2][j] = vmlaq_n_f32(vaddq_f32(FZ[1][j], FZ[2][j]), vaddq_f32(FZ[3][j], FZ[4][j]), 4.0f); + + // f[3][j] = 1*FZ[1][j] + -1*FZ[2][j] + 8*FZ[3][j] + -8*FZ[4][j] + 1*FZ[5][j]; + f[3][j] = vaddq_f32(vmlaq_n_f32(vsubq_f32(FZ[1][j], FZ[2][j]), vsubq_f32(FZ[3][j], FZ[4][j]), 8.0f), FZ[5][j]); + } + + // Write out the output tile + if (bptr != nullptr) + { + b = vld1q_f32(bptr); + bptr += 4; + } + else + { + b = vdupq_n_f32(0.0f); + } + for (int i = 0; i < output_tile_rows; i++) + { + for (int j = 0; j < output_tile_cols; j++) + { + const auto y = + vmaxq_f32(vminq_f32(vaddq_f32(f[i][j], b), vdupq_n_f32(output_max)), + vdupq_n_f32(output_min)); + vst1q_f32(outptrs[i][j], y); + outptrs[i][j] += 4; + } + } + } +#endif // __aarch64__ +#ifdef __arm_any__ for (; channels_remaining >= 2; channels_remaining -= 2) { // Matrices used and computed during this transform @@ -1790,12 +187,15 @@ void winograd::OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots { for (int j = 0; j < output_tile_cols; j++) { - vst1_f32(outptrs[i][j], vadd_f32(f[i][j], b)); + const auto y = + vmax_f32(vmin_f32(vadd_f32(f[i][j], b), vdup_n_f32(output_max)), + vdup_n_f32(output_min)); + vst1_f32(outptrs[i][j], y); outptrs[i][j] += 2; } } } -#endif // __arm__ +#endif // __arm_any__ for (; channels_remaining; channels_remaining--) { // Matrices used and computed during this transform @@ -1842,14 +242,13 @@ void winograd::OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots { for (int j = 0; j < output_tile_cols; j++) { - *(outptrs[i][j]++) = f[i][j] + b; + const auto y = std::max(std::min(f[i][j] + b, output_max), output_min); + *(outptrs[i][j]++) = y; } } } } -#endif - template class OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots::Integers>; } // namespace winograd diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_6_3_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_6_3_fp32_fp32_integers.cpp index ce921cea01..05f06a81ee 100644 --- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_6_3_fp32_fp32_integers.cpp +++ b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_6_3_fp32_fp32_integers.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -36,7 +36,9 @@ void OutputTransform<1, 3, 1, 8, float, float, WinogradRoots::Integers>::transfo const float* bptr, float* const output, const int, // No need to stride across rows - const int output_col_stride + const int output_col_stride, + const float output_min, + const float output_max ) { // Construct a map to the output cells @@ -76,7 +78,9 @@ void OutputTransform<1, 3, 1, 8, float, float, WinogradRoots::Integers>::transfo } for (int j = 0; j < output_tile_cols; j++) { - vst1q_f32(outptrs[j], f[j] + b); + const auto y = vminq_f32(vmaxq_f32(f[j] + b, vdupq_n_f32(output_min)), + vdupq_n_f32(output_max)); + vst1q_f32(outptrs[j], y); outptrs[j] += 4; } } @@ -107,7 +111,9 @@ void OutputTransform<1, 3, 1, 8, float, float, WinogradRoots::Integers>::transfo } for (int j = 0; j < output_tile_cols; j++) { - vst1_f32(outptrs[j], f[j] + b); + const auto y = vmin_f32(vmax_f32(f[j] + b, vdup_n_f32(output_min)), + vdup_n_f32(output_max)); + vst1_f32(outptrs[j], y); outptrs[j] += 2; } } @@ -138,7 +144,7 @@ void OutputTransform<1, 3, 1, 8, float, float, WinogradRoots::Integers>::transfo } for (int j = 0; j < output_tile_cols; j++) { - *(outptrs[j]++) = f[j] + b; + *(outptrs[j]++) = std::max(std::min(f[j] + b, output_max), output_min); } } } diff --git a/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp b/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp index e699ad1815..6983c1c01b 100644 --- a/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp +++ b/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp @@ -33,6 +33,7 @@ #include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h" #include "support/ToolchainSupport.h" +#include "arm_compute/core/NEON/kernels/convolution/common/utils.hpp" #include "arm_compute/core/NEON/kernels/convolution/winograd/winograd.hpp" namespace arm_compute @@ -232,6 +233,31 @@ bool check_support_fast_math(const Size2D &output_tile, const Size2D &kernel_siz return std::find(fast_math_winograd.begin(), fast_math_winograd.end(), p) != fast_math_winograd.end(); } +inline bool fuse_function_supported(const ActivationLayerInfo &act_info) +{ + return act_info.activation() == ActivationLayerInfo::ActivationFunction::RELU || + act_info.activation() == ActivationLayerInfo::ActivationFunction::BOUNDED_RELU; +} + +arm_gemm::Activation arm_gemm_activation_from_acl_activation(const ActivationLayerInfo &act_info) +{ + switch(act_info.activation()) + { + case ActivationLayerInfo::ActivationFunction::RELU: + { + return arm_gemm::Activation(arm_gemm::Activation::Type::ReLU, act_info.a(), act_info.b()); + } + case ActivationLayerInfo::ActivationFunction::BOUNDED_RELU: + { + return arm_gemm::Activation(arm_gemm::Activation::Type::BoundedReLU, act_info.a(), act_info.b()); + } + default: + { + return arm_gemm::Activation(arm_gemm::Activation::Type::None); + } + } +} + } //namespace NEWinogradConvolutionLayer::NEWinogradConvolutionLayer(const std::shared_ptr &memory_manager) @@ -257,6 +283,8 @@ void NEWinogradConvolutionLayer::configure(const ITensor *input, const ITensor * const Size2D kernel_size = Size2D(weights->info()->dimension(width_idx), weights->info()->dimension(height_idx)); const Size2D output_tile = winograd_output_tile(input_dims, kernel_size); + + // Check if the Winograd configuration requires fast math if(!enable_fast_math) { @@ -388,21 +416,15 @@ void NEWinogradConvolutionLayer::configure(const ITensor *input, const ITensor * * data_type_size; // Output storage - const size_t output_storage_size = transform_output_kernel->get_output_storage_size(in_shape.n_batches, in_shape.n_rows, in_shape.n_cols, out_channels, - use_same_padding) - * data_type_size; - ; - const KernelShape kernel_shape({ out_channels, static_cast(kernel_size.height), static_cast(kernel_size.width), in_channels }); - const int kernel_matrix_stride = transform_weights_kernel->get_matrix_stride(kernel_shape); - - const int output_matrix_stride = transform_output_kernel->get_matrix_stride(kernel_shape, in_shape, use_padding_type); - const auto output_shape(transform_output_kernel->get_output_shape(kernel_shape, in_shape, use_padding_type)); - - const int input_matrix_stride = transform_input_kernel->get_matrix_stride(kernel_shape, in_shape, use_padding_type); + const size_t output_storage_size = transform_output_kernel->get_output_storage_size(in_shape.n_batches, in_shape.n_rows, in_shape.n_cols, out_channels) * data_type_size; + const int kernel_matrix_stride = transform_weights_kernel->get_matrix_stride(out_channels, in_channels); + const int output_matrix_stride = transform_output_kernel->get_matrix_stride(in_shape.n_batches, in_shape.n_rows, in_shape.n_cols, out_channels); + const auto output_shape = transform_output_kernel->get_output_shape(in_shape.n_rows, in_shape.n_cols, use_padding_type == PADDING_SAME); + const int input_matrix_stride = transform_input_kernel->get_matrix_stride(in_shape.n_batches, in_channels, in_shape.n_rows, in_shape.n_cols, use_padding_type == PADDING_SAME); // Configure GEMM - const int tile_rows = iceildiv(output_shape.n_rows, output_tile.height); - const int tile_cols = iceildiv(output_shape.n_cols, output_tile.width); + const int tile_rows = iceildiv(output_shape.first, output_tile.height); + const int tile_cols = iceildiv(output_shape.second, output_tile.width); const int m = in_shape.n_batches * tile_rows * tile_cols; const int k = in_shape.n_channels; const int n = out_channels; @@ -489,9 +511,19 @@ void NEWinogradConvolutionLayer::configure(const ITensor *input, const ITensor * _memory_group.manage(&_output_nhwc); output_to_use = &_output_nhwc; } - transform_output_kernel->configure(biases, &_output_transformed, - output_matrix_stride, output_to_use, - in_shape.n_batches, output_shape.n_rows, output_shape.n_cols, out_channels, &_output_workspace); + const arm_gemm::Activation activation = arm_gemm_activation_from_acl_activation(act_info); + + transform_output_kernel->configure(biases, + &_output_transformed, + output_matrix_stride, + output_to_use, + in_shape.n_batches, + output_shape.first, + output_shape.second, + out_channels, + &_output_workspace, + activation); + const size_t output_workspace_size = transform_output_kernel->get_working_space_size(max_num_threads); TensorInfo output_workspace_info(TensorShape(output_workspace_size), 1, _output->info()->data_type()); _output_workspace.allocator()->init(output_workspace_info); @@ -510,7 +542,7 @@ void NEWinogradConvolutionLayer::configure(const ITensor *input, const ITensor * _transform_output_kernel = std::move(transform_output_kernel); //Configure Activation Layer - _is_activationlayer_enabled = act_info.enabled(); + _is_activationlayer_enabled = act_info.enabled() && ! fuse_function_supported(act_info); if(_is_activationlayer_enabled) { _activationlayer_function.configure(_output, nullptr, act_info); @@ -546,7 +578,7 @@ void NEWinogradConvolutionLayer::run() _permute_output.run(); } - if(_is_activationlayer_enabled) + if(_is_activationlayer_enabled ) { _activationlayer_function.run(); } -- cgit v1.2.1