aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/NEON/kernels/winograd/transforms
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/NEON/kernels/winograd/transforms')
-rw-r--r--arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp638
-rw-r--r--arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp1498
-rw-r--r--arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp961
-rw-r--r--arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp195
-rw-r--r--arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp822
-rw-r--r--arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp356
-rw-r--r--arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp650
-rw-r--r--arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp655
8 files changed, 0 insertions, 5775 deletions
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp
deleted file mode 100644
index 7013c66ac0..0000000000
--- a/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp
+++ /dev/null
@@ -1,638 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#pragma once
-#include "../tensor.hpp"
-
-namespace winograd {
- /* Transform an input tensor into the Winograd domain.
- */
- template <typename T>
- struct Winograd2x2_3x3GemmInput {
- static void execute(
- const T *inptr,
- const Tensor4DShape& input_shape,
- const PaddingType padding_type,
- const int tile_M,
- const int tile_N,
- T *outptr_base,
- const int matrix_stride,
- const int matrix_batch_stride,
- const int matrix_row_stride
- );
-
- static size_t bytes_read(const Tensor4DShape &input_shape,
- const Tensor4DShape &output_shape) {
- const int tile_rows = iceildiv(output_shape.n_rows, 2);
- const int tile_cols = iceildiv(output_shape.n_cols, 2);
- return input_shape.n_batches * tile_rows * (16 + 8*(tile_cols - 1)) * input_shape.n_channels * sizeof(T);
- }
-
- static int flops_performed(const Tensor4DShape &input_shape,
- const Tensor4DShape &output_shape) {
- const int tile_rows = iceildiv(output_shape.n_rows, 2);
- const int tile_cols = iceildiv(output_shape.n_cols, 2);
- return input_shape.n_batches * tile_rows * (32 + 24*(tile_cols - 1)) * input_shape.n_channels;
- }
-
- static size_t bytes_written(const Tensor4DShape &input_shape,
- const Tensor4DShape &output_shape) {
- const int tile_rows = iceildiv(output_shape.n_rows, 2);
- const int tile_cols = iceildiv(output_shape.n_cols, 2);
- const int M = input_shape.n_batches * tile_rows * tile_cols;
- return 16 * M * input_shape.n_channels * sizeof(T);
- }
-
- protected:
- template <const PaddingType padding, const int pad_bottom, const int pad_right>
- static void process_tile_tensor(
- const int tile_M, // Number of rows of tiles
- const int tile_N, // Number of columns of tiles
- int n_channels, // Number of input channels
- const T* const input, // Base input pointer (appropriate to batch and channel)
- const int input_row_stride, // Stride between rows of the input
- const int input_col_stride, // Stride between columns of the input
- T* const matrix, // 1st output matrix (appropriate to batch and channel)
- const int matrix_stride, // Stride between matrices
- const int matrix_row_stride // Stride between rows of the output matrix
- );
-
- template <const int pad_top, const int pad_left,
- const int pad_bottom, const int pad_right,
- const int proc_channels>
- static void process_tile_row(
- const int tile_N, // Number of tiles in the row
- const T* const input, // Base input pointer (appropriate to batch, channel and row)
- const int input_row_stride, // Stride between rows of the input
- const int input_col_stride, // Stride between columns of the input
- T* const matrix, // 1st output matrix (appropriate to batch, channel and row)
- const int matrix_stride, // Stride between matrices
- const int matrix_row_stride // Stride between rows of the output matrix
- );
- };
-
- template <typename T>
- struct Winograd2x2_3x3GemmInputChannelwise {
- static void execute(
- const T *inptr,
- const Tensor4DShape& input_shape,
- const PaddingType padding_type,
- const int tile_M,
- const int tile_N,
- T *outptr_base,
- const int matrix_stride,
- const int matrix_batch_stride,
- const int matrix_row_stride
- );
-
- static size_t bytes_read(const Tensor4DShape &input_shape,
- const Tensor4DShape &output_shape) {
- // We read as many bytes as we write
- return bytes_written(input_shape, output_shape);
- }
-
- static int flops_performed(const Tensor4DShape &input_shape,
- const Tensor4DShape &output_shape) {
- const int tile_rows = iceildiv(output_shape.n_rows, 2);
- const int tile_cols = iceildiv(output_shape.n_cols, 2);
- return input_shape.n_batches * tile_rows * 32 * tile_cols * input_shape.n_channels;
- }
-
- static size_t bytes_written(const Tensor4DShape &input_shape,
- const Tensor4DShape &output_shape) {
- return winograd::Winograd2x2_3x3GemmInput<T>::bytes_written(input_shape, output_shape);
- }
-
- protected:
- typedef void (*tilefunc)(int, const T*, int, int, T*, int);
- template <const int pad_top,
- const int pad_left,
- const int pad_bottom,
- const int pad_right>
- static void process_tile(
- int n_channels, // Number of channels in the tile
- const T* const input_base,
- const int input_row_stride,
- const int input_col_stride,
- T* const matrix_base,
- const int matrix_stride
- );
-
- private:
- template <const int pad_top,
- const int pad_left,
- const int pad_bottom,
- const int pad_right,
- const int proc_channels>
- static void _process_tile(
- int &n_channels, const T* &inptr,
- const int input_row_stride, const int input_col_stride,
- T* &outptr, const int matrix_stride
- );
- };
-}
-
-/*****************************************************************************/
-// Include specialised implementations here
-#include "input_2x2_3x3/a64_float.hpp"
-#include "input_2x2_3x3/a64_float_channelwise.hpp"
-/*****************************************************************************/
-
-/*****************************************************************************/
-template <typename T>
-void winograd::Winograd2x2_3x3GemmInput<T>::execute(
- const T *inptr_base,
- const Tensor4DShape& input_shape,
- const PaddingType padding_type,
- const int tile_M,
- const int tile_N,
- T *outptr_base,
- const int matrix_stride,
- const int matrix_batch_stride,
- const int matrix_row_stride
-) {
- // Select an appropriate matrix processing method for the shape and padding
- // of the input tensor.
- typedef void (*tensorfunc)(int, int, int, const T*, int, int, T*, int, int);
- const auto process_tensor = [&padding_type, &input_shape] () -> tensorfunc {
- if (padding_type == PADDING_VALID) {
- const int pad_bottom = input_shape.n_rows % 2;
- const int pad_right = input_shape.n_cols % 2;
-
- if (pad_bottom == 0 && pad_right == 0) {
- return process_tile_tensor<PADDING_VALID, 0, 0>;
- } else if (pad_bottom == 0 && pad_right == 1) {
- return process_tile_tensor<PADDING_VALID, 0, 1>;
- } else if (pad_bottom == 1 && pad_right == 0) {
- return process_tile_tensor<PADDING_VALID, 1, 0>;
- } else if (pad_bottom == 1 && pad_right == 1) {
- return process_tile_tensor<PADDING_VALID, 1, 1>;
- }
- } else { // PADDING_SAME
- const int pad_bottom = 1 + input_shape.n_rows % 2;
- const int pad_right = 1 + input_shape.n_cols % 2;
-
- if (pad_bottom == 1 && pad_right == 1) {
- return process_tile_tensor<PADDING_SAME, 1, 1>;
- } else if (pad_bottom == 1 && pad_right == 2) {
- return process_tile_tensor<PADDING_SAME, 1, 2>;
- } else if (pad_bottom == 2 && pad_right == 1) {
- return process_tile_tensor<PADDING_SAME, 2, 1>;
- } else if (pad_bottom == 2 && pad_right == 2) {
- return process_tile_tensor<PADDING_SAME, 2, 2>;
- }
- }
-
- printf("%s::%u Uncovered case.\n", __FILE__, __LINE__);
- exit(-1);
- return NULL; // No function found
- } ();
-
- // Compute strides
- const int input_row_stride = input_shape.n_cols * input_shape.n_channels;
- const int input_col_stride = input_shape.n_channels;
-
- // Process each batch of the tensor in turn.
- for (int batch = 0; batch < input_shape.n_batches; batch++) {
- // Work out pointers
- const T *inptr = inptr_base + (batch * input_shape.n_rows *
- input_shape.n_cols * input_shape.n_channels);
- T *outptr = outptr_base + batch * matrix_batch_stride;
-
- // Delegate doing the actual work
- process_tensor(
- tile_M, tile_N, input_shape.n_channels,
- inptr, input_row_stride, input_col_stride,
- outptr, matrix_stride, matrix_row_stride
- );
- }
-}
-
-/*****************************************************************************/
-template <typename T>
-template <const PaddingType padding, const int pad_bottom, const int pad_right>
-void winograd::Winograd2x2_3x3GemmInput<T>::process_tile_tensor(
- const int tile_M, // Number of rows of tiles
- const int tile_N, // Number of columns of tiles
- int n_channels, // Number of input channels
- const T* const input, // Base input pointer (appropriate to batch and channel)
- const int input_row_stride, // Stride between rows of the input
- const int input_col_stride, // Stride between columns of the input
- T* const matrix, // 1st output matrix (appropriate to batch and channel)
- const int matrix_stride, // Stride between matrices
- const int matrix_row_stride // Stride between rows of the output matrix
-) {
- // Base row processing functions
- typedef void (*rowfunc)(int, const T*, int, int, T*, int, int);
- const rowfunc process_top_row[3] = {
- (padding == PADDING_VALID)
- ? process_tile_row<0, 0, 0, pad_right, 1>
- : process_tile_row<1, 1, 0, pad_right, 1>,
- (padding == PADDING_VALID)
- ? process_tile_row<0, 0, 0, pad_right, 2>
- : process_tile_row<1, 1, 0, pad_right, 2>,
- (padding == PADDING_VALID)
- ? process_tile_row<0, 0, 0, pad_right, 4>
- : process_tile_row<1, 1, 0, pad_right, 4>,
- };
- const rowfunc process_middle_row[3] = {
- (padding == PADDING_VALID)
- ? process_tile_row<0, 0, 0, pad_right, 1>
- : process_tile_row<0, 1, 0, pad_right, 1>,
- (padding == PADDING_VALID)
- ? process_tile_row<0, 0, 0, pad_right, 2>
- : process_tile_row<0, 1, 0, pad_right, 2>,
- (padding == PADDING_VALID)
- ? process_tile_row<0, 0, 0, pad_right, 4>
- : process_tile_row<0, 1, 0, pad_right, 4>,
- };
- const rowfunc process_bottom_row[3] = {
- (padding == PADDING_VALID)
- ? process_tile_row<0, 0, pad_bottom, pad_right, 1>
- : process_tile_row<0, 1, pad_bottom, pad_right, 1>,
- (padding == PADDING_VALID)
- ? process_tile_row<0, 0, pad_bottom, pad_right, 2>
- : process_tile_row<0, 1, pad_bottom, pad_right, 2>,
- (padding == PADDING_VALID)
- ? process_tile_row<0, 0, pad_bottom, pad_right, 4>
- : process_tile_row<0, 1, pad_bottom, pad_right, 4>,
- };
-
- // Method to get an input pointer for the given tile row
- const auto get_inptr = [&input, &input_row_stride] (const int tile_i) {
- if (padding == PADDING_VALID) {
- return input + 2 * tile_i * input_row_stride;
- } else {
- return input + (2 * tile_i - (tile_i ? 1 : 0)) * input_row_stride;
- }
- };
-
- // Wrapper to process a row of tiles, covering all channels.
- const auto process_row =
- [tile_N, input_row_stride, input_col_stride, matrix_stride, matrix_row_stride, n_channels]
- (const rowfunc f[3], const T *inptr, T *outptr) {
- int rem_channels = n_channels;
-
- // While there remain channels to process continue to process the
- // row.
- for (; rem_channels >= 4; rem_channels -= 4, inptr += 4, outptr += 4) {
- f[2](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride);
- }
- for (; rem_channels >= 2; rem_channels -= 2, inptr += 2, outptr += 2) {
- f[1](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride);
- }
- if (rem_channels) {
- f[0](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride);
- }
- };
-
- // Process all rows of tiles in the tensor
- for (int tile_i = 0; tile_i < tile_M; tile_i++) {
- T* const m_row = matrix + tile_i * tile_N * matrix_row_stride;
- const T *row_inptr = get_inptr(tile_i);
-
- if (tile_i == 0) {
- // Top row of the input
- process_row(process_top_row, row_inptr, m_row);
- } else if (tile_i == tile_M - 1) {
- // Bottom row of the input
- process_row(process_bottom_row, row_inptr, m_row);
- } else {
- // Any other row of the input
- process_row(process_middle_row, row_inptr, m_row);
- }
- }
-}
-
-/*****************************************************************************/
-template <typename T>
-template <const int pad_top, const int pad_left,
- const int pad_bottom, const int pad_right,
- const int proc_channels>
-void winograd::Winograd2x2_3x3GemmInput<T>::process_tile_row(
- const int tile_N, // Number of tiles in the row
- const T* const input, // Base input pointer (appropriate to batch, channel and row)
- const int input_row_stride, // Stride between rows of the input
- const int input_col_stride, // Stride between columns of the input
- T* const matrix, // 1st output matrix (appropriate to batch, channel and row)
- const int matrix_stride, // Stride between matrices
- const int matrix_row_stride // Stride between rows of the output matrix
-) {
- // Construct copies of the pointers
- const T *inptr = input;
- T *outptr = matrix;
-
- // Storage for the tensors x, X.T x, and X.T x X.
- T x[4][4][proc_channels], XTx[4][4][proc_channels], XTxX[4][4][proc_channels];
-
- // For every tile in the row
- for (int tile_j = 0; tile_j < tile_N; tile_j++) {
- // Determine the padding for the tile
- const int tile_pad_left = (tile_j == 0) ? pad_left : 0;
- const int tile_pad_right = (tile_j == tile_N - 1) ? pad_right : 0;
-
- // Load tile values. If this is the first tile in the row then we must load
- // all values, otherwise we can just load the final two columns of the input.
- for (int i = 0; i < 4; i++) {
- for (int j = ((tile_j == 0) ? 0 : 2); j < 4; j++) {
- // Fill with padding if required
- if (i < pad_top || 4 - pad_bottom <= i ||
- j < tile_pad_left || 4 - tile_pad_right <= j) {
- for (int c = 0; c < proc_channels; c++) {
- x[i][j][c] = static_cast<T>(0); // Padding
- }
- } else {
- // Load values, note that the initial padding offsets the pointer we
- // were provided.
- for (int c = 0; c < proc_channels; c++) {
- const int row_offset = (i - pad_top) * input_row_stride;
- const int col_offset = (j - tile_pad_left) * input_col_stride;
- x[i][j][c] = inptr[row_offset + col_offset + c];
- }
- }
- }
- }
-
- // Compute the matrix X.T x. Note, can elide operations depending on the
- // padding. Furthermore, if this isn't the left-most tile we can skip half
- // of the operations by copying results from the previous version of X.T x.
- // This latter optimisation can be simplified by unrolling the outermost
- // loop by two and by renaming the registers containing XTx.
- if (tile_j == 0) {
- for (int j = 0; j < 4; j++) {
- for (int c = 0; c < proc_channels; c++) {
- XTx[0][j][c] = x[0][j][c] - x[2][j][c];
- XTx[1][j][c] = x[1][j][c] + x[2][j][c];
- XTx[2][j][c] = -x[1][j][c] + x[2][j][c];
- XTx[3][j][c] = x[1][j][c] - x[3][j][c];
- }
- }
- } else {
- for (int j = 0; j < 2; j++) {
- for (int c = 0; c < proc_channels; c++) {
- XTx[0][j][c] = XTx[0][j + 2][c];
- XTx[1][j][c] = XTx[1][j + 2][c];
- XTx[2][j][c] = XTx[2][j + 2][c];
- XTx[3][j][c] = XTx[3][j + 2][c];
- }
- }
- for (int j = 2; j < 4; j++) {
- for (int c = 0; c < proc_channels; c++) {
- XTx[0][j][c] = x[0][j][c] - x[2][j][c];
- XTx[1][j][c] = x[1][j][c] + x[2][j][c];
- XTx[2][j][c] = -x[1][j][c] + x[2][j][c];
- XTx[3][j][c] = x[1][j][c] - x[3][j][c];
- }
- }
- }
-
- // Compute the matrix X.T x X. Note, can elide operations based on the
- // padding.
- for (int i = 0; i < 4; i++) {
- for (int c = 0; c < proc_channels; c++) {
- XTxX[i][0][c] = XTx[i][0][c] - XTx[i][2][c];
- XTxX[i][1][c] = XTx[i][1][c] + XTx[i][2][c];
- XTxX[i][2][c] = -XTx[i][1][c] + XTx[i][2][c];
- XTxX[i][3][c] = XTx[i][1][c] - XTx[i][3][c];
- }
- }
-
- // Store the output matrix (X.T x X)
- for (int i = 0; i < 4; i++) {
- for (int j = 0; j < 4; j++) {
- // Get a pointer to the relevant output matrix
- T *mptr = outptr + (i*4 + j)*matrix_stride;
-
- // Write out the channels
- for (int c = 0; c < proc_channels; c++) {
- mptr[c] = XTxX[i][j][c];
- }
- }
- }
-
- // Update the pointers
- inptr += input_col_stride * ((tile_j == 0 && pad_left) ? 1 : 2);
- outptr += matrix_row_stride;
- }
-}
-
-/*****************************************************************************/
-template <typename T>
-void winograd::Winograd2x2_3x3GemmInputChannelwise<T>::execute(
- const T *inptr,
- const Tensor4DShape& input_shape,
- const PaddingType padding_type,
- const int tile_M,
- const int tile_N,
- T *outptr_base,
- const int matrix_stride,
- const int matrix_batch_stride,
- const int matrix_row_stride
-) {
- const int n_channels = input_shape.n_channels;
- const int input_col_stride = n_channels;
- const int input_row_stride = input_shape.n_cols * input_col_stride;
-
- // Determine the padding and hence select appropriate methods for each tile.
- tilefunc fs[3][3];
-
- if (padding_type == PADDING_VALID) {
- constexpr int pad_top = 0;
- constexpr int pad_left = 0;
- const int pad_right = input_shape.n_cols % 2 == 0;
-
- fs[0][0] = process_tile<pad_top, pad_left, 0, 0>;
- fs[0][1] = process_tile<pad_top, 0, 0, 0>;
- fs[0][2] = (pad_right) ? process_tile<pad_top, 0, 0, 0> : process_tile<pad_top, 0, 0, 1>;
-
- fs[1][0] = process_tile<0, pad_left, 0, 0>;
- fs[1][1] = process_tile<0, 0, 0, 0>;
- fs[1][2] = (pad_right) ? process_tile<0, 0, 0, 0> : process_tile<0, 0, 0, 1>;
-
- if (input_shape.n_rows % 2 == 0) {
- constexpr int pad_bottom = 0;
- fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>;
- fs[2][1] = process_tile<0, 0, pad_bottom, 0>;
- fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 0> : process_tile<0, 0, pad_bottom, 1>;
- } else {
- constexpr int pad_bottom = 1;
- fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>;
- fs[2][1] = process_tile<0, 0, pad_bottom, 0>;
- fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 0> : process_tile<0, 0, pad_bottom, 1>;
- }
- } else {
- constexpr int pad_top = 1;
- constexpr int pad_left = 1;
- const int pad_right = input_shape.n_cols % 2 == 0;
-
- fs[0][0] = process_tile<pad_top, pad_left, 0, 0>;
- fs[0][1] = process_tile<pad_top, 0, 0, 0>;
- fs[0][2] = (pad_right) ? process_tile<pad_top, 0, 0, 1> : process_tile<pad_top, 0, 0, 2>;
-
- fs[1][0] = process_tile<0, pad_left, 0, 0>;
- fs[1][1] = process_tile<0, 0, 0, 0>;
- fs[1][2] = (pad_right) ? process_tile<0, 0, 0, 1> : process_tile<0, 0, 0, 2>;
-
- if (input_shape.n_rows % 2 == 0) {
- constexpr int pad_bottom = 1;
- fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>;
- fs[2][1] = process_tile<0, 0, pad_bottom, 0>;
- fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 1> : process_tile<0, 0, pad_bottom, 2>;
- } else {
- constexpr int pad_bottom = 2;
- fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>;
- fs[2][1] = process_tile<0, 0, pad_bottom, 0>;
- fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 1> : process_tile<0, 0, pad_bottom, 2>;
- }
- }
-
- // Process each tile in turn
- for (int batch = 0; batch < input_shape.n_batches; batch++) {
- const T* const input_base_batch = inptr + batch*input_shape.n_rows*input_shape.n_cols*n_channels;
-
- for (int tile_i = 0; tile_i < tile_M; tile_i++) {
- const int row_offset = (tile_i == 0) ? 0 : ((padding_type == PADDING_VALID) ? 0 : 1);
- const T* const input_base_row = input_base_batch + (2*tile_i - row_offset)*input_shape.n_cols*n_channels;
-
- // Select the set of functions for the row
- const int fs_i = (tile_i == 0) ? 0 : ((tile_i < tile_M - 1) ? 1 : 2);
-
- for (int tile_j = 0; tile_j < tile_N; tile_j++) {
- // Select the function for the column
- const int fs_j = (tile_j == 0) ? 0 : ((tile_j < tile_N - 1) ? 1 : 2);
- const auto f = fs[fs_i][fs_j];
-
- // Get pointers into the input and outputs
- const int col_offset = (tile_j == 0) ? 0 : ((padding_type == PADDING_VALID) ? 0 : 1);
- const T* const input_base_col = input_base_row + (2*tile_j - col_offset)*n_channels;
- T* const matrix_base = outptr_base + batch*matrix_batch_stride + (tile_i*tile_N + tile_j)*matrix_row_stride;
- f(n_channels, input_base_col, input_row_stride, input_col_stride,
- matrix_base, matrix_stride);
- }
- }
- }
-}
-
-template <typename T>
-template <const int pad_top,
- const int pad_left,
- const int pad_bottom,
- const int pad_right>
-void winograd::Winograd2x2_3x3GemmInputChannelwise<T>::process_tile(
- int n_channels, // Number of channels in the tile
- const T* const input_base,
- const int input_row_stride,
- const int input_col_stride,
- T* const matrix_base,
- const int matrix_stride
-) {
- // Copy pointers
- const T *inptr = input_base;
- T *outptr = matrix_base;
-
- // Process channels (modifies inptr, outptr and n_channels)
- _process_tile<pad_top, pad_left, pad_bottom, pad_right, 4>(
- n_channels, inptr, input_row_stride, input_col_stride,
- outptr, matrix_stride
- );
- _process_tile<pad_top, pad_left, pad_bottom, pad_right, 2>(
- n_channels, inptr, input_row_stride, input_col_stride,
- outptr, matrix_stride
- );
- _process_tile<pad_top, pad_left, pad_bottom, pad_right, 1>(
- n_channels, inptr, input_row_stride, input_col_stride,
- outptr, matrix_stride
- );
-}
-
-template <typename T>
-template <const int pad_top,
- const int pad_left,
- const int pad_bottom,
- const int pad_right,
- const int proc_channels>
-void winograd::Winograd2x2_3x3GemmInputChannelwise<T>::_process_tile(
- int &n_channels,
- const T* &inptr, const int input_row_stride, const int input_col_stride,
- T* &outptr, const int matrix_stride
-) {
- // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three
- // offsets to access the intermediate matrices.
- T* outptrs[4] = {
- outptr,
- outptr + matrix_stride * 4,
- outptr + matrix_stride * 8,
- outptr + matrix_stride * 12
- };
-
- // The matrix X; zeroed to account for padding.
- T x[4][4];
- for (int i = 0; i < 4; i++) {
- for (int j = 0; j < 4; j++) {
- x[i][j] = 0;
- }
- }
-
- // The matrices X.T x and U
- T XTx[4][4], U[4][4];
-
- // Now progress through each channel
- for (; n_channels >= proc_channels; n_channels -= proc_channels) {
- for (int n = 0; n < proc_channels; n++) {
- // Load the matrix X
- for (int cell_i = pad_top, i = 0; cell_i < 4 - pad_bottom; cell_i++, i++) {
- for (int cell_j = pad_left, j = 0; cell_j < 4 - pad_right; cell_j++, j++) {
- x[cell_i][cell_j] = inptr[i*input_row_stride + j*input_col_stride];
- }
- }
- inptr++;
-
- // Compute the matrix X.T
- for (int j = 0; j < 4; j++) {
- XTx[0][j] = x[0][j] - x[2][j];
- XTx[1][j] = x[1][j] + x[2][j];
- XTx[2][j] = x[2][j] - x[1][j];
- XTx[3][j] = x[1][j] - x[3][j];
- }
-
- // Hence compute the matrix U
- for (int i = 0; i < 4; i++) {
- U[i][0] = XTx[i][0] - XTx[i][2];
- U[i][1] = XTx[i][1] + XTx[i][2];
- U[i][2] = XTx[i][2] - XTx[i][1];
- U[i][3] = XTx[i][1] - XTx[i][3];
- }
-
- // Store the matrix U
- for (int i = 0; i < 4; i++) {
- for (int j = 0; j < 4; j++) {
- outptrs[i][j * matrix_stride] = U[i][j];
- }
- outptrs[i]++;
- }
- }
- }
-
- // Update the output pointer for future calls
- outptr = outptrs[0];
-}
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp
deleted file mode 100644
index a99cbe325b..0000000000
--- a/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp
+++ /dev/null
@@ -1,1498 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#pragma once
-#include "../input_2x2_3x3.hpp"
-
-#ifdef __aarch64__
-namespace winograd {
-
-// Pad left by one column, pad right by one column, no upper or lower padding, 4 channels
-template <>
-template <>
-inline void Winograd2x2_3x3GemmInput<float>::process_tile_row<0, 1, 0, 1, 4>(
- const int tile_N, // Number of tiles in the row
- const float* const input, // Base input pointer (appropriate to batch, channel and row)
- const int input_row_stride, // Stride between rows of the input
- const int input_col_stride, // Stride between columns of the input
- float* const matrix, // 1st output matrix (appropriate to batch, channel and row)
- const int matrix_stride, // Stride between matrices
- const int matrix_row_stride // Stride between rows of the output matrix
-) {
- /* SIMD register allocation
- * ========================
- *
- * In the following code we read 4x4 tiles of a matrix `x`, with which we
- * compute another matrix `X.T x` where:
- *
- * / 1 0 0 0 \
- * X = | 0 1 -1 1 |
- * | -1 1 1 0 |
- * \ 0 0 0 -1 /
- *
- * Hence, `X.T` is a program which operates upon rows of the matrix `X`.
- * We subsequently compute and store the matrix `U = (X.T x) X`.
- *
- * Importantly, each iteration of the loop below loads a new matrix `x'`
- * where the final two columns of `x'` are the first two columns of the
- * previous `x`. That is:
- *
- * x11 x12 x13 x14
- * x21 x22 x23 x24
- * x31 x32 x33 x34
- * x41 x42 x43 x44
- *
- * x'11 x'12 x'13 x'14
- * x'21 x'22 x'23 x'24
- * x'31 x'32 x'33 x'34
- * x'41 x'42 x'43 x'44
- *
- * Consequently, while the first iteration of the below loop must load 16
- * values for `x`, the second need load only 8. *Furthermore*, since we noted
- * above that the operation `X.T x` was a program which operated upon *rows*
- * of the matrix `x` it follows that that the relation that `x'[i][1] =
- * x[i][3]` and `x'[i][2] = x[i][4]` applies also the matrices `X.T x'` and
- * `X.T x`. That is:
- *
- * (X.T x)11 (X.T x)12 (X.T x)13 (X.T x)14
- * (X.T x)21 (X.T x)22 (X.T x)23 (X.T x)24
- * (X.T x)31 (X.T x)32 (X.T x)33 (X.T x)34
- * (X.T x)41 (X.T x)42 (X.T x)43 (X.T x)44
- *
- * (X.T x')11 (X.T x')12 (X.T x')13 (X.T x')14
- * (X.T x')12 (X.T x')12 (X.T x')12 (X.T x')12
- * (X.T x')13 (X.T x')13 (X.T x')13 (X.T x')13
- * (X.T x')14 (X.T x')14 (X.T x')14 (X.T x')14
- *
- * Hence, as well as not needing to load new values for x'[i][1..2] it is
- * also unnecessary to recompute values for (X.T x')[i][1..2].
- *
- * Following this we break the registers into blocks `A` and `B` used by the
- * two stages of the unrolled loop. These registers named such that the
- * latter columns of `A` become the earlier columns of `B` and vice-versa:
- *
- * AXTx11 AXTx12 > AXTx13 AXTx14 |
- * AXTx21 AXTx22 > AXTx23 AXTx24 |
- * AXTx31 AXTx32 > AXTx33 AXTx34 |
- * AXTx41 AXTx42 > AXTx43 AXTx44 |
- *
- * BXTx13 BXTx14 | BXTx11 BXTx12 >
- * BXTx23 BXTx24 | BXTx21 BXTx22 >
- * BXTx33 BXTx34 | BXTx31 BXTx32 >
- * BXTx43 BXTx44 | BXTx41 BXTx42 >
- *
- * These 32 named registers require only 16 architectural registers. 1
- * additional architectural register is used as scratch space and 8
- * architectural registers are used to load in the values x[1..4][3,4].
- *
- * Input and output addressing
- * ===========================
- * TODO Description
- */
- const float *inptr0 = input;
- const float *inptr1 = input + input_row_stride;
- const float *inptr2 = input + input_row_stride * 2;
- const float *inptr3 = input + input_row_stride * 3;
-
- float *outptr0 = matrix;
- float *outptr4 = matrix + matrix_stride * 4;
- float *outptr8 = matrix + matrix_stride * 8;
- float *outptr12 = matrix + matrix_stride * 12;
-
- int tile_j = tile_N; // Tiles to process
-
- asm volatile (
- // Named SIMD registers according to the policy given above
- // Registers into which to load the latter two columns of `x`
- "x_13 .req v0\n qx_13 .req q0\n" "x_14 .req v4\n qx_14 .req q4\n"
- "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n"
- "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n"
- "x_43 .req v3\n qx_43 .req q3\n" "x_44 .req v7\n qx_44 .req q7\n"
-
- // Registers for storing X.T x (both A and B halves)
- "AXTx11 .req v8\n" "BXTx13 .req v8\n"
- "AXTx12 .req v9\n" "BXTx14 .req v9\n" "qAXTx12 .req q9\n"
- "AXTx21 .req v10\n" "BXTx23 .req v10\n"
- "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n"
- "AXTx31 .req v12\n" "BXTx33 .req v12\n"
- "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n"
- "AXTx41 .req v14\n" "BXTx43 .req v14\n"
- "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n"
- "AXTx13 .req v16\n" "BXTx11 .req v16\n"
- "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n"
- "AXTx23 .req v18\n" "BXTx21 .req v18\n"
- "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n"
- "AXTx33 .req v20\n" "BXTx31 .req v20\n"
- "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n"
- "AXTx43 .req v22\n" "BXTx41 .req v22\n"
- "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n"
-
- // Result register (TODO Does using more registers yield better
- // performance)
- "U .req v24\n qU .req q24\n"
-
- // ----------------------------------------------------------------------
- // Head of loop
- // Loads a complete 4x4 tile of x, computes X.T x, computes and stores
- // `U = X.T x X`. Prepares for the 'A' half of the loop.
- // NOTE: Since the first tile has the leftmost column padded we can
- // skip 4 loads and 4 calculations for the matrix X.T x X.
-
- // Temporarily alias registers for computing the first (non-padded)
- // column of x.
- "x_12 .req v0\n qx_12 .req q0\n"
- "x_22 .req v1\n qx_22 .req q1\n"
- "x_32 .req v2\n qx_32 .req q2\n"
- "x_42 .req v3\n qx_42 .req q3\n"
-
- "ldr qx_12, [%x[inptr0]]\n"
- "ldr qx_22, [%x[inptr1]]\n"
- "ldr qx_32, [%x[inptr2]]\n"
- "ldr qx_42, [%x[inptr3]]\n"
-
- "fsub BXTx12.4s, x_12.4s, x_32.4s\n"
- "fadd BXTx22.4s, x_22.4s, x_32.4s\n"
- "fsub BXTx32.4s, x_32.4s, x_22.4s\n"
- "fsub BXTx42.4s, x_22.4s, x_42.4s\n"
-
- ".unreq x_12\n .unreq qx_12\n"
- ".unreq x_22\n .unreq qx_22\n"
- ".unreq x_32\n .unreq qx_32\n"
- ".unreq x_42\n .unreq qx_42\n"
-
- // Load and compute latter two columns of the first tile. Progress the
- // input pointers (by three columns so that the each points are the
- // second column of the next tile, that is, each points at the first
- // column which must be read for the next tile.
- "ldr qx_13, [%x[inptr0], %x[colstride1]]\n"
- "ldr qx_23, [%x[inptr1], %x[colstride1]]\n"
- "ldr qx_33, [%x[inptr2], %x[colstride1]]\n"
- "ldr qx_43, [%x[inptr3], %x[colstride1]]\n"
-
- "fsub BXTx13.4s, x_13.4s, x_33.4s\n"
- "ldr qx_14, [%x[inptr0], %x[colstride2]]\n"
-
- "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
- "ldr qx_24, [%x[inptr1], %x[colstride2]]\n"
-
- "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
- "ldr qx_34, [%x[inptr2], %x[colstride2]]\n"
-
- "fsub BXTx43.4s, x_23.4s, x_43.4s\n"
- "ldr qx_44, [%x[inptr3], %x[colstride2]]\n"
-
- "fsub BXTx14.4s, x_14.4s, x_34.4s\n"
- "add %x[inptr0], %x[inptr0], %x[colstride3]\n"
-
- "fadd BXTx24.4s, x_24.4s, x_34.4s\n"
- "add %x[inptr1], %x[inptr1], %x[colstride3]\n"
-
- "fsub BXTx34.4s, x_34.4s, x_24.4s\n"
- "add %x[inptr2], %x[inptr2], %x[colstride3]\n"
-
- "fsub BXTx44.4s, x_24.4s, x_44.4s\n"
- "add %x[inptr3], %x[inptr3], %x[colstride3]\n"
-
- // Compute and store U for the first tile
- // First row
- "fneg U.4s, BXTx13.4s\n"
- "str qU, [%x[outptr0]]\n"
- "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
- "str qU, [%x[outptr0], %x[mstride1]]\n"
- "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
- "str qU, [%x[outptr0], %x[mstride2]]\n"
- "fsub U.4s, BXTx12.4s, BXTx14.4s\n"
- "str qU, [%x[outptr0], %x[mstride3]]\n"
- "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
-
- // Second row
- "fneg U.4s, BXTx23.4s\n"
- "str qU, [%x[outptr4]]\n"
- "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
- "str qU, [%x[outptr4], %x[mstride1]]\n"
- "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
- "str qU, [%x[outptr4], %x[mstride2]]\n"
- "fsub U.4s, BXTx22.4s, BXTx24.4s\n"
- "str qU, [%x[outptr4], %x[mstride3]]\n"
- "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
-
- // Third row
- "fneg U.4s, BXTx33.4s\n"
- "str qU, [%x[outptr8]]\n"
- "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
- "str qU, [%x[outptr8], %x[mstride1]]\n"
- "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
- "str qU, [%x[outptr8], %x[mstride2]]\n"
- "fsub U.4s, BXTx32.4s, BXTx34.4s\n"
- "str qU, [%x[outptr8], %x[mstride3]]\n"
- "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
-
- // Fourth row, simultaneously load the first column of inputs for the
- // next tile.
- "fneg U.4s, BXTx43.4s\n"
- "str qU, [%x[outptr12]]\n"
- "ldr qx_13, [%x[inptr0]]\n"
-
- "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
- "str qU, [%x[outptr12], %x[mstride1]]\n"
- "ldr qx_23, [%x[inptr1]]\n"
-
- "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
- "str qU, [%x[outptr12], %x[mstride2]]\n"
- "ldr qx_33, [%x[inptr2]]\n"
-
- "fsub U.4s, BXTx42.4s, BXTx44.4s\n"
- "str qU, [%x[outptr12], %x[mstride3]]\n"
- "ldr qx_43, [%x[inptr3]]\n"
-
- "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
-
- // Update the loop counter, subtract two to account for both the head and
- // the tail.
- "subs %x[tile_j], %x[tile_j], #2\n"
- "beq 2f\n" // Jump to "A" tail if out of tiles
-
- // ----------------------------------------------------------------------
- "1:"
- // Start part A
- // Load last column of this tile (the first column has already been
- // loaded) and compute latter two columns of X.T x.
- "fsub AXTx13.4s, x_13.4s, x_33.4s\n"
- "ldr qx_14, [%x[inptr0], %x[colstride1]]\n"
- "fadd AXTx23.4s, x_23.4s, x_33.4s\n"
- "ldr qx_24, [%x[inptr1], %x[colstride1]]\n"
- "fsub AXTx33.4s, x_33.4s, x_23.4s\n"
- "ldr qx_34, [%x[inptr2], %x[colstride1]]\n"
- "fsub AXTx43.4s, x_23.4s, x_43.4s\n"
- "ldr qx_44, [%x[inptr3], %x[colstride1]]\n"
- "fsub AXTx14.4s, x_14.4s, x_34.4s\n"
- "add %x[inptr0], %x[inptr0], %x[colstride2]\n"
- "fadd AXTx24.4s, x_24.4s, x_34.4s\n"
- "add %x[inptr1], %x[inptr1], %x[colstride2]\n"
- "fsub AXTx34.4s, x_34.4s, x_24.4s\n"
- "add %x[inptr2], %x[inptr2], %x[colstride2]\n"
- "fsub AXTx44.4s, x_24.4s, x_44.4s\n"
- "add %x[inptr3], %x[inptr3], %x[colstride2]\n"
-
- // Compute and store U.
- // First row
- "fsub U.4s, AXTx11.4s, AXTx13.4s\n"
- "str qU, [%x[outptr0]]\n"
- "fadd U.4s, AXTx12.4s, AXTx13.4s\n"
- "str qU, [%x[outptr0], %x[mstride1]]\n"
- "fsub U.4s, AXTx13.4s, AXTx12.4s\n"
- "str qU, [%x[outptr0], %x[mstride2]]\n"
- "fsub U.4s, AXTx12.4s, AXTx14.4s\n"
- "str qU, [%x[outptr0], %x[mstride3]]\n"
- "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
-
- // Second row
- "fsub U.4s, AXTx21.4s, AXTx23.4s\n"
- "str qU, [%x[outptr4]]\n"
- "fadd U.4s, AXTx22.4s, AXTx23.4s\n"
- "str qU, [%x[outptr4], %x[mstride1]]\n"
- "fsub U.4s, AXTx23.4s, AXTx22.4s\n"
- "str qU, [%x[outptr4], %x[mstride2]]\n"
- "fsub U.4s, AXTx22.4s, AXTx24.4s\n"
- "str qU, [%x[outptr4], %x[mstride3]]\n"
- "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
-
- // Third row
- "fsub U.4s, AXTx31.4s, AXTx33.4s\n"
- "str qU, [%x[outptr8]]\n"
- "fadd U.4s, AXTx32.4s, AXTx33.4s\n"
- "str qU, [%x[outptr8], %x[mstride1]]\n"
- "fsub U.4s, AXTx33.4s, AXTx32.4s\n"
- "str qU, [%x[outptr8], %x[mstride2]]\n"
- "fsub U.4s, AXTx32.4s, AXTx34.4s\n"
- "str qU, [%x[outptr8], %x[mstride3]]\n"
- "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
-
- // Fourth row
- "fsub U.4s, AXTx41.4s, AXTx43.4s\n"
- "str qU, [%x[outptr12]]\n"
- "ldr qx_13, [%x[inptr0]]\n"
-
- "fadd U.4s, AXTx42.4s, AXTx43.4s\n"
- "str qU, [%x[outptr12], %x[mstride1]]\n"
- "ldr qx_23, [%x[inptr1]]\n"
-
- "fsub U.4s, AXTx43.4s, AXTx42.4s\n"
- "str qU, [%x[outptr12], %x[mstride2]]\n"
- "ldr qx_33, [%x[inptr2]]\n"
-
- "fsub U.4s, AXTx42.4s, AXTx44.4s\n"
- "str qU, [%x[outptr12], %x[mstride3]]\n"
- "ldr qx_43, [%x[inptr3]]\n"
-
- "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
-
- "subs %x[tile_j], %x[tile_j], #1\n"
- "beq 3f\n" // Jump to 'B' tail
-
- // Start part B
- // Load last column of this tile (the first column has already been
- // loaded) and compute latter two columns of X.T x.
- "fsub BXTx13.4s, x_13.4s, x_33.4s\n"
- "ldr qx_14, [%x[inptr0], %x[colstride1]]\n"
- "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
- "ldr qx_24, [%x[inptr1], %x[colstride1]]\n"
- "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
- "ldr qx_34, [%x[inptr2], %x[colstride1]]\n"
- "fsub BXTx43.4s, x_23.4s, x_43.4s\n"
- "ldr qx_44, [%x[inptr3], %x[colstride1]]\n"
- "fsub BXTx14.4s, x_14.4s, x_34.4s\n"
- "add %x[inptr0], %x[inptr0], %x[colstride2]\n"
- "fadd BXTx24.4s, x_24.4s, x_34.4s\n"
- "add %x[inptr1], %x[inptr1], %x[colstride2]\n"
- "fsub BXTx34.4s, x_34.4s, x_24.4s\n"
- "add %x[inptr2], %x[inptr2], %x[colstride2]\n"
- "fsub BXTx44.4s, x_24.4s, x_44.4s\n"
- "add %x[inptr3], %x[inptr3], %x[colstride2]\n"
-
- // Compute and store U.
- // First row
- "fsub U.4s, BXTx11.4s, BXTx13.4s\n"
- "str qU, [%x[outptr0]]\n"
- "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
- "str qU, [%x[outptr0], %x[mstride1]]\n"
- "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
- "str qU, [%x[outptr0], %x[mstride2]]\n"
- "fsub U.4s, BXTx12.4s, BXTx14.4s\n"
- "str qU, [%x[outptr0], %x[mstride3]]\n"
- "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
-
- // Second row
- "fsub U.4s, BXTx21.4s, BXTx23.4s\n"
- "str qU, [%x[outptr4]]\n"
- "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
- "str qU, [%x[outptr4], %x[mstride1]]\n"
- "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
- "str qU, [%x[outptr4], %x[mstride2]]\n"
- "fsub U.4s, BXTx22.4s, BXTx24.4s\n"
- "str qU, [%x[outptr4], %x[mstride3]]\n"
- "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
-
- // Third row
- "fsub U.4s, BXTx31.4s, BXTx33.4s\n"
- "str qU, [%x[outptr8]]\n"
- "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
- "str qU, [%x[outptr8], %x[mstride1]]\n"
- "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
- "str qU, [%x[outptr8], %x[mstride2]]\n"
- "fsub U.4s, BXTx32.4s, BXTx34.4s\n"
- "str qU, [%x[outptr8], %x[mstride3]]\n"
- "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
-
- // Fourth row
- "fsub U.4s, BXTx41.4s, BXTx43.4s\n"
- "str qU, [%x[outptr12]]\n"
- "ldr qx_13, [%x[inptr0]]\n"
-
- "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
- "str qU, [%x[outptr12], %x[mstride1]]\n"
- "ldr qx_23, [%x[inptr1]]\n"
-
- "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
- "str qU, [%x[outptr12], %x[mstride2]]\n"
- "ldr qx_33, [%x[inptr2]]\n"
-
- "fsub U.4s, BXTx42.4s, BXTx44.4s\n"
- "str qU, [%x[outptr12], %x[mstride3]]\n"
- "ldr qx_43, [%x[inptr3]]\n"
-
- "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
- "subs %x[tile_j], %x[tile_j], #1\n"
- "bne 1b\n" // Continue loop, otherwise flow into 'A' tail
-
- // ----------------------------------------------------------------------
- "2:"
- // 'A' tail
- // Since the final column is padding and the last-but-one column has
- // already been loaded just compute the 3rd column of `X.T x'.
- "fsub AXTx13.4s, x_13.4s, x_33.4s\n"
- "fadd AXTx23.4s, x_23.4s, x_33.4s\n"
- "fsub AXTx33.4s, x_33.4s, x_23.4s\n"
- "fsub AXTx43.4s, x_23.4s, x_43.4s\n"
-
- // Compute and store U. Modified to account for the final column of X.T
- // x containing padding. Note, it is also unnecessary to update the
- // output pointers.
- // First row
- "fsub U.4s, AXTx11.4s, AXTx13.4s\n"
- "str qU, [%x[outptr0]]\n"
- "fadd U.4s, AXTx12.4s, AXTx13.4s\n"
- "str qU, [%x[outptr0], %x[mstride1]]\n"
- "fsub U.4s, AXTx13.4s, AXTx12.4s\n"
- "str qU, [%x[outptr0], %x[mstride2]]\n"
- "str qAXTx12, [%x[outptr0], %x[mstride3]]\n"
-
- // Second row
- "fsub U.4s, AXTx21.4s, AXTx23.4s\n"
- "str qU, [%x[outptr4]]\n"
- "fadd U.4s, AXTx22.4s, AXTx23.4s\n"
- "str qU, [%x[outptr4], %x[mstride1]]\n"
- "fsub U.4s, AXTx23.4s, AXTx22.4s\n"
- "str qU, [%x[outptr4], %x[mstride2]]\n"
- "str qAXTx22, [%x[outptr4], %x[mstride3]]\n"
-
- // Third row
- "fsub U.4s, AXTx31.4s, AXTx33.4s\n"
- "str qU, [%x[outptr8]]\n"
- "fadd U.4s, AXTx32.4s, AXTx33.4s\n"
- "str qU, [%x[outptr8], %x[mstride1]]\n"
- "fsub U.4s, AXTx33.4s, AXTx32.4s\n"
- "str qU, [%x[outptr8], %x[mstride2]]\n"
- "str qAXTx32, [%x[outptr8], %x[mstride3]]\n"
-
- // Fourth row
- "fsub U.4s, AXTx41.4s, AXTx43.4s\n"
- "str qU, [%x[outptr12]]\n"
- "fadd U.4s, AXTx42.4s, AXTx43.4s\n"
- "str qU, [%x[outptr12], %x[mstride1]]\n"
- "fsub U.4s, AXTx43.4s, AXTx42.4s\n"
- "str qU, [%x[outptr12], %x[mstride2]]\n"
- "str qAXTx42, [%x[outptr12], %x[mstride3]]\n"
-
- "b 4f\n" // Jump to end of function
-
- // ----------------------------------------------------------------------
- "3:"
- // 'B' tail
- // Since the final column is padding and the last-but-one column has
- // already been loaded just compute the 3rd column of `X.T x'.
- "fsub BXTx13.4s, x_13.4s, x_33.4s\n"
- "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
- "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
- "fsub BXTx43.4s, x_23.4s, x_43.4s\n"
-
- // Compute and store U. Modified to account for the final column of X.T
- // x containing padding. Note, it is also unnecessary to update the
- // output pointers.
- // First row
- "fsub U.4s, BXTx11.4s, BXTx13.4s\n"
- "str qU, [%x[outptr0]]\n"
- "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
- "str qU, [%x[outptr0], %x[mstride1]]\n"
- "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
- "str qU, [%x[outptr0], %x[mstride2]]\n"
- "str qBXTx12, [%x[outptr0], %x[mstride3]]\n"
-
- // Second row
- "fsub U.4s, BXTx21.4s, BXTx23.4s\n"
- "str qU, [%x[outptr4]]\n"
- "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
- "str qU, [%x[outptr4], %x[mstride1]]\n"
- "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
- "str qU, [%x[outptr4], %x[mstride2]]\n"
- "str qBXTx22, [%x[outptr4], %x[mstride3]]\n"
-
- // Third row
- "fsub U.4s, BXTx31.4s, BXTx33.4s\n"
- "str qU, [%x[outptr8]]\n"
- "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
- "str qU, [%x[outptr8], %x[mstride1]]\n"
- "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
- "str qU, [%x[outptr8], %x[mstride2]]\n"
- "str qBXTx32, [%x[outptr8], %x[mstride3]]\n"
-
- // Fourth row
- "fsub U.4s, BXTx41.4s, BXTx43.4s\n"
- "str qU, [%x[outptr12]]\n"
- "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
- "str qU, [%x[outptr12], %x[mstride1]]\n"
- "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
- "str qU, [%x[outptr12], %x[mstride2]]\n"
- "str qBXTx42, [%x[outptr12], %x[mstride3]]\n"
-
- // ----------------------------------------------------------------------
- "4:"
- // End of function
-
- // Clear names
- ".unreq x_13\n" ".unreq qx_13\n" ".unreq x_14\n" ".unreq qx_14\n"
- ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n"
- ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n"
- ".unreq x_43\n" ".unreq qx_43\n" ".unreq x_44\n" ".unreq qx_44\n"
- ".unreq AXTx11\n" ".unreq BXTx13\n"
- ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n"
- ".unreq AXTx21\n" ".unreq BXTx23\n"
- ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n"
- ".unreq AXTx31\n" ".unreq BXTx33\n"
- ".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n"
- ".unreq AXTx41\n" ".unreq BXTx43\n"
- ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n"
- ".unreq AXTx13\n" ".unreq BXTx11\n"
- ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n"
- ".unreq AXTx23\n" ".unreq BXTx21\n"
- ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n"
- ".unreq AXTx33\n" ".unreq BXTx31\n"
- ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n"
- ".unreq AXTx43\n" ".unreq BXTx41\n"
- ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n"
- ".unreq U\n" ".unreq qU\n"
- : [inptr0] "+r" (inptr0),
- [inptr1] "+r" (inptr1),
- [inptr2] "+r" (inptr2),
- [inptr3] "+r" (inptr3),
- [outptr0] "+r" (outptr0),
- [outptr4] "+r" (outptr4),
- [outptr8] "+r" (outptr8),
- [outptr12] "+r" (outptr12),
- [tile_j] "+r" (tile_j) // Tile counter
- : [colstride1] "r" (1 * input_col_stride * sizeof(float)),
- [colstride2] "r" (2 * input_col_stride * sizeof(float)),
- [colstride3] "r" (3 * input_col_stride * sizeof(float)),
- [mstride1] "r" (1 * matrix_stride * sizeof(float)),
- [mstride2] "r" (2 * matrix_stride * sizeof(float)),
- [mstride3] "r" (3 * matrix_stride * sizeof(float)),
- [matrix_row_stride] "r" (matrix_row_stride * sizeof(float))
- : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
- "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
- "v22", "v23", "v24"
- );
-}
-
-// Pad top, left and right by 1.
-template <>
-template <>
-inline void Winograd2x2_3x3GemmInput<float>::process_tile_row<1, 1, 0, 1, 4>(
- const int tile_N,
- const float* const input,
- const int input_row_stride,
- const int input_col_stride,
- float* const matrix,
- const int matrix_stride,
- const int matrix_row_stride
-) {
- const float *inptr0 = input;
- const float *inptr1 = input + input_row_stride;
- const float *inptr2 = input + input_row_stride * 2;
-
- float *outptr0 = matrix;
- float *outptr4 = matrix + matrix_stride * 4;
- float *outptr8 = matrix + matrix_stride * 8;
- float *outptr12 = matrix + matrix_stride * 12;
-
- int tile_j = tile_N; // Tiles to process
-
- asm volatile (
- // Named SIMD registers according to the policy given above
- // Registers into which to load the latter two columns of `x`
- // NOTE: We need only load the latter three rows since we know that the
- // first row is padded.
- "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n"
- "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n"
- "x_43 .req v3\n qx_43 .req q3\n" "x_44 .req v7\n qx_44 .req q7\n"
-
- // Registers for storing X.T x (both A and B halves)
- "AXTx11 .req v8\n" "BXTx13 .req v8\n"
- "AXTx12 .req v9\n" "BXTx14 .req v9\n" "qAXTx12 .req q9\n"
- "AXTx21 .req v10\n" "BXTx23 .req v10\n"
- "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n"
- "AXTx31 .req v12\n" "BXTx33 .req v12\n"
- "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n"
- "AXTx41 .req v14\n" "BXTx43 .req v14\n"
- "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n"
- "AXTx13 .req v16\n" "BXTx11 .req v16\n"
- "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n"
- "AXTx23 .req v18\n" "BXTx21 .req v18\n"
- "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n"
- "AXTx33 .req v20\n" "BXTx31 .req v20\n"
- "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n"
- "AXTx43 .req v22\n" "BXTx41 .req v22\n"
- "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n"
-
- // Result register (TODO Does using more registers yield better
- // performance)
- "U .req v24\n qU .req q24\n"
-
- // ----------------------------------------------------------------------
- // Head of loop
- // Loads a complete 4x4 tile of x, computes X.T x, computes and stores
- // `U = X.T x X`. Prepares for the 'A' half of the loop.
- // NOTE: Since the first tile has the leftmost column padded we can
- // skip 4 loads and 4 calculations for the matrix X.T x X.
-
- // Temporarily alias registers for computing the first (non-padded)
- // column of x.
- "x_22 .req v1\n qx_22 .req q1\n"
- "x_32 .req v2\n qx_32 .req q2\n"
- "x_42 .req v3\n qx_42 .req q3\n"
-
- "ldr qx_22, [%x[inptr1]]\n"
- "ldr qx_32, [%x[inptr2]]\n"
- "ldr qx_42, [%x[inptr3]]\n"
-
- "fneg BXTx12.4s, x_32.4s\n"
- "fadd BXTx22.4s, x_22.4s, x_32.4s\n"
- "fsub BXTx32.4s, x_32.4s, x_22.4s\n"
- "fsub BXTx42.4s, x_22.4s, x_42.4s\n"
-
- ".unreq x_22\n .unreq qx_22\n"
- ".unreq x_32\n .unreq qx_32\n"
- ".unreq x_42\n .unreq qx_42\n"
-
- // Load and compute latter two columns of the first tile. Progress the
- // input pointers (by three columns so that the each points are the
- // second column of the next tile, that is, each points at the first
- // column which must be read for the next tile.
- "ldr qx_23, [%x[inptr1], %x[colstride1]]\n"
- "ldr qx_33, [%x[inptr2], %x[colstride1]]\n"
- "ldr qx_43, [%x[inptr3], %x[colstride1]]\n"
-
- "fneg BXTx13.4s, x_33.4s\n"
-
- "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
- "ldr qx_24, [%x[inptr1], %x[colstride2]]\n"
-
- "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
- "ldr qx_34, [%x[inptr2], %x[colstride2]]\n"
-
- "fsub BXTx43.4s, x_23.4s, x_43.4s\n"
- "ldr qx_44, [%x[inptr3], %x[colstride2]]\n"
-
- "fneg BXTx14.4s, x_34.4s\n"
-
- "fadd BXTx24.4s, x_24.4s, x_34.4s\n"
- "add %x[inptr1], %x[inptr1], %x[colstride3]\n"
-
- "fsub BXTx34.4s, x_34.4s, x_24.4s\n"
- "add %x[inptr2], %x[inptr2], %x[colstride3]\n"
-
- "fsub BXTx44.4s, x_24.4s, x_44.4s\n"
- "add %x[inptr3], %x[inptr3], %x[colstride3]\n"
-
- // Compute and store U for the first tile
- // First row
- "fneg U.4s, BXTx13.4s\n"
- "str qU, [%x[outptr0]]\n"
- "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
- "str qU, [%x[outptr0], %x[mstride1]]\n"
- "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
- "str qU, [%x[outptr0], %x[mstride2]]\n"
- "fsub U.4s, BXTx12.4s, BXTx14.4s\n"
- "str qU, [%x[outptr0], %x[mstride3]]\n"
- "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
-
- // Second row
- "fneg U.4s, BXTx23.4s\n"
- "str qU, [%x[outptr4]]\n"
- "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
- "str qU, [%x[outptr4], %x[mstride1]]\n"
- "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
- "str qU, [%x[outptr4], %x[mstride2]]\n"
- "fsub U.4s, BXTx22.4s, BXTx24.4s\n"
- "str qU, [%x[outptr4], %x[mstride3]]\n"
- "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
-
- // Third row
- "fneg U.4s, BXTx33.4s\n"
- "str qU, [%x[outptr8]]\n"
- "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
- "str qU, [%x[outptr8], %x[mstride1]]\n"
- "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
- "str qU, [%x[outptr8], %x[mstride2]]\n"
- "fsub U.4s, BXTx32.4s, BXTx34.4s\n"
- "str qU, [%x[outptr8], %x[mstride3]]\n"
- "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
-
- // Fourth row, simultaneously load the first column of inputs for the
- // next tile.
- "fneg U.4s, BXTx43.4s\n"
- "str qU, [%x[outptr12]]\n"
-
- "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
- "str qU, [%x[outptr12], %x[mstride1]]\n"
- "ldr qx_23, [%x[inptr1]]\n"
-
- "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
- "str qU, [%x[outptr12], %x[mstride2]]\n"
- "ldr qx_33, [%x[inptr2]]\n"
-
- "fsub U.4s, BXTx42.4s, BXTx44.4s\n"
- "str qU, [%x[outptr12], %x[mstride3]]\n"
- "ldr qx_43, [%x[inptr3]]\n"
-
- "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
-
- // Update the loop counter, subtract two to account for both the head and
- // the tail.
- "subs %x[tile_j], %x[tile_j], #2\n"
- "beq 2f\n" // Jump to "A" tail if out of tiles
-
- // ----------------------------------------------------------------------
- "1:"
- // Start part A
- // Load last column of this tile (the first column has already been
- // loaded) and compute latter two columns of X.T x.
- "fneg AXTx13.4s, x_33.4s\n"
- "fadd AXTx23.4s, x_23.4s, x_33.4s\n"
- "ldr qx_24, [%x[inptr1], %x[colstride1]]\n"
- "fsub AXTx33.4s, x_33.4s, x_23.4s\n"
- "ldr qx_34, [%x[inptr2], %x[colstride1]]\n"
- "fsub AXTx43.4s, x_23.4s, x_43.4s\n"
- "ldr qx_44, [%x[inptr3], %x[colstride1]]\n"
- "fneg AXTx14.4s, x_34.4s\n"
- "fadd AXTx24.4s, x_24.4s, x_34.4s\n"
- "add %x[inptr1], %x[inptr1], %x[colstride2]\n"
- "fsub AXTx34.4s, x_34.4s, x_24.4s\n"
- "add %x[inptr2], %x[inptr2], %x[colstride2]\n"
- "fsub AXTx44.4s, x_24.4s, x_44.4s\n"
- "add %x[inptr3], %x[inptr3], %x[colstride2]\n"
-
- // Compute and store U.
- // First row
- "fsub U.4s, AXTx11.4s, AXTx13.4s\n"
- "str qU, [%x[outptr0]]\n"
- "fadd U.4s, AXTx12.4s, AXTx13.4s\n"
- "str qU, [%x[outptr0], %x[mstride1]]\n"
- "fsub U.4s, AXTx13.4s, AXTx12.4s\n"
- "str qU, [%x[outptr0], %x[mstride2]]\n"
- "fsub U.4s, AXTx12.4s, AXTx14.4s\n"
- "str qU, [%x[outptr0], %x[mstride3]]\n"
- "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
-
- // Second row
- "fsub U.4s, AXTx21.4s, AXTx23.4s\n"
- "str qU, [%x[outptr4]]\n"
- "fadd U.4s, AXTx22.4s, AXTx23.4s\n"
- "str qU, [%x[outptr4], %x[mstride1]]\n"
- "fsub U.4s, AXTx23.4s, AXTx22.4s\n"
- "str qU, [%x[outptr4], %x[mstride2]]\n"
- "fsub U.4s, AXTx22.4s, AXTx24.4s\n"
- "str qU, [%x[outptr4], %x[mstride3]]\n"
- "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
-
- // Third row
- "fsub U.4s, AXTx31.4s, AXTx33.4s\n"
- "str qU, [%x[outptr8]]\n"
- "fadd U.4s, AXTx32.4s, AXTx33.4s\n"
- "str qU, [%x[outptr8], %x[mstride1]]\n"
- "fsub U.4s, AXTx33.4s, AXTx32.4s\n"
- "str qU, [%x[outptr8], %x[mstride2]]\n"
- "fsub U.4s, AXTx32.4s, AXTx34.4s\n"
- "str qU, [%x[outptr8], %x[mstride3]]\n"
- "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
-
- // Fourth row
- "fsub U.4s, AXTx41.4s, AXTx43.4s\n"
- "str qU, [%x[outptr12]]\n"
-
- "fadd U.4s, AXTx42.4s, AXTx43.4s\n"
- "str qU, [%x[outptr12], %x[mstride1]]\n"
- "ldr qx_23, [%x[inptr1]]\n"
-
- "fsub U.4s, AXTx43.4s, AXTx42.4s\n"
- "str qU, [%x[outptr12], %x[mstride2]]\n"
- "ldr qx_33, [%x[inptr2]]\n"
-
- "fsub U.4s, AXTx42.4s, AXTx44.4s\n"
- "str qU, [%x[outptr12], %x[mstride3]]\n"
- "ldr qx_43, [%x[inptr3]]\n"
-
- "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
-
- "subs %x[tile_j], %x[tile_j], #1\n"
- "beq 3f\n" // Jump to 'B' tail
-
- // Start part B
- // Load last column of this tile (the first column has already been
- // loaded) and compute latter two columns of X.T x.
- "fneg BXTx13.4s, x_33.4s\n"
- "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
- "ldr qx_24, [%x[inptr1], %x[colstride1]]\n"
- "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
- "ldr qx_34, [%x[inptr2], %x[colstride1]]\n"
- "fsub BXTx43.4s, x_23.4s, x_43.4s\n"
- "ldr qx_44, [%x[inptr3], %x[colstride1]]\n"
- "fneg BXTx14.4s, x_34.4s\n"
- "fadd BXTx24.4s, x_24.4s, x_34.4s\n"
- "add %x[inptr1], %x[inptr1], %x[colstride2]\n"
- "fsub BXTx34.4s, x_34.4s, x_24.4s\n"
- "add %x[inptr2], %x[inptr2], %x[colstride2]\n"
- "fsub BXTx44.4s, x_24.4s, x_44.4s\n"
- "add %x[inptr3], %x[inptr3], %x[colstride2]\n"
-
- // Compute and store U.
- // First row
- "fsub U.4s, BXTx11.4s, BXTx13.4s\n"
- "str qU, [%x[outptr0]]\n"
- "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
- "str qU, [%x[outptr0], %x[mstride1]]\n"
- "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
- "str qU, [%x[outptr0], %x[mstride2]]\n"
- "fsub U.4s, BXTx12.4s, BXTx14.4s\n"
- "str qU, [%x[outptr0], %x[mstride3]]\n"
- "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
-
- // Second row
- "fsub U.4s, BXTx21.4s, BXTx23.4s\n"
- "str qU, [%x[outptr4]]\n"
- "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
- "str qU, [%x[outptr4], %x[mstride1]]\n"
- "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
- "str qU, [%x[outptr4], %x[mstride2]]\n"
- "fsub U.4s, BXTx22.4s, BXTx24.4s\n"
- "str qU, [%x[outptr4], %x[mstride3]]\n"
- "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
-
- // Third row
- "fsub U.4s, BXTx31.4s, BXTx33.4s\n"
- "str qU, [%x[outptr8]]\n"
- "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
- "str qU, [%x[outptr8], %x[mstride1]]\n"
- "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
- "str qU, [%x[outptr8], %x[mstride2]]\n"
- "fsub U.4s, BXTx32.4s, BXTx34.4s\n"
- "str qU, [%x[outptr8], %x[mstride3]]\n"
- "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
-
- // Fourth row
- "fsub U.4s, BXTx41.4s, BXTx43.4s\n"
- "str qU, [%x[outptr12]]\n"
-
- "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
- "str qU, [%x[outptr12], %x[mstride1]]\n"
- "ldr qx_23, [%x[inptr1]]\n"
-
- "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
- "str qU, [%x[outptr12], %x[mstride2]]\n"
- "ldr qx_33, [%x[inptr2]]\n"
-
- "fsub U.4s, BXTx42.4s, BXTx44.4s\n"
- "str qU, [%x[outptr12], %x[mstride3]]\n"
- "ldr qx_43, [%x[inptr3]]\n"
-
- "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
- "subs %x[tile_j], %x[tile_j], #1\n"
- "bne 1b\n" // Continue loop, otherwise flow into 'A' tail
-
- // ----------------------------------------------------------------------
- "2:"
- // 'A' tail
- // Since the final column is padding and the last-but-one column has
- // already been loaded just compute the 3rd column of `X.T x'.
- "fneg AXTx13.4s, x_33.4s\n"
- "fadd AXTx23.4s, x_23.4s, x_33.4s\n"
- "fsub AXTx33.4s, x_33.4s, x_23.4s\n"
- "fsub AXTx43.4s, x_23.4s, x_43.4s\n"
-
- // Compute and store U. Modified to account for the final column of X.T
- // x containing padding. Note, it is also unnecessary to update the
- // output pointers.
- // First row
- "fsub U.4s, AXTx11.4s, AXTx13.4s\n"
- "str qU, [%x[outptr0]]\n"
- "fadd U.4s, AXTx12.4s, AXTx13.4s\n"
- "str qU, [%x[outptr0], %x[mstride1]]\n"
- "fsub U.4s, AXTx13.4s, AXTx12.4s\n"
- "str qU, [%x[outptr0], %x[mstride2]]\n"
- "str qAXTx12, [%x[outptr0], %x[mstride3]]\n"
-
- // Second row
- "fsub U.4s, AXTx21.4s, AXTx23.4s\n"
- "str qU, [%x[outptr4]]\n"
- "fadd U.4s, AXTx22.4s, AXTx23.4s\n"
- "str qU, [%x[outptr4], %x[mstride1]]\n"
- "fsub U.4s, AXTx23.4s, AXTx22.4s\n"
- "str qU, [%x[outptr4], %x[mstride2]]\n"
- "str qAXTx22, [%x[outptr4], %x[mstride3]]\n"
-
- // Third row
- "fsub U.4s, AXTx31.4s, AXTx33.4s\n"
- "str qU, [%x[outptr8]]\n"
- "fadd U.4s, AXTx32.4s, AXTx33.4s\n"
- "str qU, [%x[outptr8], %x[mstride1]]\n"
- "fsub U.4s, AXTx33.4s, AXTx32.4s\n"
- "str qU, [%x[outptr8], %x[mstride2]]\n"
- "str qAXTx32, [%x[outptr8], %x[mstride3]]\n"
-
- // Fourth row
- "fsub U.4s, AXTx41.4s, AXTx43.4s\n"
- "str qU, [%x[outptr12]]\n"
- "fadd U.4s, AXTx42.4s, AXTx43.4s\n"
- "str qU, [%x[outptr12], %x[mstride1]]\n"
- "fsub U.4s, AXTx43.4s, AXTx42.4s\n"
- "str qU, [%x[outptr12], %x[mstride2]]\n"
- "str qAXTx42, [%x[outptr12], %x[mstride3]]\n"
-
- "b 4f\n" // Jump to end of function
-
- // ----------------------------------------------------------------------
- "3:"
- // 'B' tail
- // Since the final column is padding and the last-but-one column has
- // already been loaded just compute the 3rd column of `X.T x'.
- "fneg BXTx13.4s, x_33.4s\n"
- "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
- "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
- "fsub BXTx43.4s, x_23.4s, x_43.4s\n"
-
- // Compute and store U. Modified to account for the final column of X.T
- // x containing padding. Note, it is also unnecessary to update the
- // output pointers.
- // First row
- "fsub U.4s, BXTx11.4s, BXTx13.4s\n"
- "str qU, [%x[outptr0]]\n"
- "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
- "str qU, [%x[outptr0], %x[mstride1]]\n"
- "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
- "str qU, [%x[outptr0], %x[mstride2]]\n"
- "str qBXTx12, [%x[outptr0], %x[mstride3]]\n"
-
- // Second row
- "fsub U.4s, BXTx21.4s, BXTx23.4s\n"
- "str qU, [%x[outptr4]]\n"
- "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
- "str qU, [%x[outptr4], %x[mstride1]]\n"
- "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
- "str qU, [%x[outptr4], %x[mstride2]]\n"
- "str qBXTx22, [%x[outptr4], %x[mstride3]]\n"
-
- // Third row
- "fsub U.4s, BXTx31.4s, BXTx33.4s\n"
- "str qU, [%x[outptr8]]\n"
- "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
- "str qU, [%x[outptr8], %x[mstride1]]\n"
- "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
- "str qU, [%x[outptr8], %x[mstride2]]\n"
- "str qBXTx32, [%x[outptr8], %x[mstride3]]\n"
-
- // Fourth row
- "fsub U.4s, BXTx41.4s, BXTx43.4s\n"
- "str qU, [%x[outptr12]]\n"
- "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
- "str qU, [%x[outptr12], %x[mstride1]]\n"
- "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
- "str qU, [%x[outptr12], %x[mstride2]]\n"
- "str qBXTx42, [%x[outptr12], %x[mstride3]]\n"
-
- // ----------------------------------------------------------------------
- "4:"
- // End of function
-
- // Clear names
- ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n"
- ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n"
- ".unreq x_43\n" ".unreq qx_43\n" ".unreq x_44\n" ".unreq qx_44\n"
- ".unreq AXTx11\n" ".unreq BXTx13\n"
- ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n"
- ".unreq AXTx21\n" ".unreq BXTx23\n"
- ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n"
- ".unreq AXTx31\n" ".unreq BXTx33\n"
- ".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n"
- ".unreq AXTx41\n" ".unreq BXTx43\n"
- ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n"
- ".unreq AXTx13\n" ".unreq BXTx11\n"
- ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n"
- ".unreq AXTx23\n" ".unreq BXTx21\n"
- ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n"
- ".unreq AXTx33\n" ".unreq BXTx31\n"
- ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n"
- ".unreq AXTx43\n" ".unreq BXTx41\n"
- ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n"
- ".unreq U\n" ".unreq qU\n"
- : [inptr1] "+r" (inptr0), // Offset to account for padded row
- [inptr2] "+r" (inptr1), // Offset to account for padded row
- [inptr3] "+r" (inptr2), // Offset to account for padded row
- [outptr0] "+r" (outptr0),
- [outptr4] "+r" (outptr4),
- [outptr8] "+r" (outptr8),
- [outptr12] "+r" (outptr12),
- [tile_j] "+r" (tile_j) // Tile counter
- : [colstride1] "r" (1 * input_col_stride * sizeof(float)),
- [colstride2] "r" (2 * input_col_stride * sizeof(float)),
- [colstride3] "r" (3 * input_col_stride * sizeof(float)),
- [mstride1] "r" (1 * matrix_stride * sizeof(float)),
- [mstride2] "r" (2 * matrix_stride * sizeof(float)),
- [mstride3] "r" (3 * matrix_stride * sizeof(float)),
- [matrix_row_stride] "r" (matrix_row_stride * sizeof(float))
- : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
- "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
- "v22", "v23", "v24"
- );
-}
-
-// Pad left, right and bottom by 1.
-template <>
-template <>
-inline void Winograd2x2_3x3GemmInput<float>::process_tile_row<0, 1, 1, 1, 4>(
- const int tile_N,
- const float* const input,
- const int input_row_stride,
- const int input_col_stride,
- float* const matrix,
- const int matrix_stride,
- const int matrix_row_stride
-) {
- const float *inptr0 = input;
- const float *inptr1 = input + input_row_stride;
- const float *inptr2 = input + input_row_stride * 2;
-
- float *outptr0 = matrix;
- float *outptr4 = matrix + matrix_stride * 4;
- float *outptr8 = matrix + matrix_stride * 8;
- float *outptr12 = matrix + matrix_stride * 12;
-
- int tile_j = tile_N; // Tiles to process
-
- asm volatile (
- // Named SIMD registers according to the policy given above
- // Registers into which to load the latter two columns of `x`
- // NOTE: Bottom row is not required since since it is padded.
- "x_13 .req v0\n qx_13 .req q0\n" "x_14 .req v4\n qx_14 .req q4\n"
- "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n"
- "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n"
-
- // Registers for storing X.T x (both A and B halves)
- "AXTx11 .req v8\n" "BXTx13 .req v8\n"
- "AXTx12 .req v9\n" "BXTx14 .req v9\n" "qAXTx12 .req q9\n"
- "AXTx21 .req v10\n" "BXTx23 .req v10\n"
- "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n"
- "AXTx31 .req v12\n" "BXTx33 .req v12\n"
- "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n"
- "AXTx41 .req v14\n" "BXTx43 .req v14\n"
- "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n"
- "AXTx13 .req v16\n" "BXTx11 .req v16\n"
- "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n"
- "AXTx23 .req v18\n" "BXTx21 .req v18\n"
- "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n"
- "AXTx33 .req v20\n" "BXTx31 .req v20\n"
- "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n"
- "AXTx43 .req v22\n" "BXTx41 .req v22\n"
- "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n"
-
- // Result register (TODO Does using more registers yield better
- // performance)
- "U .req v24\n qU .req q24\n"
-
- // ----------------------------------------------------------------------
- // Head of loop
- // Loads a complete 4x4 tile of x, computes X.T x, computes and stores
- // `U = X.T x X`. Prepares for the 'A' half of the loop.
- // NOTE: Since the first tile has the leftmost column padded we can
- // skip 4 loads and 4 calculations for the matrix X.T x X.
-
- // Temporarily alias registers for computing the first (non-padded)
- // column of x.
- "x_12 .req v0\n qx_12 .req q0\n"
- "x_22 .req v1\n qx_22 .req q1\n"
- "x_32 .req v2\n qx_32 .req q2\n"
-
- "ldr qx_12, [%x[inptr0]]\n"
- "ldr qx_22, [%x[inptr1]]\n"
- "ldr qx_32, [%x[inptr2]]\n"
-
- "fsub BXTx12.4s, x_12.4s, x_32.4s\n"
- "fadd BXTx22.4s, x_22.4s, x_32.4s\n"
- "fsub BXTx32.4s, x_32.4s, x_22.4s\n"
- "mov BXTx42.16b, x_22.16b\n" // Probably should do better
-
- ".unreq x_12\n .unreq qx_12\n"
- ".unreq x_22\n .unreq qx_22\n"
- ".unreq x_32\n .unreq qx_32\n"
-
- // Load and compute latter two columns of the first tile. Progress the
- // input pointers (by three columns so that the each points are the
- // second column of the next tile, that is, each points at the first
- // column which must be read for the next tile.
- "ldr qx_13, [%x[inptr0], %x[colstride1]]\n"
- "ldr qx_23, [%x[inptr1], %x[colstride1]]\n"
- "ldr qx_33, [%x[inptr2], %x[colstride1]]\n"
-
- "fsub BXTx13.4s, x_13.4s, x_33.4s\n"
- "ldr qx_14, [%x[inptr0], %x[colstride2]]\n"
-
- "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
- "ldr qx_24, [%x[inptr1], %x[colstride2]]\n"
-
- "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
- "ldr qx_34, [%x[inptr2], %x[colstride2]]\n"
-
- "mov BXTx43.16b, x_23.16b\n"
- "fsub BXTx14.4s, x_14.4s, x_34.4s\n"
- "add %x[inptr0], %x[inptr0], %x[colstride3]\n"
-
- "fadd BXTx24.4s, x_24.4s, x_34.4s\n"
- "add %x[inptr1], %x[inptr1], %x[colstride3]\n"
-
- "fsub BXTx34.4s, x_34.4s, x_24.4s\n"
- "add %x[inptr2], %x[inptr2], %x[colstride3]\n"
-
- "mov BXTx44.16b, x_24.16b\n"
-
- // Compute and store U for the first tile
- // First row
- "fneg U.4s, BXTx13.4s\n"
- "str qU, [%x[outptr0]]\n"
- "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
- "str qU, [%x[outptr0], %x[mstride1]]\n"
- "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
- "str qU, [%x[outptr0], %x[mstride2]]\n"
- "fsub U.4s, BXTx12.4s, BXTx14.4s\n"
- "str qU, [%x[outptr0], %x[mstride3]]\n"
- "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
-
- // Second row
- "fneg U.4s, BXTx23.4s\n"
- "str qU, [%x[outptr4]]\n"
- "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
- "str qU, [%x[outptr4], %x[mstride1]]\n"
- "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
- "str qU, [%x[outptr4], %x[mstride2]]\n"
- "fsub U.4s, BXTx22.4s, BXTx24.4s\n"
- "str qU, [%x[outptr4], %x[mstride3]]\n"
- "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
-
- // Third row
- "fneg U.4s, BXTx33.4s\n"
- "str qU, [%x[outptr8]]\n"
- "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
- "str qU, [%x[outptr8], %x[mstride1]]\n"
- "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
- "str qU, [%x[outptr8], %x[mstride2]]\n"
- "fsub U.4s, BXTx32.4s, BXTx34.4s\n"
- "str qU, [%x[outptr8], %x[mstride3]]\n"
- "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
-
- // Fourth row, simultaneously load the first column of inputs for the
- // next tile.
- "fneg U.4s, BXTx43.4s\n"
- "str qU, [%x[outptr12]]\n"
- "ldr qx_13, [%x[inptr0]]\n"
-
- "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
- "str qU, [%x[outptr12], %x[mstride1]]\n"
- "ldr qx_23, [%x[inptr1]]\n"
-
- "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
- "str qU, [%x[outptr12], %x[mstride2]]\n"
- "ldr qx_33, [%x[inptr2]]\n"
-
- "fsub U.4s, BXTx42.4s, BXTx44.4s\n"
- "str qU, [%x[outptr12], %x[mstride3]]\n"
-
- "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
-
- // Update the loop counter, subtract two to account for both the head and
- // the tail.
- "subs %x[tile_j], %x[tile_j], #2\n"
- "beq 2f\n" // Jump to "A" tail if out of tiles
-
- // ----------------------------------------------------------------------
- "1:"
- // Start part A
- // Load last column of this tile (the first column has already been
- // loaded) and compute latter two columns of X.T x.
- "fsub AXTx13.4s, x_13.4s, x_33.4s\n"
- "ldr qx_14, [%x[inptr0], %x[colstride1]]\n"
- "fadd AXTx23.4s, x_23.4s, x_33.4s\n"
- "ldr qx_24, [%x[inptr1], %x[colstride1]]\n"
- "fsub AXTx33.4s, x_33.4s, x_23.4s\n"
- "ldr qx_34, [%x[inptr2], %x[colstride1]]\n"
- "mov AXTx43.16b, x_23.16b\n"
-
- "fsub AXTx14.4s, x_14.4s, x_34.4s\n"
- "add %x[inptr0], %x[inptr0], %x[colstride2]\n"
- "fadd AXTx24.4s, x_24.4s, x_34.4s\n"
- "add %x[inptr1], %x[inptr1], %x[colstride2]\n"
- "fsub AXTx34.4s, x_34.4s, x_24.4s\n"
- "add %x[inptr2], %x[inptr2], %x[colstride2]\n"
- "mov AXTx44.16b, x_24.16b\n"
-
- // Compute and store U.
- // First row
- "fsub U.4s, AXTx11.4s, AXTx13.4s\n"
- "str qU, [%x[outptr0]]\n"
- "fadd U.4s, AXTx12.4s, AXTx13.4s\n"
- "str qU, [%x[outptr0], %x[mstride1]]\n"
- "fsub U.4s, AXTx13.4s, AXTx12.4s\n"
- "str qU, [%x[outptr0], %x[mstride2]]\n"
- "fsub U.4s, AXTx12.4s, AXTx14.4s\n"
- "str qU, [%x[outptr0], %x[mstride3]]\n"
- "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
-
- // Second row
- "fsub U.4s, AXTx21.4s, AXTx23.4s\n"
- "str qU, [%x[outptr4]]\n"
- "fadd U.4s, AXTx22.4s, AXTx23.4s\n"
- "str qU, [%x[outptr4], %x[mstride1]]\n"
- "fsub U.4s, AXTx23.4s, AXTx22.4s\n"
- "str qU, [%x[outptr4], %x[mstride2]]\n"
- "fsub U.4s, AXTx22.4s, AXTx24.4s\n"
- "str qU, [%x[outptr4], %x[mstride3]]\n"
- "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
-
- // Third row
- "fsub U.4s, AXTx31.4s, AXTx33.4s\n"
- "str qU, [%x[outptr8]]\n"
- "fadd U.4s, AXTx32.4s, AXTx33.4s\n"
- "str qU, [%x[outptr8], %x[mstride1]]\n"
- "fsub U.4s, AXTx33.4s, AXTx32.4s\n"
- "str qU, [%x[outptr8], %x[mstride2]]\n"
- "fsub U.4s, AXTx32.4s, AXTx34.4s\n"
- "str qU, [%x[outptr8], %x[mstride3]]\n"
- "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
-
- // Fourth row
- "fsub U.4s, AXTx41.4s, AXTx43.4s\n"
- "str qU, [%x[outptr12]]\n"
- "ldr qx_13, [%x[inptr0]]\n"
-
- "fadd U.4s, AXTx42.4s, AXTx43.4s\n"
- "str qU, [%x[outptr12], %x[mstride1]]\n"
- "ldr qx_23, [%x[inptr1]]\n"
-
- "fsub U.4s, AXTx43.4s, AXTx42.4s\n"
- "str qU, [%x[outptr12], %x[mstride2]]\n"
- "ldr qx_33, [%x[inptr2]]\n"
-
- "fsub U.4s, AXTx42.4s, AXTx44.4s\n"
- "str qU, [%x[outptr12], %x[mstride3]]\n"
-
- "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
-
- "subs %x[tile_j], %x[tile_j], #1\n"
- "beq 3f\n" // Jump to 'B' tail
-
- // Start part B
- // Load last column of this tile (the first column has already been
- // loaded) and compute latter two columns of X.T x.
- "fsub BXTx13.4s, x_13.4s, x_33.4s\n"
- "ldr qx_14, [%x[inptr0], %x[colstride1]]\n"
- "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
- "ldr qx_24, [%x[inptr1], %x[colstride1]]\n"
- "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
- "ldr qx_34, [%x[inptr2], %x[colstride1]]\n"
- "mov BXTx43.16b, x_23.16b\n"
-
- "fsub BXTx14.4s, x_14.4s, x_34.4s\n"
- "add %x[inptr0], %x[inptr0], %x[colstride2]\n"
- "fadd BXTx24.4s, x_24.4s, x_34.4s\n"
- "add %x[inptr1], %x[inptr1], %x[colstride2]\n"
- "fsub BXTx34.4s, x_34.4s, x_24.4s\n"
- "add %x[inptr2], %x[inptr2], %x[colstride2]\n"
- "mov BXTx44.16b, x_24.16b\n"
-
- // Compute and store U.
- // First row
- "fsub U.4s, BXTx11.4s, BXTx13.4s\n"
- "str qU, [%x[outptr0]]\n"
- "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
- "str qU, [%x[outptr0], %x[mstride1]]\n"
- "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
- "str qU, [%x[outptr0], %x[mstride2]]\n"
- "fsub U.4s, BXTx12.4s, BXTx14.4s\n"
- "str qU, [%x[outptr0], %x[mstride3]]\n"
- "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
-
- // Second row
- "fsub U.4s, BXTx21.4s, BXTx23.4s\n"
- "str qU, [%x[outptr4]]\n"
- "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
- "str qU, [%x[outptr4], %x[mstride1]]\n"
- "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
- "str qU, [%x[outptr4], %x[mstride2]]\n"
- "fsub U.4s, BXTx22.4s, BXTx24.4s\n"
- "str qU, [%x[outptr4], %x[mstride3]]\n"
- "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
-
- // Third row
- "fsub U.4s, BXTx31.4s, BXTx33.4s\n"
- "str qU, [%x[outptr8]]\n"
- "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
- "str qU, [%x[outptr8], %x[mstride1]]\n"
- "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
- "str qU, [%x[outptr8], %x[mstride2]]\n"
- "fsub U.4s, BXTx32.4s, BXTx34.4s\n"
- "str qU, [%x[outptr8], %x[mstride3]]\n"
- "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
-
- // Fourth row
- "fsub U.4s, BXTx41.4s, BXTx43.4s\n"
- "str qU, [%x[outptr12]]\n"
- "ldr qx_13, [%x[inptr0]]\n"
-
- "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
- "str qU, [%x[outptr12], %x[mstride1]]\n"
- "ldr qx_23, [%x[inptr1]]\n"
-
- "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
- "str qU, [%x[outptr12], %x[mstride2]]\n"
- "ldr qx_33, [%x[inptr2]]\n"
-
- "fsub U.4s, BXTx42.4s, BXTx44.4s\n"
- "str qU, [%x[outptr12], %x[mstride3]]\n"
-
- "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
- "subs %x[tile_j], %x[tile_j], #1\n"
- "bne 1b\n" // Continue loop, otherwise flow into 'A' tail
-
- // ----------------------------------------------------------------------
- "2:"
- // 'A' tail
- // Since the final column is padding and the last-but-one column has
- // already been loaded just compute the 3rd column of `X.T x'.
- "fsub AXTx13.4s, x_13.4s, x_33.4s\n"
- "fadd AXTx23.4s, x_23.4s, x_33.4s\n"
- "fsub AXTx33.4s, x_33.4s, x_23.4s\n"
- "mov AXTx43.16b, x_23.16b\n"
-
- // Compute and store U. Modified to account for the final column of X.T
- // x containing padding. Note, it is also unnecessary to update the
- // output pointers.
- // First row
- "fsub U.4s, AXTx11.4s, AXTx13.4s\n"
- "str qU, [%x[outptr0]]\n"
- "fadd U.4s, AXTx12.4s, AXTx13.4s\n"
- "str qU, [%x[outptr0], %x[mstride1]]\n"
- "fsub U.4s, AXTx13.4s, AXTx12.4s\n"
- "str qU, [%x[outptr0], %x[mstride2]]\n"
- "str qAXTx12, [%x[outptr0], %x[mstride3]]\n"
-
- // Second row
- "fsub U.4s, AXTx21.4s, AXTx23.4s\n"
- "str qU, [%x[outptr4]]\n"
- "fadd U.4s, AXTx22.4s, AXTx23.4s\n"
- "str qU, [%x[outptr4], %x[mstride1]]\n"
- "fsub U.4s, AXTx23.4s, AXTx22.4s\n"
- "str qU, [%x[outptr4], %x[mstride2]]\n"
- "str qAXTx22, [%x[outptr4], %x[mstride3]]\n"
-
- // Third row
- "fsub U.4s, AXTx31.4s, AXTx33.4s\n"
- "str qU, [%x[outptr8]]\n"
- "fadd U.4s, AXTx32.4s, AXTx33.4s\n"
- "str qU, [%x[outptr8], %x[mstride1]]\n"
- "fsub U.4s, AXTx33.4s, AXTx32.4s\n"
- "str qU, [%x[outptr8], %x[mstride2]]\n"
- "str qAXTx32, [%x[outptr8], %x[mstride3]]\n"
-
- // Fourth row
- "fsub U.4s, AXTx41.4s, AXTx43.4s\n"
- "str qU, [%x[outptr12]]\n"
- "fadd U.4s, AXTx42.4s, AXTx43.4s\n"
- "str qU, [%x[outptr12], %x[mstride1]]\n"
- "fsub U.4s, AXTx43.4s, AXTx42.4s\n"
- "str qU, [%x[outptr12], %x[mstride2]]\n"
- "str qAXTx42, [%x[outptr12], %x[mstride3]]\n"
-
- "b 4f\n" // Jump to end of function
-
- // ----------------------------------------------------------------------
- "3:"
- // 'B' tail
- // Since the final column is padding and the last-but-one column has
- // already been loaded just compute the 3rd column of `X.T x'.
- "fsub BXTx13.4s, x_13.4s, x_33.4s\n"
- "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
- "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
- "mov BXTx43.16b, x_23.16b\n"
-
- // Compute and store U. Modified to account for the final column of X.T
- // x containing padding. Note, it is also unnecessary to update the
- // output pointers.
- // First row
- "fsub U.4s, BXTx11.4s, BXTx13.4s\n"
- "str qU, [%x[outptr0]]\n"
- "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
- "str qU, [%x[outptr0], %x[mstride1]]\n"
- "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
- "str qU, [%x[outptr0], %x[mstride2]]\n"
- "str qBXTx12, [%x[outptr0], %x[mstride3]]\n"
-
- // Second row
- "fsub U.4s, BXTx21.4s, BXTx23.4s\n"
- "str qU, [%x[outptr4]]\n"
- "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
- "str qU, [%x[outptr4], %x[mstride1]]\n"
- "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
- "str qU, [%x[outptr4], %x[mstride2]]\n"
- "str qBXTx22, [%x[outptr4], %x[mstride3]]\n"
-
- // Third row
- "fsub U.4s, BXTx31.4s, BXTx33.4s\n"
- "str qU, [%x[outptr8]]\n"
- "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
- "str qU, [%x[outptr8], %x[mstride1]]\n"
- "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
- "str qU, [%x[outptr8], %x[mstride2]]\n"
- "str qBXTx32, [%x[outptr8], %x[mstride3]]\n"
-
- // Fourth row
- "fsub U.4s, BXTx41.4s, BXTx43.4s\n"
- "str qU, [%x[outptr12]]\n"
- "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
- "str qU, [%x[outptr12], %x[mstride1]]\n"
- "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
- "str qU, [%x[outptr12], %x[mstride2]]\n"
- "str qBXTx42, [%x[outptr12], %x[mstride3]]\n"
-
- // ----------------------------------------------------------------------
- "4:"
- // End of function
-
- // Clear names
- ".unreq x_13\n" ".unreq qx_13\n" ".unreq x_14\n" ".unreq qx_14\n"
- ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n"
- ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n"
- ".unreq AXTx11\n" ".unreq BXTx13\n"
- ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n"
- ".unreq AXTx21\n" ".unreq BXTx23\n"
- ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n"
- ".unreq AXTx31\n" ".unreq BXTx33\n"
- ".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n"
- ".unreq AXTx41\n" ".unreq BXTx43\n"
- ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n"
- ".unreq AXTx13\n" ".unreq BXTx11\n"
- ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n"
- ".unreq AXTx23\n" ".unreq BXTx21\n"
- ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n"
- ".unreq AXTx33\n" ".unreq BXTx31\n"
- ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n"
- ".unreq AXTx43\n" ".unreq BXTx41\n"
- ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n"
- ".unreq U\n" ".unreq qU\n"
- : [inptr0] "+r" (inptr0),
- [inptr1] "+r" (inptr1),
- [inptr2] "+r" (inptr2),
- [outptr0] "+r" (outptr0),
- [outptr4] "+r" (outptr4),
- [outptr8] "+r" (outptr8),
- [outptr12] "+r" (outptr12),
- [tile_j] "+r" (tile_j) // Tile counter
- : [colstride1] "r" (1 * input_col_stride * sizeof(float)),
- [colstride2] "r" (2 * input_col_stride * sizeof(float)),
- [colstride3] "r" (3 * input_col_stride * sizeof(float)),
- [mstride1] "r" (1 * matrix_stride * sizeof(float)),
- [mstride2] "r" (2 * matrix_stride * sizeof(float)),
- [mstride3] "r" (3 * matrix_stride * sizeof(float)),
- [matrix_row_stride] "r" (matrix_row_stride * sizeof(float))
- : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
- "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
- "v22", "v23", "v24"
- );
-}
-}
-#endif // __aarch64__
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp
deleted file mode 100644
index ad1ad55291..0000000000
--- a/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp
+++ /dev/null
@@ -1,961 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#pragma once
-#include "../input_2x2_3x3.hpp"
-
-#ifdef __aarch64__
-
-namespace winograd {
-
-template <>
-template <>
-inline void Winograd2x2_3x3GemmInputChannelwise<float>::_process_tile<0, 0, 0, 0, 4>(
- int &n_channels, // Number of channels in the tile
- const float* &inptr0,
- const int input_row_stride,
- const int input_col_stride,
- float* &outptr0,
- const int matrix_stride
-) {
- // We use 4 pointers to point to the starting position on each row and use
- // three offsets to extract elements from each of the other 3 columns.
- auto inptr1 = inptr0 + 1*input_row_stride;
- auto inptr2 = inptr0 + 2*input_row_stride;
- auto inptr3 = inptr0 + 3*input_row_stride;
-
- // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three
- // offsets to access the intermediate matrices.
- auto outptr1 = outptr0 + matrix_stride * 4;
- auto outptr2 = outptr0 + matrix_stride * 8;
- auto outptr3 = outptr0 + matrix_stride * 12;
-
- for (; n_channels > 3; n_channels -= 4) {
- asm volatile (
- "X_11 .req v0\n" "qX_11 .req q0\n"
- "X_12 .req v1\n" "qX_12 .req q1\n"
- "X_13 .req v2\n" "qX_13 .req q2\n"
- "X_14 .req v3\n" "qX_14 .req q3\n"
- "X_21 .req v4\n" "qX_21 .req q4\n"
- "X_22 .req v5\n" "qX_22 .req q5\n"
- "X_23 .req v6\n" "qX_23 .req q6\n"
- "X_24 .req v7\n" "qX_24 .req q7\n"
- "X_31 .req v8\n" "qX_31 .req q8\n"
- "X_32 .req v9\n" "qX_32 .req q9\n"
- "X_33 .req v10\n" "qX_33 .req q10\n"
- "X_34 .req v11\n" "qX_34 .req q11\n"
- "X_41 .req v12\n" "qX_41 .req q12\n"
- "X_42 .req v13\n" "qX_42 .req q13\n"
- "X_43 .req v14\n" "qX_43 .req q14\n"
- "X_44 .req v15\n" "qX_44 .req q15\n"
- "xX_11 .req v16\n"
- "xX_12 .req v17\n"
- "xX_13 .req v18\n"
- "xX_14 .req v19\n"
- "xX_21 .req v20\n"
- "xX_22 .req v21\n"
- "xX_23 .req v22\n"
- "xX_24 .req v23\n"
- "xX_31 .req v24\n"
- "xX_32 .req v25\n"
- "xX_33 .req v26\n"
- "xX_34 .req v27\n"
- "xX_41 .req v28\n"
- "xX_42 .req v29\n"
- "xX_43 .req v30\n"
- "xX_44 .req v31\n"
- " U .req v0\n"
- "qU .req q0\n"
-
- // Load the tile, and compute compute the matrix xX
- "ldr qX_11, [%x[inptr0]]\n"
- "ldr qX_12, [%x[inptr0], %x[colstride1]]\n"
- "ldr qX_13, [%x[inptr0], %x[colstride2]]\n"
- "ldr qX_14, [%x[inptr0], %x[colstride3]]\n"
- "add %x[inptr0], %x[inptr0], #0x10\n"
-
- "ldr qX_21, [%x[inptr1]]\n"
- "fsub xX_11.4s, x_11.4s, x_13.4s\n"
- "ldr qX_22, [%x[inptr1], %x[colstride1]]\n"
- "fadd xX_12.4s, x_12.4s, x_13.4s\n"
- "ldr qX_23, [%x[inptr1], %x[colstride2]]\n"
- "fsub xX_13.4s, x_13.4s, x_12.4s\n"
- "ldr qX_24, [%x[inptr1], %x[colstride3]]\n"
- "fsub xX_14.4s, x_12.4s, x_14.4s\n"
- "add %x[inptr1], %x[inptr1], #0x10\n"
-
- "ldr qX_31, [%x[inptr2]]\n"
- "fsub xX_21.4s, x_21.4s, x_23.4s\n"
- "ldr qX_32, [%x[inptr2], %x[colstride1]]\n"
- "fadd xX_22.4s, x_22.4s, x_23.4s\n"
- "ldr qX_33, [%x[inptr2], %x[colstride2]]\n"
- "fsub xX_23.4s, x_23.4s, x_22.4s\n"
- "ldr qX_34, [%x[inptr2], %x[colstride3]]\n"
- "fsub xX_24.4s, x_22.4s, x_24.4s\n"
- "add %x[inptr2], %x[inptr2], #0x10\n"
-
- "ldr qX_41, [%x[inptr3]]\n"
- "fsub xX_31.4s, x_31.4s, x_33.4s\n"
- "ldr qX_42, [%x[inptr3], %x[colstride1]]\n"
- "fadd xX_32.4s, x_32.4s, x_33.4s\n"
- "ldr qX_43, [%x[inptr3], %x[colstride2]]\n"
- "fsub xX_33.4s, x_33.4s, x_32.4s\n"
- "ldr qX_44, [%x[inptr3], %x[colstride3]]\n"
- "fsub xX_34.4s, x_32.4s, x_34.4s\n"
- "add %x[inptr3], %x[inptr3], #0x10\n"
-
- // Complete computing xX while beginning to compute and store
- // $U = X.T x X$
-
- "fsub xX_41.4s, x_41.4s, x_43.4s\n"
-
- "fsub U.4s, xX_11.4s, xX_31.4s\n"
- "str qU, [%x[outptr0]]\n"
- "fsub U.4s, xX_12.4s, xX_32.4s\n"
- "str qU, [%x[outptr0], %x[mstride1]]\n"
- "fsub U.4s, xX_13.4s, xX_33.4s\n"
- "str qU, [%x[outptr0], %x[mstride2]]\n"
- "fsub U.4s, xX_14.4s, xX_34.4s\n"
- "str qU, [%x[outptr0], %x[mstride3]]\n"
- "add %x[outptr0], %x[outptr0], #0x10\n"
-
- "fadd xX_42.4s, x_42.4s, x_43.4s\n"
-
- "fadd U.4s, xX_21.4s, xX_31.4s\n"
- "str qU, [%x[outptr4]]\n"
- "fadd U.4s, xX_22.4s, xX_32.4s\n"
- "str qU, [%x[outptr4], %x[mstride1]]\n"
- "fadd U.4s, xX_23.4s, xX_33.4s\n"
- "str qU, [%x[outptr4], %x[mstride2]]\n"
- "fadd U.4s, xX_24.4s, xX_34.4s\n"
- "str qU, [%x[outptr4], %x[mstride3]]\n"
- "add %x[outptr4], %x[outptr4], #0x10\n"
-
- "fsub xX_43.4s, x_43.4s, x_42.4s\n"
-
- "fsub U.4s, xX_31.4s, xX_21.4s\n"
- "str qU, [%x[outptr8]]\n"
- "fsub U.4s, xX_32.4s, xX_22.4s\n"
- "str qU, [%x[outptr8], %x[mstride1]]\n"
- "fsub U.4s, xX_33.4s, xX_23.4s\n"
- "str qU, [%x[outptr8], %x[mstride2]]\n"
- "fsub U.4s, xX_34.4s, xX_24.4s\n"
- "str qU, [%x[outptr8], %x[mstride3]]\n"
- "add %x[outptr8], %x[outptr8], #0x10\n"
-
- "fsub xX_44.4s, x_42.4s, x_44.4s\n"
-
- "fsub U.4s, xX_21.4s, xX_41.4s\n"
- "str qU, [%x[outptr12]]\n"
- "fsub U.4s, xX_22.4s, xX_42.4s\n"
- "str qU, [%x[outptr12], %x[mstride1]]\n"
- "fsub U.4s, xX_23.4s, xX_43.4s\n"
- "str qU, [%x[outptr12], %x[mstride2]]\n"
- "fsub U.4s, xX_24.4s, xX_44.4s\n"
- "str qU, [%x[outptr12], %x[mstride3]]\n"
- "add %x[outptr12], %x[outptr12], #0x10\n"
-
- ".unreq qU\n"
- ".unreq U\n"
- ".unreq X_11\n" ".unreq qX_11\n"
- ".unreq X_12\n" ".unreq qX_12\n"
- ".unreq X_13\n" ".unreq qX_13\n"
- ".unreq X_14\n" ".unreq qX_14\n"
- ".unreq X_21\n" ".unreq qX_21\n"
- ".unreq X_22\n" ".unreq qX_22\n"
- ".unreq X_23\n" ".unreq qX_23\n"
- ".unreq X_24\n" ".unreq qX_24\n"
- ".unreq X_31\n" ".unreq qX_31\n"
- ".unreq X_32\n" ".unreq qX_32\n"
- ".unreq X_33\n" ".unreq qX_33\n"
- ".unreq X_34\n" ".unreq qX_34\n"
- ".unreq X_41\n" ".unreq qX_41\n"
- ".unreq X_42\n" ".unreq qX_42\n"
- ".unreq X_43\n" ".unreq qX_43\n"
- ".unreq X_44\n" ".unreq qX_44\n"
- ".unreq xX_11\n"
- ".unreq xX_12\n"
- ".unreq xX_13\n"
- ".unreq xX_14\n"
- ".unreq xX_21\n"
- ".unreq xX_22\n"
- ".unreq xX_23\n"
- ".unreq xX_24\n"
- ".unreq xX_31\n"
- ".unreq xX_32\n"
- ".unreq xX_33\n"
- ".unreq xX_34\n"
- ".unreq xX_41\n"
- ".unreq xX_42\n"
- ".unreq xX_43\n"
- ".unreq xX_44\n"
-
- : [inptr0] "+r" (inptr0),
- [inptr1] "+r" (inptr1),
- [inptr2] "+r" (inptr2),
- [inptr3] "+r" (inptr3),
- [outptr0] "+r" (outptr0),
- [outptr4] "+r" (outptr1),
- [outptr8] "+r" (outptr2),
- [outptr12] "+r" (outptr3)
- : [colstride1] "r" (input_col_stride * sizeof(float)),
- [colstride2] "r" (input_col_stride * sizeof(float) * 2),
- [colstride3] "r" (input_col_stride * sizeof(float) * 3),
- [mstride1] "r" (matrix_stride * sizeof(float)),
- [mstride2] "r" (matrix_stride * sizeof(float) * 2),
- [mstride3] "r" (matrix_stride * sizeof(float) * 3)
- : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
- "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
- "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
- "v30", "v31"
- );
- }
-}
-
-// Pad top by 1
-template <>
-template <>
-inline void Winograd2x2_3x3GemmInputChannelwise<float>::_process_tile<1, 0, 0, 0, 4>(
- int &n_channels, // Number of channels in the tile
- const float* &inptr0,
- const int input_row_stride,
- const int input_col_stride,
- float* &outptr0,
- const int matrix_stride
-) {
- // We use 4 pointers to point to the starting position on each row and use
- // three offsets to extract elements from each of the other 3 columns.
- auto inptr1 = inptr0 + 0*input_row_stride;
- auto inptr2 = inptr0 + 1*input_row_stride;
-
- // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three
- // offsets to access the intermediate matrices.
- auto outptr1 = outptr0 + matrix_stride * 4;
- auto outptr2 = outptr0 + matrix_stride * 8;
- auto outptr3 = outptr0 + matrix_stride * 12;
-
- for (; n_channels > 3; n_channels -= 4) {
- asm volatile (
- "X_21 .req v4\n" "qX_21 .req q4\n"
- "X_22 .req v5\n" "qX_22 .req q5\n"
- "X_23 .req v6\n" "qX_23 .req q6\n"
- "X_24 .req v7\n" "qX_24 .req q7\n"
- "X_31 .req v8\n" "qX_31 .req q8\n"
- "X_32 .req v9\n" "qX_32 .req q9\n"
- "X_33 .req v10\n" "qX_33 .req q10\n"
- "X_34 .req v11\n" "qX_34 .req q11\n"
- "X_41 .req v12\n" "qX_41 .req q12\n"
- "X_42 .req v13\n" "qX_42 .req q13\n"
- "X_43 .req v14\n" "qX_43 .req q14\n"
- "X_44 .req v15\n" "qX_44 .req q15\n"
- "xX_21 .req v20\n"
- "xX_22 .req v21\n"
- "xX_23 .req v22\n"
- "xX_24 .req v23\n"
- "xX_31 .req v24\n"
- "xX_32 .req v25\n"
- "xX_33 .req v26\n"
- "xX_34 .req v27\n"
- "xX_41 .req v28\n"
- "xX_42 .req v29\n"
- "xX_43 .req v30\n"
- "xX_44 .req v31\n"
- " U .req v0\n"
- "qU .req q0\n"
-
- // Load the tile, and compute compute the matrix xX
- "ldr qX_21, [%x[inptr1]]\n"
- "ldr qX_22, [%x[inptr1], %x[colstride1]]\n"
- "ldr qX_23, [%x[inptr1], %x[colstride2]]\n"
- "ldr qX_24, [%x[inptr1], %x[colstride3]]\n"
- "add %x[inptr1], %x[inptr1], #0x10\n"
-
- "ldr qX_31, [%x[inptr2]]\n"
- "fsub xX_21.4s, x_21.4s, x_23.4s\n"
- "ldr qX_32, [%x[inptr2], %x[colstride1]]\n"
- "fadd xX_22.4s, x_22.4s, x_23.4s\n"
- "ldr qX_33, [%x[inptr2], %x[colstride2]]\n"
- "fsub xX_23.4s, x_23.4s, x_22.4s\n"
- "ldr qX_34, [%x[inptr2], %x[colstride3]]\n"
- "fsub xX_24.4s, x_22.4s, x_24.4s\n"
- "add %x[inptr2], %x[inptr2], #0x10\n"
-
- "ldr qX_41, [%x[inptr3]]\n"
- "fsub xX_31.4s, x_31.4s, x_33.4s\n"
- "ldr qX_42, [%x[inptr3], %x[colstride1]]\n"
- "fadd xX_32.4s, x_32.4s, x_33.4s\n"
- "ldr qX_43, [%x[inptr3], %x[colstride2]]\n"
- "fsub xX_33.4s, x_33.4s, x_32.4s\n"
- "ldr qX_44, [%x[inptr3], %x[colstride3]]\n"
- "fsub xX_34.4s, x_32.4s, x_34.4s\n"
- "add %x[inptr3], %x[inptr3], #0x10\n"
-
- // Complete computing xX while beginning to compute and store
- // $U = X.T x X$
-
- "fsub xX_41.4s, x_41.4s, x_43.4s\n"
-
- "fneg U.4s, xX_31.4s\n"
- "str qU, [%x[outptr0]]\n"
- "fneg U.4s, xX_32.4s\n"
- "str qU, [%x[outptr0], %x[mstride1]]\n"
- "fneg U.4s, xX_33.4s\n"
- "str qU, [%x[outptr0], %x[mstride2]]\n"
- "fneg U.4s, xX_34.4s\n"
- "str qU, [%x[outptr0], %x[mstride3]]\n"
- "add %x[outptr0], %x[outptr0], #0x10\n"
-
- "fadd xX_42.4s, x_42.4s, x_43.4s\n"
-
- "fadd U.4s, xX_21.4s, xX_31.4s\n"
- "str qU, [%x[outptr4]]\n"
- "fadd U.4s, xX_22.4s, xX_32.4s\n"
- "str qU, [%x[outptr4], %x[mstride1]]\n"
- "fadd U.4s, xX_23.4s, xX_33.4s\n"
- "str qU, [%x[outptr4], %x[mstride2]]\n"
- "fadd U.4s, xX_24.4s, xX_34.4s\n"
- "str qU, [%x[outptr4], %x[mstride3]]\n"
- "add %x[outptr4], %x[outptr4], #0x10\n"
-
- "fsub xX_43.4s, x_43.4s, x_42.4s\n"
-
- "fsub U.4s, xX_31.4s, xX_21.4s\n"
- "str qU, [%x[outptr8]]\n"
- "fsub U.4s, xX_32.4s, xX_22.4s\n"
- "str qU, [%x[outptr8], %x[mstride1]]\n"
- "fsub U.4s, xX_33.4s, xX_23.4s\n"
- "str qU, [%x[outptr8], %x[mstride2]]\n"
- "fsub U.4s, xX_34.4s, xX_24.4s\n"
- "str qU, [%x[outptr8], %x[mstride3]]\n"
- "add %x[outptr8], %x[outptr8], #0x10\n"
-
- "fsub xX_44.4s, x_42.4s, x_44.4s\n"
-
- "fsub U.4s, xX_21.4s, xX_41.4s\n"
- "str qU, [%x[outptr12]]\n"
- "fsub U.4s, xX_22.4s, xX_42.4s\n"
- "str qU, [%x[outptr12], %x[mstride1]]\n"
- "fsub U.4s, xX_23.4s, xX_43.4s\n"
- "str qU, [%x[outptr12], %x[mstride2]]\n"
- "fsub U.4s, xX_24.4s, xX_44.4s\n"
- "str qU, [%x[outptr12], %x[mstride3]]\n"
- "add %x[outptr12], %x[outptr12], #0x10\n"
-
- ".unreq qU\n"
- ".unreq U\n"
- ".unreq X_21\n" ".unreq qX_21\n"
- ".unreq X_22\n" ".unreq qX_22\n"
- ".unreq X_23\n" ".unreq qX_23\n"
- ".unreq X_24\n" ".unreq qX_24\n"
- ".unreq X_31\n" ".unreq qX_31\n"
- ".unreq X_32\n" ".unreq qX_32\n"
- ".unreq X_33\n" ".unreq qX_33\n"
- ".unreq X_34\n" ".unreq qX_34\n"
- ".unreq X_41\n" ".unreq qX_41\n"
- ".unreq X_42\n" ".unreq qX_42\n"
- ".unreq X_43\n" ".unreq qX_43\n"
- ".unreq X_44\n" ".unreq qX_44\n"
- ".unreq xX_21\n"
- ".unreq xX_22\n"
- ".unreq xX_23\n"
- ".unreq xX_24\n"
- ".unreq xX_31\n"
- ".unreq xX_32\n"
- ".unreq xX_33\n"
- ".unreq xX_34\n"
- ".unreq xX_41\n"
- ".unreq xX_42\n"
- ".unreq xX_43\n"
- ".unreq xX_44\n"
-
- : [inptr1] "+r" (inptr0), // Offset for missing row
- [inptr2] "+r" (inptr1), // Offset for missing row
- [inptr3] "+r" (inptr2), // Offset for missing row
- [outptr0] "+r" (outptr0),
- [outptr4] "+r" (outptr1),
- [outptr8] "+r" (outptr2),
- [outptr12] "+r" (outptr3)
- : [colstride1] "r" (input_col_stride * sizeof(float)),
- [colstride2] "r" (input_col_stride * sizeof(float) * 2),
- [colstride3] "r" (input_col_stride * sizeof(float) * 3),
- [mstride1] "r" (matrix_stride * sizeof(float)),
- [mstride2] "r" (matrix_stride * sizeof(float) * 2),
- [mstride3] "r" (matrix_stride * sizeof(float) * 3)
- : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
- "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
- "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
- "v30", "v31"
- );
- }
-}
-
-// Pad left by 1
-template <>
-template <>
-inline void Winograd2x2_3x3GemmInputChannelwise<float>::_process_tile<0, 1, 0, 0, 4>(
- int &n_channels, // Number of channels in the tile
- const float* &inptr0,
- const int input_row_stride,
- const int input_col_stride,
- float* &outptr0,
- const int matrix_stride
-) {
- // We use 4 pointers to point to the starting position on each row and use
- // three offsets to extract elements from each of the other 3 columns.
- auto inptr1 = inptr0 + 1*input_row_stride;
- auto inptr2 = inptr0 + 2*input_row_stride;
- auto inptr3 = inptr0 + 3*input_row_stride;
-
- // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three
- // offsets to access the intermediate matrices.
- auto outptr1 = outptr0 + matrix_stride * 4;
- auto outptr2 = outptr0 + matrix_stride * 8;
- auto outptr3 = outptr0 + matrix_stride * 12;
-
- for (; n_channels > 3; n_channels -= 4) {
- asm volatile (
- "X_12 .req v1\n" "qX_12 .req q1\n"
- "X_13 .req v2\n" "qX_13 .req q2\n"
- "X_14 .req v3\n" "qX_14 .req q3\n"
- "X_22 .req v5\n" "qX_22 .req q5\n"
- "X_23 .req v6\n" "qX_23 .req q6\n"
- "X_24 .req v7\n" "qX_24 .req q7\n"
- "X_32 .req v9\n" "qX_32 .req q9\n"
- "X_33 .req v10\n" "qX_33 .req q10\n"
- "X_34 .req v11\n" "qX_34 .req q11\n"
- "X_42 .req v13\n" "qX_42 .req q13\n"
- "X_43 .req v14\n" "qX_43 .req q14\n"
- "X_44 .req v15\n" "qX_44 .req q15\n"
- "xX_11 .req v16\n"
- "xX_12 .req v17\n"
- "xX_13 .req v18\n"
- "xX_14 .req v19\n"
- "xX_21 .req v20\n"
- "xX_22 .req v21\n"
- "xX_23 .req v22\n"
- "xX_24 .req v23\n"
- "xX_31 .req v24\n"
- "xX_32 .req v25\n"
- "xX_33 .req v26\n"
- "xX_34 .req v27\n"
- "xX_41 .req v28\n"
- "xX_42 .req v29\n"
- "xX_43 .req v30\n"
- "xX_44 .req v31\n"
- " U .req v0\n"
- "qU .req q0\n"
-
- // Load the tile, and compute compute the matrix xX
- "ldr qX_12, [%x[inptr0]]\n"
- "ldr qX_13, [%x[inptr0], %x[colstride1]]\n"
- "ldr qX_14, [%x[inptr0], %x[colstride2]]\n"
- "add %x[inptr0], %x[inptr0], #0x10\n"
-
- "fneg xX_11.4s, x_13.4s\n"
- "ldr qX_22, [%x[inptr1]]\n"
- "fadd xX_12.4s, x_12.4s, x_13.4s\n"
- "ldr qX_23, [%x[inptr1], %x[colstride1]]\n"
- "fsub xX_13.4s, x_13.4s, x_12.4s\n"
- "ldr qX_24, [%x[inptr1], %x[colstride2]]\n"
- "fsub xX_14.4s, x_12.4s, x_14.4s\n"
- "add %x[inptr1], %x[inptr1], #0x10\n"
-
- "fneg xX_21.4s, x_23.4s\n"
- "ldr qX_32, [%x[inptr2]]\n"
- "fadd xX_22.4s, x_22.4s, x_23.4s\n"
- "ldr qX_33, [%x[inptr2], %x[colstride1]]\n"
- "fsub xX_23.4s, x_23.4s, x_22.4s\n"
- "ldr qX_34, [%x[inptr2], %x[colstride2]]\n"
- "fsub xX_24.4s, x_22.4s, x_24.4s\n"
- "add %x[inptr2], %x[inptr2], #0x10\n"
-
- "fneg xX_31.4s, x_33.4s\n"
- "ldr qX_42, [%x[inptr3]]\n"
- "fadd xX_32.4s, x_32.4s, x_33.4s\n"
- "ldr qX_43, [%x[inptr3], %x[colstride1]]\n"
- "fsub xX_33.4s, x_33.4s, x_32.4s\n"
- "ldr qX_44, [%x[inptr3], %x[colstride2]]\n"
- "fsub xX_34.4s, x_32.4s, x_34.4s\n"
- "add %x[inptr3], %x[inptr3], #0x10\n"
-
- // Complete computing xX while beginning to compute and store
- // $U = X.T x X$
-
- "fneg xX_41.4s, x_43.4s\n"
-
- "fsub U.4s, xX_11.4s, xX_31.4s\n"
- "str qU, [%x[outptr0]]\n"
- "fsub U.4s, xX_12.4s, xX_32.4s\n"
- "str qU, [%x[outptr0], %x[mstride1]]\n"
- "fsub U.4s, xX_13.4s, xX_33.4s\n"
- "str qU, [%x[outptr0], %x[mstride2]]\n"
- "fsub U.4s, xX_14.4s, xX_34.4s\n"
- "str qU, [%x[outptr0], %x[mstride3]]\n"
- "add %x[outptr0], %x[outptr0], #0x10\n"
-
- "fadd xX_42.4s, x_42.4s, x_43.4s\n"
-
- "fadd U.4s, xX_21.4s, xX_31.4s\n"
- "str qU, [%x[outptr4]]\n"
- "fadd U.4s, xX_22.4s, xX_32.4s\n"
- "str qU, [%x[outptr4], %x[mstride1]]\n"
- "fadd U.4s, xX_23.4s, xX_33.4s\n"
- "str qU, [%x[outptr4], %x[mstride2]]\n"
- "fadd U.4s, xX_24.4s, xX_34.4s\n"
- "str qU, [%x[outptr4], %x[mstride3]]\n"
- "add %x[outptr4], %x[outptr4], #0x10\n"
-
- "fsub xX_43.4s, x_43.4s, x_42.4s\n"
-
- "fsub U.4s, xX_31.4s, xX_21.4s\n"
- "str qU, [%x[outptr8]]\n"
- "fsub U.4s, xX_32.4s, xX_22.4s\n"
- "str qU, [%x[outptr8], %x[mstride1]]\n"
- "fsub U.4s, xX_33.4s, xX_23.4s\n"
- "str qU, [%x[outptr8], %x[mstride2]]\n"
- "fsub U.4s, xX_34.4s, xX_24.4s\n"
- "str qU, [%x[outptr8], %x[mstride3]]\n"
- "add %x[outptr8], %x[outptr8], #0x10\n"
-
- "fsub xX_44.4s, x_42.4s, x_44.4s\n"
-
- "fsub U.4s, xX_21.4s, xX_41.4s\n"
- "str qU, [%x[outptr12]]\n"
- "fsub U.4s, xX_22.4s, xX_42.4s\n"
- "str qU, [%x[outptr12], %x[mstride1]]\n"
- "fsub U.4s, xX_23.4s, xX_43.4s\n"
- "str qU, [%x[outptr12], %x[mstride2]]\n"
- "fsub U.4s, xX_24.4s, xX_44.4s\n"
- "str qU, [%x[outptr12], %x[mstride3]]\n"
- "add %x[outptr12], %x[outptr12], #0x10\n"
-
- ".unreq X_12\n" ".unreq qX_12\n"
- ".unreq X_13\n" ".unreq qX_13\n"
- ".unreq X_14\n" ".unreq qX_14\n"
- ".unreq X_22\n" ".unreq qX_22\n"
- ".unreq X_23\n" ".unreq qX_23\n"
- ".unreq X_24\n" ".unreq qX_24\n"
- ".unreq X_32\n" ".unreq qX_32\n"
- ".unreq X_33\n" ".unreq qX_33\n"
- ".unreq X_34\n" ".unreq qX_34\n"
- ".unreq X_42\n" ".unreq qX_42\n"
- ".unreq X_43\n" ".unreq qX_43\n"
- ".unreq X_44\n" ".unreq qX_44\n"
- ".unreq xX_11\n"
- ".unreq xX_12\n"
- ".unreq xX_13\n"
- ".unreq xX_14\n"
- ".unreq xX_21\n"
- ".unreq xX_22\n"
- ".unreq xX_23\n"
- ".unreq xX_24\n"
- ".unreq xX_31\n"
- ".unreq xX_32\n"
- ".unreq xX_33\n"
- ".unreq xX_34\n"
- ".unreq xX_41\n"
- ".unreq xX_42\n"
- ".unreq xX_43\n"
- ".unreq xX_44\n"
- ".unreq U\n"
- ".unreq qU\n"
-
- : [inptr0] "+r" (inptr0),
- [inptr1] "+r" (inptr1),
- [inptr2] "+r" (inptr2),
- [inptr3] "+r" (inptr3),
- [outptr0] "+r" (outptr0),
- [outptr4] "+r" (outptr1),
- [outptr8] "+r" (outptr2),
- [outptr12] "+r" (outptr3)
- : [colstride1] "r" (input_col_stride * sizeof(float)),
- [colstride2] "r" (input_col_stride * sizeof(float) * 2),
- [mstride1] "r" (matrix_stride * sizeof(float)),
- [mstride2] "r" (matrix_stride * sizeof(float) * 2),
- [mstride3] "r" (matrix_stride * sizeof(float) * 3)
- : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
- "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
- "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
- "v30", "v31"
- );
- }
-}
-
-// Pad bottom by 1
-template <>
-template <>
-inline void Winograd2x2_3x3GemmInputChannelwise<float>::_process_tile<0, 0, 1, 0, 4>(
- int &n_channels, // Number of channels in the tile
- const float* &inptr0,
- const int input_row_stride,
- const int input_col_stride,
- float* &outptr0,
- const int matrix_stride
-) {
- // We use 4 pointers to point to the starting position on each row and use
- // three offsets to extract elements from each of the other 3 columns.
- auto inptr1 = inptr0 + 1*input_row_stride;
- auto inptr2 = inptr0 + 2*input_row_stride;
-
- // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three
- // offsets to access the intermediate matrices.
- auto outptr1 = outptr0 + matrix_stride * 4;
- auto outptr2 = outptr0 + matrix_stride * 8;
- auto outptr3 = outptr0 + matrix_stride * 12;
-
- for (; n_channels > 3; n_channels -= 4) {
- asm volatile (
- "X_11 .req v0\n" "qX_11 .req q0\n"
- "X_12 .req v1\n" "qX_12 .req q1\n"
- "X_13 .req v2\n" "qX_13 .req q2\n"
- "X_14 .req v3\n" "qX_14 .req q3\n"
- "X_21 .req v4\n" "qX_21 .req q4\n"
- "X_22 .req v5\n" "qX_22 .req q5\n"
- "X_23 .req v6\n" "qX_23 .req q6\n"
- "X_24 .req v7\n" "qX_24 .req q7\n"
- "X_31 .req v8\n" "qX_31 .req q8\n"
- "X_32 .req v9\n" "qX_32 .req q9\n"
- "X_33 .req v10\n" "qX_33 .req q10\n"
- "X_34 .req v11\n" "qX_34 .req q11\n"
- "xX_11 .req v16\n"
- "xX_12 .req v17\n"
- "xX_13 .req v18\n"
- "xX_14 .req v19\n"
- "xX_21 .req v20\n" "qxX_21 .req q20\n"
- "xX_22 .req v21\n" "qxX_22 .req q21\n"
- "xX_23 .req v22\n" "qxX_23 .req q22\n"
- "xX_24 .req v23\n" "qxX_24 .req q23\n"
- "xX_31 .req v24\n"
- "xX_32 .req v25\n"
- "xX_33 .req v26\n"
- "xX_34 .req v27\n"
- " U .req v0\n"
- "qU .req q0\n"
-
- // Load the tile, and compute compute the matrix xX
- "ldr qX_11, [%x[inptr0]]\n"
- "ldr qX_12, [%x[inptr0], %x[colstride1]]\n"
- "ldr qX_13, [%x[inptr0], %x[colstride2]]\n"
- "ldr qX_14, [%x[inptr0], %x[colstride3]]\n"
- "add %x[inptr0], %x[inptr0], #0x10\n"
-
- "ldr qX_21, [%x[inptr1]]\n"
- "fsub xX_11.4s, x_11.4s, x_13.4s\n"
- "ldr qX_22, [%x[inptr1], %x[colstride1]]\n"
- "fadd xX_12.4s, x_12.4s, x_13.4s\n"
- "ldr qX_23, [%x[inptr1], %x[colstride2]]\n"
- "fsub xX_13.4s, x_13.4s, x_12.4s\n"
- "ldr qX_24, [%x[inptr1], %x[colstride3]]\n"
- "fsub xX_14.4s, x_12.4s, x_14.4s\n"
- "add %x[inptr1], %x[inptr1], #0x10\n"
-
- "ldr qX_31, [%x[inptr2]]\n"
- "fsub xX_21.4s, x_21.4s, x_23.4s\n"
- "ldr qX_32, [%x[inptr2], %x[colstride1]]\n"
- "fadd xX_22.4s, x_22.4s, x_23.4s\n"
- "ldr qX_33, [%x[inptr2], %x[colstride2]]\n"
- "fsub xX_23.4s, x_23.4s, x_22.4s\n"
- "ldr qX_34, [%x[inptr2], %x[colstride3]]\n"
- "fsub xX_24.4s, x_22.4s, x_24.4s\n"
- "add %x[inptr2], %x[inptr2], #0x10\n"
-
- "fsub xX_31.4s, x_31.4s, x_33.4s\n"
- "fadd xX_32.4s, x_32.4s, x_33.4s\n"
- "fsub xX_33.4s, x_33.4s, x_32.4s\n"
- "fsub xX_34.4s, x_32.4s, x_34.4s\n"
-
- // Complete computing xX while beginning to compute and store
- // $U = X.T x X$
-
- "fsub U.4s, xX_11.4s, xX_31.4s\n"
- "str qU, [%x[outptr0]]\n"
- "fsub U.4s, xX_12.4s, xX_32.4s\n"
- "str qU, [%x[outptr0], %x[mstride1]]\n"
- "fsub U.4s, xX_13.4s, xX_33.4s\n"
- "str qU, [%x[outptr0], %x[mstride2]]\n"
- "fsub U.4s, xX_14.4s, xX_34.4s\n"
- "str qU, [%x[outptr0], %x[mstride3]]\n"
- "add %x[outptr0], %x[outptr0], #0x10\n"
-
- "fadd U.4s, xX_21.4s, xX_31.4s\n"
- "str qU, [%x[outptr4]]\n"
- "fadd U.4s, xX_22.4s, xX_32.4s\n"
- "str qU, [%x[outptr4], %x[mstride1]]\n"
- "fadd U.4s, xX_23.4s, xX_33.4s\n"
- "str qU, [%x[outptr4], %x[mstride2]]\n"
- "fadd U.4s, xX_24.4s, xX_34.4s\n"
- "str qU, [%x[outptr4], %x[mstride3]]\n"
- "add %x[outptr4], %x[outptr4], #0x10\n"
-
- "fsub U.4s, xX_31.4s, xX_21.4s\n"
- "str qU, [%x[outptr8]]\n"
- "fsub U.4s, xX_32.4s, xX_22.4s\n"
- "str qU, [%x[outptr8], %x[mstride1]]\n"
- "fsub U.4s, xX_33.4s, xX_23.4s\n"
- "str qU, [%x[outptr8], %x[mstride2]]\n"
- "fsub U.4s, xX_34.4s, xX_24.4s\n"
- "str qU, [%x[outptr8], %x[mstride3]]\n"
- "add %x[outptr8], %x[outptr8], #0x10\n"
-
- "str qxX_21, [%x[outptr12]]\n"
- "str qxX_22, [%x[outptr12], %x[mstride1]]\n"
- "str qxX_23, [%x[outptr12], %x[mstride2]]\n"
- "str qxX_24, [%x[outptr12], %x[mstride3]]\n"
- "add %x[outptr12], %x[outptr12], #0x10\n"
-
- ".unreq qU\n"
- ".unreq U\n"
- ".unreq X_11\n" ".unreq qX_11\n"
- ".unreq X_12\n" ".unreq qX_12\n"
- ".unreq X_13\n" ".unreq qX_13\n"
- ".unreq X_14\n" ".unreq qX_14\n"
- ".unreq X_21\n" ".unreq qX_21\n"
- ".unreq X_22\n" ".unreq qX_22\n"
- ".unreq X_23\n" ".unreq qX_23\n"
- ".unreq X_24\n" ".unreq qX_24\n"
- ".unreq X_31\n" ".unreq qX_31\n"
- ".unreq X_32\n" ".unreq qX_32\n"
- ".unreq X_33\n" ".unreq qX_33\n"
- ".unreq X_34\n" ".unreq qX_34\n"
- ".unreq xX_11\n"
- ".unreq xX_12\n"
- ".unreq xX_13\n"
- ".unreq xX_14\n"
- ".unreq xX_21\n" ".unreq qxX_21\n"
- ".unreq xX_22\n" ".unreq qxX_22\n"
- ".unreq xX_23\n" ".unreq qxX_23\n"
- ".unreq xX_24\n" ".unreq qxX_24\n"
- ".unreq xX_31\n"
- ".unreq xX_32\n"
- ".unreq xX_33\n"
- ".unreq xX_34\n"
-
- : [inptr0] "+r" (inptr0),
- [inptr1] "+r" (inptr1),
- [inptr2] "+r" (inptr2),
- [outptr0] "+r" (outptr0),
- [outptr4] "+r" (outptr1),
- [outptr8] "+r" (outptr2),
- [outptr12] "+r" (outptr3)
- : [colstride1] "r" (input_col_stride * sizeof(float)),
- [colstride2] "r" (input_col_stride * sizeof(float) * 2),
- [colstride3] "r" (input_col_stride * sizeof(float) * 3),
- [mstride1] "r" (matrix_stride * sizeof(float)),
- [mstride2] "r" (matrix_stride * sizeof(float) * 2),
- [mstride3] "r" (matrix_stride * sizeof(float) * 3)
- : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
- "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
- "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
- "v30", "v31"
- );
- }
-}
-
-// Pad right by 1
-template <>
-template <>
-inline void Winograd2x2_3x3GemmInputChannelwise<float>::_process_tile<0, 0, 0, 1, 4>(
- int &n_channels, // Number of channels in the tile
- const float* &inptr0,
- const int input_row_stride,
- const int input_col_stride,
- float* &outptr0,
- const int matrix_stride
-) {
- // We use 4 pointers to point to the starting position on each row and use
- // three offsets to extract elements from each of the other 3 columns.
- auto inptr1 = inptr0 + 1*input_row_stride;
- auto inptr2 = inptr0 + 2*input_row_stride;
- auto inptr3 = inptr0 + 3*input_row_stride;
-
- // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three
- // offsets to access the intermediate matrices.
- auto outptr1 = outptr0 + matrix_stride * 4;
- auto outptr2 = outptr0 + matrix_stride * 8;
- auto outptr3 = outptr0 + matrix_stride * 12;
-
- for (; n_channels > 3; n_channels -= 4) {
- asm volatile (
- "X_11 .req v0\n" "qX_11 .req q0\n"
- "X_12 .req v1\n" "qX_12 .req q1\n"
- "X_13 .req v2\n" "qX_13 .req q2\n"
- "X_21 .req v4\n" "qX_21 .req q4\n"
- "X_22 .req v5\n" "qX_22 .req q5\n"
- "X_23 .req v6\n" "qX_23 .req q6\n"
- "X_31 .req v8\n" "qX_31 .req q8\n"
- "X_32 .req v9\n" "qX_32 .req q9\n"
- "X_33 .req v10\n" "qX_33 .req q10\n"
- "X_41 .req v12\n" "qX_41 .req q12\n"
- "X_42 .req v13\n" "qX_42 .req q13\n"
- "X_43 .req v14\n" "qX_43 .req q14\n"
- "xX_11 .req v16\n"
- "xX_12 .req v17\n"
- "xX_13 .req v18\n"
- "xX_14 .req x_12\n"
- "xX_21 .req v20\n"
- "xX_22 .req v21\n"
- "xX_23 .req v22\n"
- "xX_24 .req x_22\n"
- "xX_31 .req v24\n"
- "xX_32 .req v25\n"
- "xX_33 .req v26\n"
- "xX_34 .req x_32\n"
- "xX_41 .req v28\n"
- "xX_42 .req v29\n"
- "xX_43 .req v30\n"
- "xX_44 .req x_42\n"
- " U .req v0\n"
- "qU .req q0\n"
-
- // Load the tile, and compute compute the matrix xX
- "ldr qX_11, [%x[inptr0]]\n"
- "ldr qX_12, [%x[inptr0], %x[colstride1]]\n"
- "ldr qX_13, [%x[inptr0], %x[colstride2]]\n"
- "add %x[inptr0], %x[inptr0], #0x10\n"
-
- "ldr qX_21, [%x[inptr1]]\n"
- "fsub xX_11.4s, x_11.4s, x_13.4s\n"
- "ldr qX_22, [%x[inptr1], %x[colstride1]]\n"
- "fadd xX_12.4s, x_12.4s, x_13.4s\n"
- "ldr qX_23, [%x[inptr1], %x[colstride2]]\n"
- "fsub xX_13.4s, x_13.4s, x_12.4s\n"
- "add %x[inptr1], %x[inptr1], #0x10\n"
-
- "ldr qX_31, [%x[inptr2]]\n"
- "fsub xX_21.4s, x_21.4s, x_23.4s\n"
- "ldr qX_32, [%x[inptr2], %x[colstride1]]\n"
- "fadd xX_22.4s, x_22.4s, x_23.4s\n"
- "ldr qX_33, [%x[inptr2], %x[colstride2]]\n"
- "fsub xX_23.4s, x_23.4s, x_22.4s\n"
- "add %x[inptr2], %x[inptr2], #0x10\n"
-
- "ldr qX_41, [%x[inptr3]]\n"
- "fsub xX_31.4s, x_31.4s, x_33.4s\n"
- "ldr qX_42, [%x[inptr3], %x[colstride1]]\n"
- "fadd xX_32.4s, x_32.4s, x_33.4s\n"
- "ldr qX_43, [%x[inptr3], %x[colstride2]]\n"
- "fsub xX_33.4s, x_33.4s, x_32.4s\n"
- "add %x[inptr3], %x[inptr3], #0x10\n"
-
- // Complete computing xX while beginning to compute and store
- // $U = X.T x X$
-
- "fsub xX_41.4s, x_41.4s, x_43.4s\n"
-
- "fsub U.4s, xX_11.4s, xX_31.4s\n"
- "str qU, [%x[outptr0]]\n"
- "fsub U.4s, xX_12.4s, xX_32.4s\n"
- "str qU, [%x[outptr0], %x[mstride1]]\n"
- "fsub U.4s, xX_13.4s, xX_33.4s\n"
- "str qU, [%x[outptr0], %x[mstride2]]\n"
- "fsub U.4s, xX_14.4s, xX_34.4s\n"
- "str qU, [%x[outptr0], %x[mstride3]]\n"
- "add %x[outptr0], %x[outptr0], #0x10\n"
-
- "fadd xX_42.4s, x_42.4s, x_43.4s\n"
-
- "fadd U.4s, xX_21.4s, xX_31.4s\n"
- "str qU, [%x[outptr4]]\n"
- "fadd U.4s, xX_22.4s, xX_32.4s\n"
- "str qU, [%x[outptr4], %x[mstride1]]\n"
- "fadd U.4s, xX_23.4s, xX_33.4s\n"
- "str qU, [%x[outptr4], %x[mstride2]]\n"
- "fadd U.4s, xX_24.4s, xX_34.4s\n"
- "str qU, [%x[outptr4], %x[mstride3]]\n"
- "add %x[outptr4], %x[outptr4], #0x10\n"
-
- "fsub xX_43.4s, x_43.4s, x_42.4s\n"
-
- "fsub U.4s, xX_31.4s, xX_21.4s\n"
- "str qU, [%x[outptr8]]\n"
- "fsub U.4s, xX_32.4s, xX_22.4s\n"
- "str qU, [%x[outptr8], %x[mstride1]]\n"
- "fsub U.4s, xX_33.4s, xX_23.4s\n"
- "str qU, [%x[outptr8], %x[mstride2]]\n"
- "fsub U.4s, xX_34.4s, xX_24.4s\n"
- "str qU, [%x[outptr8], %x[mstride3]]\n"
- "add %x[outptr8], %x[outptr8], #0x10\n"
-
- "fsub U.4s, xX_21.4s, xX_41.4s\n"
- "str qU, [%x[outptr12]]\n"
- "fsub U.4s, xX_22.4s, xX_42.4s\n"
- "str qU, [%x[outptr12], %x[mstride1]]\n"
- "fsub U.4s, xX_23.4s, xX_43.4s\n"
- "str qU, [%x[outptr12], %x[mstride2]]\n"
- "fsub U.4s, xX_24.4s, xX_44.4s\n"
- "str qU, [%x[outptr12], %x[mstride3]]\n"
- "add %x[outptr12], %x[outptr12], #0x10\n"
-
- ".unreq qU\n"
- ".unreq U\n"
- ".unreq X_11\n" ".unreq qX_11\n"
- ".unreq X_12\n" ".unreq qX_12\n"
- ".unreq X_13\n" ".unreq qX_13\n"
- ".unreq X_21\n" ".unreq qX_21\n"
- ".unreq X_22\n" ".unreq qX_22\n"
- ".unreq X_23\n" ".unreq qX_23\n"
- ".unreq X_31\n" ".unreq qX_31\n"
- ".unreq X_32\n" ".unreq qX_32\n"
- ".unreq X_33\n" ".unreq qX_33\n"
- ".unreq X_41\n" ".unreq qX_41\n"
- ".unreq X_42\n" ".unreq qX_42\n"
- ".unreq X_43\n" ".unreq qX_43\n"
- ".unreq xX_11\n"
- ".unreq xX_12\n"
- ".unreq xX_13\n"
- ".unreq xX_14\n"
- ".unreq xX_21\n"
- ".unreq xX_22\n"
- ".unreq xX_23\n"
- ".unreq xX_24\n"
- ".unreq xX_31\n"
- ".unreq xX_32\n"
- ".unreq xX_33\n"
- ".unreq xX_34\n"
- ".unreq xX_41\n"
- ".unreq xX_42\n"
- ".unreq xX_43\n"
- ".unreq xX_44\n"
-
- : [inptr0] "+r" (inptr0),
- [inptr1] "+r" (inptr1),
- [inptr2] "+r" (inptr2),
- [inptr3] "+r" (inptr3),
- [outptr0] "+r" (outptr0),
- [outptr4] "+r" (outptr1),
- [outptr8] "+r" (outptr2),
- [outptr12] "+r" (outptr3)
- : [colstride1] "r" (input_col_stride * sizeof(float)),
- [colstride2] "r" (input_col_stride * sizeof(float) * 2),
- [mstride1] "r" (matrix_stride * sizeof(float)),
- [mstride2] "r" (matrix_stride * sizeof(float) * 2),
- [mstride3] "r" (matrix_stride * sizeof(float) * 3)
- : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
- "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
- "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
- "v30", "v31"
- );
- }
-}
-}
-#endif
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp
deleted file mode 100644
index 033442aa14..0000000000
--- a/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp
+++ /dev/null
@@ -1,195 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#pragma once
-
-namespace winograd {
- /* Transform a kernel into the Winograd domain.
- *
- * NOTE: It is assumed that the kernel is in the form [height x width x
- * input_channels x output_channel].
- */
- template <typename T>
- struct winograd2x2_3x3_gemm_kernel_transform_impl{
- static void execute(
- const KernelShape &shape,
- const T* const kernel,
- T* const matrix_base,
- const int matrix_stride,
- const int matrix_row_stride
- );
-
- protected:
- template <const int output_channel_tail>
- static void transform_kernel(
- const T* const kernel,
- const int n_input_channels,
- const int n_output_channels,
- T* const matrix_base,
- const int matrix_stride,
- const int matrix_row_stride
- );
- };
-}
-
-/*****************************************************************************/
-/* Transform a fp32 kernel into the Winograd domain.
- */
-#include "kernel_2x2_3x3/a64_float.hpp" // AArch64 specialisations
-
-namespace winograd
-{
-template <>
-inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::execute(
- const KernelShape &shape,
- const float* const kernel,
- float* const matrix_base,
- const int matrix_stride,
- const int matrix_row_stride
-) {
- // Delegate based on tail size
- const int n_input_channels = shape.n_input_channels;
- const int n_output_channels = shape.n_output_channels;
-
- switch (n_output_channels % 4) {
- case 0:
- transform_kernel<0>(
- kernel, n_input_channels, n_output_channels,
- matrix_base, matrix_stride, matrix_row_stride
- );
- break;
- case 1:
- transform_kernel<1>(
- kernel, n_input_channels, n_output_channels,
- matrix_base, matrix_stride, matrix_row_stride
- );
- break;
- case 2:
- transform_kernel<2>(
- kernel, n_input_channels, n_output_channels,
- matrix_base, matrix_stride, matrix_row_stride
- );
- break;
- case 3:
- transform_kernel<3>(
- kernel, n_input_channels, n_output_channels,
- matrix_base, matrix_stride, matrix_row_stride
- );
- break;
- default:
- ARM_COMPUTE_ERROR("Cannot happen");
- break;
- }
-}
-
-template <>
-template<const int output_channel_tail>
-inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::transform_kernel(
- const float* const kernel,
- const int n_input_channels,
- const int n_output_channels,
- float* const matrix_base,
- const int mstride,
- const int matrix_row_stride
-) {
- // Use one input pointer for each row of the kernel, use two additional
- // offsets to extract columns.
- const int kernel_col_stride = n_input_channels * n_output_channels;
- const int kernel_row_stride = 3 * kernel_col_stride;
- const float *inptr0 = kernel;
- const float *inptr1 = kernel + kernel_row_stride;
- const float *inptr2 = kernel + kernel_row_stride*2;
-
- // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three
- // offsets to extract further matrices.
- float *outptr0 = matrix_base;
- float *outptr4 = matrix_base + mstride * 4;
- float *outptr8 = matrix_base + mstride * 8;
- float *outptr12 = matrix_base + mstride * 12;
-
- // For every input channel
- for (int in_c = 0; in_c < n_input_channels; in_c++) {
- // For every output channel
- for (int c = 0; c < n_output_channels; c++) {
- // Read in the kernel
- float w11 = inptr0[0], w12 = inptr0[kernel_col_stride], w13 = inptr0[kernel_col_stride*2];
- float w21 = inptr1[0], w22 = inptr1[kernel_col_stride], w23 = inptr1[kernel_col_stride*2];
- float w31 = inptr2[0], w32 = inptr2[kernel_col_stride], w33 = inptr2[kernel_col_stride*2];
-
- // Progress input pointers
- inptr0++;
- inptr1++;
- inptr2++;
-
- // Compute the kernel W w, note we need only compute the middle two rows
- // (2 and 3) because the first and last rows are merely copies of values
- // from the matrix w.
- float Ww11 = w11, Ww12 = w12, Ww13 = w13;
- float Ww21 = 0.5*(w11 + w21 + w31), Ww22 = 0.5*(w12 + w22 + w32), Ww23 = 0.5*(w13 + w23 + w33);
- float Ww31 = 0.5*(w11 - w21 + w31), Ww32 = 0.5*(w12 - w22 + w32), Ww33 = 0.5*(w13 - w23 + w33);
- float Ww41 = w31, Ww42 = w32, Ww43 = w33;
-
- // Hence compute W w W.T; again note we need compute only the middle two
- // columns since the first and last columns are copies of the first and
- // last columns of the previous matrix.
- float WwWT11 = Ww11, WwWT12 = 0.5*(Ww11 + Ww12 + Ww13), WwWT13 = 0.5*(Ww11 - Ww12 + Ww13), WwWT14 = Ww13;
- float WwWT21 = Ww21, WwWT22 = 0.5*(Ww21 + Ww22 + Ww23), WwWT23 = 0.5*(Ww21 - Ww22 + Ww23), WwWT24 = Ww23;
- float WwWT31 = Ww31, WwWT32 = 0.5*(Ww31 + Ww32 + Ww33), WwWT33 = 0.5*(Ww31 - Ww32 + Ww33), WwWT34 = Ww33;
- float WwWT41 = Ww41, WwWT42 = 0.5*(Ww41 + Ww42 + Ww43), WwWT43 = 0.5*(Ww41 - Ww42 + Ww43), WwWT44 = Ww43;
-
- // Store the computed weights
- outptr0[0 * mstride] = WwWT11;
- outptr0[1 * mstride] = WwWT12;
- outptr0[2 * mstride] = WwWT13;
- outptr0[3 * mstride] = WwWT14;
-
- outptr4[0 * mstride] = WwWT21;
- outptr4[1 * mstride] = WwWT22;
- outptr4[2 * mstride] = WwWT23;
- outptr4[3 * mstride] = WwWT24;
-
- outptr8[0 * mstride] = WwWT31;
- outptr8[1 * mstride] = WwWT32;
- outptr8[2 * mstride] = WwWT33;
- outptr8[3 * mstride] = WwWT34;
-
- outptr12[0 * mstride] = WwWT41;
- outptr12[1 * mstride] = WwWT42;
- outptr12[2 * mstride] = WwWT43;
- outptr12[3 * mstride] = WwWT44;
-
- // Progress output pointers
- outptr0++;
- outptr4++;
- outptr8++;
- outptr12++;
- }
-
- // Progression to complete stride
- outptr0 += matrix_row_stride - n_output_channels;
- outptr4 += matrix_row_stride - n_output_channels;
- outptr8 += matrix_row_stride - n_output_channels;
- outptr12 += matrix_row_stride - n_output_channels;
- }
-}
-}
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp
deleted file mode 100644
index 3dd62d1ac1..0000000000
--- a/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp
+++ /dev/null
@@ -1,822 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#pragma once
-
-#ifdef __aarch64__
-namespace winograd {
-template <>
-template <>
-inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::transform_kernel<0>(
- const float* const kernel,
- const int n_input_channels,
- const int n_output_channels,
- float* const matrix_base,
- const int mstride,
- const int matrix_row_stride
-) {
- // Use one input pointer for each row of the kernel, use two additional
- // offsets to extract columns.
- const int kernel_col_stride = n_input_channels * n_output_channels;
- const int kernel_row_stride = 3 * kernel_col_stride;
- const float *inptr0 = kernel;
- const float *inptr1 = kernel + kernel_row_stride;
- const float *inptr2 = kernel + kernel_row_stride*2;
-
- // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three
- // offsets to extract further matrices.
- float *outptr0 = matrix_base;
- float *outptr4 = matrix_base + mstride * 4;
- float *outptr8 = matrix_base + mstride * 8;
- float *outptr12 = matrix_base + mstride * 12;
-
- // For every input channel
- for (int in_c = 0; in_c < n_input_channels; in_c++) {
- int n_remaining_channels = n_output_channels;
-
- asm volatile (
- // Registers into which to read the kernel
- "w_11 .req v0\n" "qw_11 .req q0\n"
- "w_12 .req v1\n" "qw_12 .req q1\n"
- "w_13 .req v2\n" "qw_13 .req q2\n"
- "w_21 .req v3\n" "qw_21 .req q3\n"
- "w_22 .req v4\n" "qw_22 .req q4\n"
- "w_23 .req v5\n" "qw_23 .req q5\n"
- "w_31 .req v6\n" "qw_31 .req q6\n"
- "w_32 .req v7\n" "qw_32 .req q7\n"
- "w_33 .req v8\n" "qw_33 .req q8\n"
-
- // Transformed matrix Ww
- "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n"
- "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n"
- "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n"
- "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n"
-
- // Output matrix U = WwWT
- "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n"
- "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n"
- "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n"
- "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n"
-
- // Storage view of output matrices
- "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n"
- "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n"
- "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n"
- "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n"
-
- "half .req v23\n" // {0.5, ..., 0.5}
- "dup half.4s, %w[one_half]\n"
- "scratch .req v24\n"
-
- "1:"
- // Load tile of the kernel
- "ldr qw_11, [%x[inptr0]]\n"
- "str qU11, [%x[outptr0]]\n"
- "ldr qw_12, [%x[inptr0], %x[colstride1]]\n"
- "ldr qw_13, [%x[inptr0], %x[colstride2]]\n"
- "str qU14, [%x[outptr0], %x[mstride3]]\n"
- "add %x[inptr0], %x[inptr0], #0x10\n"
-
- "ldr qw_21, [%x[inptr1]]\n"
- "ldr qw_22, [%x[inptr1], %x[colstride1]]\n"
- "ldr qw_23, [%x[inptr1], %x[colstride2]]\n"
- "add %x[inptr1], %x[inptr1], #0x10\n"
-
- "ldr qw_31, [%x[inptr2]]\n"
- "str qU41, [%x[outptr12]]\n"
- "ldr qw_32, [%x[inptr2], %x[colstride1]]\n"
- "ldr qw_33, [%x[inptr2], %x[colstride2]]\n"
- "str qU44, [%x[outptr12], %x[mstride3]]\n"
- "add %x[inptr2], %x[inptr2], #0x10\n"
-
- // Compute 2nd and 3rd rows of Ww
- "fadd scratch.4s, w_11.4s, w_31.4s\n"
- "fmul Ww21.4s, scratch.4s, half.4s\n"
- "fmla Ww21.4s, w_21.4s, half.4s\n"
- "str qU21, [%x[outptr4]]\n"
- "fmul Ww31.4s, scratch.4s, half.4s\n"
- "fmls Ww31.4s, w_21.4s, half.4s\n"
- "str qU31, [%x[outptr8]]\n"
-
- "fadd scratch.4s, w_12.4s, w_32.4s\n"
- "fmul Ww22.4s, scratch.4s, half.4s\n"
- "fmla Ww22.4s, w_22.4s, half.4s\n"
- "fmul Ww32.4s, scratch.4s, half.4s\n"
- "fmls Ww32.4s, w_22.4s, half.4s\n"
-
- "fadd scratch.4s, w_13.4s, w_33.4s\n"
- "fmul Ww23.4s, scratch.4s, half.4s\n"
- "fmla Ww23.4s, w_23.4s, half.4s\n"
- "str qU24, [%x[outptr4], %x[mstride3]]\n"
- "fmul Ww33.4s, scratch.4s, half.4s\n"
- "fmls Ww33.4s, w_23.4s, half.4s\n"
- "str qU34, [%x[outptr8], %x[mstride3]]\n"
-
- // Compute and store U, only need to compute the 2nd and 3rd columns
- // of U and update output pointers
- "fadd scratch.4s, Ww11.4s, Ww13.4s\n"
- "fmul U12.4s, scratch.4s, half.4s\n"
- "fmla U12.4s, Ww12.4s, half.4s\n"
- "str qU12, [%x[outptr0], %x[mstride1]]\n"
- "fmul U13.4s, scratch.4s, half.4s\n"
- "fmls U13.4s, Ww12.4s, half.4s\n"
- "str qU13, [%x[outptr0], %x[mstride2]]\n"
- "add %x[outptr0], %x[outptr0], #0x10\n"
-
- "fadd scratch.4s, Ww21.4s, Ww23.4s\n"
- "fmul U22.4s, scratch.4s, half.4s\n"
- "fmla U22.4s, Ww22.4s, half.4s\n"
- "str qU22, [%x[outptr4], %x[mstride1]]\n"
- "fmul U23.4s, scratch.4s, half.4s\n"
- "fmls U23.4s, Ww22.4s, half.4s\n"
- "str qU23, [%x[outptr4], %x[mstride2]]\n"
- "add %x[outptr4], %x[outptr4], #0x10\n"
-
- "fadd scratch.4s, Ww31.4s, Ww33.4s\n"
- "fmul U32.4s, scratch.4s, half.4s\n"
- "fmla U32.4s, Ww32.4s, half.4s\n"
- "str qU32, [%x[outptr8], %x[mstride1]]\n"
- "fmul U33.4s, scratch.4s, half.4s\n"
- "fmls U33.4s, Ww32.4s, half.4s\n"
- "str qU33, [%x[outptr8], %x[mstride2]]\n"
- "add %x[outptr8], %x[outptr8], #0x10\n"
-
- "fadd scratch.4s, Ww41.4s, Ww43.4s\n"
- "fmul U42.4s, scratch.4s, half.4s\n"
- "fmla U42.4s, Ww42.4s, half.4s\n"
- "str qU42, [%x[outptr12], %x[mstride1]]\n"
- "fmul U43.4s, scratch.4s, half.4s\n"
- "fmls U43.4s, Ww42.4s, half.4s\n"
- "str qU43, [%x[outptr12], %x[mstride2]]\n"
- "add %x[outptr12], %x[outptr12], #0x10\n"
-
- "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n"
- "bne 1b\n"
-
- // Clear aliases
- ".unreq half\n"
- ".unreq scratch\n"
- ".unreq w_11\n" ".unreq qw_11\n"
- ".unreq w_12\n" ".unreq qw_12\n"
- ".unreq w_13\n" ".unreq qw_13\n"
- ".unreq w_21\n" ".unreq qw_21\n"
- ".unreq w_22\n" ".unreq qw_22\n"
- ".unreq w_23\n" ".unreq qw_23\n"
- ".unreq w_31\n" ".unreq qw_31\n"
- ".unreq w_32\n" ".unreq qw_32\n"
- ".unreq w_33\n" ".unreq qw_33\n"
- ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n"
- ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n"
- ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n"
- ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n"
- ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n"
- ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n"
- ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n"
- ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n"
- ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n"
- ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n"
- ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n"
- ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq qU44\n"
-
- : [inptr0] "+r" (inptr0),
- [inptr1] "+r" (inptr1),
- [inptr2] "+r" (inptr2),
- [outptr0] "+r" (outptr0),
- [outptr4] "+r" (outptr4),
- [outptr8] "+r" (outptr8),
- [outptr12] "+r" (outptr12),
- [n_remaining_channels] "+r" (n_remaining_channels)
- : [mstride1] "r" (sizeof(float) * mstride),
- [mstride2] "r" (sizeof(float) * mstride * 2),
- [mstride3] "r" (sizeof(float) * mstride * 3),
- [colstride1] "r" (sizeof(float) * kernel_col_stride),
- [colstride2] "r" (sizeof(float) * kernel_col_stride * 2),
- [one_half] "r" (0.5f)
- : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
- "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
- "v20", "v21", "v22", "v23", "v24"
- );
-
- // Progression to complete stride
- outptr0 += matrix_row_stride - n_output_channels;
- outptr4 += matrix_row_stride - n_output_channels;
- outptr8 += matrix_row_stride - n_output_channels;
- outptr12 += matrix_row_stride - n_output_channels;
- }
-}
-
-template <>
-template <>
-inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::transform_kernel<2>(
- const float* const kernel,
- const int n_input_channels,
- const int n_output_channels,
- float* const matrix_base,
- const int mstride,
- const int matrix_row_stride
-) {
- // Use one input pointer for each row of the kernel, use two additional
- // offsets to extract columns.
- const int kernel_col_stride = n_input_channels * n_output_channels;
- const int kernel_row_stride = 3 * kernel_col_stride;
- const float *inptr0 = kernel;
- const float *inptr1 = kernel + kernel_row_stride;
- const float *inptr2 = kernel + kernel_row_stride*2;
-
- // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three
- // offsets to extract further matrices.
- float *outptr0 = matrix_base;
- float *outptr4 = matrix_base + mstride * 4;
- float *outptr8 = matrix_base + mstride * 8;
- float *outptr12 = matrix_base + mstride * 12;
-
- // For every input channel
- for (int in_c = 0; in_c < n_input_channels; in_c++) {
- int n_remaining_channels = n_output_channels;
-
- asm volatile (
- // Registers into which to read the kernel
- "w_11 .req v0\n" "qw_11 .req q0\n" "dw_11 .req d0\n"
- "w_12 .req v1\n" "qw_12 .req q1\n" "dw_12 .req d1\n"
- "w_13 .req v2\n" "qw_13 .req q2\n" "dw_13 .req d2\n"
- "w_21 .req v3\n" "qw_21 .req q3\n" "dw_21 .req d3\n"
- "w_22 .req v4\n" "qw_22 .req q4\n" "dw_22 .req d4\n"
- "w_23 .req v5\n" "qw_23 .req q5\n" "dw_23 .req d5\n"
- "w_31 .req v6\n" "qw_31 .req q6\n" "dw_31 .req d6\n"
- "w_32 .req v7\n" "qw_32 .req q7\n" "dw_32 .req d7\n"
- "w_33 .req v8\n" "qw_33 .req q8\n" "dw_33 .req d8\n"
-
- // Transformed matrix Ww
- "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n"
- "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n"
- "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n"
- "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n"
-
- // Output matrix U = WwWT
- "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n"
- "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n"
- "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n"
- "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n"
-
- // Storage view of output matrices
- "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n"
- "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n"
- "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n"
- "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n"
-
- "dU11 .req d0\n" "dU12 .req d15\n" "dU13 .req d16\n" "dU14 .req d2\n"
- "dU21 .req d9\n" "dU22 .req d17\n" "dU23 .req d18\n" "dU24 .req d11\n"
- "dU31 .req d12\n" "dU32 .req d19\n" "dU33 .req d20\n" "dU34 .req d14\n"
- "dU41 .req d6\n" "dU42 .req d21\n" "dU43 .req d22\n" "dU44 .req d8\n"
-
- "half .req v23\n" // {0.5, ..., 0.5}
- "dup half.4s, %w[one_half]\n"
- "scratch .req v24\n"
-
- // Subtract the tail from the number of remaining channels and jump to
- // the tail if necessary.
- "subs %x[n_remaining_channels], %x[n_remaining_channels], #2\n"
- "beq 2f\n"
-
- "1:"
- // Load tile of the kernel
- "ldr qw_11, [%x[inptr0]]\n"
- "str qU11, [%x[outptr0]]\n"
- "ldr qw_12, [%x[inptr0], %x[colstride1]]\n"
- "ldr qw_13, [%x[inptr0], %x[colstride2]]\n"
- "str qU14, [%x[outptr0], %x[mstride3]]\n"
- "add %x[inptr0], %x[inptr0], #0x10\n"
-
- "ldr qw_21, [%x[inptr1]]\n"
- "ldr qw_22, [%x[inptr1], %x[colstride1]]\n"
- "ldr qw_23, [%x[inptr1], %x[colstride2]]\n"
- "add %x[inptr1], %x[inptr1], #0x10\n"
-
- "ldr qw_31, [%x[inptr2]]\n"
- "str qU41, [%x[outptr12]]\n"
- "ldr qw_32, [%x[inptr2], %x[colstride1]]\n"
- "ldr qw_33, [%x[inptr2], %x[colstride2]]\n"
- "str qU44, [%x[outptr12], %x[mstride3]]\n"
- "add %x[inptr2], %x[inptr2], #0x10\n"
-
- // Compute 2nd and 3rd rows of Ww
- "fadd scratch.4s, w_11.4s, w_31.4s\n"
- "fmul Ww21.4s, scratch.4s, half.4s\n"
- "fmla Ww21.4s, w_21.4s, half.4s\n"
- "str qU21, [%x[outptr4]]\n"
- "fmul Ww31.4s, scratch.4s, half.4s\n"
- "fmls Ww31.4s, w_21.4s, half.4s\n"
- "str qU31, [%x[outptr8]]\n"
-
- "fadd scratch.4s, w_12.4s, w_32.4s\n"
- "fmul Ww22.4s, scratch.4s, half.4s\n"
- "fmla Ww22.4s, w_22.4s, half.4s\n"
- "fmul Ww32.4s, scratch.4s, half.4s\n"
- "fmls Ww32.4s, w_22.4s, half.4s\n"
-
- "fadd scratch.4s, w_13.4s, w_33.4s\n"
- "fmul Ww23.4s, scratch.4s, half.4s\n"
- "fmla Ww23.4s, w_23.4s, half.4s\n"
- "str qU24, [%x[outptr4], %x[mstride3]]\n"
- "fmul Ww33.4s, scratch.4s, half.4s\n"
- "fmls Ww33.4s, w_23.4s, half.4s\n"
- "str qU34, [%x[outptr8], %x[mstride3]]\n"
-
- // Compute and store U, only need to compute the 2nd and 3rd columns
- // of U and update output pointers
- "fadd scratch.4s, Ww11.4s, Ww13.4s\n"
- "fmul U12.4s, scratch.4s, half.4s\n"
- "fmla U12.4s, Ww12.4s, half.4s\n"
- "str qU12, [%x[outptr0], %x[mstride1]]\n"
- "fmul U13.4s, scratch.4s, half.4s\n"
- "fmls U13.4s, Ww12.4s, half.4s\n"
- "str qU13, [%x[outptr0], %x[mstride2]]\n"
- "add %x[outptr0], %x[outptr0], #0x10\n"
-
- "fadd scratch.4s, Ww21.4s, Ww23.4s\n"
- "fmul U22.4s, scratch.4s, half.4s\n"
- "fmla U22.4s, Ww22.4s, half.4s\n"
- "str qU22, [%x[outptr4], %x[mstride1]]\n"
- "fmul U23.4s, scratch.4s, half.4s\n"
- "fmls U23.4s, Ww22.4s, half.4s\n"
- "str qU23, [%x[outptr4], %x[mstride2]]\n"
- "add %x[outptr4], %x[outptr4], #0x10\n"
-
- "fadd scratch.4s, Ww31.4s, Ww33.4s\n"
- "fmul U32.4s, scratch.4s, half.4s\n"
- "fmla U32.4s, Ww32.4s, half.4s\n"
- "str qU32, [%x[outptr8], %x[mstride1]]\n"
- "fmul U33.4s, scratch.4s, half.4s\n"
- "fmls U33.4s, Ww32.4s, half.4s\n"
- "str qU33, [%x[outptr8], %x[mstride2]]\n"
- "add %x[outptr8], %x[outptr8], #0x10\n"
-
- "fadd scratch.4s, Ww41.4s, Ww43.4s\n"
- "fmul U42.4s, scratch.4s, half.4s\n"
- "fmla U42.4s, Ww42.4s, half.4s\n"
- "str qU42, [%x[outptr12], %x[mstride1]]\n"
- "fmul U43.4s, scratch.4s, half.4s\n"
- "fmls U43.4s, Ww42.4s, half.4s\n"
- "str qU43, [%x[outptr12], %x[mstride2]]\n"
- "add %x[outptr12], %x[outptr12], #0x10\n"
-
- "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n"
- "bne 1b\n"
-
- // Tail size 2
- "2:"
- // Load tile of the kernel
- "ldr dw_11, [%x[inptr0]]\n"
- "str dU11, [%x[outptr0]]\n"
- "ldr dw_12, [%x[inptr0], %x[colstride1]]\n"
- "ldr dw_13, [%x[inptr0], %x[colstride2]]\n"
- "str dU14, [%x[outptr0], %x[mstride3]]\n"
- "add %x[inptr0], %x[inptr0], #0x08\n"
-
- "ldr dw_21, [%x[inptr1]]\n"
- "ldr dw_22, [%x[inptr1], %x[colstride1]]\n"
- "ldr dw_23, [%x[inptr1], %x[colstride2]]\n"
- "add %x[inptr1], %x[inptr1], #0x08\n"
-
- "ldr dw_31, [%x[inptr2]]\n"
- "str dU41, [%x[outptr12]]\n"
- "ldr dw_32, [%x[inptr2], %x[colstride1]]\n"
- "ldr dw_33, [%x[inptr2], %x[colstride2]]\n"
- "str dU44, [%x[outptr12], %x[mstride3]]\n"
- "add %x[inptr2], %x[inptr2], #0x08\n"
-
- // Compute 2nd and 3rd rows of Ww
- "fadd scratch.2s, w_11.2s, w_31.2s\n"
- "fmul Ww21.2s, scratch.2s, half.2s\n"
- "fmla Ww21.2s, w_21.2s, half.2s\n"
- "str dU21, [%x[outptr4]]\n"
- "fmul Ww31.2s, scratch.2s, half.2s\n"
- "fmls Ww31.2s, w_21.2s, half.2s\n"
- "str dU31, [%x[outptr8]]\n"
-
- "fadd scratch.2s, w_12.2s, w_32.2s\n"
- "fmul Ww22.2s, scratch.2s, half.2s\n"
- "fmla Ww22.2s, w_22.2s, half.2s\n"
- "fmul Ww32.2s, scratch.2s, half.2s\n"
- "fmls Ww32.2s, w_22.2s, half.2s\n"
-
- "fadd scratch.2s, w_13.2s, w_33.2s\n"
- "fmul Ww23.2s, scratch.2s, half.2s\n"
- "fmla Ww23.2s, w_23.2s, half.2s\n"
- "str dU24, [%x[outptr4], %x[mstride3]]\n"
- "fmul Ww33.2s, scratch.2s, half.2s\n"
- "fmls Ww33.2s, w_23.2s, half.2s\n"
- "str dU34, [%x[outptr8], %x[mstride3]]\n"
-
- // Compute and store U, only need to compute the 2nd and 3rd columns of
- // U and update output pointers
- "fadd scratch.2s, Ww11.2s, Ww13.2s\n"
- "fmul U12.2s, scratch.2s, half.2s\n"
- "fmla U12.2s, Ww12.2s, half.2s\n"
- "str dU12, [%x[outptr0], %x[mstride1]]\n"
- "fmul U13.2s, scratch.2s, half.2s\n"
- "fmls U13.2s, Ww12.2s, half.2s\n"
- "str dU13, [%x[outptr0], %x[mstride2]]\n"
- "add %x[outptr0], %x[outptr0], #0x08\n"
-
- "fadd scratch.2s, Ww21.2s, Ww23.2s\n"
- "fmul U22.2s, scratch.2s, half.2s\n"
- "fmla U22.2s, Ww22.2s, half.2s\n"
- "str dU22, [%x[outptr4], %x[mstride1]]\n"
- "fmul U23.2s, scratch.2s, half.2s\n"
- "fmls U23.2s, Ww22.2s, half.2s\n"
- "str dU23, [%x[outptr4], %x[mstride2]]\n"
- "add %x[outptr4], %x[outptr4], #0x08\n"
-
- "fadd scratch.2s, Ww31.2s, Ww33.2s\n"
- "fmul U32.2s, scratch.2s, half.2s\n"
- "fmla U32.2s, Ww32.2s, half.2s\n"
- "str dU32, [%x[outptr8], %x[mstride1]]\n"
- "fmul U33.2s, scratch.2s, half.2s\n"
- "fmls U33.2s, Ww32.2s, half.2s\n"
- "str dU33, [%x[outptr8], %x[mstride2]]\n"
- "add %x[outptr8], %x[outptr8], #0x08\n"
-
- "fadd scratch.2s, Ww41.2s, Ww43.2s\n"
- "fmul U42.2s, scratch.2s, half.2s\n"
- "fmla U42.2s, Ww42.2s, half.2s\n"
- "str dU42, [%x[outptr12], %x[mstride1]]\n"
- "fmul U43.2s, scratch.2s, half.2s\n"
- "fmls U43.2s, Ww42.2s, half.2s\n"
- "str dU43, [%x[outptr12], %x[mstride2]]\n"
- "add %x[outptr12], %x[outptr12], #0x08\n"
-
- // Clear aliases
- ".unreq half\n"
- ".unreq scratch\n"
- ".unreq w_11\n" ".unreq qw_11\n" ".unreq dw_11\n"
- ".unreq w_12\n" ".unreq qw_12\n" ".unreq dw_12\n"
- ".unreq w_13\n" ".unreq qw_13\n" ".unreq dw_13\n"
- ".unreq w_21\n" ".unreq qw_21\n" ".unreq dw_21\n"
- ".unreq w_22\n" ".unreq qw_22\n" ".unreq dw_22\n"
- ".unreq w_23\n" ".unreq qw_23\n" ".unreq dw_23\n"
- ".unreq w_31\n" ".unreq qw_31\n" ".unreq dw_31\n"
- ".unreq w_32\n" ".unreq qw_32\n" ".unreq dw_32\n"
- ".unreq w_33\n" ".unreq qw_33\n" ".unreq dw_33\n"
- ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n"
- ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n"
- ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n"
- ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n"
- ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n"
- ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n"
- ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n"
- ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n"
- ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n"
- ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n"
- ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n"
- ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq qU44\n"
- ".unreq dU11\n" ".unreq dU12\n" ".unreq dU13\n" ".unreq dU14\n"
- ".unreq dU21\n" ".unreq dU22\n" ".unreq dU23\n" ".unreq dU24\n"
- ".unreq dU31\n" ".unreq dU32\n" ".unreq dU33\n" ".unreq dU34\n"
- ".unreq dU41\n" ".unreq dU42\n" ".unreq dU43\n" ".unreq dU44\n"
-
- : [inptr0] "+r" (inptr0),
- [inptr1] "+r" (inptr1),
- [inptr2] "+r" (inptr2),
- [outptr0] "+r" (outptr0),
- [outptr4] "+r" (outptr4),
- [outptr8] "+r" (outptr8),
- [outptr12] "+r" (outptr12),
- [n_remaining_channels] "+r" (n_remaining_channels)
- : [mstride1] "r" (sizeof(float) * mstride),
- [mstride2] "r" (sizeof(float) * mstride * 2),
- [mstride3] "r" (sizeof(float) * mstride * 3),
- [colstride1] "r" (sizeof(float) * kernel_col_stride),
- [colstride2] "r" (sizeof(float) * kernel_col_stride * 2),
- [one_half] "r" (0.5f)
- : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
- "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
- "v20", "v21", "v22", "v23", "v24"
- );
-
- // Progression to complete stride
- outptr0 += matrix_row_stride - n_output_channels;
- outptr4 += matrix_row_stride - n_output_channels;
- outptr8 += matrix_row_stride - n_output_channels;
- outptr12 += matrix_row_stride - n_output_channels;
- }
-}
-
-template <>
-template <>
-inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::transform_kernel<1>(
- const float* const kernel,
- const int n_input_channels,
- const int n_output_channels,
- float* const matrix_base,
- const int mstride,
- const int matrix_row_stride
-) {
- // Use one input pointer for each row of the kernel, use two additional
- // offsets to extract columns.
- const int kernel_col_stride = n_input_channels * n_output_channels;
- const int kernel_row_stride = 3 * kernel_col_stride;
- const float *inptr0 = kernel;
- const float *inptr1 = kernel + kernel_row_stride;
- const float *inptr2 = kernel + kernel_row_stride*2;
-
- // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three
- // offsets to extract further matrices.
- float *outptr0 = matrix_base;
- float *outptr4 = matrix_base + mstride * 4;
- float *outptr8 = matrix_base + mstride * 8;
- float *outptr12 = matrix_base + mstride * 12;
-
- // For every input channel
- for (int in_c = 0; in_c < n_input_channels; in_c++) {
- int n_remaining_channels = n_output_channels;
-
- asm volatile (
- // Registers into which to read the kernel
- "w_11 .req v0\n" "qw_11 .req q0\n" "sw_11 .req s0\n"
- "w_12 .req v1\n" "qw_12 .req q1\n" "sw_12 .req s1\n"
- "w_13 .req v2\n" "qw_13 .req q2\n" "sw_13 .req s2\n"
- "w_21 .req v3\n" "qw_21 .req q3\n" "sw_21 .req s3\n"
- "w_22 .req v4\n" "qw_22 .req q4\n" "sw_22 .req s4\n"
- "w_23 .req v5\n" "qw_23 .req q5\n" "sw_23 .req s5\n"
- "w_31 .req v6\n" "qw_31 .req q6\n" "sw_31 .req s6\n"
- "w_32 .req v7\n" "qw_32 .req q7\n" "sw_32 .req s7\n"
- "w_33 .req v8\n" "qw_33 .req q8\n" "sw_33 .req s8\n"
-
- // Transformed matrix Ww
- "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n"
- "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n"
- "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n"
- "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n"
-
- // Output matrix U = WwWT
- "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n"
- "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n"
- "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n"
- "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n"
-
- // Storage view of output matrices
- "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n"
- "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n"
- "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n"
- "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n"
-
- "sU11 .req s0\n" "sU12 .req s15\n" "sU13 .req s16\n" "sU14 .req s2\n"
- "sU21 .req s9\n" "sU22 .req s17\n" "sU23 .req s18\n" "sU24 .req s11\n"
- "sU31 .req s12\n" "sU32 .req s19\n" "sU33 .req s20\n" "sU34 .req s14\n"
- "sU41 .req s6\n" "sU42 .req s21\n" "sU43 .req s22\n" "sU44 .req s8\n"
-
- "half .req v23\n" // {0.5, ..., 0.5}
- "dup half.4s, %w[one_half]\n"
- "scratch .req v24\n"
-
- // Subtract the tail from the number of remaining channels and jump to
- // the tail if necessary.
- "subs %x[n_remaining_channels], %x[n_remaining_channels], #1\n"
- "beq 2f\n"
-
- "1:"
- // Load tile of the kernel
- "ldr qw_11, [%x[inptr0]]\n"
- "str qU11, [%x[outptr0]]\n"
- "ldr qw_12, [%x[inptr0], %x[colstride1]]\n"
- "ldr qw_13, [%x[inptr0], %x[colstride2]]\n"
- "str qU14, [%x[outptr0], %x[mstride3]]\n"
- "add %x[inptr0], %x[inptr0], #0x10\n"
-
- "ldr qw_21, [%x[inptr1]]\n"
- "ldr qw_22, [%x[inptr1], %x[colstride1]]\n"
- "ldr qw_23, [%x[inptr1], %x[colstride2]]\n"
- "add %x[inptr1], %x[inptr1], #0x10\n"
-
- "ldr qw_31, [%x[inptr2]]\n"
- "str qU41, [%x[outptr12]]\n"
- "ldr qw_32, [%x[inptr2], %x[colstride1]]\n"
- "ldr qw_33, [%x[inptr2], %x[colstride2]]\n"
- "str qU44, [%x[outptr12], %x[mstride3]]\n"
- "add %x[inptr2], %x[inptr2], #0x10\n"
-
- // Compute 2nd and 3rd rows of Ww
- "fadd scratch.4s, w_11.4s, w_31.4s\n"
- "fmul Ww21.4s, scratch.4s, half.4s\n"
- "fmla Ww21.4s, w_21.4s, half.4s\n"
- "str qU21, [%x[outptr4]]\n"
- "fmul Ww31.4s, scratch.4s, half.4s\n"
- "fmls Ww31.4s, w_21.4s, half.4s\n"
- "str qU31, [%x[outptr8]]\n"
-
- "fadd scratch.4s, w_12.4s, w_32.4s\n"
- "fmul Ww22.4s, scratch.4s, half.4s\n"
- "fmla Ww22.4s, w_22.4s, half.4s\n"
- "fmul Ww32.4s, scratch.4s, half.4s\n"
- "fmls Ww32.4s, w_22.4s, half.4s\n"
-
- "fadd scratch.4s, w_13.4s, w_33.4s\n"
- "fmul Ww23.4s, scratch.4s, half.4s\n"
- "fmla Ww23.4s, w_23.4s, half.4s\n"
- "str qU24, [%x[outptr4], %x[mstride3]]\n"
- "fmul Ww33.4s, scratch.4s, half.4s\n"
- "fmls Ww33.4s, w_23.4s, half.4s\n"
- "str qU34, [%x[outptr8], %x[mstride3]]\n"
-
- // Compute and store U, only need to compute the 2nd and 3rd columns
- // of U and update output pointers
- "fadd scratch.4s, Ww11.4s, Ww13.4s\n"
- "fmul U12.4s, scratch.4s, half.4s\n"
- "fmla U12.4s, Ww12.4s, half.4s\n"
- "str qU12, [%x[outptr0], %x[mstride1]]\n"
- "fmul U13.4s, scratch.4s, half.4s\n"
- "fmls U13.4s, Ww12.4s, half.4s\n"
- "str qU13, [%x[outptr0], %x[mstride2]]\n"
- "add %x[outptr0], %x[outptr0], #0x10\n"
-
- "fadd scratch.4s, Ww21.4s, Ww23.4s\n"
- "fmul U22.4s, scratch.4s, half.4s\n"
- "fmla U22.4s, Ww22.4s, half.4s\n"
- "str qU22, [%x[outptr4], %x[mstride1]]\n"
- "fmul U23.4s, scratch.4s, half.4s\n"
- "fmls U23.4s, Ww22.4s, half.4s\n"
- "str qU23, [%x[outptr4], %x[mstride2]]\n"
- "add %x[outptr4], %x[outptr4], #0x10\n"
-
- "fadd scratch.4s, Ww31.4s, Ww33.4s\n"
- "fmul U32.4s, scratch.4s, half.4s\n"
- "fmla U32.4s, Ww32.4s, half.4s\n"
- "str qU32, [%x[outptr8], %x[mstride1]]\n"
- "fmul U33.4s, scratch.4s, half.4s\n"
- "fmls U33.4s, Ww32.4s, half.4s\n"
- "str qU33, [%x[outptr8], %x[mstride2]]\n"
- "add %x[outptr8], %x[outptr8], #0x10\n"
-
- "fadd scratch.4s, Ww41.4s, Ww43.4s\n"
- "fmul U42.4s, scratch.4s, half.4s\n"
- "fmla U42.4s, Ww42.4s, half.4s\n"
- "str qU42, [%x[outptr12], %x[mstride1]]\n"
- "fmul U43.4s, scratch.4s, half.4s\n"
- "fmls U43.4s, Ww42.4s, half.4s\n"
- "str qU43, [%x[outptr12], %x[mstride2]]\n"
- "add %x[outptr12], %x[outptr12], #0x10\n"
-
- "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n"
- "bne 1b\n"
-
- // Tail size 1
- "2:"
- // Load tile of the kernel
- "ldr sw_11, [%x[inptr0]]\n"
- "str sU11, [%x[outptr0]]\n"
- "ldr sw_12, [%x[inptr0], %x[colstride1]]\n"
- "ldr sw_13, [%x[inptr0], %x[colstride2]]\n"
- "str sU14, [%x[outptr0], %x[mstride3]]\n"
- "add %x[inptr0], %x[inptr0], #0x04\n"
-
- "ldr sw_21, [%x[inptr1]]\n"
- "ldr sw_22, [%x[inptr1], %x[colstride1]]\n"
- "ldr sw_23, [%x[inptr1], %x[colstride2]]\n"
- "add %x[inptr1], %x[inptr1], #0x04\n"
-
- "ldr sw_31, [%x[inptr2]]\n"
- "str sU41, [%x[outptr12]]\n"
- "ldr sw_32, [%x[inptr2], %x[colstride1]]\n"
- "ldr sw_33, [%x[inptr2], %x[colstride2]]\n"
- "str sU44, [%x[outptr12], %x[mstride3]]\n"
- "add %x[inptr2], %x[inptr2], #0x04\n"
-
- // Compute 2nd and 3rd rows of Ww
- "fadd scratch.2s, w_11.2s, w_31.2s\n"
- "fmul Ww21.2s, scratch.2s, half.2s\n"
- "fmla Ww21.2s, w_21.2s, half.2s\n"
- "str sU21, [%x[outptr4]]\n"
- "fmul Ww31.2s, scratch.2s, half.2s\n"
- "fmls Ww31.2s, w_21.2s, half.2s\n"
- "str sU31, [%x[outptr8]]\n"
-
- "fadd scratch.2s, w_12.2s, w_32.2s\n"
- "fmul Ww22.2s, scratch.2s, half.2s\n"
- "fmla Ww22.2s, w_22.2s, half.2s\n"
- "fmul Ww32.2s, scratch.2s, half.2s\n"
- "fmls Ww32.2s, w_22.2s, half.2s\n"
-
- "fadd scratch.2s, w_13.2s, w_33.2s\n"
- "fmul Ww23.2s, scratch.2s, half.2s\n"
- "fmla Ww23.2s, w_23.2s, half.2s\n"
- "str sU24, [%x[outptr4], %x[mstride3]]\n"
- "fmul Ww33.2s, scratch.2s, half.2s\n"
- "fmls Ww33.2s, w_23.2s, half.2s\n"
- "str sU34, [%x[outptr8], %x[mstride3]]\n"
-
- // Compute and store U, only need to compute the 2nd and 3rd columns of
- // U and update output pointers
- "fadd scratch.2s, Ww11.2s, Ww13.2s\n"
- "fmul U12.2s, scratch.2s, half.2s\n"
- "fmla U12.2s, Ww12.2s, half.2s\n"
- "str sU12, [%x[outptr0], %x[mstride1]]\n"
- "fmul U13.2s, scratch.2s, half.2s\n"
- "fmls U13.2s, Ww12.2s, half.2s\n"
- "str sU13, [%x[outptr0], %x[mstride2]]\n"
- "add %x[outptr0], %x[outptr0], #0x04\n"
-
- "fadd scratch.2s, Ww21.2s, Ww23.2s\n"
- "fmul U22.2s, scratch.2s, half.2s\n"
- "fmla U22.2s, Ww22.2s, half.2s\n"
- "str sU22, [%x[outptr4], %x[mstride1]]\n"
- "fmul U23.2s, scratch.2s, half.2s\n"
- "fmls U23.2s, Ww22.2s, half.2s\n"
- "str sU23, [%x[outptr4], %x[mstride2]]\n"
- "add %x[outptr4], %x[outptr4], #0x04\n"
-
- "fadd scratch.2s, Ww31.2s, Ww33.2s\n"
- "fmul U32.2s, scratch.2s, half.2s\n"
- "fmla U32.2s, Ww32.2s, half.2s\n"
- "str sU32, [%x[outptr8], %x[mstride1]]\n"
- "fmul U33.2s, scratch.2s, half.2s\n"
- "fmls U33.2s, Ww32.2s, half.2s\n"
- "str sU33, [%x[outptr8], %x[mstride2]]\n"
- "add %x[outptr8], %x[outptr8], #0x04\n"
-
- "fadd scratch.2s, Ww41.2s, Ww43.2s\n"
- "fmul U42.2s, scratch.2s, half.2s\n"
- "fmla U42.2s, Ww42.2s, half.2s\n"
- "str sU42, [%x[outptr12], %x[mstride1]]\n"
- "fmul U43.2s, scratch.2s, half.2s\n"
- "fmls U43.2s, Ww42.2s, half.2s\n"
- "str sU43, [%x[outptr12], %x[mstride2]]\n"
- "add %x[outptr12], %x[outptr12], #0x04\n"
-
- // Clear aliases
- ".unreq half\n"
- ".unreq scratch\n"
- ".unreq w_11\n" ".unreq qw_11\n" ".unreq sw_11\n"
- ".unreq w_12\n" ".unreq qw_12\n" ".unreq sw_12\n"
- ".unreq w_13\n" ".unreq qw_13\n" ".unreq sw_13\n"
- ".unreq w_21\n" ".unreq qw_21\n" ".unreq sw_21\n"
- ".unreq w_22\n" ".unreq qw_22\n" ".unreq sw_22\n"
- ".unreq w_23\n" ".unreq qw_23\n" ".unreq sw_23\n"
- ".unreq w_31\n" ".unreq qw_31\n" ".unreq sw_31\n"
- ".unreq w_32\n" ".unreq qw_32\n" ".unreq sw_32\n"
- ".unreq w_33\n" ".unreq qw_33\n" ".unreq sw_33\n"
- ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n"
- ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n"
- ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n"
- ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n"
- ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n"
- ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n"
- ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n"
- ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n"
- ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n"
- ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n"
- ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n"
- ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq qU44\n"
- ".unreq sU11\n" ".unreq sU12\n" ".unreq sU13\n" ".unreq sU14\n"
- ".unreq sU21\n" ".unreq sU22\n" ".unreq sU23\n" ".unreq sU24\n"
- ".unreq sU31\n" ".unreq sU32\n" ".unreq sU33\n" ".unreq sU34\n"
- ".unreq sU41\n" ".unreq sU42\n" ".unreq sU43\n" ".unreq sU44\n"
-
- : [inptr0] "+r" (inptr0),
- [inptr1] "+r" (inptr1),
- [inptr2] "+r" (inptr2),
- [outptr0] "+r" (outptr0),
- [outptr4] "+r" (outptr4),
- [outptr8] "+r" (outptr8),
- [outptr12] "+r" (outptr12),
- [n_remaining_channels] "+r" (n_remaining_channels)
- : [mstride1] "r" (sizeof(float) * mstride),
- [mstride2] "r" (sizeof(float) * mstride * 2),
- [mstride3] "r" (sizeof(float) * mstride * 3),
- [colstride1] "r" (sizeof(float) * kernel_col_stride),
- [colstride2] "r" (sizeof(float) * kernel_col_stride * 2),
- [one_half] "r" (0.5f)
- : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
- "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
- "v20", "v21", "v22", "v23", "v24"
- );
-
- // Progression to complete stride
- outptr0 += matrix_row_stride - n_output_channels;
- outptr4 += matrix_row_stride - n_output_channels;
- outptr8 += matrix_row_stride - n_output_channels;
- outptr12 += matrix_row_stride - n_output_channels;
- }
-}
-}
-#endif // __aarch64__
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp
deleted file mode 100644
index 0992c0bb44..0000000000
--- a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp
+++ /dev/null
@@ -1,356 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#pragma once
-
-namespace winograd {
- /* Transform from the Winograd domain back to the spatial domain.
- */
- template <typename T>
- struct Winograd2x2_3x3GemmOutput {
- static void execute(
- const Tensor4DShape &output_shape,
- T* const matrix_base,
- const int matrix_stride,
- const int matrix_row_stride,
- T* const output
- );
-
- protected:
- /* Specialised implementation method. */
- template <bool tail_M, bool tail_N, int channel_tail>
- static void _execute(
- const Tensor4DShape &output_shape,
- T *output,
- const T *input,
- const int matrix_stride,
- const int matrix_row_stride
- );
- };
-
- /* Two-stage implementation of the transformation from the Winograd domain.
- *
- * First computes Z.F and then computes (Z.F).Z^T.
- */
- template <typename T>
- struct Winograd2x2_3x3GemmOutput_TwoStage {
- static void execute(
- const Tensor4DShape &output_shape,
- T* const matrix_base,
- const int matrix_stride,
- const int matrix_row_stride,
- T* const output
- );
-
- protected:
- template <int channel_tail>
- static void compute_zf(
- const int n_rows, const int n_channels,
- T* const zf, const T* const input[16]
- );
-
- template <bool tail_M, bool tail_N, int channel_tail>
- static void compute_zfzT(
- const Tensor4DShape &output_shape,
- T* const output, const T* const zf
- );
- };
-}
-
-#include "output_2x2_3x3/a64_float.hpp"
-// #include "output_2x2_3x3/a64_float_two_stage.hpp"
-
-/*****************************************************************************/
-/*
-template <typename T>
-void winograd::Winograd2x2_3x3GemmOutput<T>::execute(
- const Tensor4DShape &output_shape,
- const int tile_M,
- const int tile_N,
- T* const matrix_base,
- const int matrix_stride,
- const int matrix_row_stride,
- T* const output
-) {
- T* const antipadding = reinterpret_cast<T *>(malloc(sizeof(T) * output_shape.n_channels));
-
- // Get input pointers
- const T* inptrs[16];
- for (int i = 0; i < 16; i++) {
- inptrs[i] = matrices[i];
- }
-
- for (int batch = 0; batch < output_shape.n_batches; batch++) {
- for (int tile_i = 0; tile_i < tile_M; tile_i++) {
- for (int tile_j = 0; tile_j < tile_N; tile_j++) {
- // Get pointers for each of the 4 output cells required for this computation
- T* outptrs[4];
- for (int cell_i = 0, c = 0; cell_i < 2; cell_i++) {
- for (int cell_j = 0; cell_j < 2; cell_j++, c++) {
- const int i = tile_i*2 + cell_i;
- const int j = tile_j*2 + cell_j;
-
- if (i < output_shape.n_rows && j < output_shape.n_cols) {
- outptrs[c] = output + (
- (batch*output_shape.n_rows + i) * output_shape.n_cols +
- j) * output_shape.n_channels;
- } else {
- outptrs[c] = antipadding;
- }
- } // cell_j
- } // cell_i
-
- for (int n = 0; n < output_shape.n_channels; n++) {
- // Read 16 values and progress pointers
- T v[16];
- for (int i = 0; i < 16; i++) {
- v[i] = *(inptrs[i]++);
- }
-
- // Compute output for 4 pixels
- *(outptrs[0]++) = v[ 0] + v[ 1] + v[ 2] +
- v[ 4] + v[ 5] + v[ 6] +
- v[ 8] + v[ 9] + v[10];
- *(outptrs[1]++) = v[ 1] - v[ 2] - v[ 3] +
- v[ 5] - v[ 6] - v[ 7] +
- v[ 9] - v[10] - v[11];
- *(outptrs[2]++) = v[ 4] + v[ 5] + v[ 6] -
- v[ 8] - v[ 9] - v[10] -
- v[12] - v[13] - v[14];
- *(outptrs[3]++) = v[ 5] - v[ 6] - v[ 7] -
- v[ 9] + v[10] + v[11] -
- v[13] + v[14] + v[15];
- } // output_channel
- } // tile_j
- } // tile_i
- } // batch
-
- free(antipadding);
-}
-*/
-
-/*****************************************************************************/
-/*
-template <typename T>
-void winograd::Winograd2x2_3x3GemmOutput_TwoStage<T>::execute(
- const Tensor4DShape &output_shape,
- T* const matrices[16], T* const output
-) {
- // Allocate memory for the intermediate matrices
- const int tile_M = iceildiv(output_shape.n_rows, 2);
- const int tile_N = iceildiv(output_shape.n_cols, 2);
- const int n_rows = output_shape.n_batches * tile_M * tile_N;
- const int n_channels = output_shape.n_channels;
- T* matrices_zf = reinterpret_cast<T*>(
- calloc(8 * n_rows * n_channels, sizeof(T))
- );
-
- // Perform the first stage transform, computing ZF.
- // Specializations should dispatch to different methods based on tail size.
- compute_zf<0>(n_rows, n_channels, matrices_zf, matrices);
-
- // Perform the second stage transform, finishing Z F Z^T - variable dispatch
- // based on size of the output. Specialisations can also dispatch based on
- // the tail-size of the channel.
- if (output_shape.n_rows % 2 && output_shape.n_cols % 2) {
- compute_zfzT<true, true, 0>(output_shape, output, matrices_zf);
- } else if (output_shape.n_rows % 2) {
- compute_zfzT<true, false, 0>(output_shape, output, matrices_zf);
- } else if (output_shape.n_cols % 2) {
- compute_zfzT<false, true, 0>(output_shape, output, matrices_zf);
- } else {
- compute_zfzT<false, false, 0>(output_shape, output, matrices_zf);
- }
-
- free(reinterpret_cast<void*>(matrices_zf));
-}
-
-template <typename T>
-template <int channel_tail>
-void winograd::Winograd2x2_3x3GemmOutput_TwoStage<T>::compute_zf(
- const int n_rows, const int n_channels,
- T* output, const T* const input[16]
-) {
- // Extract 8 output pointers
- T* outptr[8];
- for (int i = 0; i < 8; i++) {
- outptr[i] = output + i*n_rows*n_channels;
- }
-
- // Copy the 16 input pointers
- const T* inptr[16];
- for (int i = 0; i < 16; i++) {
- inptr[i] = input[i];
- }
-
- // For every row of the matrices
- for (int i = 0; i < n_rows; i++) {
- // For every channel
- for (int j = 0; j < n_channels; j++) {
- // Extract values from the input matrices
- T val[16];
- for (int n = 0; n < 16; n++) {
- val[n] = *(inptr[n]++);
- }
-
- // Compute output values
- *(outptr[0]++) = val[0] + val[1] + val[2];
- *(outptr[1]++) = val[1] - val[2] - val[3];
- *(outptr[2]++) = val[4] + val[5] + val[6];
- *(outptr[3]++) = val[5] - val[6] - val[7];
- *(outptr[4]++) = val[8] + val[9] + val[10];
- *(outptr[5]++) = val[9] - val[10] - val[11];
- *(outptr[6]++) = val[12] + val[13] + val[14];
- *(outptr[7]++) = val[13] - val[14] - val[15];
- }
- }
-}
-
-template <typename T>
-template <bool tail_M, bool tail_N, int channel_tail>
-void winograd::Winograd2x2_3x3GemmOutput_TwoStage<T>::compute_zfzT(
- const Tensor4DShape &output_shape,
- T* const output, const T* const input
-) {
- // Sizing information
- const int tile_M = output_shape.n_rows / 2;
- const int tile_N = output_shape.n_cols / 2;
-
- const int n_rows = (output_shape.n_batches *
- (tile_M + (tail_M ? 1 : 0)) *
- (tile_N + (tail_N ? 1 : 0)));
- const int n_channels = output_shape.n_channels;
-
- // Extract 8 input pointers
- const T* inptr[8];
- for (int i = 0; i < 8; i++) {
- inptr[i] = input + i*n_rows*n_channels;
- }
-
- // Extract 4 output pointers
- T* outptr00 = output;
- T* outptr01 = outptr00 + n_channels;
- T* outptr10 = outptr00 + output_shape.n_cols * n_channels;
- T* outptr11 = outptr10 + n_channels;
-
- // Progress over the output tiles, generating output values.
- for (int batch = 0; batch < output_shape.n_batches; batch++) {
- for (int tile_i = 0; tile_i < tile_M; tile_i++) {
- for (int tile_j = 0; tile_j < tile_N; tile_j++) {
- for (int channel = 0; channel < n_channels; channel++) {
- // Read values from the input pointers
- T v[8];
- for (int i = 0; i < 8; i++) {
- v[i] = *(inptr[i]++);
- }
-
- // Compute the output values and progress the output pointers.
- *(outptr00++) = v[0] + v[2] + v[4];
- *(outptr01++) = v[1] + v[3] + v[5];
- *(outptr10++) = v[2] - v[4] - v[6];
- *(outptr11++) = v[3] - v[5] - v[7];
- }
-
- // Progress the output pointers to the next column
- outptr00 += n_channels;
- outptr01 += n_channels;
- outptr10 += n_channels;
- outptr11 += n_channels;
- }
-
- if (tail_N) {
- // Only evaluate the left-most columns of the output
- for (int channel = 0; channel < n_channels; channel++) {
- // Read values from the input pointers
- T v[8];
- for (int i = 0; i < 4; i++) {
- v[i * 2] = *inptr[i * 2];
- }
- for (int i = 0; i < 8; i++) {
- inptr[i]++;
- }
-
- // Compute the output values and progress the output pointers.
- *(outptr00++) = v[0] + v[2] + v[4];
- *(outptr10++) = v[2] - v[4] - v[6];
- }
-
- // Progress the output pointers to the next column
- outptr01 += n_channels; // Account for being skipped above
- outptr11 += n_channels; // Account for being skipped above
- }
-
- // Progress the output pointers to the next row
- outptr00 += output_shape.n_cols * n_channels;
- outptr01 += output_shape.n_cols * n_channels;
- outptr10 += output_shape.n_cols * n_channels;
- outptr11 += output_shape.n_cols * n_channels;
- }
-
- if (tail_M) {
- // Only work on the upper row of the output
- for (int tile_j = 0; tile_j < tile_N; tile_j++) {
- for (int channel = 0; channel < n_channels; channel++) {
- // Read values from the input pointers
- T v[8];
- for (int i = 0; i < 8; i++) {
- v[i] = *(inptr[i]++);
- }
-
- // Compute the output values and progress the output pointers.
- *(outptr00++) = v[0] + v[2] + v[4];
- *(outptr01++) = v[1] + v[3] + v[5];
- }
-
- // Progress the output pointers to the next column
- outptr00 += n_channels;
- outptr01 += n_channels;
- outptr10 += 2 * n_channels; // Account for being skipped above
- outptr11 += 2 * n_channels; // Account for being skipped above
- }
-
- if (tail_N) {
- // Only evaluate the upper-left cell of the output
- for (int channel = 0; channel < n_channels; channel++) {
- // Read values from the input pointers
- T v[8];
- for (int i = 0; i < 3; i++) {
- v[i * 2] = *inptr[i * 2];
- }
- for (int i = 0; i < 8; i++) {
- inptr[i]++;
- }
-
- // Compute the output values and progress the output pointers.
- *(outptr00++) = v[0] + v[2] + v[4];
- }
-
- // Progress the output pointers to the next column
- outptr01 += n_channels; // Account for being skipped above
- outptr10 += n_channels; // Account for being skipped above
- outptr11 += n_channels; // Account for being skipped above
- }
- }
- }
-}
-*/
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp
deleted file mode 100644
index 5925f9d569..0000000000
--- a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp
+++ /dev/null
@@ -1,650 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#pragma once
-
-/* Float implementation for AArch64.
- */
-#ifdef __aarch64__
-namespace winograd {
-
-
-template <>
-template <>
-inline void Winograd2x2_3x3GemmOutput<float>::_execute<false, false, 0>(
- const Tensor4DShape &output_shape,
- float *output,
- const float *input,
- const int mstride,
- const int matrix_row_stride
-) {
- const int tile_M = output_shape.n_rows / 2;
- const int tile_N = output_shape.n_cols / 2;
- int batch = output_shape.n_batches;
- float *outptr = output;
-
- const float *inptr0 = input;
- const float *inptr4 = input + 4 * mstride;
- const float *inptr8 = input + 8 * mstride;
- const float *inptr12 = input + 12 * mstride;
-
- const size_t col_stride = sizeof(float) * output_shape.n_channels;
- const size_t row_stride = col_stride * tile_N * 2;
-
- asm volatile (
- // Aliases for elements of the input matrix `F`
- // V-register Q-register
- "F11 .req v0\n" "qF11 .req q0\n"
- "F12 .req v1\n" "qF12 .req q1\n"
- "F13 .req v2\n" "qF13 .req q2\n"
- "F14 .req v3\n" "qF14 .req q3\n"
- "F21 .req v4\n" "qF21 .req q4\n"
- "F22 .req v5\n" "qF22 .req q5\n"
- "F23 .req v6\n" "qF23 .req q6\n"
- "F24 .req v7\n" "qF24 .req q7\n"
- "F31 .req v8\n" "qF31 .req q8\n"
- "F32 .req v9\n" "qF32 .req q9\n"
- "F33 .req v10\n" "qF33 .req q10\n"
- "F34 .req v11\n" "qF34 .req q11\n"
- "F41 .req v12\n" "qF41 .req q12\n"
- "F42 .req v13\n" "qF42 .req q13\n"
- "F43 .req v14\n" "qF43 .req q14\n"
- "F44 .req v15\n" "qF44 .req q15\n"
-
- // Aliases for elements of the intermediate matrix `FZ`
- "FZ11 .req v16\n"
- "FZ12 .req v17\n"
- "FZ21 .req v18\n"
- "FZ22 .req v19\n"
- "FZ31 .req v20\n"
- "FZ32 .req v21\n"
- "FZ41 .req v22\n"
- "FZ42 .req v23\n"
-
- // Aliases for elements of the output matrix `f` (called `g` due to case
- // insensitivity of aliases).
- " g11 .req v24\n"
- "qg11 .req q24\n"
- " g12 .req v25\n"
- "qg12 .req q25\n"
- " g21 .req v26\n"
- "qg21 .req q26\n"
- " g22 .req v27\n"
- "qg22 .req q27\n"
-
- // Prepare the various strides
- "col_stride .req %x[col_stride]\n"
- "row_stride .req %x[row_stride]\n"
- "row_plus_col_stride .req %x[row_plus_col_stride]\n"
-
- "mstride1 .req %x[mstride1]\n"
- "mstride2 .req %x[mstride2]\n"
- "mstride3 .req %x[mstride3]\n"
-
- "tile_i .req x19\n" // Tile row counter
- "tile_j .req x20\n" // Tile column counter
- "channel .req x21\n" // Channel counter
-
- "1:" // Loop over batches
- "mov tile_i, %x[tile_M]\n" // Reset tile row counter
-
- "2:" // Loop over rows of tiles
- "mov tile_j, %x[tile_N]\n" // Reset tile column counter
-
- "3:" // Loop over columns of tiles
- // Perform initial loads of the matrix `F`
- "ldr qF11, [%x[inptr0]]\n"
- "ldr qF12, [%x[inptr0], mstride1]\n"
- "ldr qF13, [%x[inptr0], mstride2]\n"
- "ldr qF14, [%x[inptr0], mstride3]\n"
- "add %x[inptr0], %x[inptr0], #0x10\n"
- "ldr qF21, [%x[inptr4]]\n"
- "ldr qF22, [%x[inptr4], mstride1]\n"
- "subs channel, %x[n_channels], #4\n" // Reset channel counter
-
- "ldr qF23, [%x[inptr4], mstride2]\n"
- "ldr qF24, [%x[inptr4], mstride3]\n"
- "add %x[inptr4], %x[inptr4], #0x10\n"
- "beq 5f\n" // Jump straight to tail if necessary
-
- "4:" // Loop over channels
- "ldr qF31, [%x[inptr8]]\n"
- "fadd FZ11.4s, F11.4s, F12.4s\n"
-
- "ldr qF32, [%x[inptr8], mstride1]\n"
- "fsub FZ12.4s, F12.4s, F13.4s\n"
-
- "ldr qF33, [%x[inptr8], mstride2]\n"
- "fadd FZ11.4s, FZ11.4s, F13.4s\n"
-
- "ldr qF34, [%x[inptr8], mstride3]\n"
- "fsub FZ12.4s, FZ12.4s, F14.4s\n"
-
- "ldr qF41, [%x[inptr12]]\n"
- "fadd FZ21.4s, F21.4s, F22.4s\n"
-
- "ldr qF42, [%x[inptr12], mstride1]\n"
- "fsub FZ22.4s, F22.4s, F23.4s\n"
-
- "ldr qF43, [%x[inptr12], mstride2]\n"
- "fadd FZ21.4s, FZ21.4s, F23.4s\n"
-
- "ldr qF44, [%x[inptr12], mstride3]\n"
- "fsub FZ22.4s, FZ22.4s, F24.4s\n"
-
- "fadd FZ31.4s, F31.4s, F32.4s\n"
- "add %x[inptr8], %x[inptr8], #0x10\n"
-
- "fsub FZ32.4s, F32.4s, F33.4s\n"
- "add %x[inptr12], %x[inptr12], #0x10\n"
-
- "fadd FZ31.4s, FZ31.4s, F33.4s\n"
-
- "fsub FZ32.4s, FZ32.4s, F34.4s\n"
-
- "fadd g11.4s, FZ11.4s, FZ21.4s\n"
-
- "fadd g12.4s, FZ12.4s, FZ22.4s\n"
-
- "fadd g11.4s, g11.4s, FZ31.4s\n"
-
- "fadd g12.4s, g12.4s, FZ32.4s\n"
-
- "ldr qF11, [%x[inptr0]]\n"
- "fadd FZ41.4s, F41.4s, F42.4s\n"
-
- "ldr qF12, [%x[inptr0], mstride1]\n"
- "fsub g21.4s, FZ21.4s, FZ31.4s\n"
-
- "ldr qF13, [%x[inptr0], mstride2]\n"
- "fsub FZ42.4s, F42.4s, F43.4s\n"
-
- "ldr qF14, [%x[inptr0], mstride3]\n"
- "str qg11, [%x[outptr]]\n"
-
- "ldr qF21, [%x[inptr4]]\n"
- "fadd FZ41.4s, FZ41.4s, F43.4s\n"
-
- "ldr qF22, [%x[inptr4], mstride1]\n"
- "str qg12, [%x[outptr], col_stride]\n"
-
- "ldr qF23, [%x[inptr4], mstride2]\n"
- "fsub FZ42.4s, FZ42.4s, F44.4s\n"
-
- "ldr qF24, [%x[inptr4], mstride3]\n"
- "fsub g22.4s, FZ22.4s, FZ32.4s\n"
-
- "fsub g21.4s, g21.4s, FZ41.4s\n"
- "add %x[inptr0], %x[inptr0], #0x10\n"
-
- "fsub g22.4s, g22.4s, FZ42.4s\n"
- "add %x[inptr4], %x[inptr4], #0x10\n"
-
- "subs channel, channel, #4\n"
-
- "str qg21, [%x[outptr], row_stride]\n"
-
- "str qg22, [%x[outptr], row_plus_col_stride]\n"
-
- "add %x[outptr], %x[outptr], #0x10\n"
-
- "bne 4b\n"
-
- "5:" // Channel tail
- "ldr qF31, [%x[inptr8]]\n"
- "fadd FZ11.4s, F11.4s, F12.4s\n"
-
- "ldr qF32, [%x[inptr8], mstride1]\n"
- "fsub FZ12.4s, F12.4s, F13.4s\n"
-
- "ldr qF33, [%x[inptr8], mstride2]\n"
- "fadd FZ11.4s, FZ11.4s, F13.4s\n"
-
- "ldr qF34, [%x[inptr8], mstride3]\n"
- "fsub FZ12.4s, FZ12.4s, F14.4s\n"
-
- "ldr qF41, [%x[inptr12]]\n"
- "fadd FZ21.4s, F21.4s, F22.4s\n"
-
- "ldr qF42, [%x[inptr12], mstride1]\n"
- "fsub FZ22.4s, F22.4s, F23.4s\n"
-
- "ldr qF43, [%x[inptr12], mstride2]\n"
- "fadd FZ21.4s, FZ21.4s, F23.4s\n"
-
- "ldr qF44, [%x[inptr12], mstride3]\n"
- "fsub FZ22.4s, FZ22.4s, F24.4s\n"
-
- "fadd FZ31.4s, F31.4s, F32.4s\n"
- "add %x[inptr8], %x[inptr8], #0x10\n"
-
- "fsub FZ32.4s, F32.4s, F33.4s\n"
- "add %x[inptr12], %x[inptr12], #0x10\n"
-
- "fadd FZ31.4s, FZ31.4s, F33.4s\n"
-
- "fsub FZ32.4s, FZ32.4s, F34.4s\n"
-
- "fadd g11.4s, FZ11.4s, FZ21.4s\n"
-
- "fadd g12.4s, FZ12.4s, FZ22.4s\n"
-
- "fadd g11.4s, g11.4s, FZ31.4s\n"
-
- "fadd g12.4s, g12.4s, FZ32.4s\n"
-
- "fadd FZ41.4s, F41.4s, F42.4s\n"
-
- "fsub g21.4s, FZ21.4s, FZ31.4s\n"
-
- "fsub FZ42.4s, F42.4s, F43.4s\n"
-
- "str qg11, [%x[outptr]]\n"
-
- "fadd FZ41.4s, FZ41.4s, F43.4s\n"
-
- "str qg12, [%x[outptr], col_stride]\n"
-
- "fsub FZ42.4s, FZ42.4s, F44.4s\n"
-
- "fsub g22.4s, FZ22.4s, FZ32.4s\n"
-
- "fsub g21.4s, g21.4s, FZ41.4s\n"
-
- "fsub g22.4s, g22.4s, FZ42.4s\n"
-
- "subs channel, channel, #4\n"
-
- "str qg21, [%x[outptr], row_stride]\n"
-
- // Progress input pointers to the next row of the matrix
- "add %x[inptr0], %x[inptr0], %x[mrowpad]\n"
- "add %x[inptr4], %x[inptr4], %x[mrowpad]\n"
- "add %x[inptr8], %x[inptr8], %x[mrowpad]\n"
- "add %x[inptr12], %x[inptr12], %x[mrowpad]\n"
-
- "str qg22, [%x[outptr], row_plus_col_stride]\n"
-
- "add %x[outptr], %x[outptr], #0x10\n"
-
-
- "add %x[outptr], %x[outptr], col_stride\n"
- "subs tile_j, tile_j, #1\n"
- "bne 3b\n"
-
- "add %x[outptr], %x[outptr], row_stride\n"
- "subs tile_i, tile_i, #1\n"
- "bne 2b\n"
-
- "subs %[batch], %[batch], #1\n"
- "bne 1b\n"
-
- ".unreq F11\n" ".unreq qF11\n"
- ".unreq F12\n" ".unreq qF12\n"
- ".unreq F13\n" ".unreq qF13\n"
- ".unreq F14\n" ".unreq qF14\n"
- ".unreq F21\n" ".unreq qF21\n"
- ".unreq F22\n" ".unreq qF22\n"
- ".unreq F23\n" ".unreq qF23\n"
- ".unreq F24\n" ".unreq qF24\n"
- ".unreq F31\n" ".unreq qF31\n"
- ".unreq F32\n" ".unreq qF32\n"
- ".unreq F33\n" ".unreq qF33\n"
- ".unreq F34\n" ".unreq qF34\n"
- ".unreq F41\n" ".unreq qF41\n"
- ".unreq F42\n" ".unreq qF42\n"
- ".unreq F43\n" ".unreq qF43\n"
- ".unreq F44\n" ".unreq qF44\n"
-
- ".unreq FZ11\n" ".unreq FZ12\n"
- ".unreq FZ21\n" ".unreq FZ22\n"
- ".unreq FZ31\n" ".unreq FZ32\n"
- ".unreq FZ41\n" ".unreq FZ42\n"
-
- ".unreq g11\n" ".unreq qg11\n"
- ".unreq g12\n" ".unreq qg12\n"
- ".unreq g21\n" ".unreq qg21\n"
- ".unreq g22\n" ".unreq qg22\n"
-
- ".unreq col_stride\n"
- ".unreq row_stride\n"
- ".unreq row_plus_col_stride\n"
-
- ".unreq mstride1\n"
- ".unreq mstride2\n"
- ".unreq mstride3\n"
-
- ".unreq tile_i \n"
- ".unreq tile_j \n"
- ".unreq channel\n"
-
- : [batch] "+r" (batch),
- [outptr] "+r" (outptr),
- [inptr0] "+r" (inptr0),
- [inptr4] "+r" (inptr4),
- [inptr8] "+r" (inptr8),
- [inptr12] "+r" (inptr12)
- : [tile_M] "r" (tile_M),
- [tile_N] "r" (tile_N),
- [n_channels] "r" (output_shape.n_channels),
- [col_stride] "r" (col_stride),
- [row_stride] "r" (row_stride),
- [row_plus_col_stride] "r" (row_stride + col_stride),
- [mstride1] "r" (mstride * sizeof(float)),
- [mstride2] "r" (2 * mstride * sizeof(float)),
- [mstride3] "r" (3 * mstride * sizeof(float)),
- [mrowpad] "r" ((matrix_row_stride - output_shape.n_channels) * sizeof(float))
- : "x19", "x20", "x21",
- "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
- "q12", "q13", "q14", "q15", "q16", "q17", "q18", "q19", "q20", "q21",
- "q22", "q23", "q24", "q25", "q26", "q27",
- "cc", "memory"
- );
-}
-
-template <>
-template <bool tail_M, bool tail_N, const int channel_tail>
-inline void Winograd2x2_3x3GemmOutput<float>::_execute(
- const Tensor4DShape &output_shape,
- float *output,
- const float *input,
- const int mstride,
- const int matrix_row_stride
-) {
- // Compute basic information about the shape of the matrices
- const int tile_M = output_shape.n_rows / 2;
- const int tile_N = output_shape.n_cols / 2;
- const int n_channels = output_shape.n_channels;
-
- // Extract 16 input pointers
- const float* inptr[16];
- for (int i = 0; i < 16; i++) {
- inptr[i] = input + i*mstride;
- }
-
- // Extract 4 output pointers
- float *outptr00 = output;
- float *outptr01 = outptr00 + n_channels;
- float *outptr10 = outptr00 + output_shape.n_cols * n_channels;
- float *outptr11 = outptr10 + n_channels;
-
- // Progress over the output tiles, generating output values.
- for (int batch = 0; batch < output_shape.n_batches; batch++) {
- for (int tile_i = 0; tile_i < tile_M; tile_i++) {
- for (int tile_j = 0; tile_j < tile_N; tile_j++) {
- for (int channel = 0; channel < n_channels; channel++) {
- // Read values from the input pointers
- float F[4][4];
- for (int i = 0; i < 4; i++) {
- for (int j = 0; j < 4; j++) {
- F[i][j] = *(inptr[i*4 + j]++);
- }
- }
-
- // Compute the matrix F.Z
- float ZF[4][2];
- ZF[0][0] = F[0][0] + F[0][1] + F[0][2];
- ZF[0][1] = F[0][1] - F[0][2] - F[0][3];
- ZF[1][0] = F[1][0] + F[1][1] + F[1][2];
- ZF[1][1] = F[1][1] - F[1][2] - F[1][3];
- ZF[2][0] = F[2][0] + F[2][1] + F[2][2];
- ZF[2][1] = F[2][1] - F[2][2] - F[2][3];
- ZF[3][0] = F[3][0] + F[3][1] + F[3][2];
- ZF[3][1] = F[3][1] - F[3][2] - F[3][3];
-
- // Hence compute the output matrix Z^T . (F.Z)
- *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0];
- *(outptr01++) = ZF[0][1] + ZF[1][1] + ZF[2][1];
- *(outptr10++) = ZF[1][0] - ZF[2][0] - ZF[3][0];
- *(outptr11++) = ZF[1][1] - ZF[2][1] - ZF[3][1];
- }
-
- // Progress the input pointers to the next row
- for (int i = 0; i < 16; i++) {
- inptr[i] += matrix_row_stride - n_channels;
- }
-
- // Progress the output pointers to the next column
- outptr00 += n_channels;
- outptr01 += n_channels;
- outptr10 += n_channels;
- outptr11 += n_channels;
- }
-
- if (tail_N) {
- // Only evaluate the left-most columns of the output
- for (int channel = 0; channel < n_channels; channel++) {
- // Read values from the input pointers
- float F[4][3];
- for (int i = 0; i < 4; i++) {
- for (int j = 0; j < 3; j++) {
- F[i][j] = *(inptr[i*4 + j]++);
- }
- }
- for (int i = 0; i < 4; i++) {
- inptr[i*4 + 3]++;
- }
-
- // Compute the matrix F.Z
- float ZF[4][1];
- ZF[0][0] = F[0][0] + F[0][1] + F[0][2];
- ZF[1][0] = F[1][0] + F[1][1] + F[1][2];
- ZF[2][0] = F[2][0] + F[2][1] + F[2][2];
- ZF[3][0] = F[3][0] + F[3][1] + F[3][2];
-
- // Hence compute the output matrix Z^T . (F.Z)
- *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0];
- *(outptr10++) = ZF[1][0] - ZF[2][0] - ZF[3][0];
- }
-
- // Progress the input pointers to the next row
- for (int i = 0; i < 16; i++) {
- inptr[i] += matrix_row_stride - n_channels;
- }
-
- // Progress the output pointers to the next column
- outptr01 += n_channels; // Account for being skipped above
- outptr11 += n_channels; // Account for being skipped above
- }
-
- // Progress the output pointers to the next row
- outptr00 += output_shape.n_cols * n_channels;
- outptr01 += output_shape.n_cols * n_channels;
- outptr10 += output_shape.n_cols * n_channels;
- outptr11 += output_shape.n_cols * n_channels;
- }
-
- if (tail_M) {
- // Only work on the upper row of the output
- for (int tile_j = 0; tile_j < tile_N; tile_j++) {
- for (int channel = 0; channel < n_channels; channel++) {
- // Read values from the input pointers
- float F[3][4];
- for (int i = 0; i < 3; i++) {
- for (int j = 0; j < 4; j++) {
- F[i][j] = *(inptr[i*4 + j]++);
- }
- }
- for (int j = 0; j < 4; j++) {
- inptr[12 + j]++;
- }
-
- // Compute the matrix F.Z
- float ZF[3][2];
- ZF[0][0] = F[0][0] + F[0][1] + F[0][2];
- ZF[0][1] = F[0][1] - F[0][2] - F[0][3];
- ZF[1][0] = F[1][0] + F[1][1] + F[1][2];
- ZF[1][1] = F[1][1] - F[1][2] - F[1][3];
- ZF[2][0] = F[2][0] + F[2][1] + F[2][2];
- ZF[2][1] = F[2][1] - F[2][2] - F[2][3];
-
- // Hence compute the output matrix Z^T . (F.Z)
- *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0];
- *(outptr01++) = ZF[0][1] + ZF[1][1] + ZF[2][1];
- }
-
- // Progress the input pointers to the next row
- for (int i = 0; i < 16; i++) {
- inptr[i] += matrix_row_stride - n_channels;
- }
-
- // Progress the output pointers to the next column
- outptr00 += n_channels;
- outptr01 += n_channels;
- outptr10 += 2 * n_channels; // Account for being skipped above
- outptr11 += 2 * n_channels; // Account for being skipped above
- }
-
- if (tail_N) {
- // Only evaluate the upper-left cell of the output
- for (int channel = 0; channel < n_channels; channel++) {
- // Read values from the input pointers
- float F[3][3];
- for (int i = 0; i < 3; i++) {
- for (int j = 0; j < 3; j++) {
- F[i][j] = *(inptr[i*4 + j]);
- }
- }
- for (int i = 0; i < 16; i++) {
- inptr[i]++;
- }
-
- // Compute the matrix F.Z
- float ZF[3][1];
- ZF[0][0] = F[0][0] + F[0][1] + F[0][2];
- ZF[1][0] = F[1][0] + F[1][1] + F[1][2];
- ZF[2][0] = F[2][0] + F[2][1] + F[2][2];
-
- // Hence compute the output matrix Z^T . (F.Z)
- *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0];
- }
-
- // Progress the input pointers to the next row
- for (int i = 0; i < 16; i++) {
- inptr[i] += matrix_row_stride - n_channels;
- }
-
- // Progress the output pointers to the next column
- outptr01 += n_channels; // Account for being skipped above
- outptr10 += n_channels; // Account for being skipped above
- outptr11 += n_channels; // Account for being skipped above
- }
- }
- }
-}
-
-/*****************************************************************************/
-template <>
-inline void Winograd2x2_3x3GemmOutput<float>::execute(
- const Tensor4DShape &output_shape,
- float* const matrix_base,
- const int matrix_stride,
- const int matrix_row_stride,
- float* const output
-) {
- // Dispatch to an appropriate implementation based on the shape of the output
- // tensor.
- if (output_shape.n_rows % 2 && output_shape.n_cols % 2) {
- constexpr bool tail_M = true, tail_N = true;
- switch (output_shape.n_channels % 4) {
- case 0:
- _execute<tail_M, tail_N, 0>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 1:
- _execute<tail_M, tail_N, 1>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 2:
- _execute<tail_M, tail_N, 2>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 3:
- _execute<tail_M, tail_N, 3>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- default:
- assert(0);
- break;
- }
- } else if (output_shape.n_rows % 2) {
- constexpr bool tail_M = true, tail_N = false;
- switch (output_shape.n_channels % 4) {
- case 0:
- _execute<tail_M, tail_N, 0>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 1:
- _execute<tail_M, tail_N, 1>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 2:
- _execute<tail_M, tail_N, 2>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 3:
- _execute<tail_M, tail_N, 3>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- default:
- assert(0);
- break;
- }
- } else if (output_shape.n_cols % 2) {
- constexpr bool tail_M = false, tail_N = true;
- switch (output_shape.n_channels % 4) {
- case 0:
- _execute<tail_M, tail_N, 0>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 1:
- _execute<tail_M, tail_N, 1>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 2:
- _execute<tail_M, tail_N, 2>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 3:
- _execute<tail_M, tail_N, 3>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- default:
- assert(0);
- break;
-
- }
- } else {
- constexpr bool tail_M = false, tail_N = false;
- switch (output_shape.n_channels % 4) {
- case 0:
- _execute<tail_M, tail_N, 0>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 1:
- _execute<tail_M, tail_N, 1>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 2:
- _execute<tail_M, tail_N, 2>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 3:
- _execute<tail_M, tail_N, 3>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- default:
- assert(0);
- break;
-
- }
- }
-}
-/*****************************************************************************/
-
-} // namespace winograd
-#endif // __aarch64__
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp
deleted file mode 100644
index f551b12b52..0000000000
--- a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp
+++ /dev/null
@@ -1,655 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#pragma once
-
-#ifdef __aarch64__
-
-/*****************************************************************************/
-// Compute ZF specializations
-
-template <>
-template <>
-inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage<float>::compute_zf<0>(
- const int n_rows, const int n_channels,
- float* output, const float* const input[16]
-) {
- // Make copies of some variables
- int row = n_rows;
- float* outptr = output;
- const float* inptr = input[0];
-
- // Perform the transformation
- asm volatile (
- // "inptr0 .req %x[inptr]\n"
- "inptr1 .req x0\n"
- "inptr2 .req x1\n"
- "inptr3 .req x2\n"
- "inptr4 .req x3\n"
- "inptr5 .req x4\n"
- "inptr6 .req x5\n"
- "inptr7 .req x6\n"
- "inptr8 .req x7\n"
- "inptr9 .req x8\n"
- "inptr10 .req x9\n"
- "inptr11 .req x10\n"
- "inptr12 .req x11\n"
- "inptr13 .req x12\n"
- "inptr14 .req x13\n"
- "inptr15 .req x14\n"
-
- // "outptr0 .req %x[outptr]\n"
- "outptr1 .req x15\n"
- "outptr2 .req x16\n"
- "outptr3 .req x17\n"
- "outptr4 .req x18\n"
- "outptr5 .req x19\n"
- "outptr6 .req x20\n"
- "outptr7 .req x21\n"
-
- // Compute additional pointers into the input and output matrices.
- "mstride .req x22\n" // Matrix stride
- "mul mstride, %x[row], %x[n_channels]\n"
- "lsl mstride, mstride, #2\n" // * sizeof(float)
-
- "add inptr1, %x[inptr], mstride\n"
- "add inptr2, %x[inptr], mstride, LSL #1\n"
- "add inptr3, inptr2, mstride\n"
- "add inptr4, inptr3, mstride\n"
- "add inptr5, inptr4, mstride\n"
- "add inptr6, inptr5, mstride\n"
- "add inptr7, inptr6, mstride\n"
- "add inptr8, inptr7, mstride\n"
- "add inptr9, inptr8, mstride\n"
- "add inptr10, inptr9, mstride\n"
- "add inptr11, inptr10, mstride\n"
- "add inptr12, inptr11, mstride\n"
- "add inptr13, inptr12, mstride\n"
- "add inptr14, inptr13, mstride\n"
- "add inptr15, inptr14, mstride\n"
-
- "add outptr1, %[outptr], mstride\n"
- "add outptr2, outptr1, mstride\n"
- "add outptr3, outptr2, mstride\n"
- "add outptr4, outptr3, mstride\n"
- "add outptr5, outptr4, mstride\n"
- "add outptr6, outptr5, mstride\n"
- "add outptr7, outptr6, mstride\n"
-
- ".unreq mstride\n"
-
- "column .req x22\n" // Column loop counter
-
- "1:" // Loop over rows
- "ldr q0, [%x[inptr]], #0x10\n"
- "ldr q1, [inptr1], #0x10\n"
- "ldr q2, [inptr2], #0x10\n"
- "ldr q3, [inptr3], #0x10\n"
- "ldr q4, [inptr4], #0x10\n"
- "ldr q5, [inptr5], #0x10\n"
- "ldr q6, [inptr6], #0x10\n"
- "ldr q7, [inptr7], #0x10\n"
- "subs column, %x[n_channels], #0x4\n"
- "beq 3f\n"
-
- "2:" // Loop over columns
- "ldr q8, [inptr8], #0x10\n"
- "prfm pldl1keep, [%x[inptr], #196]\n"
- "fadd v16.4s, v0.4s, v1.4s\n"
-
- "ldr q9, [inptr9], #0x10\n"
- "prfm pldl1keep, [inptr1, #196]\n"
- "fsub v17.4s, v1.4s, v2.4s\n"
-
- "ldr q10, [inptr10], #0x10\n"
- "prfm pldl1keep, [inptr2, #196]\n"
- "fadd v16.4s, v16.4s, v2.4s\n"
-
- "ldr q11, [inptr11], #0x10\n"
- "prfm pldl1keep, [inptr3, #196]\n"
- "fsub v17.4s, v17.4s, v3.4s\n"
-
- "ldr q12, [inptr12], #0x10\n"
- "prfm pldl1keep, [inptr4, #196]\n"
- "str q16, [%x[outptr]], #0x10\n"
-
- "ldr q13, [inptr13], #0x10\n"
- "prfm pldl1keep, [inptr5, #196]\n"
- "str q17, [outptr1], #0x10\n"
-
- "ldr q14, [inptr14], #0x10\n"
- "prfm pldl1keep, [inptr6, #196]\n"
- "fadd v16.4s, v4.4s, v5.4s\n"
-
- "ldr q15, [inptr15], #0x10\n"
- "prfm pldl1keep, [inptr7, #196]\n"
- "fsub v17.4s, v5.4s, v6.4s\n"
-
- "ldr q0, [%x[inptr]], #0x10\n"
- "prfm pldl1keep, [inptr8, #196]\n"
- "fadd v16.4s, v16.4s, v6.4s\n"
-
- "ldr q1, [inptr1], #0x10\n"
- "prfm pldl1keep, [inptr9, #196]\n"
- "fsub v17.4s, v17.4s, v7.4s\n"
-
- "ldr q2, [inptr2], #0x10\n"
- "prfm pldl1keep, [inptr10, #196]\n"
- "str q16, [outptr2], #0x10\n"
-
- "ldr q3, [inptr3], #0x10\n"
- "prfm pldl1keep, [inptr11, #196]\n"
- "str q17, [outptr3], #0x10\n"
-
- "ldr q4, [inptr4], #0x10\n"
- "prfm pldl1keep, [inptr12, #196]\n"
- "fadd v16.4s, v8.4s, v9.4s\n"
-
- "ldr q5, [inptr5], #0x10\n"
- "prfm pldl1keep, [inptr13, #196]\n"
- "fsub v17.4s, v9.4s, v10.4s\n"
-
- "ldr q6, [inptr6], #0x10\n"
- "prfm pldl1keep, [inptr14, #196]\n"
- "fadd v16.4s, v16.4s, v10.4s\n"
-
- "ldr q7, [inptr7], #0x10\n"
- "prfm pldl1keep, [inptr15, #196]\n"
- "fsub v17.4s, v17.4s, v11.4s\n"
-
- "str q16, [outptr4], #0x10\n"
- "fadd v16.4s, v12.4s, v13.4s\n"
- "fsub v18.4s, v13.4s, v14.4s\n"
-
- "str q17, [outptr5], #0x10\n"
- "fadd v16.4s, v16.4s, v14.4s\n"
- "fsub v18.4s, v18.4s, v15.4s\n"
-
- "str q16, [outptr6], #0x10\n"
- "subs column, column, #0x4\n"
-
- "str q18, [outptr7], #0x10\n"
- "bne 2b\n"
-
- "3:" // Tail
- "ldr q8, [inptr8], #0x10\n"
- "prfm pldl1keep, [%x[inptr], #196]\n"
- "fadd v16.4s, v0.4s, v1.4s\n"
-
- "ldr q9, [inptr9], #0x10\n"
- "prfm pldl1keep, [inptr1, #196]\n"
- "fsub v17.4s, v1.4s, v2.4s\n"
-
- "ldr q10, [inptr10], #0x10\n"
- "prfm pldl1keep, [inptr2, #196]\n"
- "fadd v16.4s, v16.4s, v2.4s\n"
-
- "ldr q11, [inptr11], #0x10\n"
- "prfm pldl1keep, [inptr3, #196]\n"
- "fsub v17.4s, v17.4s, v3.4s\n"
-
- "ldr q12, [inptr12], #0x10\n"
- "prfm pldl1keep, [inptr4, #196]\n"
- "str q16, [%x[outptr]], #0x10\n"
-
- "ldr q13, [inptr13], #0x10\n"
- "prfm pldl1keep, [inptr5, #196]\n"
- "str q17, [outptr1], #0x10\n"
-
- "ldr q14, [inptr14], #0x10\n"
- "prfm pldl1keep, [inptr6, #196]\n"
- "fadd v16.4s, v4.4s, v5.4s\n"
-
- "ldr q15, [inptr15], #0x10\n"
- "prfm pldl1keep, [inptr7, #196]\n"
- "fsub v17.4s, v5.4s, v6.4s\n"
-
- "prfm pldl1keep, [inptr8, #196]\n"
- "prfm pldl1keep, [inptr9, #196]\n"
- "fadd v16.4s, v16.4s, v6.4s\n"
-
- "prfm pldl1keep, [inptr10, #196]\n"
- "prfm pldl1keep, [inptr11, #196]\n"
- "fsub v17.4s, v17.4s, v7.4s\n"
-
- "prfm pldl1keep, [inptr12, #196]\n"
- "prfm pldl1keep, [inptr13, #196]\n"
- "str q16, [outptr2], #0x10\n"
-
- "prfm pldl1keep, [inptr14, #196]\n"
- "prfm pldl1keep, [inptr15, #196]\n"
- "str q17, [outptr3], #0x10\n"
-
- "fadd v16.4s, v8.4s, v9.4s\n"
- "fsub v17.4s, v9.4s, v10.4s\n"
-
- "fadd v16.4s, v16.4s, v10.4s\n"
- "fsub v17.4s, v17.4s, v11.4s\n"
-
- "str q16, [outptr4], #0x10\n"
- "fadd v16.4s, v12.4s, v13.4s\n"
- "fsub v18.4s, v13.4s, v14.4s\n"
-
- "str q17, [outptr5], #0x10\n"
- "fadd v16.4s, v16.4s, v14.4s\n"
- "fsub v18.4s, v18.4s, v15.4s\n"
-
- "str q16, [outptr6], #0x10\n"
- "str q18, [outptr7], #0x10\n"
-
- "subs %x[row], %x[row], #0x1\n"
- "bne 1b\n"
-
- ".unreq inptr1\n"
- ".unreq inptr2\n"
- ".unreq inptr3\n"
- ".unreq inptr4\n"
- ".unreq inptr5\n"
- ".unreq inptr6\n"
- ".unreq inptr7\n"
- ".unreq inptr8\n"
- ".unreq inptr9\n"
- ".unreq inptr10\n"
- ".unreq inptr11\n"
- ".unreq inptr12\n"
- ".unreq inptr13\n"
- ".unreq inptr14\n"
- ".unreq inptr15\n"
- ".unreq outptr1\n"
- ".unreq outptr2\n"
- ".unreq outptr3\n"
- ".unreq outptr4\n"
- ".unreq outptr5\n"
- ".unreq outptr6\n"
- ".unreq outptr7\n"
-
- : [row] "+r" (row),
- [inptr] "+r" (inptr),
- [outptr] "+r" (outptr)
- : [n_channels] "r" (n_channels),
- [sizeof_float] "i" (sizeof(float))
- : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
- "q12", "q13", "q14", "q15", "q16", "q17", "x0", "x1", "x2", "x3", "x4",
- "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15",
- "x16", "x17", "x18", "x19", "x20", "x21", "x22", "cc", "memory"
- );
-}
-
-/*****************************************************************************/
-// Compute ZFZ^T specializations
-
-template <>
-template <>
-inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage<float>::compute_zfzT<false, false, 0>(
- const Tensor4DShape &output_shape,
- float* const output, const float* const input
-) {
- const int tile_M = output_shape.n_rows / 2;
- const int tile_N = output_shape.n_cols / 2;
- int batch = output_shape.n_batches;
- float *outptr = output;
- const float *inptr = input;
-
- asm volatile (
- // Compute input pointers
- "inptr1 .req x0\n"
- "inptr2 .req x1\n"
- "inptr3 .req x2\n"
- "inptr4 .req x3\n"
- "inptr5 .req x4\n"
- "inptr6 .req x5\n"
- "inptr7 .req x6\n"
- "inptr8 .req x7\n"
-
- "mstride .req x8\n"
- "mul mstride, %x[tile_M], %x[tile_N]\n"
- "mul mstride, mstride, %x[n_channels]\n"
- "lsl mstride, mstride, #2\n" // * sizeof(float)
-
- "add inptr1, %[inptr], mstride\n"
- "add inptr2, inptr1, mstride\n"
- "add inptr3, inptr2, mstride\n"
- "add inptr4, inptr3, mstride\n"
- "add inptr5, inptr4, mstride\n"
- "add inptr6, inptr5, mstride\n"
- "add inptr7, inptr6, mstride\n"
- "add inptr8, inptr7, mstride\n"
-
- ".unreq mstride\n"
-
- // Compute initial output pointers
- "outptr01 .req x8\n"
- "outptr10 .req x9\n"
- "outptr11 .req x10\n"
-
- "add outptr01, %x[outptr], %x[n_channels], LSL #2\n"
- "add outptr10, %x[outptr], %x[row_stride], LSL #2\n"
- "add outptr11, outptr10, %x[n_channels], LSL #2\n"
-
- "tile_i .req x11\n"
- "tile_j .req x12\n"
- "channel .req x13\n"
-
- "1:" // Loop over batches
- "mov tile_i, %x[tile_M]\n"
-
- "2:" // Loop over rows of output tiles
- "mov tile_j, %x[tile_N]\n"
-
- "3:" // Loop over columns of output tiles
- "ldr q0, [%x[inptr]], #0x10\n"
- "ldr q2, [inptr2], #0x10\n"
- "subs channel, %x[n_channels], #0x4\n"
-
- "ldr q1, [inptr1], #0x10\n"
- "ldr q3, [inptr3], #0x10\n"
- "beq 6f\n"
-
- "4:"
- "ldr q4, [inptr4], #0x10\n"
- "ldr q5, [inptr5], #0x10\n"
- "fadd v16.4s, v0.4s, v2.4s\n"
-
- "ldr q6, [inptr6], #0x10\n"
- "ldr q7, [inptr7], #0x10\n"
- "fadd v17.4s, v1.4s, v3.4s\n"
-
- "ldr q8, [%x[inptr]], #0x10\n"
- "ldr q10, [inptr2], #0x10\n"
- "fadd v16.4s, v16.4s, v4.4s\n"
-
- "ldr q9, [inptr1], #0x10\n"
- "ldr q11, [inptr3], #0x10\n"
- "fadd v17.4s, v17.4s, v5.4s\n"
-
- "str q16, [%x[outptr]], #0x10\n"
- "prfm pldl1strm, [%x[inptr], #196]\n"
- "fsub v18.4s, v2.4s, v4.4s\n"
-
- "str q17, [outptr01], #0x10\n"
- "prfm pldl1strm, [inptr2, #196]\n"
- "fsub v19.4s, v3.4s, v5.4s\n"
-
- "prfm pldl1strm, [inptr1, #196]\n"
- "prfm pldl1strm, [inptr3, #196]\n"
- "fsub v18.4s, v18.4s, v6.4s\n"
-
- "prfm pldl1strm, [inptr4, #196]\n"
- "prfm pldl1strm, [inptr5, #196]\n"
- "fsub v19.4s, v19.4s, v7.4s\n"
-
- "str q18, [outptr10], #0x10\n"
- "prfm pldl1strm, [inptr6, #196]\n"
- "prfm pldl1strm, [inptr7, #196]\n"
-
- "subs channel, channel, #0x4\n"
-
- "str q19, [outptr11], #0x10\n"
- "beq 6f\n" // Branch to tail
-
- "ldr q12, [inptr4], #0x10\n"
- "ldr q13, [inptr5], #0x10\n"
- "fadd v16.4s, v8.4s, v10.4s\n"
-
- "ldr q14, [inptr6], #0x10\n"
- "ldr q15, [inptr7], #0x10\n"
- "fadd v17.4s, v9.4s, v11.4s\n"
-
- "ldr q0, [%x[inptr]], #0x10\n"
- "ldr q2, [inptr2], #0x10\n"
- "fadd v16.4s, v16.4s, v12.4s\n"
-
- "ldr q1, [inptr1], #0x10\n"
- "ldr q3, [inptr3], #0x10\n"
- "fadd v17.4s, v17.4s, v13.4s\n"
-
- "str q16, [%x[outptr]], #0x10\n"
- "prfm pldl1strm, [%x[inptr], #196]\n"
- "fsub v18.4s, v10.4s, v12.4s\n"
-
- "str q17, [outptr01], #0x10\n"
- "prfm pldl1strm, [inptr2, #196]\n"
- "fsub v19.4s, v11.4s, v13.4s\n"
-
- "prfm pldl1strm, [inptr1, #196]\n"
- "prfm pldl1strm, [inptr3, #196]\n"
- "fsub v18.4s, v18.4s, v14.4s\n"
-
- "prfm pldl1strm, [inptr4, #196]\n"
- "prfm pldl1strm, [inptr5, #196]\n"
- "fsub v19.4s, v19.4s, v15.4s\n"
-
- "str q18, [outptr10], #0x10\n"
- "prfm pldl1strm, [inptr6, #196]\n"
- "prfm pldl1strm, [inptr7, #196]\n"
-
- "subs channel, channel, #0x4\n"
-
- "str q19, [outptr11], #0x10\n"
- "bne 4b\n" // Continue loop
-
- "5:" // Tail
- "ldr q12, [inptr4], #0x10\n"
- "ldr q13, [inptr5], #0x10\n"
- "fadd v16.4s, v8.4s, v10.4s\n"
-
- "ldr q14, [inptr6], #0x10\n"
- "ldr q15, [inptr7], #0x10\n"
- "fadd v17.4s, v9.4s, v11.4s\n"
-
- "fadd v16.4s, v16.4s, v12.4s\n"
-
- "fadd v17.4s, v17.4s, v13.4s\n"
-
- "str q16, [%x[outptr]], #0x10\n"
- "fsub v18.4s, v10.4s, v12.4s\n"
- "fsub v19.4s, v11.4s, v13.4s\n"
-
- "str q17, [outptr01], #0x10\n"
- "fsub v18.4s, v18.4s, v14.4s\n"
- "fsub v19.4s, v19.4s, v15.4s\n"
-
- "str q18, [outptr10], #0x10\n"
- "str q19, [outptr11], #0x10\n"
- "b 7f\n"
-
- "6:" // Tail
- "ldr q4, [inptr4], #0x10\n"
- "ldr q5, [inptr5], #0x10\n"
- "fadd v16.4s, v0.4s, v2.4s\n"
-
- "ldr q6, [inptr6], #0x10\n"
- "ldr q7, [inptr7], #0x10\n"
- "fadd v17.4s, v1.4s, v3.4s\n"
-
- "fadd v16.4s, v16.4s, v4.4s\n"
-
- "fadd v17.4s, v17.4s, v5.4s\n"
-
- "str q16, [%x[outptr]], #0x10\n"
- "fsub v18.4s, v2.4s, v4.4s\n"
- "fsub v19.4s, v3.4s, v5.4s\n"
-
- "str q17, [outptr01], #0x10\n"
- "fsub v18.4s, v18.4s, v6.4s\n"
- "fsub v19.4s, v19.4s, v7.4s\n"
-
- "str q18, [outptr10], #0x10\n"
- "str q19, [outptr11], #0x10\n"
-
- "7:"
- "add %x[outptr], %x[outptr], %x[n_channels], LSL #2\n"
- "add outptr01, outptr01, %x[n_channels], LSL #2\n"
- "add outptr10, outptr10, %x[n_channels], LSL #2\n"
- "add outptr11, outptr11, %x[n_channels], LSL #2\n"
-
- "subs tile_j, tile_j, #1\n"
- "bne 3b\n"
-
- // Progress the output pointers to the new row
- "add %x[outptr], %x[outptr], %x[row_stride], LSL #2\n"
- "add outptr01, outptr01, %x[row_stride], LSL #2\n"
- "add outptr10, outptr10, %x[row_stride], LSL #2\n"
- "add outptr11, outptr11, %x[row_stride], LSL #2\n"
-
- "subs tile_i, tile_i, #1\n"
- "bne 2b\n"
-
- "subs %[batch], %[batch], #1\n"
- "bne 1b\n"
- "5:"
-
- ".unreq inptr1\n"
- ".unreq inptr2\n"
- ".unreq inptr3\n"
- ".unreq inptr4\n"
- ".unreq inptr5\n"
- ".unreq inptr6\n"
- ".unreq inptr7\n"
- ".unreq inptr8\n"
- ".unreq outptr01\n"
- ".unreq outptr10\n"
- ".unreq outptr11\n"
- : [batch] "+r" (batch),
- [outptr] "+r" (outptr),
- [inptr] "+r" (inptr)
- : [tile_M] "r" (tile_M),
- [tile_N] "r" (tile_N),
- [n_channels] "r" (output_shape.n_channels),
- [row_stride] "r" (output_shape.n_cols * output_shape.n_channels)
- : "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11",
- "x12", "x13", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
- "cc", "memory"
- );
-}
-/*****************************************************************************/
-
-/*****************************************************************************/
-template <>
-inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage<float>::execute(
- const Tensor4DShape &output_shape,
- float* const matrices[16], float* const output
-) {
- // profiler prof;
-
- // Allocate memory for the intermediate matrices
- const int tile_M = iceildiv(output_shape.n_rows, 2);
- const int tile_N = iceildiv(output_shape.n_cols, 2);
- const int n_rows = output_shape.n_batches * tile_M * tile_N;
- const int n_channels = output_shape.n_channels;
- float* matrices_zf = reinterpret_cast<float*>(
- calloc(8 * n_rows * n_channels, sizeof(float))
- );
-
- // Perform the first stage transform, computing ZF.
- const auto f_compute_zf = [&] () {
- switch (n_channels % 4) {
- case 0:
- compute_zf<0>(n_rows, n_channels, matrices_zf, matrices);
- break;
- case 1:
- compute_zf<1>(n_rows, n_channels, matrices_zf, matrices);
- break;
- case 2:
- compute_zf<2>(n_rows, n_channels, matrices_zf, matrices);
- break;
- case 3:
- compute_zf<3>(n_rows, n_channels, matrices_zf, matrices);
- };
- };
- // prof("Compute ZF", f_compute_zf, 16 * n_rows * n_channels * sizeof(float), 0, 8 * n_rows * n_channels * sizeof(float));
- f_compute_zf();
-
- // Perform the second stage transform, finishing Z F Z^T - variable dispatch
- // based on size of the output and the channel tail.
- const auto f_compute_zfzT = [&] () {
- if (output_shape.n_rows % 2 && output_shape.n_cols % 2) {
- constexpr bool tail_M = true, tail_N = true;
- switch (n_channels % 4) {
- case 0:
- compute_zfzT<tail_M, tail_N, 0>(output_shape, output, matrices_zf);
- break;
- case 1:
- compute_zfzT<tail_M, tail_N, 1>(output_shape, output, matrices_zf);
- break;
- case 2:
- compute_zfzT<tail_M, tail_N, 2>(output_shape, output, matrices_zf);
- break;
- case 3:
- compute_zfzT<tail_M, tail_N, 3>(output_shape, output, matrices_zf);
- }
- } else if (output_shape.n_rows % 2) {
- constexpr bool tail_M = true, tail_N = false;
- switch (n_channels % 4) {
- case 0:
- compute_zfzT<tail_M, tail_N, 0>(output_shape, output, matrices_zf);
- break;
- case 1:
- compute_zfzT<tail_M, tail_N, 1>(output_shape, output, matrices_zf);
- break;
- case 2:
- compute_zfzT<tail_M, tail_N, 2>(output_shape, output, matrices_zf);
- break;
- case 3:
- compute_zfzT<tail_M, tail_N, 3>(output_shape, output, matrices_zf);
- }
- } else if (output_shape.n_cols % 2) {
- constexpr bool tail_M = false, tail_N = true;
- switch (n_channels % 4) {
- case 0:
- compute_zfzT<tail_M, tail_N, 0>(output_shape, output, matrices_zf);
- break;
- case 1:
- compute_zfzT<tail_M, tail_N, 1>(output_shape, output, matrices_zf);
- break;
- case 2:
- compute_zfzT<tail_M, tail_N, 2>(output_shape, output, matrices_zf);
- break;
- case 3:
- compute_zfzT<tail_M, tail_N, 3>(output_shape, output, matrices_zf);
- }
- } else {
- constexpr bool tail_M = false, tail_N = false;
- switch (n_channels % 4) {
- case 0:
- compute_zfzT<tail_M, tail_N, 0>(output_shape, output, matrices_zf);
- break;
- case 1:
- compute_zfzT<tail_M, tail_N, 1>(output_shape, output, matrices_zf);
- break;
- case 2:
- compute_zfzT<tail_M, tail_N, 2>(output_shape, output, matrices_zf);
- break;
- case 3:
- compute_zfzT<tail_M, tail_N, 3>(output_shape, output, matrices_zf);
- }
- }
- };
- // prof("Compute ZFZT", f_compute_zfzT, 8 * n_rows * n_channels * sizeof(float), 0, 4 * n_rows * n_channels * sizeof(float));
- f_compute_zfzT();
-
- free(reinterpret_cast<void*>(matrices_zf));
-}
-/*****************************************************************************/
-
-#endif // __aarch64__