From 8951933e5dd7be8d922affea3cc23a48a05b694d Mon Sep 17 00:00:00 2001 From: Pablo Tello Date: Fri, 17 Nov 2017 11:52:36 +0000 Subject: COMPMID-687: Winograd layer. Change-Id: Ica682d08e851491bf4a26b8d17908c014844055e Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/110990 Reviewed-by: Anthony Barbier Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com --- .../kernels/winograd/transforms/input_2x2_3x3.hpp | 638 +++++++++ .../transforms/input_2x2_3x3/a64_float.hpp | 1498 ++++++++++++++++++++ .../input_2x2_3x3/a64_float_channelwise.hpp | 961 +++++++++++++ .../kernels/winograd/transforms/kernel_2x2_3x3.hpp | 195 +++ .../transforms/kernel_2x2_3x3/a64_float.hpp | 822 +++++++++++ .../kernels/winograd/transforms/output_2x2_3x3.hpp | 356 +++++ .../transforms/output_2x2_3x3/a64_float.hpp | 650 +++++++++ .../output_2x2_3x3/a64_float_two_stage.hpp | 655 +++++++++ 8 files changed, 5775 insertions(+) create mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp (limited to 'arm_compute/core/NEON/kernels/winograd/transforms') diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp new file mode 100644 index 0000000000..7013c66ac0 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp @@ -0,0 +1,638 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once +#include "../tensor.hpp" + +namespace winograd { + /* Transform an input tensor into the Winograd domain. 
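+ *
+ * Each 4x4 input tile x is transformed into U = X.T x X, where (as set out
+ * in the specialised implementations below):
+ *
+ *        /  1  0 -1  0 \
+ *  X.T = |  0  1  1  0 |
+ *        |  0 -1  1  0 |
+ *        \  0  1  0 -1 /
+ *
+ * The sixteen elements of each U are scattered across sixteen output
+ * matrices (separated by `matrix_stride`), ready for the GEMM stage of the
+ * Winograd convolution.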
+ */ + template + struct Winograd2x2_3x3GemmInput { + static void execute( + const T *inptr, + const Tensor4DShape& input_shape, + const PaddingType padding_type, + const int tile_M, + const int tile_N, + T *outptr_base, + const int matrix_stride, + const int matrix_batch_stride, + const int matrix_row_stride + ); + + static size_t bytes_read(const Tensor4DShape &input_shape, + const Tensor4DShape &output_shape) { + const int tile_rows = iceildiv(output_shape.n_rows, 2); + const int tile_cols = iceildiv(output_shape.n_cols, 2); + return input_shape.n_batches * tile_rows * (16 + 8*(tile_cols - 1)) * input_shape.n_channels * sizeof(T); + } + + static int flops_performed(const Tensor4DShape &input_shape, + const Tensor4DShape &output_shape) { + const int tile_rows = iceildiv(output_shape.n_rows, 2); + const int tile_cols = iceildiv(output_shape.n_cols, 2); + return input_shape.n_batches * tile_rows * (32 + 24*(tile_cols - 1)) * input_shape.n_channels; + } + + static size_t bytes_written(const Tensor4DShape &input_shape, + const Tensor4DShape &output_shape) { + const int tile_rows = iceildiv(output_shape.n_rows, 2); + const int tile_cols = iceildiv(output_shape.n_cols, 2); + const int M = input_shape.n_batches * tile_rows * tile_cols; + return 16 * M * input_shape.n_channels * sizeof(T); + } + + protected: + template + static void process_tile_tensor( + const int tile_M, // Number of rows of tiles + const int tile_N, // Number of columns of tiles + int n_channels, // Number of input channels + const T* const input, // Base input pointer (appropriate to batch and channel) + const int input_row_stride, // Stride between rows of the input + const int input_col_stride, // Stride between columns of the input + T* const matrix, // 1st output matrix (appropriate to batch and channel) + const int matrix_stride, // Stride between matrices + const int matrix_row_stride // Stride between rows of the output matrix + ); + + template + static void process_tile_row( + const int tile_N, // Number of tiles in the row + const T* const input, // Base input pointer (appropriate to batch, channel and row) + const int input_row_stride, // Stride between rows of the input + const int input_col_stride, // Stride between columns of the input + T* const matrix, // 1st output matrix (appropriate to batch, channel and row) + const int matrix_stride, // Stride between matrices + const int matrix_row_stride // Stride between rows of the output matrix + ); + }; + + template + struct Winograd2x2_3x3GemmInputChannelwise { + static void execute( + const T *inptr, + const Tensor4DShape& input_shape, + const PaddingType padding_type, + const int tile_M, + const int tile_N, + T *outptr_base, + const int matrix_stride, + const int matrix_batch_stride, + const int matrix_row_stride + ); + + static size_t bytes_read(const Tensor4DShape &input_shape, + const Tensor4DShape &output_shape) { + // We read as many bytes as we write + return bytes_written(input_shape, output_shape); + } + + static int flops_performed(const Tensor4DShape &input_shape, + const Tensor4DShape &output_shape) { + const int tile_rows = iceildiv(output_shape.n_rows, 2); + const int tile_cols = iceildiv(output_shape.n_cols, 2); + return input_shape.n_batches * tile_rows * 32 * tile_cols * input_shape.n_channels; + } + + static size_t bytes_written(const Tensor4DShape &input_shape, + const Tensor4DShape &output_shape) { + return winograd::Winograd2x2_3x3GemmInput::bytes_written(input_shape, output_shape); + } + + protected: + typedef void (*tilefunc)(int, const T*, 
int, int, T*, int); + template + static void process_tile( + int n_channels, // Number of channels in the tile + const T* const input_base, + const int input_row_stride, + const int input_col_stride, + T* const matrix_base, + const int matrix_stride + ); + + private: + template + static void _process_tile( + int &n_channels, const T* &inptr, + const int input_row_stride, const int input_col_stride, + T* &outptr, const int matrix_stride + ); + }; +} + +/*****************************************************************************/ +// Include specialised implementations here +#include "input_2x2_3x3/a64_float.hpp" +#include "input_2x2_3x3/a64_float_channelwise.hpp" +/*****************************************************************************/ + +/*****************************************************************************/ +template +void winograd::Winograd2x2_3x3GemmInput::execute( + const T *inptr_base, + const Tensor4DShape& input_shape, + const PaddingType padding_type, + const int tile_M, + const int tile_N, + T *outptr_base, + const int matrix_stride, + const int matrix_batch_stride, + const int matrix_row_stride +) { + // Select an appropriate matrix processing method for the shape and padding + // of the input tensor. + typedef void (*tensorfunc)(int, int, int, const T*, int, int, T*, int, int); + const auto process_tensor = [&padding_type, &input_shape] () -> tensorfunc { + if (padding_type == PADDING_VALID) { + const int pad_bottom = input_shape.n_rows % 2; + const int pad_right = input_shape.n_cols % 2; + + if (pad_bottom == 0 && pad_right == 0) { + return process_tile_tensor; + } else if (pad_bottom == 0 && pad_right == 1) { + return process_tile_tensor; + } else if (pad_bottom == 1 && pad_right == 0) { + return process_tile_tensor; + } else if (pad_bottom == 1 && pad_right == 1) { + return process_tile_tensor; + } + } else { // PADDING_SAME + const int pad_bottom = 1 + input_shape.n_rows % 2; + const int pad_right = 1 + input_shape.n_cols % 2; + + if (pad_bottom == 1 && pad_right == 1) { + return process_tile_tensor; + } else if (pad_bottom == 1 && pad_right == 2) { + return process_tile_tensor; + } else if (pad_bottom == 2 && pad_right == 1) { + return process_tile_tensor; + } else if (pad_bottom == 2 && pad_right == 2) { + return process_tile_tensor; + } + } + + printf("%s::%u Uncovered case.\n", __FILE__, __LINE__); + exit(-1); + return NULL; // No function found + } (); + + // Compute strides + const int input_row_stride = input_shape.n_cols * input_shape.n_channels; + const int input_col_stride = input_shape.n_channels; + + // Process each batch of the tensor in turn. 
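+ // The input is stored NHWC: rows are n_cols * n_channels elements apart
+ // and columns n_channels elements apart (the strides computed above),
+ // hence the per-batch offset of n_rows * n_cols * n_channels used below.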
+ for (int batch = 0; batch < input_shape.n_batches; batch++) { + // Work out pointers + const T *inptr = inptr_base + (batch * input_shape.n_rows * + input_shape.n_cols * input_shape.n_channels); + T *outptr = outptr_base + batch * matrix_batch_stride; + + // Delegate doing the actual work + process_tensor( + tile_M, tile_N, input_shape.n_channels, + inptr, input_row_stride, input_col_stride, + outptr, matrix_stride, matrix_row_stride + ); + } +} + +/*****************************************************************************/ +template +template +void winograd::Winograd2x2_3x3GemmInput::process_tile_tensor( + const int tile_M, // Number of rows of tiles + const int tile_N, // Number of columns of tiles + int n_channels, // Number of input channels + const T* const input, // Base input pointer (appropriate to batch and channel) + const int input_row_stride, // Stride between rows of the input + const int input_col_stride, // Stride between columns of the input + T* const matrix, // 1st output matrix (appropriate to batch and channel) + const int matrix_stride, // Stride between matrices + const int matrix_row_stride // Stride between rows of the output matrix +) { + // Base row processing functions + typedef void (*rowfunc)(int, const T*, int, int, T*, int, int); + const rowfunc process_top_row[3] = { + (padding == PADDING_VALID) + ? process_tile_row<0, 0, 0, pad_right, 1> + : process_tile_row<1, 1, 0, pad_right, 1>, + (padding == PADDING_VALID) + ? process_tile_row<0, 0, 0, pad_right, 2> + : process_tile_row<1, 1, 0, pad_right, 2>, + (padding == PADDING_VALID) + ? process_tile_row<0, 0, 0, pad_right, 4> + : process_tile_row<1, 1, 0, pad_right, 4>, + }; + const rowfunc process_middle_row[3] = { + (padding == PADDING_VALID) + ? process_tile_row<0, 0, 0, pad_right, 1> + : process_tile_row<0, 1, 0, pad_right, 1>, + (padding == PADDING_VALID) + ? process_tile_row<0, 0, 0, pad_right, 2> + : process_tile_row<0, 1, 0, pad_right, 2>, + (padding == PADDING_VALID) + ? process_tile_row<0, 0, 0, pad_right, 4> + : process_tile_row<0, 1, 0, pad_right, 4>, + }; + const rowfunc process_bottom_row[3] = { + (padding == PADDING_VALID) + ? process_tile_row<0, 0, pad_bottom, pad_right, 1> + : process_tile_row<0, 1, pad_bottom, pad_right, 1>, + (padding == PADDING_VALID) + ? process_tile_row<0, 0, pad_bottom, pad_right, 2> + : process_tile_row<0, 1, pad_bottom, pad_right, 2>, + (padding == PADDING_VALID) + ? process_tile_row<0, 0, pad_bottom, pad_right, 4> + : process_tile_row<0, 1, pad_bottom, pad_right, 4>, + }; + + // Method to get an input pointer for the given tile row + const auto get_inptr = [&input, &input_row_stride] (const int tile_i) { + if (padding == PADDING_VALID) { + return input + 2 * tile_i * input_row_stride; + } else { + return input + (2 * tile_i - (tile_i ? 1 : 0)) * input_row_stride; + } + }; + + // Wrapper to process a row of tiles, covering all channels. + const auto process_row = + [tile_N, input_row_stride, input_col_stride, matrix_stride, matrix_row_stride, n_channels] + (const rowfunc f[3], const T *inptr, T *outptr) { + int rem_channels = n_channels; + + // While there remain channels to process continue to process the + // row. 
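+ // Channels are consumed four at a time (f[2]), then in pairs (f[1]) and
+ // finally singly (f[0]), advancing the input and output pointers by the
+ // number of channels just processed.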
+ for (; rem_channels >= 4; rem_channels -= 4, inptr += 4, outptr += 4) { + f[2](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride); + } + for (; rem_channels >= 2; rem_channels -= 2, inptr += 2, outptr += 2) { + f[1](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride); + } + if (rem_channels) { + f[0](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride); + } + }; + + // Process all rows of tiles in the tensor + for (int tile_i = 0; tile_i < tile_M; tile_i++) { + T* const m_row = matrix + tile_i * tile_N * matrix_row_stride; + const T *row_inptr = get_inptr(tile_i); + + if (tile_i == 0) { + // Top row of the input + process_row(process_top_row, row_inptr, m_row); + } else if (tile_i == tile_M - 1) { + // Bottom row of the input + process_row(process_bottom_row, row_inptr, m_row); + } else { + // Any other row of the input + process_row(process_middle_row, row_inptr, m_row); + } + } +} + +/*****************************************************************************/ +template +template +void winograd::Winograd2x2_3x3GemmInput::process_tile_row( + const int tile_N, // Number of tiles in the row + const T* const input, // Base input pointer (appropriate to batch, channel and row) + const int input_row_stride, // Stride between rows of the input + const int input_col_stride, // Stride between columns of the input + T* const matrix, // 1st output matrix (appropriate to batch, channel and row) + const int matrix_stride, // Stride between matrices + const int matrix_row_stride // Stride between rows of the output matrix +) { + // Construct copies of the pointers + const T *inptr = input; + T *outptr = matrix; + + // Storage for the tensors x, X.T x, and X.T x X. + T x[4][4][proc_channels], XTx[4][4][proc_channels], XTxX[4][4][proc_channels]; + + // For every tile in the row + for (int tile_j = 0; tile_j < tile_N; tile_j++) { + // Determine the padding for the tile + const int tile_pad_left = (tile_j == 0) ? pad_left : 0; + const int tile_pad_right = (tile_j == tile_N - 1) ? pad_right : 0; + + // Load tile values. If this is the first tile in the row then we must load + // all values, otherwise we can just load the final two columns of the input. + for (int i = 0; i < 4; i++) { + for (int j = ((tile_j == 0) ? 0 : 2); j < 4; j++) { + // Fill with padding if required + if (i < pad_top || 4 - pad_bottom <= i || + j < tile_pad_left || 4 - tile_pad_right <= j) { + for (int c = 0; c < proc_channels; c++) { + x[i][j][c] = static_cast(0); // Padding + } + } else { + // Load values, note that the initial padding offsets the pointer we + // were provided. + for (int c = 0; c < proc_channels; c++) { + const int row_offset = (i - pad_top) * input_row_stride; + const int col_offset = (j - tile_pad_left) * input_col_stride; + x[i][j][c] = inptr[row_offset + col_offset + c]; + } + } + } + } + + // Compute the matrix X.T x. Note, can elide operations depending on the + // padding. Furthermore, if this isn't the left-most tile we can skip half + // of the operations by copying results from the previous version of X.T x. + // This latter optimisation can be simplified by unrolling the outermost + // loop by two and by renaming the registers containing XTx. 
+ if (tile_j == 0) { + for (int j = 0; j < 4; j++) { + for (int c = 0; c < proc_channels; c++) { + XTx[0][j][c] = x[0][j][c] - x[2][j][c]; + XTx[1][j][c] = x[1][j][c] + x[2][j][c]; + XTx[2][j][c] = -x[1][j][c] + x[2][j][c]; + XTx[3][j][c] = x[1][j][c] - x[3][j][c]; + } + } + } else { + for (int j = 0; j < 2; j++) { + for (int c = 0; c < proc_channels; c++) { + XTx[0][j][c] = XTx[0][j + 2][c]; + XTx[1][j][c] = XTx[1][j + 2][c]; + XTx[2][j][c] = XTx[2][j + 2][c]; + XTx[3][j][c] = XTx[3][j + 2][c]; + } + } + for (int j = 2; j < 4; j++) { + for (int c = 0; c < proc_channels; c++) { + XTx[0][j][c] = x[0][j][c] - x[2][j][c]; + XTx[1][j][c] = x[1][j][c] + x[2][j][c]; + XTx[2][j][c] = -x[1][j][c] + x[2][j][c]; + XTx[3][j][c] = x[1][j][c] - x[3][j][c]; + } + } + } + + // Compute the matrix X.T x X. Note, can elide operations based on the + // padding. + for (int i = 0; i < 4; i++) { + for (int c = 0; c < proc_channels; c++) { + XTxX[i][0][c] = XTx[i][0][c] - XTx[i][2][c]; + XTxX[i][1][c] = XTx[i][1][c] + XTx[i][2][c]; + XTxX[i][2][c] = -XTx[i][1][c] + XTx[i][2][c]; + XTxX[i][3][c] = XTx[i][1][c] - XTx[i][3][c]; + } + } + + // Store the output matrix (X.T x X) + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + // Get a pointer to the relevant output matrix + T *mptr = outptr + (i*4 + j)*matrix_stride; + + // Write out the channels + for (int c = 0; c < proc_channels; c++) { + mptr[c] = XTxX[i][j][c]; + } + } + } + + // Update the pointers + inptr += input_col_stride * ((tile_j == 0 && pad_left) ? 1 : 2); + outptr += matrix_row_stride; + } +} + +/*****************************************************************************/ +template +void winograd::Winograd2x2_3x3GemmInputChannelwise::execute( + const T *inptr, + const Tensor4DShape& input_shape, + const PaddingType padding_type, + const int tile_M, + const int tile_N, + T *outptr_base, + const int matrix_stride, + const int matrix_batch_stride, + const int matrix_row_stride +) { + const int n_channels = input_shape.n_channels; + const int input_col_stride = n_channels; + const int input_row_stride = input_shape.n_cols * input_col_stride; + + // Determine the padding and hence select appropriate methods for each tile. + tilefunc fs[3][3]; + + if (padding_type == PADDING_VALID) { + constexpr int pad_top = 0; + constexpr int pad_left = 0; + const int pad_right = input_shape.n_cols % 2 == 0; + + fs[0][0] = process_tile; + fs[0][1] = process_tile; + fs[0][2] = (pad_right) ? process_tile : process_tile; + + fs[1][0] = process_tile<0, pad_left, 0, 0>; + fs[1][1] = process_tile<0, 0, 0, 0>; + fs[1][2] = (pad_right) ? process_tile<0, 0, 0, 0> : process_tile<0, 0, 0, 1>; + + if (input_shape.n_rows % 2 == 0) { + constexpr int pad_bottom = 0; + fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>; + fs[2][1] = process_tile<0, 0, pad_bottom, 0>; + fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 0> : process_tile<0, 0, pad_bottom, 1>; + } else { + constexpr int pad_bottom = 1; + fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>; + fs[2][1] = process_tile<0, 0, pad_bottom, 0>; + fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 0> : process_tile<0, 0, pad_bottom, 1>; + } + } else { + constexpr int pad_top = 1; + constexpr int pad_left = 1; + const int pad_right = input_shape.n_cols % 2 == 0; + + fs[0][0] = process_tile; + fs[0][1] = process_tile; + fs[0][2] = (pad_right) ? process_tile : process_tile; + + fs[1][0] = process_tile<0, pad_left, 0, 0>; + fs[1][1] = process_tile<0, 0, 0, 0>; + fs[1][2] = (pad_right) ? 
process_tile<0, 0, 0, 1> : process_tile<0, 0, 0, 2>; + + if (input_shape.n_rows % 2 == 0) { + constexpr int pad_bottom = 1; + fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>; + fs[2][1] = process_tile<0, 0, pad_bottom, 0>; + fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 1> : process_tile<0, 0, pad_bottom, 2>; + } else { + constexpr int pad_bottom = 2; + fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>; + fs[2][1] = process_tile<0, 0, pad_bottom, 0>; + fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 1> : process_tile<0, 0, pad_bottom, 2>; + } + } + + // Process each tile in turn + for (int batch = 0; batch < input_shape.n_batches; batch++) { + const T* const input_base_batch = inptr + batch*input_shape.n_rows*input_shape.n_cols*n_channels; + + for (int tile_i = 0; tile_i < tile_M; tile_i++) { + const int row_offset = (tile_i == 0) ? 0 : ((padding_type == PADDING_VALID) ? 0 : 1); + const T* const input_base_row = input_base_batch + (2*tile_i - row_offset)*input_shape.n_cols*n_channels; + + // Select the set of functions for the row + const int fs_i = (tile_i == 0) ? 0 : ((tile_i < tile_M - 1) ? 1 : 2); + + for (int tile_j = 0; tile_j < tile_N; tile_j++) { + // Select the function for the column + const int fs_j = (tile_j == 0) ? 0 : ((tile_j < tile_N - 1) ? 1 : 2); + const auto f = fs[fs_i][fs_j]; + + // Get pointers into the input and outputs + const int col_offset = (tile_j == 0) ? 0 : ((padding_type == PADDING_VALID) ? 0 : 1); + const T* const input_base_col = input_base_row + (2*tile_j - col_offset)*n_channels; + T* const matrix_base = outptr_base + batch*matrix_batch_stride + (tile_i*tile_N + tile_j)*matrix_row_stride; + f(n_channels, input_base_col, input_row_stride, input_col_stride, + matrix_base, matrix_stride); + } + } + } +} + +template +template +void winograd::Winograd2x2_3x3GemmInputChannelwise::process_tile( + int n_channels, // Number of channels in the tile + const T* const input_base, + const int input_row_stride, + const int input_col_stride, + T* const matrix_base, + const int matrix_stride +) { + // Copy pointers + const T *inptr = input_base; + T *outptr = matrix_base; + + // Process channels (modifies inptr, outptr and n_channels) + _process_tile( + n_channels, inptr, input_row_stride, input_col_stride, + outptr, matrix_stride + ); + _process_tile( + n_channels, inptr, input_row_stride, input_col_stride, + outptr, matrix_stride + ); + _process_tile( + n_channels, inptr, input_row_stride, input_col_stride, + outptr, matrix_stride + ); +} + +template +template +void winograd::Winograd2x2_3x3GemmInputChannelwise::_process_tile( + int &n_channels, + const T* &inptr, const int input_row_stride, const int input_col_stride, + T* &outptr, const int matrix_stride +) { + // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three + // offsets to access the intermediate matrices. + T* outptrs[4] = { + outptr, + outptr + matrix_stride * 4, + outptr + matrix_stride * 8, + outptr + matrix_stride * 12 + }; + + // The matrix X; zeroed to account for padding. 
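+ // Cells lying in the padded border are never overwritten by the loads in
+ // the loop below and so contribute zero to the transform.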
+ T x[4][4]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + x[i][j] = 0; + } + } + + // The matrices X.T x and U + T XTx[4][4], U[4][4]; + + // Now progress through each channel + for (; n_channels >= proc_channels; n_channels -= proc_channels) { + for (int n = 0; n < proc_channels; n++) { + // Load the matrix X + for (int cell_i = pad_top, i = 0; cell_i < 4 - pad_bottom; cell_i++, i++) { + for (int cell_j = pad_left, j = 0; cell_j < 4 - pad_right; cell_j++, j++) { + x[cell_i][cell_j] = inptr[i*input_row_stride + j*input_col_stride]; + } + } + inptr++; + + // Compute the matrix X.T + for (int j = 0; j < 4; j++) { + XTx[0][j] = x[0][j] - x[2][j]; + XTx[1][j] = x[1][j] + x[2][j]; + XTx[2][j] = x[2][j] - x[1][j]; + XTx[3][j] = x[1][j] - x[3][j]; + } + + // Hence compute the matrix U + for (int i = 0; i < 4; i++) { + U[i][0] = XTx[i][0] - XTx[i][2]; + U[i][1] = XTx[i][1] + XTx[i][2]; + U[i][2] = XTx[i][2] - XTx[i][1]; + U[i][3] = XTx[i][1] - XTx[i][3]; + } + + // Store the matrix U + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + outptrs[i][j * matrix_stride] = U[i][j]; + } + outptrs[i]++; + } + } + } + + // Update the output pointer for future calls + outptr = outptrs[0]; +} diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp new file mode 100644 index 0000000000..a99cbe325b --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp @@ -0,0 +1,1498 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#pragma once +#include "../input_2x2_3x3.hpp" + +#ifdef __aarch64__ +namespace winograd { + +// Pad left by one column, pad right by one column, no upper or lower padding, 4 channels +template <> +template <> +inline void Winograd2x2_3x3GemmInput::process_tile_row<0, 1, 0, 1, 4>( + const int tile_N, // Number of tiles in the row + const float* const input, // Base input pointer (appropriate to batch, channel and row) + const int input_row_stride, // Stride between rows of the input + const int input_col_stride, // Stride between columns of the input + float* const matrix, // 1st output matrix (appropriate to batch, channel and row) + const int matrix_stride, // Stride between matrices + const int matrix_row_stride // Stride between rows of the output matrix +) { + /* SIMD register allocation + * ======================== + * + * In the following code we read 4x4 tiles of a matrix `x`, with which we + * compute another matrix `X.T x` where: + * + * / 1 0 0 0 \ + * X = | 0 1 -1 1 | + * | -1 1 1 0 | + * \ 0 0 0 -1 / + * + * Hence, `X.T` is a program which operates upon rows of the matrix `X`. + * We subsequently compute and store the matrix `U = (X.T x) X`. + * + * Importantly, each iteration of the loop below loads a new matrix `x'` + * where the final two columns of `x'` are the first two columns of the + * previous `x`. That is: + * + * x11 x12 x13 x14 + * x21 x22 x23 x24 + * x31 x32 x33 x34 + * x41 x42 x43 x44 + * + * x'11 x'12 x'13 x'14 + * x'21 x'22 x'23 x'24 + * x'31 x'32 x'33 x'34 + * x'41 x'42 x'43 x'44 + * + * Consequently, while the first iteration of the below loop must load 16 + * values for `x`, the second need load only 8. *Furthermore*, since we noted + * above that the operation `X.T x` was a program which operated upon *rows* + * of the matrix `x` it follows that that the relation that `x'[i][1] = + * x[i][3]` and `x'[i][2] = x[i][4]` applies also the matrices `X.T x'` and + * `X.T x`. That is: + * + * (X.T x)11 (X.T x)12 (X.T x)13 (X.T x)14 + * (X.T x)21 (X.T x)22 (X.T x)23 (X.T x)24 + * (X.T x)31 (X.T x)32 (X.T x)33 (X.T x)34 + * (X.T x)41 (X.T x)42 (X.T x)43 (X.T x)44 + * + * (X.T x')11 (X.T x')12 (X.T x')13 (X.T x')14 + * (X.T x')12 (X.T x')12 (X.T x')12 (X.T x')12 + * (X.T x')13 (X.T x')13 (X.T x')13 (X.T x')13 + * (X.T x')14 (X.T x')14 (X.T x')14 (X.T x')14 + * + * Hence, as well as not needing to load new values for x'[i][1..2] it is + * also unnecessary to recompute values for (X.T x')[i][1..2]. + * + * Following this we break the registers into blocks `A` and `B` used by the + * two stages of the unrolled loop. These registers named such that the + * latter columns of `A` become the earlier columns of `B` and vice-versa: + * + * AXTx11 AXTx12 > AXTx13 AXTx14 | + * AXTx21 AXTx22 > AXTx23 AXTx24 | + * AXTx31 AXTx32 > AXTx33 AXTx34 | + * AXTx41 AXTx42 > AXTx43 AXTx44 | + * + * BXTx13 BXTx14 | BXTx11 BXTx12 > + * BXTx23 BXTx24 | BXTx21 BXTx22 > + * BXTx33 BXTx34 | BXTx31 BXTx32 > + * BXTx43 BXTx44 | BXTx41 BXTx42 > + * + * These 32 named registers require only 16 architectural registers. 1 + * additional architectural register is used as scratch space and 8 + * architectural registers are used to load in the values x[1..4][3,4]. 
+ * + * Input and output addressing + * =========================== + * TODO Description + */ + const float *inptr0 = input; + const float *inptr1 = input + input_row_stride; + const float *inptr2 = input + input_row_stride * 2; + const float *inptr3 = input + input_row_stride * 3; + + float *outptr0 = matrix; + float *outptr4 = matrix + matrix_stride * 4; + float *outptr8 = matrix + matrix_stride * 8; + float *outptr12 = matrix + matrix_stride * 12; + + int tile_j = tile_N; // Tiles to process + + asm volatile ( + // Named SIMD registers according to the policy given above + // Registers into which to load the latter two columns of `x` + "x_13 .req v0\n qx_13 .req q0\n" "x_14 .req v4\n qx_14 .req q4\n" + "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n" + "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n" + "x_43 .req v3\n qx_43 .req q3\n" "x_44 .req v7\n qx_44 .req q7\n" + + // Registers for storing X.T x (both A and B halves) + "AXTx11 .req v8\n" "BXTx13 .req v8\n" + "AXTx12 .req v9\n" "BXTx14 .req v9\n" "qAXTx12 .req q9\n" + "AXTx21 .req v10\n" "BXTx23 .req v10\n" + "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n" + "AXTx31 .req v12\n" "BXTx33 .req v12\n" + "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n" + "AXTx41 .req v14\n" "BXTx43 .req v14\n" + "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n" + "AXTx13 .req v16\n" "BXTx11 .req v16\n" + "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n" + "AXTx23 .req v18\n" "BXTx21 .req v18\n" + "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n" + "AXTx33 .req v20\n" "BXTx31 .req v20\n" + "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n" + "AXTx43 .req v22\n" "BXTx41 .req v22\n" + "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n" + + // Result register (TODO Does using more registers yield better + // performance) + "U .req v24\n qU .req q24\n" + + // ---------------------------------------------------------------------- + // Head of loop + // Loads a complete 4x4 tile of x, computes X.T x, computes and stores + // `U = X.T x X`. Prepares for the 'A' half of the loop. + // NOTE: Since the first tile has the leftmost column padded we can + // skip 4 loads and 4 calculations for the matrix X.T x X. + + // Temporarily alias registers for computing the first (non-padded) + // column of x. + "x_12 .req v0\n qx_12 .req q0\n" + "x_22 .req v1\n qx_22 .req q1\n" + "x_32 .req v2\n qx_32 .req q2\n" + "x_42 .req v3\n qx_42 .req q3\n" + + "ldr qx_12, [%x[inptr0]]\n" + "ldr qx_22, [%x[inptr1]]\n" + "ldr qx_32, [%x[inptr2]]\n" + "ldr qx_42, [%x[inptr3]]\n" + + "fsub BXTx12.4s, x_12.4s, x_32.4s\n" + "fadd BXTx22.4s, x_22.4s, x_32.4s\n" + "fsub BXTx32.4s, x_32.4s, x_22.4s\n" + "fsub BXTx42.4s, x_22.4s, x_42.4s\n" + + ".unreq x_12\n .unreq qx_12\n" + ".unreq x_22\n .unreq qx_22\n" + ".unreq x_32\n .unreq qx_32\n" + ".unreq x_42\n .unreq qx_42\n" + + // Load and compute latter two columns of the first tile. Progress the + // input pointers (by three columns so that the each points are the + // second column of the next tile, that is, each points at the first + // column which must be read for the next tile. 
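+ // (Adjacent tiles overlap by two columns, so only the third and fourth
+ // columns of each subsequent tile need to be loaded; the pointers are
+ // advanced by three columns here and by two columns per tile thereafter.)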
+ "ldr qx_13, [%x[inptr0], %x[colstride1]]\n" + "ldr qx_23, [%x[inptr1], %x[colstride1]]\n" + "ldr qx_33, [%x[inptr2], %x[colstride1]]\n" + "ldr qx_43, [%x[inptr3], %x[colstride1]]\n" + + "fsub BXTx13.4s, x_13.4s, x_33.4s\n" + "ldr qx_14, [%x[inptr0], %x[colstride2]]\n" + + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride2]]\n" + + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride2]]\n" + + "fsub BXTx43.4s, x_23.4s, x_43.4s\n" + "ldr qx_44, [%x[inptr3], %x[colstride2]]\n" + + "fsub BXTx14.4s, x_14.4s, x_34.4s\n" + "add %x[inptr0], %x[inptr0], %x[colstride3]\n" + + "fadd BXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride3]\n" + + "fsub BXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride3]\n" + + "fsub BXTx44.4s, x_24.4s, x_44.4s\n" + "add %x[inptr3], %x[inptr3], %x[colstride3]\n" + + // Compute and store U for the first tile + // First row + "fneg U.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, BXTx12.4s, BXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fneg U.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, BXTx22.4s, BXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fneg U.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, BXTx32.4s, BXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row, simultaneously load the first column of inputs for the + // next tile. + "fneg U.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "ldr qx_13, [%x[inptr0]]\n" + + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, BXTx42.4s, BXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "ldr qx_43, [%x[inptr3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + + // Update the loop counter, subtract two to account for both the head and + // the tail. + "subs %x[tile_j], %x[tile_j], #2\n" + "beq 2f\n" // Jump to "A" tail if out of tiles + + // ---------------------------------------------------------------------- + "1:" + // Start part A + // Load last column of this tile (the first column has already been + // loaded) and compute latter two columns of X.T x. 
+ "fsub AXTx13.4s, x_13.4s, x_33.4s\n" + "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" + "fadd AXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" + "fsub AXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" + "fsub AXTx43.4s, x_23.4s, x_43.4s\n" + "ldr qx_44, [%x[inptr3], %x[colstride1]]\n" + "fsub AXTx14.4s, x_14.4s, x_34.4s\n" + "add %x[inptr0], %x[inptr0], %x[colstride2]\n" + "fadd AXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride2]\n" + "fsub AXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride2]\n" + "fsub AXTx44.4s, x_24.4s, x_44.4s\n" + "add %x[inptr3], %x[inptr3], %x[colstride2]\n" + + // Compute and store U. + // First row + "fsub U.4s, AXTx11.4s, AXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, AXTx12.4s, AXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, AXTx13.4s, AXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, AXTx12.4s, AXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fsub U.4s, AXTx21.4s, AXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, AXTx22.4s, AXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, AXTx23.4s, AXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, AXTx22.4s, AXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fsub U.4s, AXTx31.4s, AXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, AXTx32.4s, AXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, AXTx33.4s, AXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, AXTx32.4s, AXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row + "fsub U.4s, AXTx41.4s, AXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "ldr qx_13, [%x[inptr0]]\n" + + "fadd U.4s, AXTx42.4s, AXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, AXTx43.4s, AXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, AXTx42.4s, AXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "ldr qx_43, [%x[inptr3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + + "subs %x[tile_j], %x[tile_j], #1\n" + "beq 3f\n" // Jump to 'B' tail + + // Start part B + // Load last column of this tile (the first column has already been + // loaded) and compute latter two columns of X.T x. + "fsub BXTx13.4s, x_13.4s, x_33.4s\n" + "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" + "fsub BXTx43.4s, x_23.4s, x_43.4s\n" + "ldr qx_44, [%x[inptr3], %x[colstride1]]\n" + "fsub BXTx14.4s, x_14.4s, x_34.4s\n" + "add %x[inptr0], %x[inptr0], %x[colstride2]\n" + "fadd BXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride2]\n" + "fsub BXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride2]\n" + "fsub BXTx44.4s, x_24.4s, x_44.4s\n" + "add %x[inptr3], %x[inptr3], %x[colstride2]\n" + + // Compute and store U. 
+ // First row + "fsub U.4s, BXTx11.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, BXTx12.4s, BXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fsub U.4s, BXTx21.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, BXTx22.4s, BXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fsub U.4s, BXTx31.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, BXTx32.4s, BXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row + "fsub U.4s, BXTx41.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "ldr qx_13, [%x[inptr0]]\n" + + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, BXTx42.4s, BXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "ldr qx_43, [%x[inptr3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + "subs %x[tile_j], %x[tile_j], #1\n" + "bne 1b\n" // Continue loop, otherwise flow into 'A' tail + + // ---------------------------------------------------------------------- + "2:" + // 'A' tail + // Since the final column is padding and the last-but-one column has + // already been loaded just compute the 3rd column of `X.T x'. + "fsub AXTx13.4s, x_13.4s, x_33.4s\n" + "fadd AXTx23.4s, x_23.4s, x_33.4s\n" + "fsub AXTx33.4s, x_33.4s, x_23.4s\n" + "fsub AXTx43.4s, x_23.4s, x_43.4s\n" + + // Compute and store U. Modified to account for the final column of X.T + // x containing padding. Note, it is also unnecessary to update the + // output pointers. 
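+ // (With the fourth column of x padded, the fourth column of X.T x is zero
+ // and the fourth column of U reduces to the second column of X.T x; hence
+ // the plain str of qAXTx12, qAXTx22, etc. below.)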
+ // First row + "fsub U.4s, AXTx11.4s, AXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, AXTx12.4s, AXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, AXTx13.4s, AXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "str qAXTx12, [%x[outptr0], %x[mstride3]]\n" + + // Second row + "fsub U.4s, AXTx21.4s, AXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, AXTx22.4s, AXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, AXTx23.4s, AXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "str qAXTx22, [%x[outptr4], %x[mstride3]]\n" + + // Third row + "fsub U.4s, AXTx31.4s, AXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, AXTx32.4s, AXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, AXTx33.4s, AXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "str qAXTx32, [%x[outptr8], %x[mstride3]]\n" + + // Fourth row + "fsub U.4s, AXTx41.4s, AXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "fadd U.4s, AXTx42.4s, AXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, AXTx43.4s, AXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "str qAXTx42, [%x[outptr12], %x[mstride3]]\n" + + "b 4f\n" // Jump to end of function + + // ---------------------------------------------------------------------- + "3:" + // 'B' tail + // Since the final column is padding and the last-but-one column has + // already been loaded just compute the 3rd column of `X.T x'. + "fsub BXTx13.4s, x_13.4s, x_33.4s\n" + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "fsub BXTx43.4s, x_23.4s, x_43.4s\n" + + // Compute and store U. Modified to account for the final column of X.T + // x containing padding. Note, it is also unnecessary to update the + // output pointers. + // First row + "fsub U.4s, BXTx11.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "str qBXTx12, [%x[outptr0], %x[mstride3]]\n" + + // Second row + "fsub U.4s, BXTx21.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "str qBXTx22, [%x[outptr4], %x[mstride3]]\n" + + // Third row + "fsub U.4s, BXTx31.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "str qBXTx32, [%x[outptr8], %x[mstride3]]\n" + + // Fourth row + "fsub U.4s, BXTx41.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "str qBXTx42, [%x[outptr12], %x[mstride3]]\n" + + // ---------------------------------------------------------------------- + "4:" + // End of function + + // Clear names + ".unreq x_13\n" ".unreq qx_13\n" ".unreq x_14\n" ".unreq qx_14\n" + ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n" + ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n" + ".unreq x_43\n" ".unreq qx_43\n" ".unreq x_44\n" ".unreq qx_44\n" + ".unreq AXTx11\n" ".unreq BXTx13\n" + ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n" + ".unreq AXTx21\n" ".unreq BXTx23\n" + ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n" + ".unreq AXTx31\n" ".unreq BXTx33\n" + 
".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n" + ".unreq AXTx41\n" ".unreq BXTx43\n" + ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n" + ".unreq AXTx13\n" ".unreq BXTx11\n" + ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n" + ".unreq AXTx23\n" ".unreq BXTx21\n" + ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n" + ".unreq AXTx33\n" ".unreq BXTx31\n" + ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n" + ".unreq AXTx43\n" ".unreq BXTx41\n" + ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n" + ".unreq U\n" ".unreq qU\n" + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [inptr3] "+r" (inptr3), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr4), + [outptr8] "+r" (outptr8), + [outptr12] "+r" (outptr12), + [tile_j] "+r" (tile_j) // Tile counter + : [colstride1] "r" (1 * input_col_stride * sizeof(float)), + [colstride2] "r" (2 * input_col_stride * sizeof(float)), + [colstride3] "r" (3 * input_col_stride * sizeof(float)), + [mstride1] "r" (1 * matrix_stride * sizeof(float)), + [mstride2] "r" (2 * matrix_stride * sizeof(float)), + [mstride3] "r" (3 * matrix_stride * sizeof(float)), + [matrix_row_stride] "r" (matrix_row_stride * sizeof(float)) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24" + ); +} + +// Pad top, left and right by 1. +template <> +template <> +inline void Winograd2x2_3x3GemmInput::process_tile_row<1, 1, 0, 1, 4>( + const int tile_N, + const float* const input, + const int input_row_stride, + const int input_col_stride, + float* const matrix, + const int matrix_stride, + const int matrix_row_stride +) { + const float *inptr0 = input; + const float *inptr1 = input + input_row_stride; + const float *inptr2 = input + input_row_stride * 2; + + float *outptr0 = matrix; + float *outptr4 = matrix + matrix_stride * 4; + float *outptr8 = matrix + matrix_stride * 8; + float *outptr12 = matrix + matrix_stride * 12; + + int tile_j = tile_N; // Tiles to process + + asm volatile ( + // Named SIMD registers according to the policy given above + // Registers into which to load the latter two columns of `x` + // NOTE: We need only load the latter three rows since we know that the + // first row is padded. 
+ "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n" + "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n" + "x_43 .req v3\n qx_43 .req q3\n" "x_44 .req v7\n qx_44 .req q7\n" + + // Registers for storing X.T x (both A and B halves) + "AXTx11 .req v8\n" "BXTx13 .req v8\n" + "AXTx12 .req v9\n" "BXTx14 .req v9\n" "qAXTx12 .req q9\n" + "AXTx21 .req v10\n" "BXTx23 .req v10\n" + "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n" + "AXTx31 .req v12\n" "BXTx33 .req v12\n" + "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n" + "AXTx41 .req v14\n" "BXTx43 .req v14\n" + "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n" + "AXTx13 .req v16\n" "BXTx11 .req v16\n" + "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n" + "AXTx23 .req v18\n" "BXTx21 .req v18\n" + "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n" + "AXTx33 .req v20\n" "BXTx31 .req v20\n" + "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n" + "AXTx43 .req v22\n" "BXTx41 .req v22\n" + "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n" + + // Result register (TODO Does using more registers yield better + // performance) + "U .req v24\n qU .req q24\n" + + // ---------------------------------------------------------------------- + // Head of loop + // Loads a complete 4x4 tile of x, computes X.T x, computes and stores + // `U = X.T x X`. Prepares for the 'A' half of the loop. + // NOTE: Since the first tile has the leftmost column padded we can + // skip 4 loads and 4 calculations for the matrix X.T x X. + + // Temporarily alias registers for computing the first (non-padded) + // column of x. + "x_22 .req v1\n qx_22 .req q1\n" + "x_32 .req v2\n qx_32 .req q2\n" + "x_42 .req v3\n qx_42 .req q3\n" + + "ldr qx_22, [%x[inptr1]]\n" + "ldr qx_32, [%x[inptr2]]\n" + "ldr qx_42, [%x[inptr3]]\n" + + "fneg BXTx12.4s, x_32.4s\n" + "fadd BXTx22.4s, x_22.4s, x_32.4s\n" + "fsub BXTx32.4s, x_32.4s, x_22.4s\n" + "fsub BXTx42.4s, x_22.4s, x_42.4s\n" + + ".unreq x_22\n .unreq qx_22\n" + ".unreq x_32\n .unreq qx_32\n" + ".unreq x_42\n .unreq qx_42\n" + + // Load and compute latter two columns of the first tile. Progress the + // input pointers (by three columns so that the each points are the + // second column of the next tile, that is, each points at the first + // column which must be read for the next tile. 
+ "ldr qx_23, [%x[inptr1], %x[colstride1]]\n" + "ldr qx_33, [%x[inptr2], %x[colstride1]]\n" + "ldr qx_43, [%x[inptr3], %x[colstride1]]\n" + + "fneg BXTx13.4s, x_33.4s\n" + + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride2]]\n" + + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride2]]\n" + + "fsub BXTx43.4s, x_23.4s, x_43.4s\n" + "ldr qx_44, [%x[inptr3], %x[colstride2]]\n" + + "fneg BXTx14.4s, x_34.4s\n" + + "fadd BXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride3]\n" + + "fsub BXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride3]\n" + + "fsub BXTx44.4s, x_24.4s, x_44.4s\n" + "add %x[inptr3], %x[inptr3], %x[colstride3]\n" + + // Compute and store U for the first tile + // First row + "fneg U.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, BXTx12.4s, BXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fneg U.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, BXTx22.4s, BXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fneg U.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, BXTx32.4s, BXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row, simultaneously load the first column of inputs for the + // next tile. + "fneg U.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, BXTx42.4s, BXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "ldr qx_43, [%x[inptr3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + + // Update the loop counter, subtract two to account for both the head and + // the tail. + "subs %x[tile_j], %x[tile_j], #2\n" + "beq 2f\n" // Jump to "A" tail if out of tiles + + // ---------------------------------------------------------------------- + "1:" + // Start part A + // Load last column of this tile (the first column has already been + // loaded) and compute latter two columns of X.T x. + "fneg AXTx13.4s, x_33.4s\n" + "fadd AXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" + "fsub AXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" + "fsub AXTx43.4s, x_23.4s, x_43.4s\n" + "ldr qx_44, [%x[inptr3], %x[colstride1]]\n" + "fneg AXTx14.4s, x_34.4s\n" + "fadd AXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride2]\n" + "fsub AXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride2]\n" + "fsub AXTx44.4s, x_24.4s, x_44.4s\n" + "add %x[inptr3], %x[inptr3], %x[colstride2]\n" + + // Compute and store U. 
+ // First row + "fsub U.4s, AXTx11.4s, AXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, AXTx12.4s, AXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, AXTx13.4s, AXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, AXTx12.4s, AXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fsub U.4s, AXTx21.4s, AXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, AXTx22.4s, AXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, AXTx23.4s, AXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, AXTx22.4s, AXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fsub U.4s, AXTx31.4s, AXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, AXTx32.4s, AXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, AXTx33.4s, AXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, AXTx32.4s, AXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row + "fsub U.4s, AXTx41.4s, AXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + + "fadd U.4s, AXTx42.4s, AXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, AXTx43.4s, AXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, AXTx42.4s, AXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "ldr qx_43, [%x[inptr3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + + "subs %x[tile_j], %x[tile_j], #1\n" + "beq 3f\n" // Jump to 'B' tail + + // Start part B + // Load last column of this tile (the first column has already been + // loaded) and compute latter two columns of X.T x. + "fneg BXTx13.4s, x_33.4s\n" + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" + "fsub BXTx43.4s, x_23.4s, x_43.4s\n" + "ldr qx_44, [%x[inptr3], %x[colstride1]]\n" + "fneg BXTx14.4s, x_34.4s\n" + "fadd BXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride2]\n" + "fsub BXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride2]\n" + "fsub BXTx44.4s, x_24.4s, x_44.4s\n" + "add %x[inptr3], %x[inptr3], %x[colstride2]\n" + + // Compute and store U. 
+ // First row + "fsub U.4s, BXTx11.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, BXTx12.4s, BXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fsub U.4s, BXTx21.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, BXTx22.4s, BXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fsub U.4s, BXTx31.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, BXTx32.4s, BXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row + "fsub U.4s, BXTx41.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, BXTx42.4s, BXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "ldr qx_43, [%x[inptr3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + "subs %x[tile_j], %x[tile_j], #1\n" + "bne 1b\n" // Continue loop, otherwise flow into 'A' tail + + // ---------------------------------------------------------------------- + "2:" + // 'A' tail + // Since the final column is padding and the last-but-one column has + // already been loaded just compute the 3rd column of `X.T x'. + "fneg AXTx13.4s, x_33.4s\n" + "fadd AXTx23.4s, x_23.4s, x_33.4s\n" + "fsub AXTx33.4s, x_33.4s, x_23.4s\n" + "fsub AXTx43.4s, x_23.4s, x_43.4s\n" + + // Compute and store U. Modified to account for the final column of X.T + // x containing padding. Note, it is also unnecessary to update the + // output pointers. 
+ // First row + "fsub U.4s, AXTx11.4s, AXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, AXTx12.4s, AXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, AXTx13.4s, AXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "str qAXTx12, [%x[outptr0], %x[mstride3]]\n" + + // Second row + "fsub U.4s, AXTx21.4s, AXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, AXTx22.4s, AXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, AXTx23.4s, AXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "str qAXTx22, [%x[outptr4], %x[mstride3]]\n" + + // Third row + "fsub U.4s, AXTx31.4s, AXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, AXTx32.4s, AXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, AXTx33.4s, AXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "str qAXTx32, [%x[outptr8], %x[mstride3]]\n" + + // Fourth row + "fsub U.4s, AXTx41.4s, AXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "fadd U.4s, AXTx42.4s, AXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, AXTx43.4s, AXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "str qAXTx42, [%x[outptr12], %x[mstride3]]\n" + + "b 4f\n" // Jump to end of function + + // ---------------------------------------------------------------------- + "3:" + // 'B' tail + // Since the final column is padding and the last-but-one column has + // already been loaded just compute the 3rd column of `X.T x'. + "fneg BXTx13.4s, x_33.4s\n" + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "fsub BXTx43.4s, x_23.4s, x_43.4s\n" + + // Compute and store U. Modified to account for the final column of X.T + // x containing padding. Note, it is also unnecessary to update the + // output pointers. + // First row + "fsub U.4s, BXTx11.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "str qBXTx12, [%x[outptr0], %x[mstride3]]\n" + + // Second row + "fsub U.4s, BXTx21.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "str qBXTx22, [%x[outptr4], %x[mstride3]]\n" + + // Third row + "fsub U.4s, BXTx31.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "str qBXTx32, [%x[outptr8], %x[mstride3]]\n" + + // Fourth row + "fsub U.4s, BXTx41.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "str qBXTx42, [%x[outptr12], %x[mstride3]]\n" + + // ---------------------------------------------------------------------- + "4:" + // End of function + + // Clear names + ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n" + ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n" + ".unreq x_43\n" ".unreq qx_43\n" ".unreq x_44\n" ".unreq qx_44\n" + ".unreq AXTx11\n" ".unreq BXTx13\n" + ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n" + ".unreq AXTx21\n" ".unreq BXTx23\n" + ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n" + ".unreq AXTx31\n" ".unreq BXTx33\n" + ".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n" + ".unreq AXTx41\n" 
".unreq BXTx43\n" + ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n" + ".unreq AXTx13\n" ".unreq BXTx11\n" + ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n" + ".unreq AXTx23\n" ".unreq BXTx21\n" + ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n" + ".unreq AXTx33\n" ".unreq BXTx31\n" + ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n" + ".unreq AXTx43\n" ".unreq BXTx41\n" + ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n" + ".unreq U\n" ".unreq qU\n" + : [inptr1] "+r" (inptr0), // Offset to account for padded row + [inptr2] "+r" (inptr1), // Offset to account for padded row + [inptr3] "+r" (inptr2), // Offset to account for padded row + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr4), + [outptr8] "+r" (outptr8), + [outptr12] "+r" (outptr12), + [tile_j] "+r" (tile_j) // Tile counter + : [colstride1] "r" (1 * input_col_stride * sizeof(float)), + [colstride2] "r" (2 * input_col_stride * sizeof(float)), + [colstride3] "r" (3 * input_col_stride * sizeof(float)), + [mstride1] "r" (1 * matrix_stride * sizeof(float)), + [mstride2] "r" (2 * matrix_stride * sizeof(float)), + [mstride3] "r" (3 * matrix_stride * sizeof(float)), + [matrix_row_stride] "r" (matrix_row_stride * sizeof(float)) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24" + ); +} + +// Pad left, right and bottom by 1. +template <> +template <> +inline void Winograd2x2_3x3GemmInput::process_tile_row<0, 1, 1, 1, 4>( + const int tile_N, + const float* const input, + const int input_row_stride, + const int input_col_stride, + float* const matrix, + const int matrix_stride, + const int matrix_row_stride +) { + const float *inptr0 = input; + const float *inptr1 = input + input_row_stride; + const float *inptr2 = input + input_row_stride * 2; + + float *outptr0 = matrix; + float *outptr4 = matrix + matrix_stride * 4; + float *outptr8 = matrix + matrix_stride * 8; + float *outptr12 = matrix + matrix_stride * 12; + + int tile_j = tile_N; // Tiles to process + + asm volatile ( + // Named SIMD registers according to the policy given above + // Registers into which to load the latter two columns of `x` + // NOTE: Bottom row is not required since since it is padded. 
+ "x_13 .req v0\n qx_13 .req q0\n" "x_14 .req v4\n qx_14 .req q4\n" + "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n" + "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n" + + // Registers for storing X.T x (both A and B halves) + "AXTx11 .req v8\n" "BXTx13 .req v8\n" + "AXTx12 .req v9\n" "BXTx14 .req v9\n" "qAXTx12 .req q9\n" + "AXTx21 .req v10\n" "BXTx23 .req v10\n" + "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n" + "AXTx31 .req v12\n" "BXTx33 .req v12\n" + "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n" + "AXTx41 .req v14\n" "BXTx43 .req v14\n" + "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n" + "AXTx13 .req v16\n" "BXTx11 .req v16\n" + "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n" + "AXTx23 .req v18\n" "BXTx21 .req v18\n" + "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n" + "AXTx33 .req v20\n" "BXTx31 .req v20\n" + "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n" + "AXTx43 .req v22\n" "BXTx41 .req v22\n" + "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n" + + // Result register (TODO Does using more registers yield better + // performance) + "U .req v24\n qU .req q24\n" + + // ---------------------------------------------------------------------- + // Head of loop + // Loads a complete 4x4 tile of x, computes X.T x, computes and stores + // `U = X.T x X`. Prepares for the 'A' half of the loop. + // NOTE: Since the first tile has the leftmost column padded we can + // skip 4 loads and 4 calculations for the matrix X.T x X. + + // Temporarily alias registers for computing the first (non-padded) + // column of x. + "x_12 .req v0\n qx_12 .req q0\n" + "x_22 .req v1\n qx_22 .req q1\n" + "x_32 .req v2\n qx_32 .req q2\n" + + "ldr qx_12, [%x[inptr0]]\n" + "ldr qx_22, [%x[inptr1]]\n" + "ldr qx_32, [%x[inptr2]]\n" + + "fsub BXTx12.4s, x_12.4s, x_32.4s\n" + "fadd BXTx22.4s, x_22.4s, x_32.4s\n" + "fsub BXTx32.4s, x_32.4s, x_22.4s\n" + "mov BXTx42.16b, x_22.16b\n" // Probably should do better + + ".unreq x_12\n .unreq qx_12\n" + ".unreq x_22\n .unreq qx_22\n" + ".unreq x_32\n .unreq qx_32\n" + + // Load and compute latter two columns of the first tile. Progress the + // input pointers (by three columns so that the each points are the + // second column of the next tile, that is, each points at the first + // column which must be read for the next tile. 
+ "ldr qx_13, [%x[inptr0], %x[colstride1]]\n" + "ldr qx_23, [%x[inptr1], %x[colstride1]]\n" + "ldr qx_33, [%x[inptr2], %x[colstride1]]\n" + + "fsub BXTx13.4s, x_13.4s, x_33.4s\n" + "ldr qx_14, [%x[inptr0], %x[colstride2]]\n" + + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride2]]\n" + + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride2]]\n" + + "mov BXTx43.16b, x_23.16b\n" + "fsub BXTx14.4s, x_14.4s, x_34.4s\n" + "add %x[inptr0], %x[inptr0], %x[colstride3]\n" + + "fadd BXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride3]\n" + + "fsub BXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride3]\n" + + "mov BXTx44.16b, x_24.16b\n" + + // Compute and store U for the first tile + // First row + "fneg U.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, BXTx12.4s, BXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fneg U.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, BXTx22.4s, BXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fneg U.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, BXTx32.4s, BXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row, simultaneously load the first column of inputs for the + // next tile. + "fneg U.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "ldr qx_13, [%x[inptr0]]\n" + + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, BXTx42.4s, BXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + + // Update the loop counter, subtract two to account for both the head and + // the tail. + "subs %x[tile_j], %x[tile_j], #2\n" + "beq 2f\n" // Jump to "A" tail if out of tiles + + // ---------------------------------------------------------------------- + "1:" + // Start part A + // Load last column of this tile (the first column has already been + // loaded) and compute latter two columns of X.T x. + "fsub AXTx13.4s, x_13.4s, x_33.4s\n" + "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" + "fadd AXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" + "fsub AXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" + "mov AXTx43.16b, x_23.16b\n" + + "fsub AXTx14.4s, x_14.4s, x_34.4s\n" + "add %x[inptr0], %x[inptr0], %x[colstride2]\n" + "fadd AXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride2]\n" + "fsub AXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride2]\n" + "mov AXTx44.16b, x_24.16b\n" + + // Compute and store U. 
+ // First row + "fsub U.4s, AXTx11.4s, AXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, AXTx12.4s, AXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, AXTx13.4s, AXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, AXTx12.4s, AXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fsub U.4s, AXTx21.4s, AXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, AXTx22.4s, AXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, AXTx23.4s, AXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, AXTx22.4s, AXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fsub U.4s, AXTx31.4s, AXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, AXTx32.4s, AXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, AXTx33.4s, AXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, AXTx32.4s, AXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row + "fsub U.4s, AXTx41.4s, AXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "ldr qx_13, [%x[inptr0]]\n" + + "fadd U.4s, AXTx42.4s, AXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, AXTx43.4s, AXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, AXTx42.4s, AXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + + "subs %x[tile_j], %x[tile_j], #1\n" + "beq 3f\n" // Jump to 'B' tail + + // Start part B + // Load last column of this tile (the first column has already been + // loaded) and compute latter two columns of X.T x. + "fsub BXTx13.4s, x_13.4s, x_33.4s\n" + "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" + "mov BXTx43.16b, x_23.16b\n" + + "fsub BXTx14.4s, x_14.4s, x_34.4s\n" + "add %x[inptr0], %x[inptr0], %x[colstride2]\n" + "fadd BXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride2]\n" + "fsub BXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride2]\n" + "mov BXTx44.16b, x_24.16b\n" + + // Compute and store U. 
+ // First row + "fsub U.4s, BXTx11.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, BXTx12.4s, BXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fsub U.4s, BXTx21.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, BXTx22.4s, BXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fsub U.4s, BXTx31.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, BXTx32.4s, BXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row + "fsub U.4s, BXTx41.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "ldr qx_13, [%x[inptr0]]\n" + + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, BXTx42.4s, BXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + "subs %x[tile_j], %x[tile_j], #1\n" + "bne 1b\n" // Continue loop, otherwise flow into 'A' tail + + // ---------------------------------------------------------------------- + "2:" + // 'A' tail + // Since the final column is padding and the last-but-one column has + // already been loaded just compute the 3rd column of `X.T x'. + "fsub AXTx13.4s, x_13.4s, x_33.4s\n" + "fadd AXTx23.4s, x_23.4s, x_33.4s\n" + "fsub AXTx33.4s, x_33.4s, x_23.4s\n" + "mov AXTx43.16b, x_23.16b\n" + + // Compute and store U. Modified to account for the final column of X.T + // x containing padding. Note, it is also unnecessary to update the + // output pointers. 
+ // First row + "fsub U.4s, AXTx11.4s, AXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, AXTx12.4s, AXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, AXTx13.4s, AXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "str qAXTx12, [%x[outptr0], %x[mstride3]]\n" + + // Second row + "fsub U.4s, AXTx21.4s, AXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, AXTx22.4s, AXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, AXTx23.4s, AXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "str qAXTx22, [%x[outptr4], %x[mstride3]]\n" + + // Third row + "fsub U.4s, AXTx31.4s, AXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, AXTx32.4s, AXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, AXTx33.4s, AXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "str qAXTx32, [%x[outptr8], %x[mstride3]]\n" + + // Fourth row + "fsub U.4s, AXTx41.4s, AXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "fadd U.4s, AXTx42.4s, AXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, AXTx43.4s, AXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "str qAXTx42, [%x[outptr12], %x[mstride3]]\n" + + "b 4f\n" // Jump to end of function + + // ---------------------------------------------------------------------- + "3:" + // 'B' tail + // Since the final column is padding and the last-but-one column has + // already been loaded just compute the 3rd column of `X.T x'. + "fsub BXTx13.4s, x_13.4s, x_33.4s\n" + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "mov BXTx43.16b, x_23.16b\n" + + // Compute and store U. Modified to account for the final column of X.T + // x containing padding. Note, it is also unnecessary to update the + // output pointers. + // First row + "fsub U.4s, BXTx11.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "str qBXTx12, [%x[outptr0], %x[mstride3]]\n" + + // Second row + "fsub U.4s, BXTx21.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "str qBXTx22, [%x[outptr4], %x[mstride3]]\n" + + // Third row + "fsub U.4s, BXTx31.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "str qBXTx32, [%x[outptr8], %x[mstride3]]\n" + + // Fourth row + "fsub U.4s, BXTx41.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "str qBXTx42, [%x[outptr12], %x[mstride3]]\n" + + // ---------------------------------------------------------------------- + "4:" + // End of function + + // Clear names + ".unreq x_13\n" ".unreq qx_13\n" ".unreq x_14\n" ".unreq qx_14\n" + ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n" + ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n" + ".unreq AXTx11\n" ".unreq BXTx13\n" + ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n" + ".unreq AXTx21\n" ".unreq BXTx23\n" + ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n" + ".unreq AXTx31\n" ".unreq BXTx33\n" + ".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n" + ".unreq AXTx41\n" 
".unreq BXTx43\n" + ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n" + ".unreq AXTx13\n" ".unreq BXTx11\n" + ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n" + ".unreq AXTx23\n" ".unreq BXTx21\n" + ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n" + ".unreq AXTx33\n" ".unreq BXTx31\n" + ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n" + ".unreq AXTx43\n" ".unreq BXTx41\n" + ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n" + ".unreq U\n" ".unreq qU\n" + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr4), + [outptr8] "+r" (outptr8), + [outptr12] "+r" (outptr12), + [tile_j] "+r" (tile_j) // Tile counter + : [colstride1] "r" (1 * input_col_stride * sizeof(float)), + [colstride2] "r" (2 * input_col_stride * sizeof(float)), + [colstride3] "r" (3 * input_col_stride * sizeof(float)), + [mstride1] "r" (1 * matrix_stride * sizeof(float)), + [mstride2] "r" (2 * matrix_stride * sizeof(float)), + [mstride3] "r" (3 * matrix_stride * sizeof(float)), + [matrix_row_stride] "r" (matrix_row_stride * sizeof(float)) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24" + ); +} +} +#endif // __aarch64__ diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp new file mode 100644 index 0000000000..ad1ad55291 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp @@ -0,0 +1,961 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once +#include "../input_2x2_3x3.hpp" + +#ifdef __aarch64__ + +namespace winograd { + +template <> +template <> +inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<0, 0, 0, 0, 4>( + int &n_channels, // Number of channels in the tile + const float* &inptr0, + const int input_row_stride, + const int input_col_stride, + float* &outptr0, + const int matrix_stride +) { + // We use 4 pointers to point to the starting position on each row and use + // three offsets to extract elements from each of the other 3 columns. 
+ auto inptr1 = inptr0 + 1*input_row_stride; + auto inptr2 = inptr0 + 2*input_row_stride; + auto inptr3 = inptr0 + 3*input_row_stride; + + // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three + // offsets to access the intermediate matrices. + auto outptr1 = outptr0 + matrix_stride * 4; + auto outptr2 = outptr0 + matrix_stride * 8; + auto outptr3 = outptr0 + matrix_stride * 12; + + for (; n_channels > 3; n_channels -= 4) { + asm volatile ( + "X_11 .req v0\n" "qX_11 .req q0\n" + "X_12 .req v1\n" "qX_12 .req q1\n" + "X_13 .req v2\n" "qX_13 .req q2\n" + "X_14 .req v3\n" "qX_14 .req q3\n" + "X_21 .req v4\n" "qX_21 .req q4\n" + "X_22 .req v5\n" "qX_22 .req q5\n" + "X_23 .req v6\n" "qX_23 .req q6\n" + "X_24 .req v7\n" "qX_24 .req q7\n" + "X_31 .req v8\n" "qX_31 .req q8\n" + "X_32 .req v9\n" "qX_32 .req q9\n" + "X_33 .req v10\n" "qX_33 .req q10\n" + "X_34 .req v11\n" "qX_34 .req q11\n" + "X_41 .req v12\n" "qX_41 .req q12\n" + "X_42 .req v13\n" "qX_42 .req q13\n" + "X_43 .req v14\n" "qX_43 .req q14\n" + "X_44 .req v15\n" "qX_44 .req q15\n" + "xX_11 .req v16\n" + "xX_12 .req v17\n" + "xX_13 .req v18\n" + "xX_14 .req v19\n" + "xX_21 .req v20\n" + "xX_22 .req v21\n" + "xX_23 .req v22\n" + "xX_24 .req v23\n" + "xX_31 .req v24\n" + "xX_32 .req v25\n" + "xX_33 .req v26\n" + "xX_34 .req v27\n" + "xX_41 .req v28\n" + "xX_42 .req v29\n" + "xX_43 .req v30\n" + "xX_44 .req v31\n" + " U .req v0\n" + "qU .req q0\n" + + // Load the tile, and compute compute the matrix xX + "ldr qX_11, [%x[inptr0]]\n" + "ldr qX_12, [%x[inptr0], %x[colstride1]]\n" + "ldr qX_13, [%x[inptr0], %x[colstride2]]\n" + "ldr qX_14, [%x[inptr0], %x[colstride3]]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "ldr qX_21, [%x[inptr1]]\n" + "fsub xX_11.4s, x_11.4s, x_13.4s\n" + "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" + "fadd xX_12.4s, x_12.4s, x_13.4s\n" + "ldr qX_23, [%x[inptr1], %x[colstride2]]\n" + "fsub xX_13.4s, x_13.4s, x_12.4s\n" + "ldr qX_24, [%x[inptr1], %x[colstride3]]\n" + "fsub xX_14.4s, x_12.4s, x_14.4s\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "ldr qX_31, [%x[inptr2]]\n" + "fsub xX_21.4s, x_21.4s, x_23.4s\n" + "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" + "fadd xX_22.4s, x_22.4s, x_23.4s\n" + "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" + "fsub xX_23.4s, x_23.4s, x_22.4s\n" + "ldr qX_34, [%x[inptr2], %x[colstride3]]\n" + "fsub xX_24.4s, x_22.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + "ldr qX_41, [%x[inptr3]]\n" + "fsub xX_31.4s, x_31.4s, x_33.4s\n" + "ldr qX_42, [%x[inptr3], %x[colstride1]]\n" + "fadd xX_32.4s, x_32.4s, x_33.4s\n" + "ldr qX_43, [%x[inptr3], %x[colstride2]]\n" + "fsub xX_33.4s, x_33.4s, x_32.4s\n" + "ldr qX_44, [%x[inptr3], %x[colstride3]]\n" + "fsub xX_34.4s, x_32.4s, x_34.4s\n" + "add %x[inptr3], %x[inptr3], #0x10\n" + + // Complete computing xX while beginning to compute and store + // $U = X.T x X$ + + "fsub xX_41.4s, x_41.4s, x_43.4s\n" + + "fsub U.4s, xX_11.4s, xX_31.4s\n" + "str qU, [%x[outptr0]]\n" + "fsub U.4s, xX_12.4s, xX_32.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, xX_13.4s, xX_33.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, xX_14.4s, xX_34.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd xX_42.4s, x_42.4s, x_43.4s\n" + + "fadd U.4s, xX_21.4s, xX_31.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, xX_22.4s, xX_32.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fadd U.4s, xX_23.4s, xX_33.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fadd U.4s, xX_24.4s, xX_34.4s\n" + 
"str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fsub xX_43.4s, x_43.4s, x_42.4s\n" + + "fsub U.4s, xX_31.4s, xX_21.4s\n" + "str qU, [%x[outptr8]]\n" + "fsub U.4s, xX_32.4s, xX_22.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, xX_33.4s, xX_23.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, xX_34.4s, xX_24.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "fsub xX_44.4s, x_42.4s, x_44.4s\n" + + "fsub U.4s, xX_21.4s, xX_41.4s\n" + "str qU, [%x[outptr12]]\n" + "fsub U.4s, xX_22.4s, xX_42.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, xX_23.4s, xX_43.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "fsub U.4s, xX_24.4s, xX_44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + ".unreq qU\n" + ".unreq U\n" + ".unreq X_11\n" ".unreq qX_11\n" + ".unreq X_12\n" ".unreq qX_12\n" + ".unreq X_13\n" ".unreq qX_13\n" + ".unreq X_14\n" ".unreq qX_14\n" + ".unreq X_21\n" ".unreq qX_21\n" + ".unreq X_22\n" ".unreq qX_22\n" + ".unreq X_23\n" ".unreq qX_23\n" + ".unreq X_24\n" ".unreq qX_24\n" + ".unreq X_31\n" ".unreq qX_31\n" + ".unreq X_32\n" ".unreq qX_32\n" + ".unreq X_33\n" ".unreq qX_33\n" + ".unreq X_34\n" ".unreq qX_34\n" + ".unreq X_41\n" ".unreq qX_41\n" + ".unreq X_42\n" ".unreq qX_42\n" + ".unreq X_43\n" ".unreq qX_43\n" + ".unreq X_44\n" ".unreq qX_44\n" + ".unreq xX_11\n" + ".unreq xX_12\n" + ".unreq xX_13\n" + ".unreq xX_14\n" + ".unreq xX_21\n" + ".unreq xX_22\n" + ".unreq xX_23\n" + ".unreq xX_24\n" + ".unreq xX_31\n" + ".unreq xX_32\n" + ".unreq xX_33\n" + ".unreq xX_34\n" + ".unreq xX_41\n" + ".unreq xX_42\n" + ".unreq xX_43\n" + ".unreq xX_44\n" + + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [inptr3] "+r" (inptr3), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr1), + [outptr8] "+r" (outptr2), + [outptr12] "+r" (outptr3) + : [colstride1] "r" (input_col_stride * sizeof(float)), + [colstride2] "r" (input_col_stride * sizeof(float) * 2), + [colstride3] "r" (input_col_stride * sizeof(float) * 3), + [mstride1] "r" (matrix_stride * sizeof(float)), + [mstride2] "r" (matrix_stride * sizeof(float) * 2), + [mstride3] "r" (matrix_stride * sizeof(float) * 3) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31" + ); + } +} + +// Pad top by 1 +template <> +template <> +inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<1, 0, 0, 0, 4>( + int &n_channels, // Number of channels in the tile + const float* &inptr0, + const int input_row_stride, + const int input_col_stride, + float* &outptr0, + const int matrix_stride +) { + // We use 4 pointers to point to the starting position on each row and use + // three offsets to extract elements from each of the other 3 columns. + auto inptr1 = inptr0 + 0*input_row_stride; + auto inptr2 = inptr0 + 1*input_row_stride; + + // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three + // offsets to access the intermediate matrices. 
+ auto outptr1 = outptr0 + matrix_stride * 4; + auto outptr2 = outptr0 + matrix_stride * 8; + auto outptr3 = outptr0 + matrix_stride * 12; + + for (; n_channels > 3; n_channels -= 4) { + asm volatile ( + "X_21 .req v4\n" "qX_21 .req q4\n" + "X_22 .req v5\n" "qX_22 .req q5\n" + "X_23 .req v6\n" "qX_23 .req q6\n" + "X_24 .req v7\n" "qX_24 .req q7\n" + "X_31 .req v8\n" "qX_31 .req q8\n" + "X_32 .req v9\n" "qX_32 .req q9\n" + "X_33 .req v10\n" "qX_33 .req q10\n" + "X_34 .req v11\n" "qX_34 .req q11\n" + "X_41 .req v12\n" "qX_41 .req q12\n" + "X_42 .req v13\n" "qX_42 .req q13\n" + "X_43 .req v14\n" "qX_43 .req q14\n" + "X_44 .req v15\n" "qX_44 .req q15\n" + "xX_21 .req v20\n" + "xX_22 .req v21\n" + "xX_23 .req v22\n" + "xX_24 .req v23\n" + "xX_31 .req v24\n" + "xX_32 .req v25\n" + "xX_33 .req v26\n" + "xX_34 .req v27\n" + "xX_41 .req v28\n" + "xX_42 .req v29\n" + "xX_43 .req v30\n" + "xX_44 .req v31\n" + " U .req v0\n" + "qU .req q0\n" + + // Load the tile, and compute compute the matrix xX + "ldr qX_21, [%x[inptr1]]\n" + "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" + "ldr qX_23, [%x[inptr1], %x[colstride2]]\n" + "ldr qX_24, [%x[inptr1], %x[colstride3]]\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "ldr qX_31, [%x[inptr2]]\n" + "fsub xX_21.4s, x_21.4s, x_23.4s\n" + "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" + "fadd xX_22.4s, x_22.4s, x_23.4s\n" + "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" + "fsub xX_23.4s, x_23.4s, x_22.4s\n" + "ldr qX_34, [%x[inptr2], %x[colstride3]]\n" + "fsub xX_24.4s, x_22.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + "ldr qX_41, [%x[inptr3]]\n" + "fsub xX_31.4s, x_31.4s, x_33.4s\n" + "ldr qX_42, [%x[inptr3], %x[colstride1]]\n" + "fadd xX_32.4s, x_32.4s, x_33.4s\n" + "ldr qX_43, [%x[inptr3], %x[colstride2]]\n" + "fsub xX_33.4s, x_33.4s, x_32.4s\n" + "ldr qX_44, [%x[inptr3], %x[colstride3]]\n" + "fsub xX_34.4s, x_32.4s, x_34.4s\n" + "add %x[inptr3], %x[inptr3], #0x10\n" + + // Complete computing xX while beginning to compute and store + // $U = X.T x X$ + + "fsub xX_41.4s, x_41.4s, x_43.4s\n" + + "fneg U.4s, xX_31.4s\n" + "str qU, [%x[outptr0]]\n" + "fneg U.4s, xX_32.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fneg U.4s, xX_33.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fneg U.4s, xX_34.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd xX_42.4s, x_42.4s, x_43.4s\n" + + "fadd U.4s, xX_21.4s, xX_31.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, xX_22.4s, xX_32.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fadd U.4s, xX_23.4s, xX_33.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fadd U.4s, xX_24.4s, xX_34.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fsub xX_43.4s, x_43.4s, x_42.4s\n" + + "fsub U.4s, xX_31.4s, xX_21.4s\n" + "str qU, [%x[outptr8]]\n" + "fsub U.4s, xX_32.4s, xX_22.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, xX_33.4s, xX_23.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, xX_34.4s, xX_24.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "fsub xX_44.4s, x_42.4s, x_44.4s\n" + + "fsub U.4s, xX_21.4s, xX_41.4s\n" + "str qU, [%x[outptr12]]\n" + "fsub U.4s, xX_22.4s, xX_42.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, xX_23.4s, xX_43.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "fsub U.4s, xX_24.4s, xX_44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + ".unreq qU\n" + ".unreq U\n" + ".unreq 
X_21\n" ".unreq qX_21\n" + ".unreq X_22\n" ".unreq qX_22\n" + ".unreq X_23\n" ".unreq qX_23\n" + ".unreq X_24\n" ".unreq qX_24\n" + ".unreq X_31\n" ".unreq qX_31\n" + ".unreq X_32\n" ".unreq qX_32\n" + ".unreq X_33\n" ".unreq qX_33\n" + ".unreq X_34\n" ".unreq qX_34\n" + ".unreq X_41\n" ".unreq qX_41\n" + ".unreq X_42\n" ".unreq qX_42\n" + ".unreq X_43\n" ".unreq qX_43\n" + ".unreq X_44\n" ".unreq qX_44\n" + ".unreq xX_21\n" + ".unreq xX_22\n" + ".unreq xX_23\n" + ".unreq xX_24\n" + ".unreq xX_31\n" + ".unreq xX_32\n" + ".unreq xX_33\n" + ".unreq xX_34\n" + ".unreq xX_41\n" + ".unreq xX_42\n" + ".unreq xX_43\n" + ".unreq xX_44\n" + + : [inptr1] "+r" (inptr0), // Offset for missing row + [inptr2] "+r" (inptr1), // Offset for missing row + [inptr3] "+r" (inptr2), // Offset for missing row + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr1), + [outptr8] "+r" (outptr2), + [outptr12] "+r" (outptr3) + : [colstride1] "r" (input_col_stride * sizeof(float)), + [colstride2] "r" (input_col_stride * sizeof(float) * 2), + [colstride3] "r" (input_col_stride * sizeof(float) * 3), + [mstride1] "r" (matrix_stride * sizeof(float)), + [mstride2] "r" (matrix_stride * sizeof(float) * 2), + [mstride3] "r" (matrix_stride * sizeof(float) * 3) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31" + ); + } +} + +// Pad left by 1 +template <> +template <> +inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<0, 1, 0, 0, 4>( + int &n_channels, // Number of channels in the tile + const float* &inptr0, + const int input_row_stride, + const int input_col_stride, + float* &outptr0, + const int matrix_stride +) { + // We use 4 pointers to point to the starting position on each row and use + // three offsets to extract elements from each of the other 3 columns. + auto inptr1 = inptr0 + 1*input_row_stride; + auto inptr2 = inptr0 + 2*input_row_stride; + auto inptr3 = inptr0 + 3*input_row_stride; + + // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three + // offsets to access the intermediate matrices. 
+ auto outptr1 = outptr0 + matrix_stride * 4; + auto outptr2 = outptr0 + matrix_stride * 8; + auto outptr3 = outptr0 + matrix_stride * 12; + + for (; n_channels > 3; n_channels -= 4) { + asm volatile ( + "X_12 .req v1\n" "qX_12 .req q1\n" + "X_13 .req v2\n" "qX_13 .req q2\n" + "X_14 .req v3\n" "qX_14 .req q3\n" + "X_22 .req v5\n" "qX_22 .req q5\n" + "X_23 .req v6\n" "qX_23 .req q6\n" + "X_24 .req v7\n" "qX_24 .req q7\n" + "X_32 .req v9\n" "qX_32 .req q9\n" + "X_33 .req v10\n" "qX_33 .req q10\n" + "X_34 .req v11\n" "qX_34 .req q11\n" + "X_42 .req v13\n" "qX_42 .req q13\n" + "X_43 .req v14\n" "qX_43 .req q14\n" + "X_44 .req v15\n" "qX_44 .req q15\n" + "xX_11 .req v16\n" + "xX_12 .req v17\n" + "xX_13 .req v18\n" + "xX_14 .req v19\n" + "xX_21 .req v20\n" + "xX_22 .req v21\n" + "xX_23 .req v22\n" + "xX_24 .req v23\n" + "xX_31 .req v24\n" + "xX_32 .req v25\n" + "xX_33 .req v26\n" + "xX_34 .req v27\n" + "xX_41 .req v28\n" + "xX_42 .req v29\n" + "xX_43 .req v30\n" + "xX_44 .req v31\n" + " U .req v0\n" + "qU .req q0\n" + + // Load the tile, and compute compute the matrix xX + "ldr qX_12, [%x[inptr0]]\n" + "ldr qX_13, [%x[inptr0], %x[colstride1]]\n" + "ldr qX_14, [%x[inptr0], %x[colstride2]]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "fneg xX_11.4s, x_13.4s\n" + "ldr qX_22, [%x[inptr1]]\n" + "fadd xX_12.4s, x_12.4s, x_13.4s\n" + "ldr qX_23, [%x[inptr1], %x[colstride1]]\n" + "fsub xX_13.4s, x_13.4s, x_12.4s\n" + "ldr qX_24, [%x[inptr1], %x[colstride2]]\n" + "fsub xX_14.4s, x_12.4s, x_14.4s\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "fneg xX_21.4s, x_23.4s\n" + "ldr qX_32, [%x[inptr2]]\n" + "fadd xX_22.4s, x_22.4s, x_23.4s\n" + "ldr qX_33, [%x[inptr2], %x[colstride1]]\n" + "fsub xX_23.4s, x_23.4s, x_22.4s\n" + "ldr qX_34, [%x[inptr2], %x[colstride2]]\n" + "fsub xX_24.4s, x_22.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + "fneg xX_31.4s, x_33.4s\n" + "ldr qX_42, [%x[inptr3]]\n" + "fadd xX_32.4s, x_32.4s, x_33.4s\n" + "ldr qX_43, [%x[inptr3], %x[colstride1]]\n" + "fsub xX_33.4s, x_33.4s, x_32.4s\n" + "ldr qX_44, [%x[inptr3], %x[colstride2]]\n" + "fsub xX_34.4s, x_32.4s, x_34.4s\n" + "add %x[inptr3], %x[inptr3], #0x10\n" + + // Complete computing xX while beginning to compute and store + // $U = X.T x X$ + + "fneg xX_41.4s, x_43.4s\n" + + "fsub U.4s, xX_11.4s, xX_31.4s\n" + "str qU, [%x[outptr0]]\n" + "fsub U.4s, xX_12.4s, xX_32.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, xX_13.4s, xX_33.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, xX_14.4s, xX_34.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd xX_42.4s, x_42.4s, x_43.4s\n" + + "fadd U.4s, xX_21.4s, xX_31.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, xX_22.4s, xX_32.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fadd U.4s, xX_23.4s, xX_33.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fadd U.4s, xX_24.4s, xX_34.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fsub xX_43.4s, x_43.4s, x_42.4s\n" + + "fsub U.4s, xX_31.4s, xX_21.4s\n" + "str qU, [%x[outptr8]]\n" + "fsub U.4s, xX_32.4s, xX_22.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, xX_33.4s, xX_23.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, xX_34.4s, xX_24.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "fsub xX_44.4s, x_42.4s, x_44.4s\n" + + "fsub U.4s, xX_21.4s, xX_41.4s\n" + "str qU, [%x[outptr12]]\n" + "fsub U.4s, xX_22.4s, xX_42.4s\n" + "str qU, 
[%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, xX_23.4s, xX_43.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "fsub U.4s, xX_24.4s, xX_44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + ".unreq X_12\n" ".unreq qX_12\n" + ".unreq X_13\n" ".unreq qX_13\n" + ".unreq X_14\n" ".unreq qX_14\n" + ".unreq X_22\n" ".unreq qX_22\n" + ".unreq X_23\n" ".unreq qX_23\n" + ".unreq X_24\n" ".unreq qX_24\n" + ".unreq X_32\n" ".unreq qX_32\n" + ".unreq X_33\n" ".unreq qX_33\n" + ".unreq X_34\n" ".unreq qX_34\n" + ".unreq X_42\n" ".unreq qX_42\n" + ".unreq X_43\n" ".unreq qX_43\n" + ".unreq X_44\n" ".unreq qX_44\n" + ".unreq xX_11\n" + ".unreq xX_12\n" + ".unreq xX_13\n" + ".unreq xX_14\n" + ".unreq xX_21\n" + ".unreq xX_22\n" + ".unreq xX_23\n" + ".unreq xX_24\n" + ".unreq xX_31\n" + ".unreq xX_32\n" + ".unreq xX_33\n" + ".unreq xX_34\n" + ".unreq xX_41\n" + ".unreq xX_42\n" + ".unreq xX_43\n" + ".unreq xX_44\n" + ".unreq U\n" + ".unreq qU\n" + + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [inptr3] "+r" (inptr3), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr1), + [outptr8] "+r" (outptr2), + [outptr12] "+r" (outptr3) + : [colstride1] "r" (input_col_stride * sizeof(float)), + [colstride2] "r" (input_col_stride * sizeof(float) * 2), + [mstride1] "r" (matrix_stride * sizeof(float)), + [mstride2] "r" (matrix_stride * sizeof(float) * 2), + [mstride3] "r" (matrix_stride * sizeof(float) * 3) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31" + ); + } +} + +// Pad bottom by 1 +template <> +template <> +inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<0, 0, 1, 0, 4>( + int &n_channels, // Number of channels in the tile + const float* &inptr0, + const int input_row_stride, + const int input_col_stride, + float* &outptr0, + const int matrix_stride +) { + // We use 4 pointers to point to the starting position on each row and use + // three offsets to extract elements from each of the other 3 columns. + auto inptr1 = inptr0 + 1*input_row_stride; + auto inptr2 = inptr0 + 2*input_row_stride; + + // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three + // offsets to access the intermediate matrices. 
+ auto outptr1 = outptr0 + matrix_stride * 4; + auto outptr2 = outptr0 + matrix_stride * 8; + auto outptr3 = outptr0 + matrix_stride * 12; + + for (; n_channels > 3; n_channels -= 4) { + asm volatile ( + "X_11 .req v0\n" "qX_11 .req q0\n" + "X_12 .req v1\n" "qX_12 .req q1\n" + "X_13 .req v2\n" "qX_13 .req q2\n" + "X_14 .req v3\n" "qX_14 .req q3\n" + "X_21 .req v4\n" "qX_21 .req q4\n" + "X_22 .req v5\n" "qX_22 .req q5\n" + "X_23 .req v6\n" "qX_23 .req q6\n" + "X_24 .req v7\n" "qX_24 .req q7\n" + "X_31 .req v8\n" "qX_31 .req q8\n" + "X_32 .req v9\n" "qX_32 .req q9\n" + "X_33 .req v10\n" "qX_33 .req q10\n" + "X_34 .req v11\n" "qX_34 .req q11\n" + "xX_11 .req v16\n" + "xX_12 .req v17\n" + "xX_13 .req v18\n" + "xX_14 .req v19\n" + "xX_21 .req v20\n" "qxX_21 .req q20\n" + "xX_22 .req v21\n" "qxX_22 .req q21\n" + "xX_23 .req v22\n" "qxX_23 .req q22\n" + "xX_24 .req v23\n" "qxX_24 .req q23\n" + "xX_31 .req v24\n" + "xX_32 .req v25\n" + "xX_33 .req v26\n" + "xX_34 .req v27\n" + " U .req v0\n" + "qU .req q0\n" + + // Load the tile, and compute compute the matrix xX + "ldr qX_11, [%x[inptr0]]\n" + "ldr qX_12, [%x[inptr0], %x[colstride1]]\n" + "ldr qX_13, [%x[inptr0], %x[colstride2]]\n" + "ldr qX_14, [%x[inptr0], %x[colstride3]]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "ldr qX_21, [%x[inptr1]]\n" + "fsub xX_11.4s, x_11.4s, x_13.4s\n" + "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" + "fadd xX_12.4s, x_12.4s, x_13.4s\n" + "ldr qX_23, [%x[inptr1], %x[colstride2]]\n" + "fsub xX_13.4s, x_13.4s, x_12.4s\n" + "ldr qX_24, [%x[inptr1], %x[colstride3]]\n" + "fsub xX_14.4s, x_12.4s, x_14.4s\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "ldr qX_31, [%x[inptr2]]\n" + "fsub xX_21.4s, x_21.4s, x_23.4s\n" + "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" + "fadd xX_22.4s, x_22.4s, x_23.4s\n" + "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" + "fsub xX_23.4s, x_23.4s, x_22.4s\n" + "ldr qX_34, [%x[inptr2], %x[colstride3]]\n" + "fsub xX_24.4s, x_22.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + "fsub xX_31.4s, x_31.4s, x_33.4s\n" + "fadd xX_32.4s, x_32.4s, x_33.4s\n" + "fsub xX_33.4s, x_33.4s, x_32.4s\n" + "fsub xX_34.4s, x_32.4s, x_34.4s\n" + + // Complete computing xX while beginning to compute and store + // $U = X.T x X$ + + "fsub U.4s, xX_11.4s, xX_31.4s\n" + "str qU, [%x[outptr0]]\n" + "fsub U.4s, xX_12.4s, xX_32.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, xX_13.4s, xX_33.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, xX_14.4s, xX_34.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd U.4s, xX_21.4s, xX_31.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, xX_22.4s, xX_32.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fadd U.4s, xX_23.4s, xX_33.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fadd U.4s, xX_24.4s, xX_34.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fsub U.4s, xX_31.4s, xX_21.4s\n" + "str qU, [%x[outptr8]]\n" + "fsub U.4s, xX_32.4s, xX_22.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, xX_33.4s, xX_23.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, xX_34.4s, xX_24.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "str qxX_21, [%x[outptr12]]\n" + "str qxX_22, [%x[outptr12], %x[mstride1]]\n" + "str qxX_23, [%x[outptr12], %x[mstride2]]\n" + "str qxX_24, [%x[outptr12], %x[mstride3]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + ".unreq qU\n" + ".unreq U\n" + ".unreq X_11\n" ".unreq 
qX_11\n" + ".unreq X_12\n" ".unreq qX_12\n" + ".unreq X_13\n" ".unreq qX_13\n" + ".unreq X_14\n" ".unreq qX_14\n" + ".unreq X_21\n" ".unreq qX_21\n" + ".unreq X_22\n" ".unreq qX_22\n" + ".unreq X_23\n" ".unreq qX_23\n" + ".unreq X_24\n" ".unreq qX_24\n" + ".unreq X_31\n" ".unreq qX_31\n" + ".unreq X_32\n" ".unreq qX_32\n" + ".unreq X_33\n" ".unreq qX_33\n" + ".unreq X_34\n" ".unreq qX_34\n" + ".unreq xX_11\n" + ".unreq xX_12\n" + ".unreq xX_13\n" + ".unreq xX_14\n" + ".unreq xX_21\n" ".unreq qxX_21\n" + ".unreq xX_22\n" ".unreq qxX_22\n" + ".unreq xX_23\n" ".unreq qxX_23\n" + ".unreq xX_24\n" ".unreq qxX_24\n" + ".unreq xX_31\n" + ".unreq xX_32\n" + ".unreq xX_33\n" + ".unreq xX_34\n" + + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr1), + [outptr8] "+r" (outptr2), + [outptr12] "+r" (outptr3) + : [colstride1] "r" (input_col_stride * sizeof(float)), + [colstride2] "r" (input_col_stride * sizeof(float) * 2), + [colstride3] "r" (input_col_stride * sizeof(float) * 3), + [mstride1] "r" (matrix_stride * sizeof(float)), + [mstride2] "r" (matrix_stride * sizeof(float) * 2), + [mstride3] "r" (matrix_stride * sizeof(float) * 3) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31" + ); + } +} + +// Pad right by 1 +template <> +template <> +inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<0, 0, 0, 1, 4>( + int &n_channels, // Number of channels in the tile + const float* &inptr0, + const int input_row_stride, + const int input_col_stride, + float* &outptr0, + const int matrix_stride +) { + // We use 4 pointers to point to the starting position on each row and use + // three offsets to extract elements from each of the other 3 columns. + auto inptr1 = inptr0 + 1*input_row_stride; + auto inptr2 = inptr0 + 2*input_row_stride; + auto inptr3 = inptr0 + 3*input_row_stride; + + // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three + // offsets to access the intermediate matrices. 
+ auto outptr1 = outptr0 + matrix_stride * 4; + auto outptr2 = outptr0 + matrix_stride * 8; + auto outptr3 = outptr0 + matrix_stride * 12; + + for (; n_channels > 3; n_channels -= 4) { + asm volatile ( + "X_11 .req v0\n" "qX_11 .req q0\n" + "X_12 .req v1\n" "qX_12 .req q1\n" + "X_13 .req v2\n" "qX_13 .req q2\n" + "X_21 .req v4\n" "qX_21 .req q4\n" + "X_22 .req v5\n" "qX_22 .req q5\n" + "X_23 .req v6\n" "qX_23 .req q6\n" + "X_31 .req v8\n" "qX_31 .req q8\n" + "X_32 .req v9\n" "qX_32 .req q9\n" + "X_33 .req v10\n" "qX_33 .req q10\n" + "X_41 .req v12\n" "qX_41 .req q12\n" + "X_42 .req v13\n" "qX_42 .req q13\n" + "X_43 .req v14\n" "qX_43 .req q14\n" + "xX_11 .req v16\n" + "xX_12 .req v17\n" + "xX_13 .req v18\n" + "xX_14 .req x_12\n" + "xX_21 .req v20\n" + "xX_22 .req v21\n" + "xX_23 .req v22\n" + "xX_24 .req x_22\n" + "xX_31 .req v24\n" + "xX_32 .req v25\n" + "xX_33 .req v26\n" + "xX_34 .req x_32\n" + "xX_41 .req v28\n" + "xX_42 .req v29\n" + "xX_43 .req v30\n" + "xX_44 .req x_42\n" + " U .req v0\n" + "qU .req q0\n" + + // Load the tile, and compute compute the matrix xX + "ldr qX_11, [%x[inptr0]]\n" + "ldr qX_12, [%x[inptr0], %x[colstride1]]\n" + "ldr qX_13, [%x[inptr0], %x[colstride2]]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "ldr qX_21, [%x[inptr1]]\n" + "fsub xX_11.4s, x_11.4s, x_13.4s\n" + "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" + "fadd xX_12.4s, x_12.4s, x_13.4s\n" + "ldr qX_23, [%x[inptr1], %x[colstride2]]\n" + "fsub xX_13.4s, x_13.4s, x_12.4s\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "ldr qX_31, [%x[inptr2]]\n" + "fsub xX_21.4s, x_21.4s, x_23.4s\n" + "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" + "fadd xX_22.4s, x_22.4s, x_23.4s\n" + "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" + "fsub xX_23.4s, x_23.4s, x_22.4s\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + "ldr qX_41, [%x[inptr3]]\n" + "fsub xX_31.4s, x_31.4s, x_33.4s\n" + "ldr qX_42, [%x[inptr3], %x[colstride1]]\n" + "fadd xX_32.4s, x_32.4s, x_33.4s\n" + "ldr qX_43, [%x[inptr3], %x[colstride2]]\n" + "fsub xX_33.4s, x_33.4s, x_32.4s\n" + "add %x[inptr3], %x[inptr3], #0x10\n" + + // Complete computing xX while beginning to compute and store + // $U = X.T x X$ + + "fsub xX_41.4s, x_41.4s, x_43.4s\n" + + "fsub U.4s, xX_11.4s, xX_31.4s\n" + "str qU, [%x[outptr0]]\n" + "fsub U.4s, xX_12.4s, xX_32.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, xX_13.4s, xX_33.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, xX_14.4s, xX_34.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd xX_42.4s, x_42.4s, x_43.4s\n" + + "fadd U.4s, xX_21.4s, xX_31.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, xX_22.4s, xX_32.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fadd U.4s, xX_23.4s, xX_33.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fadd U.4s, xX_24.4s, xX_34.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fsub xX_43.4s, x_43.4s, x_42.4s\n" + + "fsub U.4s, xX_31.4s, xX_21.4s\n" + "str qU, [%x[outptr8]]\n" + "fsub U.4s, xX_32.4s, xX_22.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, xX_33.4s, xX_23.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, xX_34.4s, xX_24.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "fsub U.4s, xX_21.4s, xX_41.4s\n" + "str qU, [%x[outptr12]]\n" + "fsub U.4s, xX_22.4s, xX_42.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, xX_23.4s, xX_43.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "fsub U.4s, 
xX_24.4s, xX_44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + ".unreq qU\n" + ".unreq U\n" + ".unreq X_11\n" ".unreq qX_11\n" + ".unreq X_12\n" ".unreq qX_12\n" + ".unreq X_13\n" ".unreq qX_13\n" + ".unreq X_21\n" ".unreq qX_21\n" + ".unreq X_22\n" ".unreq qX_22\n" + ".unreq X_23\n" ".unreq qX_23\n" + ".unreq X_31\n" ".unreq qX_31\n" + ".unreq X_32\n" ".unreq qX_32\n" + ".unreq X_33\n" ".unreq qX_33\n" + ".unreq X_41\n" ".unreq qX_41\n" + ".unreq X_42\n" ".unreq qX_42\n" + ".unreq X_43\n" ".unreq qX_43\n" + ".unreq xX_11\n" + ".unreq xX_12\n" + ".unreq xX_13\n" + ".unreq xX_14\n" + ".unreq xX_21\n" + ".unreq xX_22\n" + ".unreq xX_23\n" + ".unreq xX_24\n" + ".unreq xX_31\n" + ".unreq xX_32\n" + ".unreq xX_33\n" + ".unreq xX_34\n" + ".unreq xX_41\n" + ".unreq xX_42\n" + ".unreq xX_43\n" + ".unreq xX_44\n" + + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [inptr3] "+r" (inptr3), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr1), + [outptr8] "+r" (outptr2), + [outptr12] "+r" (outptr3) + : [colstride1] "r" (input_col_stride * sizeof(float)), + [colstride2] "r" (input_col_stride * sizeof(float) * 2), + [mstride1] "r" (matrix_stride * sizeof(float)), + [mstride2] "r" (matrix_stride * sizeof(float) * 2), + [mstride3] "r" (matrix_stride * sizeof(float) * 3) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31" + ); + } +} +} +#endif diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp new file mode 100644 index 0000000000..033442aa14 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +namespace winograd { + /* Transform a kernel into the Winograd domain. + * + * NOTE: It is assumed that the kernel is in the form [height x width x + * input_channels x output_channel]. 
+ */ + template + struct winograd2x2_3x3_gemm_kernel_transform_impl{ + static void execute( + const KernelShape &shape, + const T* const kernel, + T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride + ); + + protected: + template + static void transform_kernel( + const T* const kernel, + const int n_input_channels, + const int n_output_channels, + T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride + ); + }; +} + +/*****************************************************************************/ +/* Transform a fp32 kernel into the Winograd domain. + */ +#include "kernel_2x2_3x3/a64_float.hpp" // AArch64 specialisations + +namespace winograd +{ +template <> +inline void winograd2x2_3x3_gemm_kernel_transform_impl::execute( + const KernelShape &shape, + const float* const kernel, + float* const matrix_base, + const int matrix_stride, + const int matrix_row_stride +) { + // Delegate based on tail size + const int n_input_channels = shape.n_input_channels; + const int n_output_channels = shape.n_output_channels; + + switch (n_output_channels % 4) { + case 0: + transform_kernel<0>( + kernel, n_input_channels, n_output_channels, + matrix_base, matrix_stride, matrix_row_stride + ); + break; + case 1: + transform_kernel<1>( + kernel, n_input_channels, n_output_channels, + matrix_base, matrix_stride, matrix_row_stride + ); + break; + case 2: + transform_kernel<2>( + kernel, n_input_channels, n_output_channels, + matrix_base, matrix_stride, matrix_row_stride + ); + break; + case 3: + transform_kernel<3>( + kernel, n_input_channels, n_output_channels, + matrix_base, matrix_stride, matrix_row_stride + ); + break; + default: + ARM_COMPUTE_ERROR("Cannot happen"); + break; + } +} + +template <> +template +inline void winograd2x2_3x3_gemm_kernel_transform_impl::transform_kernel( + const float* const kernel, + const int n_input_channels, + const int n_output_channels, + float* const matrix_base, + const int mstride, + const int matrix_row_stride +) { + // Use one input pointer for each row of the kernel, use two additional + // offsets to extract columns. + const int kernel_col_stride = n_input_channels * n_output_channels; + const int kernel_row_stride = 3 * kernel_col_stride; + const float *inptr0 = kernel; + const float *inptr1 = kernel + kernel_row_stride; + const float *inptr2 = kernel + kernel_row_stride*2; + + // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three + // offsets to extract further matrices. + float *outptr0 = matrix_base; + float *outptr4 = matrix_base + mstride * 4; + float *outptr8 = matrix_base + mstride * 8; + float *outptr12 = matrix_base + mstride * 12; + + // For every input channel + for (int in_c = 0; in_c < n_input_channels; in_c++) { + // For every output channel + for (int c = 0; c < n_output_channels; c++) { + // Read in the kernel + float w11 = inptr0[0], w12 = inptr0[kernel_col_stride], w13 = inptr0[kernel_col_stride*2]; + float w21 = inptr1[0], w22 = inptr1[kernel_col_stride], w23 = inptr1[kernel_col_stride*2]; + float w31 = inptr2[0], w32 = inptr2[kernel_col_stride], w33 = inptr2[kernel_col_stride*2]; + + // Progress input pointers + inptr0++; + inptr1++; + inptr2++; + + // Compute the kernel W w, note we need only compute the middle two rows + // (2 and 3) because the first and last rows are merely copies of values + // from the matrix w. 
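+ // In matrix form this is Ww = W w with
+ //   W = [  1    0    0  ]
+ //       [ 1/2  1/2  1/2 ]
+ //       [ 1/2 -1/2  1/2 ]
+ //       [  0    0    1  ]
+ // (the matrix usually written G for Winograd F(2x2, 3x3)); the column step
+ // that follows applies the same combinations to the columns of Ww to give
+ // W w W.T.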
+ float Ww11 = w11, Ww12 = w12, Ww13 = w13; + float Ww21 = 0.5*(w11 + w21 + w31), Ww22 = 0.5*(w12 + w22 + w32), Ww23 = 0.5*(w13 + w23 + w33); + float Ww31 = 0.5*(w11 - w21 + w31), Ww32 = 0.5*(w12 - w22 + w32), Ww33 = 0.5*(w13 - w23 + w33); + float Ww41 = w31, Ww42 = w32, Ww43 = w33; + + // Hence compute W w W.T; again note we need compute only the middle two + // columns since the first and last columns are copies of the first and + // last columns of the previous matrix. + float WwWT11 = Ww11, WwWT12 = 0.5*(Ww11 + Ww12 + Ww13), WwWT13 = 0.5*(Ww11 - Ww12 + Ww13), WwWT14 = Ww13; + float WwWT21 = Ww21, WwWT22 = 0.5*(Ww21 + Ww22 + Ww23), WwWT23 = 0.5*(Ww21 - Ww22 + Ww23), WwWT24 = Ww23; + float WwWT31 = Ww31, WwWT32 = 0.5*(Ww31 + Ww32 + Ww33), WwWT33 = 0.5*(Ww31 - Ww32 + Ww33), WwWT34 = Ww33; + float WwWT41 = Ww41, WwWT42 = 0.5*(Ww41 + Ww42 + Ww43), WwWT43 = 0.5*(Ww41 - Ww42 + Ww43), WwWT44 = Ww43; + + // Store the computed weights + outptr0[0 * mstride] = WwWT11; + outptr0[1 * mstride] = WwWT12; + outptr0[2 * mstride] = WwWT13; + outptr0[3 * mstride] = WwWT14; + + outptr4[0 * mstride] = WwWT21; + outptr4[1 * mstride] = WwWT22; + outptr4[2 * mstride] = WwWT23; + outptr4[3 * mstride] = WwWT24; + + outptr8[0 * mstride] = WwWT31; + outptr8[1 * mstride] = WwWT32; + outptr8[2 * mstride] = WwWT33; + outptr8[3 * mstride] = WwWT34; + + outptr12[0 * mstride] = WwWT41; + outptr12[1 * mstride] = WwWT42; + outptr12[2 * mstride] = WwWT43; + outptr12[3 * mstride] = WwWT44; + + // Progress output pointers + outptr0++; + outptr4++; + outptr8++; + outptr12++; + } + + // Progression to complete stride + outptr0 += matrix_row_stride - n_output_channels; + outptr4 += matrix_row_stride - n_output_channels; + outptr8 += matrix_row_stride - n_output_channels; + outptr12 += matrix_row_stride - n_output_channels; + } +} +} diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp new file mode 100644 index 0000000000..3dd62d1ac1 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp @@ -0,0 +1,822 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#pragma once + +#ifdef __aarch64__ +namespace winograd { +template <> +template <> +inline void winograd2x2_3x3_gemm_kernel_transform_impl::transform_kernel<0>( + const float* const kernel, + const int n_input_channels, + const int n_output_channels, + float* const matrix_base, + const int mstride, + const int matrix_row_stride +) { + // Use one input pointer for each row of the kernel, use two additional + // offsets to extract columns. + const int kernel_col_stride = n_input_channels * n_output_channels; + const int kernel_row_stride = 3 * kernel_col_stride; + const float *inptr0 = kernel; + const float *inptr1 = kernel + kernel_row_stride; + const float *inptr2 = kernel + kernel_row_stride*2; + + // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three + // offsets to extract further matrices. + float *outptr0 = matrix_base; + float *outptr4 = matrix_base + mstride * 4; + float *outptr8 = matrix_base + mstride * 8; + float *outptr12 = matrix_base + mstride * 12; + + // For every input channel + for (int in_c = 0; in_c < n_input_channels; in_c++) { + int n_remaining_channels = n_output_channels; + + asm volatile ( + // Registers into which to read the kernel + "w_11 .req v0\n" "qw_11 .req q0\n" + "w_12 .req v1\n" "qw_12 .req q1\n" + "w_13 .req v2\n" "qw_13 .req q2\n" + "w_21 .req v3\n" "qw_21 .req q3\n" + "w_22 .req v4\n" "qw_22 .req q4\n" + "w_23 .req v5\n" "qw_23 .req q5\n" + "w_31 .req v6\n" "qw_31 .req q6\n" + "w_32 .req v7\n" "qw_32 .req q7\n" + "w_33 .req v8\n" "qw_33 .req q8\n" + + // Transformed matrix Ww + "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n" + "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n" + "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n" + "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n" + + // Output matrix U = WwWT + "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n" + "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n" + "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n" + "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n" + + // Storage view of output matrices + "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n" + "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n" + "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n" + "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n" + + "half .req v23\n" // {0.5, ..., 0.5} + "dup half.4s, %w[one_half]\n" + "scratch .req v24\n" + + "1:" + // Load tile of the kernel + "ldr qw_11, [%x[inptr0]]\n" + "str qU11, [%x[outptr0]]\n" + "ldr qw_12, [%x[inptr0], %x[colstride1]]\n" + "ldr qw_13, [%x[inptr0], %x[colstride2]]\n" + "str qU14, [%x[outptr0], %x[mstride3]]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "ldr qw_21, [%x[inptr1]]\n" + "ldr qw_22, [%x[inptr1], %x[colstride1]]\n" + "ldr qw_23, [%x[inptr1], %x[colstride2]]\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "ldr qw_31, [%x[inptr2]]\n" + "str qU41, [%x[outptr12]]\n" + "ldr qw_32, [%x[inptr2], %x[colstride1]]\n" + "ldr qw_33, [%x[inptr2], %x[colstride2]]\n" + "str qU44, [%x[outptr12], %x[mstride3]]\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + // Compute 2nd and 3rd rows of Ww + "fadd scratch.4s, w_11.4s, w_31.4s\n" + "fmul Ww21.4s, scratch.4s, half.4s\n" + "fmla Ww21.4s, w_21.4s, half.4s\n" + "str qU21, [%x[outptr4]]\n" + "fmul Ww31.4s, scratch.4s, half.4s\n" + "fmls Ww31.4s, w_21.4s, half.4s\n" + "str qU31, [%x[outptr8]]\n" + + "fadd scratch.4s, w_12.4s, w_32.4s\n" + 
"fmul Ww22.4s, scratch.4s, half.4s\n" + "fmla Ww22.4s, w_22.4s, half.4s\n" + "fmul Ww32.4s, scratch.4s, half.4s\n" + "fmls Ww32.4s, w_22.4s, half.4s\n" + + "fadd scratch.4s, w_13.4s, w_33.4s\n" + "fmul Ww23.4s, scratch.4s, half.4s\n" + "fmla Ww23.4s, w_23.4s, half.4s\n" + "str qU24, [%x[outptr4], %x[mstride3]]\n" + "fmul Ww33.4s, scratch.4s, half.4s\n" + "fmls Ww33.4s, w_23.4s, half.4s\n" + "str qU34, [%x[outptr8], %x[mstride3]]\n" + + // Compute and store U, only need to compute the 2nd and 3rd columns + // of U and update output pointers + "fadd scratch.4s, Ww11.4s, Ww13.4s\n" + "fmul U12.4s, scratch.4s, half.4s\n" + "fmla U12.4s, Ww12.4s, half.4s\n" + "str qU12, [%x[outptr0], %x[mstride1]]\n" + "fmul U13.4s, scratch.4s, half.4s\n" + "fmls U13.4s, Ww12.4s, half.4s\n" + "str qU13, [%x[outptr0], %x[mstride2]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd scratch.4s, Ww21.4s, Ww23.4s\n" + "fmul U22.4s, scratch.4s, half.4s\n" + "fmla U22.4s, Ww22.4s, half.4s\n" + "str qU22, [%x[outptr4], %x[mstride1]]\n" + "fmul U23.4s, scratch.4s, half.4s\n" + "fmls U23.4s, Ww22.4s, half.4s\n" + "str qU23, [%x[outptr4], %x[mstride2]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fadd scratch.4s, Ww31.4s, Ww33.4s\n" + "fmul U32.4s, scratch.4s, half.4s\n" + "fmla U32.4s, Ww32.4s, half.4s\n" + "str qU32, [%x[outptr8], %x[mstride1]]\n" + "fmul U33.4s, scratch.4s, half.4s\n" + "fmls U33.4s, Ww32.4s, half.4s\n" + "str qU33, [%x[outptr8], %x[mstride2]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "fadd scratch.4s, Ww41.4s, Ww43.4s\n" + "fmul U42.4s, scratch.4s, half.4s\n" + "fmla U42.4s, Ww42.4s, half.4s\n" + "str qU42, [%x[outptr12], %x[mstride1]]\n" + "fmul U43.4s, scratch.4s, half.4s\n" + "fmls U43.4s, Ww42.4s, half.4s\n" + "str qU43, [%x[outptr12], %x[mstride2]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n" + "bne 1b\n" + + // Clear aliases + ".unreq half\n" + ".unreq scratch\n" + ".unreq w_11\n" ".unreq qw_11\n" + ".unreq w_12\n" ".unreq qw_12\n" + ".unreq w_13\n" ".unreq qw_13\n" + ".unreq w_21\n" ".unreq qw_21\n" + ".unreq w_22\n" ".unreq qw_22\n" + ".unreq w_23\n" ".unreq qw_23\n" + ".unreq w_31\n" ".unreq qw_31\n" + ".unreq w_32\n" ".unreq qw_32\n" + ".unreq w_33\n" ".unreq qw_33\n" + ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n" + ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n" + ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n" + ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n" + ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n" + ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n" + ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n" + ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n" + ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n" + ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n" + ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n" + ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq qU44\n" + + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr4), + [outptr8] "+r" (outptr8), + [outptr12] "+r" (outptr12), + [n_remaining_channels] "+r" (n_remaining_channels) + : [mstride1] "r" (sizeof(float) * mstride), + [mstride2] "r" (sizeof(float) * mstride * 2), + [mstride3] "r" (sizeof(float) * mstride * 3), + [colstride1] "r" (sizeof(float) * kernel_col_stride), + [colstride2] "r" (sizeof(float) * kernel_col_stride * 2), + [one_half] "r" (0.5f) + : "v0", "v1", 
"v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24" + ); + + // Progression to complete stride + outptr0 += matrix_row_stride - n_output_channels; + outptr4 += matrix_row_stride - n_output_channels; + outptr8 += matrix_row_stride - n_output_channels; + outptr12 += matrix_row_stride - n_output_channels; + } +} + +template <> +template <> +inline void winograd2x2_3x3_gemm_kernel_transform_impl::transform_kernel<2>( + const float* const kernel, + const int n_input_channels, + const int n_output_channels, + float* const matrix_base, + const int mstride, + const int matrix_row_stride +) { + // Use one input pointer for each row of the kernel, use two additional + // offsets to extract columns. + const int kernel_col_stride = n_input_channels * n_output_channels; + const int kernel_row_stride = 3 * kernel_col_stride; + const float *inptr0 = kernel; + const float *inptr1 = kernel + kernel_row_stride; + const float *inptr2 = kernel + kernel_row_stride*2; + + // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three + // offsets to extract further matrices. + float *outptr0 = matrix_base; + float *outptr4 = matrix_base + mstride * 4; + float *outptr8 = matrix_base + mstride * 8; + float *outptr12 = matrix_base + mstride * 12; + + // For every input channel + for (int in_c = 0; in_c < n_input_channels; in_c++) { + int n_remaining_channels = n_output_channels; + + asm volatile ( + // Registers into which to read the kernel + "w_11 .req v0\n" "qw_11 .req q0\n" "dw_11 .req d0\n" + "w_12 .req v1\n" "qw_12 .req q1\n" "dw_12 .req d1\n" + "w_13 .req v2\n" "qw_13 .req q2\n" "dw_13 .req d2\n" + "w_21 .req v3\n" "qw_21 .req q3\n" "dw_21 .req d3\n" + "w_22 .req v4\n" "qw_22 .req q4\n" "dw_22 .req d4\n" + "w_23 .req v5\n" "qw_23 .req q5\n" "dw_23 .req d5\n" + "w_31 .req v6\n" "qw_31 .req q6\n" "dw_31 .req d6\n" + "w_32 .req v7\n" "qw_32 .req q7\n" "dw_32 .req d7\n" + "w_33 .req v8\n" "qw_33 .req q8\n" "dw_33 .req d8\n" + + // Transformed matrix Ww + "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n" + "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n" + "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n" + "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n" + + // Output matrix U = WwWT + "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n" + "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n" + "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n" + "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n" + + // Storage view of output matrices + "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n" + "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n" + "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n" + "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n" + + "dU11 .req d0\n" "dU12 .req d15\n" "dU13 .req d16\n" "dU14 .req d2\n" + "dU21 .req d9\n" "dU22 .req d17\n" "dU23 .req d18\n" "dU24 .req d11\n" + "dU31 .req d12\n" "dU32 .req d19\n" "dU33 .req d20\n" "dU34 .req d14\n" + "dU41 .req d6\n" "dU42 .req d21\n" "dU43 .req d22\n" "dU44 .req d8\n" + + "half .req v23\n" // {0.5, ..., 0.5} + "dup half.4s, %w[one_half]\n" + "scratch .req v24\n" + + // Subtract the tail from the number of remaining channels and jump to + // the tail if necessary. 
+ "subs %x[n_remaining_channels], %x[n_remaining_channels], #2\n" + "beq 2f\n" + + "1:" + // Load tile of the kernel + "ldr qw_11, [%x[inptr0]]\n" + "str qU11, [%x[outptr0]]\n" + "ldr qw_12, [%x[inptr0], %x[colstride1]]\n" + "ldr qw_13, [%x[inptr0], %x[colstride2]]\n" + "str qU14, [%x[outptr0], %x[mstride3]]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "ldr qw_21, [%x[inptr1]]\n" + "ldr qw_22, [%x[inptr1], %x[colstride1]]\n" + "ldr qw_23, [%x[inptr1], %x[colstride2]]\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "ldr qw_31, [%x[inptr2]]\n" + "str qU41, [%x[outptr12]]\n" + "ldr qw_32, [%x[inptr2], %x[colstride1]]\n" + "ldr qw_33, [%x[inptr2], %x[colstride2]]\n" + "str qU44, [%x[outptr12], %x[mstride3]]\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + // Compute 2nd and 3rd rows of Ww + "fadd scratch.4s, w_11.4s, w_31.4s\n" + "fmul Ww21.4s, scratch.4s, half.4s\n" + "fmla Ww21.4s, w_21.4s, half.4s\n" + "str qU21, [%x[outptr4]]\n" + "fmul Ww31.4s, scratch.4s, half.4s\n" + "fmls Ww31.4s, w_21.4s, half.4s\n" + "str qU31, [%x[outptr8]]\n" + + "fadd scratch.4s, w_12.4s, w_32.4s\n" + "fmul Ww22.4s, scratch.4s, half.4s\n" + "fmla Ww22.4s, w_22.4s, half.4s\n" + "fmul Ww32.4s, scratch.4s, half.4s\n" + "fmls Ww32.4s, w_22.4s, half.4s\n" + + "fadd scratch.4s, w_13.4s, w_33.4s\n" + "fmul Ww23.4s, scratch.4s, half.4s\n" + "fmla Ww23.4s, w_23.4s, half.4s\n" + "str qU24, [%x[outptr4], %x[mstride3]]\n" + "fmul Ww33.4s, scratch.4s, half.4s\n" + "fmls Ww33.4s, w_23.4s, half.4s\n" + "str qU34, [%x[outptr8], %x[mstride3]]\n" + + // Compute and store U, only need to compute the 2nd and 3rd columns + // of U and update output pointers + "fadd scratch.4s, Ww11.4s, Ww13.4s\n" + "fmul U12.4s, scratch.4s, half.4s\n" + "fmla U12.4s, Ww12.4s, half.4s\n" + "str qU12, [%x[outptr0], %x[mstride1]]\n" + "fmul U13.4s, scratch.4s, half.4s\n" + "fmls U13.4s, Ww12.4s, half.4s\n" + "str qU13, [%x[outptr0], %x[mstride2]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd scratch.4s, Ww21.4s, Ww23.4s\n" + "fmul U22.4s, scratch.4s, half.4s\n" + "fmla U22.4s, Ww22.4s, half.4s\n" + "str qU22, [%x[outptr4], %x[mstride1]]\n" + "fmul U23.4s, scratch.4s, half.4s\n" + "fmls U23.4s, Ww22.4s, half.4s\n" + "str qU23, [%x[outptr4], %x[mstride2]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fadd scratch.4s, Ww31.4s, Ww33.4s\n" + "fmul U32.4s, scratch.4s, half.4s\n" + "fmla U32.4s, Ww32.4s, half.4s\n" + "str qU32, [%x[outptr8], %x[mstride1]]\n" + "fmul U33.4s, scratch.4s, half.4s\n" + "fmls U33.4s, Ww32.4s, half.4s\n" + "str qU33, [%x[outptr8], %x[mstride2]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "fadd scratch.4s, Ww41.4s, Ww43.4s\n" + "fmul U42.4s, scratch.4s, half.4s\n" + "fmla U42.4s, Ww42.4s, half.4s\n" + "str qU42, [%x[outptr12], %x[mstride1]]\n" + "fmul U43.4s, scratch.4s, half.4s\n" + "fmls U43.4s, Ww42.4s, half.4s\n" + "str qU43, [%x[outptr12], %x[mstride2]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n" + "bne 1b\n" + + // Tail size 2 + "2:" + // Load tile of the kernel + "ldr dw_11, [%x[inptr0]]\n" + "str dU11, [%x[outptr0]]\n" + "ldr dw_12, [%x[inptr0], %x[colstride1]]\n" + "ldr dw_13, [%x[inptr0], %x[colstride2]]\n" + "str dU14, [%x[outptr0], %x[mstride3]]\n" + "add %x[inptr0], %x[inptr0], #0x08\n" + + "ldr dw_21, [%x[inptr1]]\n" + "ldr dw_22, [%x[inptr1], %x[colstride1]]\n" + "ldr dw_23, [%x[inptr1], %x[colstride2]]\n" + "add %x[inptr1], %x[inptr1], #0x08\n" + + "ldr dw_31, [%x[inptr2]]\n" + "str dU41, [%x[outptr12]]\n" + "ldr dw_32, [%x[inptr2], 
%x[colstride1]]\n" + "ldr dw_33, [%x[inptr2], %x[colstride2]]\n" + "str dU44, [%x[outptr12], %x[mstride3]]\n" + "add %x[inptr2], %x[inptr2], #0x08\n" + + // Compute 2nd and 3rd rows of Ww + "fadd scratch.2s, w_11.2s, w_31.2s\n" + "fmul Ww21.2s, scratch.2s, half.2s\n" + "fmla Ww21.2s, w_21.2s, half.2s\n" + "str dU21, [%x[outptr4]]\n" + "fmul Ww31.2s, scratch.2s, half.2s\n" + "fmls Ww31.2s, w_21.2s, half.2s\n" + "str dU31, [%x[outptr8]]\n" + + "fadd scratch.2s, w_12.2s, w_32.2s\n" + "fmul Ww22.2s, scratch.2s, half.2s\n" + "fmla Ww22.2s, w_22.2s, half.2s\n" + "fmul Ww32.2s, scratch.2s, half.2s\n" + "fmls Ww32.2s, w_22.2s, half.2s\n" + + "fadd scratch.2s, w_13.2s, w_33.2s\n" + "fmul Ww23.2s, scratch.2s, half.2s\n" + "fmla Ww23.2s, w_23.2s, half.2s\n" + "str dU24, [%x[outptr4], %x[mstride3]]\n" + "fmul Ww33.2s, scratch.2s, half.2s\n" + "fmls Ww33.2s, w_23.2s, half.2s\n" + "str dU34, [%x[outptr8], %x[mstride3]]\n" + + // Compute and store U, only need to compute the 2nd and 3rd columns of + // U and update output pointers + "fadd scratch.2s, Ww11.2s, Ww13.2s\n" + "fmul U12.2s, scratch.2s, half.2s\n" + "fmla U12.2s, Ww12.2s, half.2s\n" + "str dU12, [%x[outptr0], %x[mstride1]]\n" + "fmul U13.2s, scratch.2s, half.2s\n" + "fmls U13.2s, Ww12.2s, half.2s\n" + "str dU13, [%x[outptr0], %x[mstride2]]\n" + "add %x[outptr0], %x[outptr0], #0x08\n" + + "fadd scratch.2s, Ww21.2s, Ww23.2s\n" + "fmul U22.2s, scratch.2s, half.2s\n" + "fmla U22.2s, Ww22.2s, half.2s\n" + "str dU22, [%x[outptr4], %x[mstride1]]\n" + "fmul U23.2s, scratch.2s, half.2s\n" + "fmls U23.2s, Ww22.2s, half.2s\n" + "str dU23, [%x[outptr4], %x[mstride2]]\n" + "add %x[outptr4], %x[outptr4], #0x08\n" + + "fadd scratch.2s, Ww31.2s, Ww33.2s\n" + "fmul U32.2s, scratch.2s, half.2s\n" + "fmla U32.2s, Ww32.2s, half.2s\n" + "str dU32, [%x[outptr8], %x[mstride1]]\n" + "fmul U33.2s, scratch.2s, half.2s\n" + "fmls U33.2s, Ww32.2s, half.2s\n" + "str dU33, [%x[outptr8], %x[mstride2]]\n" + "add %x[outptr8], %x[outptr8], #0x08\n" + + "fadd scratch.2s, Ww41.2s, Ww43.2s\n" + "fmul U42.2s, scratch.2s, half.2s\n" + "fmla U42.2s, Ww42.2s, half.2s\n" + "str dU42, [%x[outptr12], %x[mstride1]]\n" + "fmul U43.2s, scratch.2s, half.2s\n" + "fmls U43.2s, Ww42.2s, half.2s\n" + "str dU43, [%x[outptr12], %x[mstride2]]\n" + "add %x[outptr12], %x[outptr12], #0x08\n" + + // Clear aliases + ".unreq half\n" + ".unreq scratch\n" + ".unreq w_11\n" ".unreq qw_11\n" ".unreq dw_11\n" + ".unreq w_12\n" ".unreq qw_12\n" ".unreq dw_12\n" + ".unreq w_13\n" ".unreq qw_13\n" ".unreq dw_13\n" + ".unreq w_21\n" ".unreq qw_21\n" ".unreq dw_21\n" + ".unreq w_22\n" ".unreq qw_22\n" ".unreq dw_22\n" + ".unreq w_23\n" ".unreq qw_23\n" ".unreq dw_23\n" + ".unreq w_31\n" ".unreq qw_31\n" ".unreq dw_31\n" + ".unreq w_32\n" ".unreq qw_32\n" ".unreq dw_32\n" + ".unreq w_33\n" ".unreq qw_33\n" ".unreq dw_33\n" + ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n" + ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n" + ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n" + ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n" + ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n" + ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n" + ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n" + ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n" + ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n" + ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n" + ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n" + ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq 
qU44\n" + ".unreq dU11\n" ".unreq dU12\n" ".unreq dU13\n" ".unreq dU14\n" + ".unreq dU21\n" ".unreq dU22\n" ".unreq dU23\n" ".unreq dU24\n" + ".unreq dU31\n" ".unreq dU32\n" ".unreq dU33\n" ".unreq dU34\n" + ".unreq dU41\n" ".unreq dU42\n" ".unreq dU43\n" ".unreq dU44\n" + + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr4), + [outptr8] "+r" (outptr8), + [outptr12] "+r" (outptr12), + [n_remaining_channels] "+r" (n_remaining_channels) + : [mstride1] "r" (sizeof(float) * mstride), + [mstride2] "r" (sizeof(float) * mstride * 2), + [mstride3] "r" (sizeof(float) * mstride * 3), + [colstride1] "r" (sizeof(float) * kernel_col_stride), + [colstride2] "r" (sizeof(float) * kernel_col_stride * 2), + [one_half] "r" (0.5f) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24" + ); + + // Progression to complete stride + outptr0 += matrix_row_stride - n_output_channels; + outptr4 += matrix_row_stride - n_output_channels; + outptr8 += matrix_row_stride - n_output_channels; + outptr12 += matrix_row_stride - n_output_channels; + } +} + +template <> +template <> +inline void winograd2x2_3x3_gemm_kernel_transform_impl::transform_kernel<1>( + const float* const kernel, + const int n_input_channels, + const int n_output_channels, + float* const matrix_base, + const int mstride, + const int matrix_row_stride +) { + // Use one input pointer for each row of the kernel, use two additional + // offsets to extract columns. + const int kernel_col_stride = n_input_channels * n_output_channels; + const int kernel_row_stride = 3 * kernel_col_stride; + const float *inptr0 = kernel; + const float *inptr1 = kernel + kernel_row_stride; + const float *inptr2 = kernel + kernel_row_stride*2; + + // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three + // offsets to extract further matrices. 
+ float *outptr0 = matrix_base; + float *outptr4 = matrix_base + mstride * 4; + float *outptr8 = matrix_base + mstride * 8; + float *outptr12 = matrix_base + mstride * 12; + + // For every input channel + for (int in_c = 0; in_c < n_input_channels; in_c++) { + int n_remaining_channels = n_output_channels; + + asm volatile ( + // Registers into which to read the kernel + "w_11 .req v0\n" "qw_11 .req q0\n" "sw_11 .req s0\n" + "w_12 .req v1\n" "qw_12 .req q1\n" "sw_12 .req s1\n" + "w_13 .req v2\n" "qw_13 .req q2\n" "sw_13 .req s2\n" + "w_21 .req v3\n" "qw_21 .req q3\n" "sw_21 .req s3\n" + "w_22 .req v4\n" "qw_22 .req q4\n" "sw_22 .req s4\n" + "w_23 .req v5\n" "qw_23 .req q5\n" "sw_23 .req s5\n" + "w_31 .req v6\n" "qw_31 .req q6\n" "sw_31 .req s6\n" + "w_32 .req v7\n" "qw_32 .req q7\n" "sw_32 .req s7\n" + "w_33 .req v8\n" "qw_33 .req q8\n" "sw_33 .req s8\n" + + // Transformed matrix Ww + "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n" + "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n" + "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n" + "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n" + + // Output matrix U = WwWT + "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n" + "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n" + "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n" + "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n" + + // Storage view of output matrices + "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n" + "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n" + "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n" + "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n" + + "sU11 .req s0\n" "sU12 .req s15\n" "sU13 .req s16\n" "sU14 .req s2\n" + "sU21 .req s9\n" "sU22 .req s17\n" "sU23 .req s18\n" "sU24 .req s11\n" + "sU31 .req s12\n" "sU32 .req s19\n" "sU33 .req s20\n" "sU34 .req s14\n" + "sU41 .req s6\n" "sU42 .req s21\n" "sU43 .req s22\n" "sU44 .req s8\n" + + "half .req v23\n" // {0.5, ..., 0.5} + "dup half.4s, %w[one_half]\n" + "scratch .req v24\n" + + // Subtract the tail from the number of remaining channels and jump to + // the tail if necessary. 
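+ // The main loop (label 1) handles four output channels per iteration; the
+ // final channel is handled at label 2 using single-lane (s-register) loads
+ // and stores.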
+ "subs %x[n_remaining_channels], %x[n_remaining_channels], #1\n" + "beq 2f\n" + + "1:" + // Load tile of the kernel + "ldr qw_11, [%x[inptr0]]\n" + "str qU11, [%x[outptr0]]\n" + "ldr qw_12, [%x[inptr0], %x[colstride1]]\n" + "ldr qw_13, [%x[inptr0], %x[colstride2]]\n" + "str qU14, [%x[outptr0], %x[mstride3]]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "ldr qw_21, [%x[inptr1]]\n" + "ldr qw_22, [%x[inptr1], %x[colstride1]]\n" + "ldr qw_23, [%x[inptr1], %x[colstride2]]\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "ldr qw_31, [%x[inptr2]]\n" + "str qU41, [%x[outptr12]]\n" + "ldr qw_32, [%x[inptr2], %x[colstride1]]\n" + "ldr qw_33, [%x[inptr2], %x[colstride2]]\n" + "str qU44, [%x[outptr12], %x[mstride3]]\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + // Compute 2nd and 3rd rows of Ww + "fadd scratch.4s, w_11.4s, w_31.4s\n" + "fmul Ww21.4s, scratch.4s, half.4s\n" + "fmla Ww21.4s, w_21.4s, half.4s\n" + "str qU21, [%x[outptr4]]\n" + "fmul Ww31.4s, scratch.4s, half.4s\n" + "fmls Ww31.4s, w_21.4s, half.4s\n" + "str qU31, [%x[outptr8]]\n" + + "fadd scratch.4s, w_12.4s, w_32.4s\n" + "fmul Ww22.4s, scratch.4s, half.4s\n" + "fmla Ww22.4s, w_22.4s, half.4s\n" + "fmul Ww32.4s, scratch.4s, half.4s\n" + "fmls Ww32.4s, w_22.4s, half.4s\n" + + "fadd scratch.4s, w_13.4s, w_33.4s\n" + "fmul Ww23.4s, scratch.4s, half.4s\n" + "fmla Ww23.4s, w_23.4s, half.4s\n" + "str qU24, [%x[outptr4], %x[mstride3]]\n" + "fmul Ww33.4s, scratch.4s, half.4s\n" + "fmls Ww33.4s, w_23.4s, half.4s\n" + "str qU34, [%x[outptr8], %x[mstride3]]\n" + + // Compute and store U, only need to compute the 2nd and 3rd columns + // of U and update output pointers + "fadd scratch.4s, Ww11.4s, Ww13.4s\n" + "fmul U12.4s, scratch.4s, half.4s\n" + "fmla U12.4s, Ww12.4s, half.4s\n" + "str qU12, [%x[outptr0], %x[mstride1]]\n" + "fmul U13.4s, scratch.4s, half.4s\n" + "fmls U13.4s, Ww12.4s, half.4s\n" + "str qU13, [%x[outptr0], %x[mstride2]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd scratch.4s, Ww21.4s, Ww23.4s\n" + "fmul U22.4s, scratch.4s, half.4s\n" + "fmla U22.4s, Ww22.4s, half.4s\n" + "str qU22, [%x[outptr4], %x[mstride1]]\n" + "fmul U23.4s, scratch.4s, half.4s\n" + "fmls U23.4s, Ww22.4s, half.4s\n" + "str qU23, [%x[outptr4], %x[mstride2]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fadd scratch.4s, Ww31.4s, Ww33.4s\n" + "fmul U32.4s, scratch.4s, half.4s\n" + "fmla U32.4s, Ww32.4s, half.4s\n" + "str qU32, [%x[outptr8], %x[mstride1]]\n" + "fmul U33.4s, scratch.4s, half.4s\n" + "fmls U33.4s, Ww32.4s, half.4s\n" + "str qU33, [%x[outptr8], %x[mstride2]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "fadd scratch.4s, Ww41.4s, Ww43.4s\n" + "fmul U42.4s, scratch.4s, half.4s\n" + "fmla U42.4s, Ww42.4s, half.4s\n" + "str qU42, [%x[outptr12], %x[mstride1]]\n" + "fmul U43.4s, scratch.4s, half.4s\n" + "fmls U43.4s, Ww42.4s, half.4s\n" + "str qU43, [%x[outptr12], %x[mstride2]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n" + "bne 1b\n" + + // Tail size 1 + "2:" + // Load tile of the kernel + "ldr sw_11, [%x[inptr0]]\n" + "str sU11, [%x[outptr0]]\n" + "ldr sw_12, [%x[inptr0], %x[colstride1]]\n" + "ldr sw_13, [%x[inptr0], %x[colstride2]]\n" + "str sU14, [%x[outptr0], %x[mstride3]]\n" + "add %x[inptr0], %x[inptr0], #0x04\n" + + "ldr sw_21, [%x[inptr1]]\n" + "ldr sw_22, [%x[inptr1], %x[colstride1]]\n" + "ldr sw_23, [%x[inptr1], %x[colstride2]]\n" + "add %x[inptr1], %x[inptr1], #0x04\n" + + "ldr sw_31, [%x[inptr2]]\n" + "str sU41, [%x[outptr12]]\n" + "ldr sw_32, [%x[inptr2], 
%x[colstride1]]\n" + "ldr sw_33, [%x[inptr2], %x[colstride2]]\n" + "str sU44, [%x[outptr12], %x[mstride3]]\n" + "add %x[inptr2], %x[inptr2], #0x04\n" + + // Compute 2nd and 3rd rows of Ww + "fadd scratch.2s, w_11.2s, w_31.2s\n" + "fmul Ww21.2s, scratch.2s, half.2s\n" + "fmla Ww21.2s, w_21.2s, half.2s\n" + "str sU21, [%x[outptr4]]\n" + "fmul Ww31.2s, scratch.2s, half.2s\n" + "fmls Ww31.2s, w_21.2s, half.2s\n" + "str sU31, [%x[outptr8]]\n" + + "fadd scratch.2s, w_12.2s, w_32.2s\n" + "fmul Ww22.2s, scratch.2s, half.2s\n" + "fmla Ww22.2s, w_22.2s, half.2s\n" + "fmul Ww32.2s, scratch.2s, half.2s\n" + "fmls Ww32.2s, w_22.2s, half.2s\n" + + "fadd scratch.2s, w_13.2s, w_33.2s\n" + "fmul Ww23.2s, scratch.2s, half.2s\n" + "fmla Ww23.2s, w_23.2s, half.2s\n" + "str sU24, [%x[outptr4], %x[mstride3]]\n" + "fmul Ww33.2s, scratch.2s, half.2s\n" + "fmls Ww33.2s, w_23.2s, half.2s\n" + "str sU34, [%x[outptr8], %x[mstride3]]\n" + + // Compute and store U, only need to compute the 2nd and 3rd columns of + // U and update output pointers + "fadd scratch.2s, Ww11.2s, Ww13.2s\n" + "fmul U12.2s, scratch.2s, half.2s\n" + "fmla U12.2s, Ww12.2s, half.2s\n" + "str sU12, [%x[outptr0], %x[mstride1]]\n" + "fmul U13.2s, scratch.2s, half.2s\n" + "fmls U13.2s, Ww12.2s, half.2s\n" + "str sU13, [%x[outptr0], %x[mstride2]]\n" + "add %x[outptr0], %x[outptr0], #0x04\n" + + "fadd scratch.2s, Ww21.2s, Ww23.2s\n" + "fmul U22.2s, scratch.2s, half.2s\n" + "fmla U22.2s, Ww22.2s, half.2s\n" + "str sU22, [%x[outptr4], %x[mstride1]]\n" + "fmul U23.2s, scratch.2s, half.2s\n" + "fmls U23.2s, Ww22.2s, half.2s\n" + "str sU23, [%x[outptr4], %x[mstride2]]\n" + "add %x[outptr4], %x[outptr4], #0x04\n" + + "fadd scratch.2s, Ww31.2s, Ww33.2s\n" + "fmul U32.2s, scratch.2s, half.2s\n" + "fmla U32.2s, Ww32.2s, half.2s\n" + "str sU32, [%x[outptr8], %x[mstride1]]\n" + "fmul U33.2s, scratch.2s, half.2s\n" + "fmls U33.2s, Ww32.2s, half.2s\n" + "str sU33, [%x[outptr8], %x[mstride2]]\n" + "add %x[outptr8], %x[outptr8], #0x04\n" + + "fadd scratch.2s, Ww41.2s, Ww43.2s\n" + "fmul U42.2s, scratch.2s, half.2s\n" + "fmla U42.2s, Ww42.2s, half.2s\n" + "str sU42, [%x[outptr12], %x[mstride1]]\n" + "fmul U43.2s, scratch.2s, half.2s\n" + "fmls U43.2s, Ww42.2s, half.2s\n" + "str sU43, [%x[outptr12], %x[mstride2]]\n" + "add %x[outptr12], %x[outptr12], #0x04\n" + + // Clear aliases + ".unreq half\n" + ".unreq scratch\n" + ".unreq w_11\n" ".unreq qw_11\n" ".unreq sw_11\n" + ".unreq w_12\n" ".unreq qw_12\n" ".unreq sw_12\n" + ".unreq w_13\n" ".unreq qw_13\n" ".unreq sw_13\n" + ".unreq w_21\n" ".unreq qw_21\n" ".unreq sw_21\n" + ".unreq w_22\n" ".unreq qw_22\n" ".unreq sw_22\n" + ".unreq w_23\n" ".unreq qw_23\n" ".unreq sw_23\n" + ".unreq w_31\n" ".unreq qw_31\n" ".unreq sw_31\n" + ".unreq w_32\n" ".unreq qw_32\n" ".unreq sw_32\n" + ".unreq w_33\n" ".unreq qw_33\n" ".unreq sw_33\n" + ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n" + ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n" + ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n" + ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n" + ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n" + ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n" + ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n" + ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n" + ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n" + ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n" + ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n" + ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq 
qU44\n" + ".unreq sU11\n" ".unreq sU12\n" ".unreq sU13\n" ".unreq sU14\n" + ".unreq sU21\n" ".unreq sU22\n" ".unreq sU23\n" ".unreq sU24\n" + ".unreq sU31\n" ".unreq sU32\n" ".unreq sU33\n" ".unreq sU34\n" + ".unreq sU41\n" ".unreq sU42\n" ".unreq sU43\n" ".unreq sU44\n" + + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr4), + [outptr8] "+r" (outptr8), + [outptr12] "+r" (outptr12), + [n_remaining_channels] "+r" (n_remaining_channels) + : [mstride1] "r" (sizeof(float) * mstride), + [mstride2] "r" (sizeof(float) * mstride * 2), + [mstride3] "r" (sizeof(float) * mstride * 3), + [colstride1] "r" (sizeof(float) * kernel_col_stride), + [colstride2] "r" (sizeof(float) * kernel_col_stride * 2), + [one_half] "r" (0.5f) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24" + ); + + // Progression to complete stride + outptr0 += matrix_row_stride - n_output_channels; + outptr4 += matrix_row_stride - n_output_channels; + outptr8 += matrix_row_stride - n_output_channels; + outptr12 += matrix_row_stride - n_output_channels; + } +} +} +#endif // __aarch64__ diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp new file mode 100644 index 0000000000..0992c0bb44 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp @@ -0,0 +1,356 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +namespace winograd { + /* Transform from the Winograd domain back to the spatial domain. + */ + template + struct Winograd2x2_3x3GemmOutput { + static void execute( + const Tensor4DShape &output_shape, + T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + T* const output + ); + + protected: + /* Specialised implementation method. */ + template + static void _execute( + const Tensor4DShape &output_shape, + T *output, + const T *input, + const int matrix_stride, + const int matrix_row_stride + ); + }; + + /* Two-stage implementation of the transformation from the Winograd domain. + * + * First computes Z.F and then computes (Z.F).Z^T. 
+ */ + template + struct Winograd2x2_3x3GemmOutput_TwoStage { + static void execute( + const Tensor4DShape &output_shape, + T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + T* const output + ); + + protected: + template + static void compute_zf( + const int n_rows, const int n_channels, + T* const zf, const T* const input[16] + ); + + template + static void compute_zfzT( + const Tensor4DShape &output_shape, + T* const output, const T* const zf + ); + }; +} + +#include "output_2x2_3x3/a64_float.hpp" +// #include "output_2x2_3x3/a64_float_two_stage.hpp" + +/*****************************************************************************/ +/* +template +void winograd::Winograd2x2_3x3GemmOutput::execute( + const Tensor4DShape &output_shape, + const int tile_M, + const int tile_N, + T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + T* const output +) { + T* const antipadding = reinterpret_cast(malloc(sizeof(T) * output_shape.n_channels)); + + // Get input pointers + const T* inptrs[16]; + for (int i = 0; i < 16; i++) { + inptrs[i] = matrices[i]; + } + + for (int batch = 0; batch < output_shape.n_batches; batch++) { + for (int tile_i = 0; tile_i < tile_M; tile_i++) { + for (int tile_j = 0; tile_j < tile_N; tile_j++) { + // Get pointers for each of the 4 output cells required for this computation + T* outptrs[4]; + for (int cell_i = 0, c = 0; cell_i < 2; cell_i++) { + for (int cell_j = 0; cell_j < 2; cell_j++, c++) { + const int i = tile_i*2 + cell_i; + const int j = tile_j*2 + cell_j; + + if (i < output_shape.n_rows && j < output_shape.n_cols) { + outptrs[c] = output + ( + (batch*output_shape.n_rows + i) * output_shape.n_cols + + j) * output_shape.n_channels; + } else { + outptrs[c] = antipadding; + } + } // cell_j + } // cell_i + + for (int n = 0; n < output_shape.n_channels; n++) { + // Read 16 values and progress pointers + T v[16]; + for (int i = 0; i < 16; i++) { + v[i] = *(inptrs[i]++); + } + + // Compute output for 4 pixels + *(outptrs[0]++) = v[ 0] + v[ 1] + v[ 2] + + v[ 4] + v[ 5] + v[ 6] + + v[ 8] + v[ 9] + v[10]; + *(outptrs[1]++) = v[ 1] - v[ 2] - v[ 3] + + v[ 5] - v[ 6] - v[ 7] + + v[ 9] - v[10] - v[11]; + *(outptrs[2]++) = v[ 4] + v[ 5] + v[ 6] - + v[ 8] - v[ 9] - v[10] - + v[12] - v[13] - v[14]; + *(outptrs[3]++) = v[ 5] - v[ 6] - v[ 7] - + v[ 9] + v[10] + v[11] - + v[13] + v[14] + v[15]; + } // output_channel + } // tile_j + } // tile_i + } // batch + + free(antipadding); +} +*/ + +/*****************************************************************************/ +/* +template +void winograd::Winograd2x2_3x3GemmOutput_TwoStage::execute( + const Tensor4DShape &output_shape, + T* const matrices[16], T* const output +) { + // Allocate memory for the intermediate matrices + const int tile_M = iceildiv(output_shape.n_rows, 2); + const int tile_N = iceildiv(output_shape.n_cols, 2); + const int n_rows = output_shape.n_batches * tile_M * tile_N; + const int n_channels = output_shape.n_channels; + T* matrices_zf = reinterpret_cast( + calloc(8 * n_rows * n_channels, sizeof(T)) + ); + + // Perform the first stage transform, computing ZF. + // Specializations should dispatch to different methods based on tail size. + compute_zf<0>(n_rows, n_channels, matrices_zf, matrices); + + // Perform the second stage transform, finishing Z F Z^T - variable dispatch + // based on size of the output. Specialisations can also dispatch based on + // the tail-size of the channel. 
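+ // The boolean template parameters of compute_zfzT flag whether a trailing
+ // (odd) output row and/or column must be handled.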
+ if (output_shape.n_rows % 2 && output_shape.n_cols % 2) { + compute_zfzT(output_shape, output, matrices_zf); + } else if (output_shape.n_rows % 2) { + compute_zfzT(output_shape, output, matrices_zf); + } else if (output_shape.n_cols % 2) { + compute_zfzT(output_shape, output, matrices_zf); + } else { + compute_zfzT(output_shape, output, matrices_zf); + } + + free(reinterpret_cast(matrices_zf)); +} + +template +template +void winograd::Winograd2x2_3x3GemmOutput_TwoStage::compute_zf( + const int n_rows, const int n_channels, + T* output, const T* const input[16] +) { + // Extract 8 output pointers + T* outptr[8]; + for (int i = 0; i < 8; i++) { + outptr[i] = output + i*n_rows*n_channels; + } + + // Copy the 16 input pointers + const T* inptr[16]; + for (int i = 0; i < 16; i++) { + inptr[i] = input[i]; + } + + // For every row of the matrices + for (int i = 0; i < n_rows; i++) { + // For every channel + for (int j = 0; j < n_channels; j++) { + // Extract values from the input matrices + T val[16]; + for (int n = 0; n < 16; n++) { + val[n] = *(inptr[n]++); + } + + // Compute output values + *(outptr[0]++) = val[0] + val[1] + val[2]; + *(outptr[1]++) = val[1] - val[2] - val[3]; + *(outptr[2]++) = val[4] + val[5] + val[6]; + *(outptr[3]++) = val[5] - val[6] - val[7]; + *(outptr[4]++) = val[8] + val[9] + val[10]; + *(outptr[5]++) = val[9] - val[10] - val[11]; + *(outptr[6]++) = val[12] + val[13] + val[14]; + *(outptr[7]++) = val[13] - val[14] - val[15]; + } + } +} + +template +template +void winograd::Winograd2x2_3x3GemmOutput_TwoStage::compute_zfzT( + const Tensor4DShape &output_shape, + T* const output, const T* const input +) { + // Sizing information + const int tile_M = output_shape.n_rows / 2; + const int tile_N = output_shape.n_cols / 2; + + const int n_rows = (output_shape.n_batches * + (tile_M + (tail_M ? 1 : 0)) * + (tile_N + (tail_N ? 1 : 0))); + const int n_channels = output_shape.n_channels; + + // Extract 8 input pointers + const T* inptr[8]; + for (int i = 0; i < 8; i++) { + inptr[i] = input + i*n_rows*n_channels; + } + + // Extract 4 output pointers + T* outptr00 = output; + T* outptr01 = outptr00 + n_channels; + T* outptr10 = outptr00 + output_shape.n_cols * n_channels; + T* outptr11 = outptr10 + n_channels; + + // Progress over the output tiles, generating output values. + for (int batch = 0; batch < output_shape.n_batches; batch++) { + for (int tile_i = 0; tile_i < tile_M; tile_i++) { + for (int tile_j = 0; tile_j < tile_N; tile_j++) { + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + T v[8]; + for (int i = 0; i < 8; i++) { + v[i] = *(inptr[i]++); + } + + // Compute the output values and progress the output pointers. + *(outptr00++) = v[0] + v[2] + v[4]; + *(outptr01++) = v[1] + v[3] + v[5]; + *(outptr10++) = v[2] - v[4] - v[6]; + *(outptr11++) = v[3] - v[5] - v[7]; + } + + // Progress the output pointers to the next column + outptr00 += n_channels; + outptr01 += n_channels; + outptr10 += n_channels; + outptr11 += n_channels; + } + + if (tail_N) { + // Only evaluate the left-most columns of the output + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + T v[8]; + for (int i = 0; i < 4; i++) { + v[i * 2] = *inptr[i * 2]; + } + for (int i = 0; i < 8; i++) { + inptr[i]++; + } + + // Compute the output values and progress the output pointers. 
+ *(outptr00++) = v[0] + v[2] + v[4]; + *(outptr10++) = v[2] - v[4] - v[6]; + } + + // Progress the output pointers to the next column + outptr01 += n_channels; // Account for being skipped above + outptr11 += n_channels; // Account for being skipped above + } + + // Progress the output pointers to the next row + outptr00 += output_shape.n_cols * n_channels; + outptr01 += output_shape.n_cols * n_channels; + outptr10 += output_shape.n_cols * n_channels; + outptr11 += output_shape.n_cols * n_channels; + } + + if (tail_M) { + // Only work on the upper row of the output + for (int tile_j = 0; tile_j < tile_N; tile_j++) { + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + T v[8]; + for (int i = 0; i < 8; i++) { + v[i] = *(inptr[i]++); + } + + // Compute the output values and progress the output pointers. + *(outptr00++) = v[0] + v[2] + v[4]; + *(outptr01++) = v[1] + v[3] + v[5]; + } + + // Progress the output pointers to the next column + outptr00 += n_channels; + outptr01 += n_channels; + outptr10 += 2 * n_channels; // Account for being skipped above + outptr11 += 2 * n_channels; // Account for being skipped above + } + + if (tail_N) { + // Only evaluate the upper-left cell of the output + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + T v[8]; + for (int i = 0; i < 3; i++) { + v[i * 2] = *inptr[i * 2]; + } + for (int i = 0; i < 8; i++) { + inptr[i]++; + } + + // Compute the output values and progress the output pointers. + *(outptr00++) = v[0] + v[2] + v[4]; + } + + // Progress the output pointers to the next column + outptr01 += n_channels; // Account for being skipped above + outptr10 += n_channels; // Account for being skipped above + outptr11 += n_channels; // Account for being skipped above + } + } + } +} +*/ diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp new file mode 100644 index 0000000000..5925f9d569 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp @@ -0,0 +1,650 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +/* Float implementation for AArch64. 
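+ *
+ * The hand-written assembly path below assumes an even number of output rows
+ * and columns and a channel count that is a multiple of four; every other
+ * output shape is routed by execute() to the generic C++ specialisation that
+ * follows it.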
+ */ +#ifdef __aarch64__ +namespace winograd { + + +template <> +template <> +inline void Winograd2x2_3x3GemmOutput::_execute( + const Tensor4DShape &output_shape, + float *output, + const float *input, + const int mstride, + const int matrix_row_stride +) { + const int tile_M = output_shape.n_rows / 2; + const int tile_N = output_shape.n_cols / 2; + int batch = output_shape.n_batches; + float *outptr = output; + + const float *inptr0 = input; + const float *inptr4 = input + 4 * mstride; + const float *inptr8 = input + 8 * mstride; + const float *inptr12 = input + 12 * mstride; + + const size_t col_stride = sizeof(float) * output_shape.n_channels; + const size_t row_stride = col_stride * tile_N * 2; + + asm volatile ( + // Aliases for elements of the input matrix `F` + // V-register Q-register + "F11 .req v0\n" "qF11 .req q0\n" + "F12 .req v1\n" "qF12 .req q1\n" + "F13 .req v2\n" "qF13 .req q2\n" + "F14 .req v3\n" "qF14 .req q3\n" + "F21 .req v4\n" "qF21 .req q4\n" + "F22 .req v5\n" "qF22 .req q5\n" + "F23 .req v6\n" "qF23 .req q6\n" + "F24 .req v7\n" "qF24 .req q7\n" + "F31 .req v8\n" "qF31 .req q8\n" + "F32 .req v9\n" "qF32 .req q9\n" + "F33 .req v10\n" "qF33 .req q10\n" + "F34 .req v11\n" "qF34 .req q11\n" + "F41 .req v12\n" "qF41 .req q12\n" + "F42 .req v13\n" "qF42 .req q13\n" + "F43 .req v14\n" "qF43 .req q14\n" + "F44 .req v15\n" "qF44 .req q15\n" + + // Aliases for elements of the intermediate matrix `FZ` + "FZ11 .req v16\n" + "FZ12 .req v17\n" + "FZ21 .req v18\n" + "FZ22 .req v19\n" + "FZ31 .req v20\n" + "FZ32 .req v21\n" + "FZ41 .req v22\n" + "FZ42 .req v23\n" + + // Aliases for elements of the output matrix `f` (called `g` due to case + // insensitivity of aliases). + " g11 .req v24\n" + "qg11 .req q24\n" + " g12 .req v25\n" + "qg12 .req q25\n" + " g21 .req v26\n" + "qg21 .req q26\n" + " g22 .req v27\n" + "qg22 .req q27\n" + + // Prepare the various strides + "col_stride .req %x[col_stride]\n" + "row_stride .req %x[row_stride]\n" + "row_plus_col_stride .req %x[row_plus_col_stride]\n" + + "mstride1 .req %x[mstride1]\n" + "mstride2 .req %x[mstride2]\n" + "mstride3 .req %x[mstride3]\n" + + "tile_i .req x19\n" // Tile row counter + "tile_j .req x20\n" // Tile column counter + "channel .req x21\n" // Channel counter + + "1:" // Loop over batches + "mov tile_i, %x[tile_M]\n" // Reset tile row counter + + "2:" // Loop over rows of tiles + "mov tile_j, %x[tile_N]\n" // Reset tile column counter + + "3:" // Loop over columns of tiles + // Perform initial loads of the matrix `F` + "ldr qF11, [%x[inptr0]]\n" + "ldr qF12, [%x[inptr0], mstride1]\n" + "ldr qF13, [%x[inptr0], mstride2]\n" + "ldr qF14, [%x[inptr0], mstride3]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + "ldr qF21, [%x[inptr4]]\n" + "ldr qF22, [%x[inptr4], mstride1]\n" + "subs channel, %x[n_channels], #4\n" // Reset channel counter + + "ldr qF23, [%x[inptr4], mstride2]\n" + "ldr qF24, [%x[inptr4], mstride3]\n" + "add %x[inptr4], %x[inptr4], #0x10\n" + "beq 5f\n" // Jump straight to tail if necessary + + "4:" // Loop over channels + "ldr qF31, [%x[inptr8]]\n" + "fadd FZ11.4s, F11.4s, F12.4s\n" + + "ldr qF32, [%x[inptr8], mstride1]\n" + "fsub FZ12.4s, F12.4s, F13.4s\n" + + "ldr qF33, [%x[inptr8], mstride2]\n" + "fadd FZ11.4s, FZ11.4s, F13.4s\n" + + "ldr qF34, [%x[inptr8], mstride3]\n" + "fsub FZ12.4s, FZ12.4s, F14.4s\n" + + "ldr qF41, [%x[inptr12]]\n" + "fadd FZ21.4s, F21.4s, F22.4s\n" + + "ldr qF42, [%x[inptr12], mstride1]\n" + "fsub FZ22.4s, F22.4s, F23.4s\n" + + "ldr qF43, [%x[inptr12], mstride2]\n" + "fadd FZ21.4s, FZ21.4s, 
F23.4s\n" + + "ldr qF44, [%x[inptr12], mstride3]\n" + "fsub FZ22.4s, FZ22.4s, F24.4s\n" + + "fadd FZ31.4s, F31.4s, F32.4s\n" + "add %x[inptr8], %x[inptr8], #0x10\n" + + "fsub FZ32.4s, F32.4s, F33.4s\n" + "add %x[inptr12], %x[inptr12], #0x10\n" + + "fadd FZ31.4s, FZ31.4s, F33.4s\n" + + "fsub FZ32.4s, FZ32.4s, F34.4s\n" + + "fadd g11.4s, FZ11.4s, FZ21.4s\n" + + "fadd g12.4s, FZ12.4s, FZ22.4s\n" + + "fadd g11.4s, g11.4s, FZ31.4s\n" + + "fadd g12.4s, g12.4s, FZ32.4s\n" + + "ldr qF11, [%x[inptr0]]\n" + "fadd FZ41.4s, F41.4s, F42.4s\n" + + "ldr qF12, [%x[inptr0], mstride1]\n" + "fsub g21.4s, FZ21.4s, FZ31.4s\n" + + "ldr qF13, [%x[inptr0], mstride2]\n" + "fsub FZ42.4s, F42.4s, F43.4s\n" + + "ldr qF14, [%x[inptr0], mstride3]\n" + "str qg11, [%x[outptr]]\n" + + "ldr qF21, [%x[inptr4]]\n" + "fadd FZ41.4s, FZ41.4s, F43.4s\n" + + "ldr qF22, [%x[inptr4], mstride1]\n" + "str qg12, [%x[outptr], col_stride]\n" + + "ldr qF23, [%x[inptr4], mstride2]\n" + "fsub FZ42.4s, FZ42.4s, F44.4s\n" + + "ldr qF24, [%x[inptr4], mstride3]\n" + "fsub g22.4s, FZ22.4s, FZ32.4s\n" + + "fsub g21.4s, g21.4s, FZ41.4s\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "fsub g22.4s, g22.4s, FZ42.4s\n" + "add %x[inptr4], %x[inptr4], #0x10\n" + + "subs channel, channel, #4\n" + + "str qg21, [%x[outptr], row_stride]\n" + + "str qg22, [%x[outptr], row_plus_col_stride]\n" + + "add %x[outptr], %x[outptr], #0x10\n" + + "bne 4b\n" + + "5:" // Channel tail + "ldr qF31, [%x[inptr8]]\n" + "fadd FZ11.4s, F11.4s, F12.4s\n" + + "ldr qF32, [%x[inptr8], mstride1]\n" + "fsub FZ12.4s, F12.4s, F13.4s\n" + + "ldr qF33, [%x[inptr8], mstride2]\n" + "fadd FZ11.4s, FZ11.4s, F13.4s\n" + + "ldr qF34, [%x[inptr8], mstride3]\n" + "fsub FZ12.4s, FZ12.4s, F14.4s\n" + + "ldr qF41, [%x[inptr12]]\n" + "fadd FZ21.4s, F21.4s, F22.4s\n" + + "ldr qF42, [%x[inptr12], mstride1]\n" + "fsub FZ22.4s, F22.4s, F23.4s\n" + + "ldr qF43, [%x[inptr12], mstride2]\n" + "fadd FZ21.4s, FZ21.4s, F23.4s\n" + + "ldr qF44, [%x[inptr12], mstride3]\n" + "fsub FZ22.4s, FZ22.4s, F24.4s\n" + + "fadd FZ31.4s, F31.4s, F32.4s\n" + "add %x[inptr8], %x[inptr8], #0x10\n" + + "fsub FZ32.4s, F32.4s, F33.4s\n" + "add %x[inptr12], %x[inptr12], #0x10\n" + + "fadd FZ31.4s, FZ31.4s, F33.4s\n" + + "fsub FZ32.4s, FZ32.4s, F34.4s\n" + + "fadd g11.4s, FZ11.4s, FZ21.4s\n" + + "fadd g12.4s, FZ12.4s, FZ22.4s\n" + + "fadd g11.4s, g11.4s, FZ31.4s\n" + + "fadd g12.4s, g12.4s, FZ32.4s\n" + + "fadd FZ41.4s, F41.4s, F42.4s\n" + + "fsub g21.4s, FZ21.4s, FZ31.4s\n" + + "fsub FZ42.4s, F42.4s, F43.4s\n" + + "str qg11, [%x[outptr]]\n" + + "fadd FZ41.4s, FZ41.4s, F43.4s\n" + + "str qg12, [%x[outptr], col_stride]\n" + + "fsub FZ42.4s, FZ42.4s, F44.4s\n" + + "fsub g22.4s, FZ22.4s, FZ32.4s\n" + + "fsub g21.4s, g21.4s, FZ41.4s\n" + + "fsub g22.4s, g22.4s, FZ42.4s\n" + + "subs channel, channel, #4\n" + + "str qg21, [%x[outptr], row_stride]\n" + + // Progress input pointers to the next row of the matrix + "add %x[inptr0], %x[inptr0], %x[mrowpad]\n" + "add %x[inptr4], %x[inptr4], %x[mrowpad]\n" + "add %x[inptr8], %x[inptr8], %x[mrowpad]\n" + "add %x[inptr12], %x[inptr12], %x[mrowpad]\n" + + "str qg22, [%x[outptr], row_plus_col_stride]\n" + + "add %x[outptr], %x[outptr], #0x10\n" + + + "add %x[outptr], %x[outptr], col_stride\n" + "subs tile_j, tile_j, #1\n" + "bne 3b\n" + + "add %x[outptr], %x[outptr], row_stride\n" + "subs tile_i, tile_i, #1\n" + "bne 2b\n" + + "subs %[batch], %[batch], #1\n" + "bne 1b\n" + + ".unreq F11\n" ".unreq qF11\n" + ".unreq F12\n" ".unreq qF12\n" + ".unreq F13\n" ".unreq qF13\n" + ".unreq F14\n" ".unreq 
qF14\n" + ".unreq F21\n" ".unreq qF21\n" + ".unreq F22\n" ".unreq qF22\n" + ".unreq F23\n" ".unreq qF23\n" + ".unreq F24\n" ".unreq qF24\n" + ".unreq F31\n" ".unreq qF31\n" + ".unreq F32\n" ".unreq qF32\n" + ".unreq F33\n" ".unreq qF33\n" + ".unreq F34\n" ".unreq qF34\n" + ".unreq F41\n" ".unreq qF41\n" + ".unreq F42\n" ".unreq qF42\n" + ".unreq F43\n" ".unreq qF43\n" + ".unreq F44\n" ".unreq qF44\n" + + ".unreq FZ11\n" ".unreq FZ12\n" + ".unreq FZ21\n" ".unreq FZ22\n" + ".unreq FZ31\n" ".unreq FZ32\n" + ".unreq FZ41\n" ".unreq FZ42\n" + + ".unreq g11\n" ".unreq qg11\n" + ".unreq g12\n" ".unreq qg12\n" + ".unreq g21\n" ".unreq qg21\n" + ".unreq g22\n" ".unreq qg22\n" + + ".unreq col_stride\n" + ".unreq row_stride\n" + ".unreq row_plus_col_stride\n" + + ".unreq mstride1\n" + ".unreq mstride2\n" + ".unreq mstride3\n" + + ".unreq tile_i \n" + ".unreq tile_j \n" + ".unreq channel\n" + + : [batch] "+r" (batch), + [outptr] "+r" (outptr), + [inptr0] "+r" (inptr0), + [inptr4] "+r" (inptr4), + [inptr8] "+r" (inptr8), + [inptr12] "+r" (inptr12) + : [tile_M] "r" (tile_M), + [tile_N] "r" (tile_N), + [n_channels] "r" (output_shape.n_channels), + [col_stride] "r" (col_stride), + [row_stride] "r" (row_stride), + [row_plus_col_stride] "r" (row_stride + col_stride), + [mstride1] "r" (mstride * sizeof(float)), + [mstride2] "r" (2 * mstride * sizeof(float)), + [mstride3] "r" (3 * mstride * sizeof(float)), + [mrowpad] "r" ((matrix_row_stride - output_shape.n_channels) * sizeof(float)) + : "x19", "x20", "x21", + "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "q16", "q17", "q18", "q19", "q20", "q21", + "q22", "q23", "q24", "q25", "q26", "q27", + "cc", "memory" + ); +} + +template <> +template +inline void Winograd2x2_3x3GemmOutput::_execute( + const Tensor4DShape &output_shape, + float *output, + const float *input, + const int mstride, + const int matrix_row_stride +) { + // Compute basic information about the shape of the matrices + const int tile_M = output_shape.n_rows / 2; + const int tile_N = output_shape.n_cols / 2; + const int n_channels = output_shape.n_channels; + + // Extract 16 input pointers + const float* inptr[16]; + for (int i = 0; i < 16; i++) { + inptr[i] = input + i*mstride; + } + + // Extract 4 output pointers + float *outptr00 = output; + float *outptr01 = outptr00 + n_channels; + float *outptr10 = outptr00 + output_shape.n_cols * n_channels; + float *outptr11 = outptr10 + n_channels; + + // Progress over the output tiles, generating output values. + for (int batch = 0; batch < output_shape.n_batches; batch++) { + for (int tile_i = 0; tile_i < tile_M; tile_i++) { + for (int tile_j = 0; tile_j < tile_N; tile_j++) { + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + float F[4][4]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + F[i][j] = *(inptr[i*4 + j]++); + } + } + + // Compute the matrix F.Z + float ZF[4][2]; + ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; + ZF[0][1] = F[0][1] - F[0][2] - F[0][3]; + ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; + ZF[1][1] = F[1][1] - F[1][2] - F[1][3]; + ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; + ZF[2][1] = F[2][1] - F[2][2] - F[2][3]; + ZF[3][0] = F[3][0] + F[3][1] + F[3][2]; + ZF[3][1] = F[3][1] - F[3][2] - F[3][3]; + + // Hence compute the output matrix Z^T . 
(F.Z) + *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; + *(outptr01++) = ZF[0][1] + ZF[1][1] + ZF[2][1]; + *(outptr10++) = ZF[1][0] - ZF[2][0] - ZF[3][0]; + *(outptr11++) = ZF[1][1] - ZF[2][1] - ZF[3][1]; + } + + // Progress the input pointers to the next row + for (int i = 0; i < 16; i++) { + inptr[i] += matrix_row_stride - n_channels; + } + + // Progress the output pointers to the next column + outptr00 += n_channels; + outptr01 += n_channels; + outptr10 += n_channels; + outptr11 += n_channels; + } + + if (tail_N) { + // Only evaluate the left-most columns of the output + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + float F[4][3]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 3; j++) { + F[i][j] = *(inptr[i*4 + j]++); + } + } + for (int i = 0; i < 4; i++) { + inptr[i*4 + 3]++; + } + + // Compute the matrix F.Z + float ZF[4][1]; + ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; + ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; + ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; + ZF[3][0] = F[3][0] + F[3][1] + F[3][2]; + + // Hence compute the output matrix Z^T . (F.Z) + *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; + *(outptr10++) = ZF[1][0] - ZF[2][0] - ZF[3][0]; + } + + // Progress the input pointers to the next row + for (int i = 0; i < 16; i++) { + inptr[i] += matrix_row_stride - n_channels; + } + + // Progress the output pointers to the next column + outptr01 += n_channels; // Account for being skipped above + outptr11 += n_channels; // Account for being skipped above + } + + // Progress the output pointers to the next row + outptr00 += output_shape.n_cols * n_channels; + outptr01 += output_shape.n_cols * n_channels; + outptr10 += output_shape.n_cols * n_channels; + outptr11 += output_shape.n_cols * n_channels; + } + + if (tail_M) { + // Only work on the upper row of the output + for (int tile_j = 0; tile_j < tile_N; tile_j++) { + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + float F[3][4]; + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 4; j++) { + F[i][j] = *(inptr[i*4 + j]++); + } + } + for (int j = 0; j < 4; j++) { + inptr[12 + j]++; + } + + // Compute the matrix F.Z + float ZF[3][2]; + ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; + ZF[0][1] = F[0][1] - F[0][2] - F[0][3]; + ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; + ZF[1][1] = F[1][1] - F[1][2] - F[1][3]; + ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; + ZF[2][1] = F[2][1] - F[2][2] - F[2][3]; + + // Hence compute the output matrix Z^T . 
(F.Z) + *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; + *(outptr01++) = ZF[0][1] + ZF[1][1] + ZF[2][1]; + } + + // Progress the input pointers to the next row + for (int i = 0; i < 16; i++) { + inptr[i] += matrix_row_stride - n_channels; + } + + // Progress the output pointers to the next column + outptr00 += n_channels; + outptr01 += n_channels; + outptr10 += 2 * n_channels; // Account for being skipped above + outptr11 += 2 * n_channels; // Account for being skipped above + } + + if (tail_N) { + // Only evaluate the upper-left cell of the output + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + float F[3][3]; + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + F[i][j] = *(inptr[i*4 + j]); + } + } + for (int i = 0; i < 16; i++) { + inptr[i]++; + } + + // Compute the matrix F.Z + float ZF[3][1]; + ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; + ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; + ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; + + // Hence compute the output matrix Z^T . (F.Z) + *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; + } + + // Progress the input pointers to the next row + for (int i = 0; i < 16; i++) { + inptr[i] += matrix_row_stride - n_channels; + } + + // Progress the output pointers to the next column + outptr01 += n_channels; // Account for being skipped above + outptr10 += n_channels; // Account for being skipped above + outptr11 += n_channels; // Account for being skipped above + } + } + } +} + +/*****************************************************************************/ +template <> +inline void Winograd2x2_3x3GemmOutput::execute( + const Tensor4DShape &output_shape, + float* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + float* const output +) { + // Dispatch to an appropriate implementation based on the shape of the output + // tensor. 
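+  // (tail_M, tail_N) indicate whether a partial final row and/or column of
+  // output tiles is required (odd output height or width); the channel count
+  // modulo four selects the remaining specialisation below.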
+ if (output_shape.n_rows % 2 && output_shape.n_cols % 2) { + constexpr bool tail_M = true, tail_N = true; + switch (output_shape.n_channels % 4) { + case 0: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 1: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 2: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 3: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + default: + assert(0); + break; + } + } else if (output_shape.n_rows % 2) { + constexpr bool tail_M = true, tail_N = false; + switch (output_shape.n_channels % 4) { + case 0: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 1: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 2: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 3: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + default: + assert(0); + break; + } + } else if (output_shape.n_cols % 2) { + constexpr bool tail_M = false, tail_N = true; + switch (output_shape.n_channels % 4) { + case 0: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 1: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 2: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 3: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + default: + assert(0); + break; + + } + } else { + constexpr bool tail_M = false, tail_N = false; + switch (output_shape.n_channels % 4) { + case 0: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 1: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 2: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 3: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + default: + assert(0); + break; + + } + } +} +/*****************************************************************************/ + +} // namespace winograd +#endif // __aarch64__ diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp new file mode 100644 index 0000000000..f551b12b52 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp @@ -0,0 +1,655 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#ifdef __aarch64__ + +/*****************************************************************************/ +// Compute ZF specializations + +template <> +template <> +inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage::compute_zf<0>( + const int n_rows, const int n_channels, + float* output, const float* const input[16] +) { + // Make copies of some variables + int row = n_rows; + float* outptr = output; + const float* inptr = input[0]; + + // Perform the transformation + asm volatile ( + // "inptr0 .req %x[inptr]\n" + "inptr1 .req x0\n" + "inptr2 .req x1\n" + "inptr3 .req x2\n" + "inptr4 .req x3\n" + "inptr5 .req x4\n" + "inptr6 .req x5\n" + "inptr7 .req x6\n" + "inptr8 .req x7\n" + "inptr9 .req x8\n" + "inptr10 .req x9\n" + "inptr11 .req x10\n" + "inptr12 .req x11\n" + "inptr13 .req x12\n" + "inptr14 .req x13\n" + "inptr15 .req x14\n" + + // "outptr0 .req %x[outptr]\n" + "outptr1 .req x15\n" + "outptr2 .req x16\n" + "outptr3 .req x17\n" + "outptr4 .req x18\n" + "outptr5 .req x19\n" + "outptr6 .req x20\n" + "outptr7 .req x21\n" + + // Compute additional pointers into the input and output matrices. + "mstride .req x22\n" // Matrix stride + "mul mstride, %x[row], %x[n_channels]\n" + "lsl mstride, mstride, #2\n" // * sizeof(float) + + "add inptr1, %x[inptr], mstride\n" + "add inptr2, %x[inptr], mstride, LSL #1\n" + "add inptr3, inptr2, mstride\n" + "add inptr4, inptr3, mstride\n" + "add inptr5, inptr4, mstride\n" + "add inptr6, inptr5, mstride\n" + "add inptr7, inptr6, mstride\n" + "add inptr8, inptr7, mstride\n" + "add inptr9, inptr8, mstride\n" + "add inptr10, inptr9, mstride\n" + "add inptr11, inptr10, mstride\n" + "add inptr12, inptr11, mstride\n" + "add inptr13, inptr12, mstride\n" + "add inptr14, inptr13, mstride\n" + "add inptr15, inptr14, mstride\n" + + "add outptr1, %[outptr], mstride\n" + "add outptr2, outptr1, mstride\n" + "add outptr3, outptr2, mstride\n" + "add outptr4, outptr3, mstride\n" + "add outptr5, outptr4, mstride\n" + "add outptr6, outptr5, mstride\n" + "add outptr7, outptr6, mstride\n" + + ".unreq mstride\n" + + "column .req x22\n" // Column loop counter + + "1:" // Loop over rows + "ldr q0, [%x[inptr]], #0x10\n" + "ldr q1, [inptr1], #0x10\n" + "ldr q2, [inptr2], #0x10\n" + "ldr q3, [inptr3], #0x10\n" + "ldr q4, [inptr4], #0x10\n" + "ldr q5, [inptr5], #0x10\n" + "ldr q6, [inptr6], #0x10\n" + "ldr q7, [inptr7], #0x10\n" + "subs column, %x[n_channels], #0x4\n" + "beq 3f\n" + + "2:" // Loop over columns + "ldr q8, [inptr8], #0x10\n" + "prfm pldl1keep, [%x[inptr], #196]\n" + "fadd v16.4s, v0.4s, v1.4s\n" + + "ldr q9, [inptr9], #0x10\n" + "prfm pldl1keep, [inptr1, #196]\n" + "fsub v17.4s, v1.4s, v2.4s\n" + + "ldr q10, [inptr10], #0x10\n" + "prfm pldl1keep, [inptr2, #196]\n" + "fadd v16.4s, v16.4s, v2.4s\n" + + "ldr q11, [inptr11], #0x10\n" + "prfm pldl1keep, [inptr3, #196]\n" + "fsub v17.4s, v17.4s, v3.4s\n" + + "ldr q12, [inptr12], #0x10\n" + "prfm pldl1keep, [inptr4, #196]\n" + "str q16, [%x[outptr]], #0x10\n" + + "ldr q13, [inptr13], #0x10\n" + "prfm pldl1keep, 
[inptr5, #196]\n" + "str q17, [outptr1], #0x10\n" + + "ldr q14, [inptr14], #0x10\n" + "prfm pldl1keep, [inptr6, #196]\n" + "fadd v16.4s, v4.4s, v5.4s\n" + + "ldr q15, [inptr15], #0x10\n" + "prfm pldl1keep, [inptr7, #196]\n" + "fsub v17.4s, v5.4s, v6.4s\n" + + "ldr q0, [%x[inptr]], #0x10\n" + "prfm pldl1keep, [inptr8, #196]\n" + "fadd v16.4s, v16.4s, v6.4s\n" + + "ldr q1, [inptr1], #0x10\n" + "prfm pldl1keep, [inptr9, #196]\n" + "fsub v17.4s, v17.4s, v7.4s\n" + + "ldr q2, [inptr2], #0x10\n" + "prfm pldl1keep, [inptr10, #196]\n" + "str q16, [outptr2], #0x10\n" + + "ldr q3, [inptr3], #0x10\n" + "prfm pldl1keep, [inptr11, #196]\n" + "str q17, [outptr3], #0x10\n" + + "ldr q4, [inptr4], #0x10\n" + "prfm pldl1keep, [inptr12, #196]\n" + "fadd v16.4s, v8.4s, v9.4s\n" + + "ldr q5, [inptr5], #0x10\n" + "prfm pldl1keep, [inptr13, #196]\n" + "fsub v17.4s, v9.4s, v10.4s\n" + + "ldr q6, [inptr6], #0x10\n" + "prfm pldl1keep, [inptr14, #196]\n" + "fadd v16.4s, v16.4s, v10.4s\n" + + "ldr q7, [inptr7], #0x10\n" + "prfm pldl1keep, [inptr15, #196]\n" + "fsub v17.4s, v17.4s, v11.4s\n" + + "str q16, [outptr4], #0x10\n" + "fadd v16.4s, v12.4s, v13.4s\n" + "fsub v18.4s, v13.4s, v14.4s\n" + + "str q17, [outptr5], #0x10\n" + "fadd v16.4s, v16.4s, v14.4s\n" + "fsub v18.4s, v18.4s, v15.4s\n" + + "str q16, [outptr6], #0x10\n" + "subs column, column, #0x4\n" + + "str q18, [outptr7], #0x10\n" + "bne 2b\n" + + "3:" // Tail + "ldr q8, [inptr8], #0x10\n" + "prfm pldl1keep, [%x[inptr], #196]\n" + "fadd v16.4s, v0.4s, v1.4s\n" + + "ldr q9, [inptr9], #0x10\n" + "prfm pldl1keep, [inptr1, #196]\n" + "fsub v17.4s, v1.4s, v2.4s\n" + + "ldr q10, [inptr10], #0x10\n" + "prfm pldl1keep, [inptr2, #196]\n" + "fadd v16.4s, v16.4s, v2.4s\n" + + "ldr q11, [inptr11], #0x10\n" + "prfm pldl1keep, [inptr3, #196]\n" + "fsub v17.4s, v17.4s, v3.4s\n" + + "ldr q12, [inptr12], #0x10\n" + "prfm pldl1keep, [inptr4, #196]\n" + "str q16, [%x[outptr]], #0x10\n" + + "ldr q13, [inptr13], #0x10\n" + "prfm pldl1keep, [inptr5, #196]\n" + "str q17, [outptr1], #0x10\n" + + "ldr q14, [inptr14], #0x10\n" + "prfm pldl1keep, [inptr6, #196]\n" + "fadd v16.4s, v4.4s, v5.4s\n" + + "ldr q15, [inptr15], #0x10\n" + "prfm pldl1keep, [inptr7, #196]\n" + "fsub v17.4s, v5.4s, v6.4s\n" + + "prfm pldl1keep, [inptr8, #196]\n" + "prfm pldl1keep, [inptr9, #196]\n" + "fadd v16.4s, v16.4s, v6.4s\n" + + "prfm pldl1keep, [inptr10, #196]\n" + "prfm pldl1keep, [inptr11, #196]\n" + "fsub v17.4s, v17.4s, v7.4s\n" + + "prfm pldl1keep, [inptr12, #196]\n" + "prfm pldl1keep, [inptr13, #196]\n" + "str q16, [outptr2], #0x10\n" + + "prfm pldl1keep, [inptr14, #196]\n" + "prfm pldl1keep, [inptr15, #196]\n" + "str q17, [outptr3], #0x10\n" + + "fadd v16.4s, v8.4s, v9.4s\n" + "fsub v17.4s, v9.4s, v10.4s\n" + + "fadd v16.4s, v16.4s, v10.4s\n" + "fsub v17.4s, v17.4s, v11.4s\n" + + "str q16, [outptr4], #0x10\n" + "fadd v16.4s, v12.4s, v13.4s\n" + "fsub v18.4s, v13.4s, v14.4s\n" + + "str q17, [outptr5], #0x10\n" + "fadd v16.4s, v16.4s, v14.4s\n" + "fsub v18.4s, v18.4s, v15.4s\n" + + "str q16, [outptr6], #0x10\n" + "str q18, [outptr7], #0x10\n" + + "subs %x[row], %x[row], #0x1\n" + "bne 1b\n" + + ".unreq inptr1\n" + ".unreq inptr2\n" + ".unreq inptr3\n" + ".unreq inptr4\n" + ".unreq inptr5\n" + ".unreq inptr6\n" + ".unreq inptr7\n" + ".unreq inptr8\n" + ".unreq inptr9\n" + ".unreq inptr10\n" + ".unreq inptr11\n" + ".unreq inptr12\n" + ".unreq inptr13\n" + ".unreq inptr14\n" + ".unreq inptr15\n" + ".unreq outptr1\n" + ".unreq outptr2\n" + ".unreq outptr3\n" + ".unreq outptr4\n" + ".unreq outptr5\n" + 
".unreq outptr6\n" + ".unreq outptr7\n" + + : [row] "+r" (row), + [inptr] "+r" (inptr), + [outptr] "+r" (outptr) + : [n_channels] "r" (n_channels), + [sizeof_float] "i" (sizeof(float)) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "q16", "q17", "x0", "x1", "x2", "x3", "x4", + "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", + "x16", "x17", "x18", "x19", "x20", "x21", "x22", "cc", "memory" + ); +} + +/*****************************************************************************/ +// Compute ZFZ^T specializations + +template <> +template <> +inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage::compute_zfzT( + const Tensor4DShape &output_shape, + float* const output, const float* const input +) { + const int tile_M = output_shape.n_rows / 2; + const int tile_N = output_shape.n_cols / 2; + int batch = output_shape.n_batches; + float *outptr = output; + const float *inptr = input; + + asm volatile ( + // Compute input pointers + "inptr1 .req x0\n" + "inptr2 .req x1\n" + "inptr3 .req x2\n" + "inptr4 .req x3\n" + "inptr5 .req x4\n" + "inptr6 .req x5\n" + "inptr7 .req x6\n" + "inptr8 .req x7\n" + + "mstride .req x8\n" + "mul mstride, %x[tile_M], %x[tile_N]\n" + "mul mstride, mstride, %x[n_channels]\n" + "lsl mstride, mstride, #2\n" // * sizeof(float) + + "add inptr1, %[inptr], mstride\n" + "add inptr2, inptr1, mstride\n" + "add inptr3, inptr2, mstride\n" + "add inptr4, inptr3, mstride\n" + "add inptr5, inptr4, mstride\n" + "add inptr6, inptr5, mstride\n" + "add inptr7, inptr6, mstride\n" + "add inptr8, inptr7, mstride\n" + + ".unreq mstride\n" + + // Compute initial output pointers + "outptr01 .req x8\n" + "outptr10 .req x9\n" + "outptr11 .req x10\n" + + "add outptr01, %x[outptr], %x[n_channels], LSL #2\n" + "add outptr10, %x[outptr], %x[row_stride], LSL #2\n" + "add outptr11, outptr10, %x[n_channels], LSL #2\n" + + "tile_i .req x11\n" + "tile_j .req x12\n" + "channel .req x13\n" + + "1:" // Loop over batches + "mov tile_i, %x[tile_M]\n" + + "2:" // Loop over rows of output tiles + "mov tile_j, %x[tile_N]\n" + + "3:" // Loop over columns of output tiles + "ldr q0, [%x[inptr]], #0x10\n" + "ldr q2, [inptr2], #0x10\n" + "subs channel, %x[n_channels], #0x4\n" + + "ldr q1, [inptr1], #0x10\n" + "ldr q3, [inptr3], #0x10\n" + "beq 6f\n" + + "4:" + "ldr q4, [inptr4], #0x10\n" + "ldr q5, [inptr5], #0x10\n" + "fadd v16.4s, v0.4s, v2.4s\n" + + "ldr q6, [inptr6], #0x10\n" + "ldr q7, [inptr7], #0x10\n" + "fadd v17.4s, v1.4s, v3.4s\n" + + "ldr q8, [%x[inptr]], #0x10\n" + "ldr q10, [inptr2], #0x10\n" + "fadd v16.4s, v16.4s, v4.4s\n" + + "ldr q9, [inptr1], #0x10\n" + "ldr q11, [inptr3], #0x10\n" + "fadd v17.4s, v17.4s, v5.4s\n" + + "str q16, [%x[outptr]], #0x10\n" + "prfm pldl1strm, [%x[inptr], #196]\n" + "fsub v18.4s, v2.4s, v4.4s\n" + + "str q17, [outptr01], #0x10\n" + "prfm pldl1strm, [inptr2, #196]\n" + "fsub v19.4s, v3.4s, v5.4s\n" + + "prfm pldl1strm, [inptr1, #196]\n" + "prfm pldl1strm, [inptr3, #196]\n" + "fsub v18.4s, v18.4s, v6.4s\n" + + "prfm pldl1strm, [inptr4, #196]\n" + "prfm pldl1strm, [inptr5, #196]\n" + "fsub v19.4s, v19.4s, v7.4s\n" + + "str q18, [outptr10], #0x10\n" + "prfm pldl1strm, [inptr6, #196]\n" + "prfm pldl1strm, [inptr7, #196]\n" + + "subs channel, channel, #0x4\n" + + "str q19, [outptr11], #0x10\n" + "beq 6f\n" // Branch to tail + + "ldr q12, [inptr4], #0x10\n" + "ldr q13, [inptr5], #0x10\n" + "fadd v16.4s, v8.4s, v10.4s\n" + + "ldr q14, [inptr6], #0x10\n" + "ldr q15, [inptr7], #0x10\n" + "fadd 
v17.4s, v9.4s, v11.4s\n" + + "ldr q0, [%x[inptr]], #0x10\n" + "ldr q2, [inptr2], #0x10\n" + "fadd v16.4s, v16.4s, v12.4s\n" + + "ldr q1, [inptr1], #0x10\n" + "ldr q3, [inptr3], #0x10\n" + "fadd v17.4s, v17.4s, v13.4s\n" + + "str q16, [%x[outptr]], #0x10\n" + "prfm pldl1strm, [%x[inptr], #196]\n" + "fsub v18.4s, v10.4s, v12.4s\n" + + "str q17, [outptr01], #0x10\n" + "prfm pldl1strm, [inptr2, #196]\n" + "fsub v19.4s, v11.4s, v13.4s\n" + + "prfm pldl1strm, [inptr1, #196]\n" + "prfm pldl1strm, [inptr3, #196]\n" + "fsub v18.4s, v18.4s, v14.4s\n" + + "prfm pldl1strm, [inptr4, #196]\n" + "prfm pldl1strm, [inptr5, #196]\n" + "fsub v19.4s, v19.4s, v15.4s\n" + + "str q18, [outptr10], #0x10\n" + "prfm pldl1strm, [inptr6, #196]\n" + "prfm pldl1strm, [inptr7, #196]\n" + + "subs channel, channel, #0x4\n" + + "str q19, [outptr11], #0x10\n" + "bne 4b\n" // Continue loop + + "5:" // Tail + "ldr q12, [inptr4], #0x10\n" + "ldr q13, [inptr5], #0x10\n" + "fadd v16.4s, v8.4s, v10.4s\n" + + "ldr q14, [inptr6], #0x10\n" + "ldr q15, [inptr7], #0x10\n" + "fadd v17.4s, v9.4s, v11.4s\n" + + "fadd v16.4s, v16.4s, v12.4s\n" + + "fadd v17.4s, v17.4s, v13.4s\n" + + "str q16, [%x[outptr]], #0x10\n" + "fsub v18.4s, v10.4s, v12.4s\n" + "fsub v19.4s, v11.4s, v13.4s\n" + + "str q17, [outptr01], #0x10\n" + "fsub v18.4s, v18.4s, v14.4s\n" + "fsub v19.4s, v19.4s, v15.4s\n" + + "str q18, [outptr10], #0x10\n" + "str q19, [outptr11], #0x10\n" + "b 7f\n" + + "6:" // Tail + "ldr q4, [inptr4], #0x10\n" + "ldr q5, [inptr5], #0x10\n" + "fadd v16.4s, v0.4s, v2.4s\n" + + "ldr q6, [inptr6], #0x10\n" + "ldr q7, [inptr7], #0x10\n" + "fadd v17.4s, v1.4s, v3.4s\n" + + "fadd v16.4s, v16.4s, v4.4s\n" + + "fadd v17.4s, v17.4s, v5.4s\n" + + "str q16, [%x[outptr]], #0x10\n" + "fsub v18.4s, v2.4s, v4.4s\n" + "fsub v19.4s, v3.4s, v5.4s\n" + + "str q17, [outptr01], #0x10\n" + "fsub v18.4s, v18.4s, v6.4s\n" + "fsub v19.4s, v19.4s, v7.4s\n" + + "str q18, [outptr10], #0x10\n" + "str q19, [outptr11], #0x10\n" + + "7:" + "add %x[outptr], %x[outptr], %x[n_channels], LSL #2\n" + "add outptr01, outptr01, %x[n_channels], LSL #2\n" + "add outptr10, outptr10, %x[n_channels], LSL #2\n" + "add outptr11, outptr11, %x[n_channels], LSL #2\n" + + "subs tile_j, tile_j, #1\n" + "bne 3b\n" + + // Progress the output pointers to the new row + "add %x[outptr], %x[outptr], %x[row_stride], LSL #2\n" + "add outptr01, outptr01, %x[row_stride], LSL #2\n" + "add outptr10, outptr10, %x[row_stride], LSL #2\n" + "add outptr11, outptr11, %x[row_stride], LSL #2\n" + + "subs tile_i, tile_i, #1\n" + "bne 2b\n" + + "subs %[batch], %[batch], #1\n" + "bne 1b\n" + "5:" + + ".unreq inptr1\n" + ".unreq inptr2\n" + ".unreq inptr3\n" + ".unreq inptr4\n" + ".unreq inptr5\n" + ".unreq inptr6\n" + ".unreq inptr7\n" + ".unreq inptr8\n" + ".unreq outptr01\n" + ".unreq outptr10\n" + ".unreq outptr11\n" + : [batch] "+r" (batch), + [outptr] "+r" (outptr), + [inptr] "+r" (inptr) + : [tile_M] "r" (tile_M), + [tile_N] "r" (tile_N), + [n_channels] "r" (output_shape.n_channels), + [row_stride] "r" (output_shape.n_cols * output_shape.n_channels) + : "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", + "x12", "x13", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", + "cc", "memory" + ); +} +/*****************************************************************************/ + +/*****************************************************************************/ +template <> +inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage::execute( + const Tensor4DShape &output_shape, + float* 
const matrices[16], float* const output +) { + // profiler prof; + + // Allocate memory for the intermediate matrices + const int tile_M = iceildiv(output_shape.n_rows, 2); + const int tile_N = iceildiv(output_shape.n_cols, 2); + const int n_rows = output_shape.n_batches * tile_M * tile_N; + const int n_channels = output_shape.n_channels; + float* matrices_zf = reinterpret_cast( + calloc(8 * n_rows * n_channels, sizeof(float)) + ); + + // Perform the first stage transform, computing ZF. + const auto f_compute_zf = [&] () { + switch (n_channels % 4) { + case 0: + compute_zf<0>(n_rows, n_channels, matrices_zf, matrices); + break; + case 1: + compute_zf<1>(n_rows, n_channels, matrices_zf, matrices); + break; + case 2: + compute_zf<2>(n_rows, n_channels, matrices_zf, matrices); + break; + case 3: + compute_zf<3>(n_rows, n_channels, matrices_zf, matrices); + }; + }; + // prof("Compute ZF", f_compute_zf, 16 * n_rows * n_channels * sizeof(float), 0, 8 * n_rows * n_channels * sizeof(float)); + f_compute_zf(); + + // Perform the second stage transform, finishing Z F Z^T - variable dispatch + // based on size of the output and the channel tail. + const auto f_compute_zfzT = [&] () { + if (output_shape.n_rows % 2 && output_shape.n_cols % 2) { + constexpr bool tail_M = true, tail_N = true; + switch (n_channels % 4) { + case 0: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 1: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 2: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 3: + compute_zfzT(output_shape, output, matrices_zf); + } + } else if (output_shape.n_rows % 2) { + constexpr bool tail_M = true, tail_N = false; + switch (n_channels % 4) { + case 0: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 1: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 2: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 3: + compute_zfzT(output_shape, output, matrices_zf); + } + } else if (output_shape.n_cols % 2) { + constexpr bool tail_M = false, tail_N = true; + switch (n_channels % 4) { + case 0: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 1: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 2: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 3: + compute_zfzT(output_shape, output, matrices_zf); + } + } else { + constexpr bool tail_M = false, tail_N = false; + switch (n_channels % 4) { + case 0: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 1: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 2: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 3: + compute_zfzT(output_shape, output, matrices_zf); + } + } + }; + // prof("Compute ZFZT", f_compute_zfzT, 8 * n_rows * n_channels * sizeof(float), 0, 4 * n_rows * n_channels * sizeof(float)); + f_compute_zfzT(); + + free(reinterpret_cast(matrices_zf)); +} +/*****************************************************************************/ + +#endif // __aarch64__ -- cgit v1.2.1
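Note on the two-stage kernel above: compute_zf and compute_zfzT implement, per channel, the same arithmetic as the scalar reference _execute earlier in this patch, with execute() first materialising all eight intermediate ZF matrices in matrices_zf and then running the second stage over them. The sketch below is a minimal standalone C++ restatement of that arithmetic for a single output tile; the function name and signature are illustrative only and are not part of the library.

#include <array>

// Stage 1 (compute_zf): ZF = F . Z, reducing the 16 transformed values to 8.
// Stage 2 (compute_zfzT): out = Z^T . ZF, producing the 2x2 spatial output.
// winograd_2x2_3x3_output_tile is an illustrative name, not a library symbol.
std::array<std::array<float, 2>, 2> winograd_2x2_3x3_output_tile(const float F[4][4])
{
  // Stage 1: for each row i of F, combine the four matrix values into two.
  float ZF[4][2];
  for (int i = 0; i < 4; i++)
  {
    ZF[i][0] = F[i][0] + F[i][1] + F[i][2];
    ZF[i][1] = F[i][1] - F[i][2] - F[i][3];
  }

  // Stage 2: combine the four rows of ZF into the two rows of the output tile.
  std::array<std::array<float, 2>, 2> out;
  for (int j = 0; j < 2; j++)
  {
    out[0][j] = ZF[0][j] + ZF[1][j] + ZF[2][j];
    out[1][j] = ZF[1][j] - ZF[2][j] - ZF[3][j];
  }
  return out;
}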