diff options
author | Pablo Tello <pablo.tello@arm.com> | 2018-01-10 16:44:13 +0000 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:43:42 +0000 |
commit | 9ceebbeb8dfe61746fdc7022a147f8e2d24c5493 (patch) | |
tree | 38647440e57fe7f8f7e7996cd8f5ede7d1bca530 /src/core/NEON/kernels/winograd/transforms | |
parent | 00afd11eaa7d408ff873732639c9a724fece9058 (diff) | |
download | ComputeLibrary-9ceebbeb8dfe61746fdc7022a147f8e2d24c5493.tar.gz |
COMPMID-815: Updated NEWinogradLayer with the latest code from Research.
Change-Id: I86d7f53b5f5d1dbc22078aea5c32b08a25d1f49e
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/116634
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/winograd/transforms')
14 files changed, 1926 insertions, 5776 deletions
diff --git a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp b/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp deleted file mode 100644 index ca8d012e5e..0000000000 --- a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp +++ /dev/null @@ -1,639 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once -#include "arm_compute/core/NEON/kernels/winograd/tensor.hpp" - - -namespace winograd { - /* Transform an input tensor into the Winograd domain. 
- */ - template <typename T> - struct Winograd2x2_3x3GemmInput { - static void execute( - const T *inptr, - const Tensor4DShape& input_shape, - const PaddingType padding_type, - const int tile_M, - const int tile_N, - T *outptr_base, - const int matrix_stride, - const int matrix_batch_stride, - const int matrix_row_stride - ); - - static size_t bytes_read(const Tensor4DShape &input_shape, - const Tensor4DShape &output_shape) { - const int tile_rows = iceildiv(output_shape.n_rows, 2); - const int tile_cols = iceildiv(output_shape.n_cols, 2); - return input_shape.n_batches * tile_rows * (16 + 8*(tile_cols - 1)) * input_shape.n_channels * sizeof(T); - } - - static int flops_performed(const Tensor4DShape &input_shape, - const Tensor4DShape &output_shape) { - const int tile_rows = iceildiv(output_shape.n_rows, 2); - const int tile_cols = iceildiv(output_shape.n_cols, 2); - return input_shape.n_batches * tile_rows * (32 + 24*(tile_cols - 1)) * input_shape.n_channels; - } - - static size_t bytes_written(const Tensor4DShape &input_shape, - const Tensor4DShape &output_shape) { - const int tile_rows = iceildiv(output_shape.n_rows, 2); - const int tile_cols = iceildiv(output_shape.n_cols, 2); - const int M = input_shape.n_batches * tile_rows * tile_cols; - return 16 * M * input_shape.n_channels * sizeof(T); - } - - protected: - template <const PaddingType padding, const int pad_bottom, const int pad_right> - static void process_tile_tensor( - const int tile_M, // Number of rows of tiles - const int tile_N, // Number of columns of tiles - int n_channels, // Number of input channels - const T* const input, // Base input pointer (appropriate to batch and channel) - const int input_row_stride, // Stride between rows of the input - const int input_col_stride, // Stride between columns of the input - T* const matrix, // 1st output matrix (appropriate to batch and channel) - const int matrix_stride, // Stride between matrices - const int matrix_row_stride // Stride between rows of 
the output matrix - ); - - template <const int pad_top, const int pad_left, - const int pad_bottom, const int pad_right, - const int proc_channels> - static void process_tile_row( - const int tile_N, // Number of tiles in the row - const T* const input, // Base input pointer (appropriate to batch, channel and row) - const int input_row_stride, // Stride between rows of the input - const int input_col_stride, // Stride between columns of the input - T* const matrix, // 1st output matrix (appropriate to batch, channel and row) - const int matrix_stride, // Stride between matrices - const int matrix_row_stride // Stride between rows of the output matrix - ); - }; - - template <typename T> - struct Winograd2x2_3x3GemmInputChannelwise { - static void execute( - const T *inptr, - const Tensor4DShape& input_shape, - const PaddingType padding_type, - const int tile_M, - const int tile_N, - T *outptr_base, - const int matrix_stride, - const int matrix_batch_stride, - const int matrix_row_stride - ); - - static size_t bytes_read(const Tensor4DShape &input_shape, - const Tensor4DShape &output_shape) { - // We read as many bytes as we write - return bytes_written(input_shape, output_shape); - } - - static int flops_performed(const Tensor4DShape &input_shape, - const Tensor4DShape &output_shape) { - const int tile_rows = iceildiv(output_shape.n_rows, 2); - const int tile_cols = iceildiv(output_shape.n_cols, 2); - return input_shape.n_batches * tile_rows * 32 * tile_cols * input_shape.n_channels; - } - - static size_t bytes_written(const Tensor4DShape &input_shape, - const Tensor4DShape &output_shape) { - return winograd::Winograd2x2_3x3GemmInput<T>::bytes_written(input_shape, output_shape); - } - - protected: - typedef void (*tilefunc)(int, const T*, int, int, T*, int); - template <const int pad_top, - const int pad_left, - const int pad_bottom, - const int pad_right> - static void process_tile( - int n_channels, // Number of channels in the tile - const T* const input_base, - 
const int input_row_stride, - const int input_col_stride, - T* const matrix_base, - const int matrix_stride - ); - - private: - template <const int pad_top, - const int pad_left, - const int pad_bottom, - const int pad_right, - const int proc_channels> - static void _process_tile( - int &n_channels, const T* &inptr, - const int input_row_stride, const int input_col_stride, - T* &outptr, const int matrix_stride - ); - }; -} - -/*****************************************************************************/ -// Include specialised implementations here -#include "input_2x2_3x3/a64_float.hpp" -#include "input_2x2_3x3/a64_float_channelwise.hpp" -/*****************************************************************************/ - -/*****************************************************************************/ -template <typename T> -void winograd::Winograd2x2_3x3GemmInput<T>::execute( - const T *inptr_base, - const Tensor4DShape& input_shape, - const PaddingType padding_type, - const int tile_M, - const int tile_N, - T *outptr_base, - const int matrix_stride, - const int matrix_batch_stride, - const int matrix_row_stride -) { - // Select an appropriate matrix processing method for the shape and padding - // of the input tensor. 
- typedef void (*tensorfunc)(int, int, int, const T*, int, int, T*, int, int); - const auto process_tensor = [&padding_type, &input_shape] () -> tensorfunc { - if (padding_type == PADDING_VALID) { - const int pad_bottom = input_shape.n_rows % 2; - const int pad_right = input_shape.n_cols % 2; - - if (pad_bottom == 0 && pad_right == 0) { - return process_tile_tensor<PADDING_VALID, 0, 0>; - } else if (pad_bottom == 0 && pad_right == 1) { - return process_tile_tensor<PADDING_VALID, 0, 1>; - } else if (pad_bottom == 1 && pad_right == 0) { - return process_tile_tensor<PADDING_VALID, 1, 0>; - } else if (pad_bottom == 1 && pad_right == 1) { - return process_tile_tensor<PADDING_VALID, 1, 1>; - } - } else { // PADDING_SAME - const int pad_bottom = 1 + input_shape.n_rows % 2; - const int pad_right = 1 + input_shape.n_cols % 2; - - if (pad_bottom == 1 && pad_right == 1) { - return process_tile_tensor<PADDING_SAME, 1, 1>; - } else if (pad_bottom == 1 && pad_right == 2) { - return process_tile_tensor<PADDING_SAME, 1, 2>; - } else if (pad_bottom == 2 && pad_right == 1) { - return process_tile_tensor<PADDING_SAME, 2, 1>; - } else if (pad_bottom == 2 && pad_right == 2) { - return process_tile_tensor<PADDING_SAME, 2, 2>; - } - } - - printf("%s::%u Uncovered case.\n", __FILE__, __LINE__); - exit(-1); - return NULL; // No function found - } (); - - // Compute strides - const int input_row_stride = input_shape.n_cols * input_shape.n_channels; - const int input_col_stride = input_shape.n_channels; - - // Process each batch of the tensor in turn. 
- for (int batch = 0; batch < input_shape.n_batches; batch++) { - // Work out pointers - const T *inptr = inptr_base + (batch * input_shape.n_rows * - input_shape.n_cols * input_shape.n_channels); - T *outptr = outptr_base + batch * matrix_batch_stride; - - // Delegate doing the actual work - process_tensor( - tile_M, tile_N, input_shape.n_channels, - inptr, input_row_stride, input_col_stride, - outptr, matrix_stride, matrix_row_stride - ); - } -} - -/*****************************************************************************/ -template <typename T> -template <const PaddingType padding, const int pad_bottom, const int pad_right> -void winograd::Winograd2x2_3x3GemmInput<T>::process_tile_tensor( - const int tile_M, // Number of rows of tiles - const int tile_N, // Number of columns of tiles - int n_channels, // Number of input channels - const T* const input, // Base input pointer (appropriate to batch and channel) - const int input_row_stride, // Stride between rows of the input - const int input_col_stride, // Stride between columns of the input - T* const matrix, // 1st output matrix (appropriate to batch and channel) - const int matrix_stride, // Stride between matrices - const int matrix_row_stride // Stride between rows of the output matrix -) { - // Base row processing functions - typedef void (*rowfunc)(int, const T*, int, int, T*, int, int); - const rowfunc process_top_row[3] = { - (padding == PADDING_VALID) - ? process_tile_row<0, 0, 0, pad_right, 1> - : process_tile_row<1, 1, 0, pad_right, 1>, - (padding == PADDING_VALID) - ? process_tile_row<0, 0, 0, pad_right, 2> - : process_tile_row<1, 1, 0, pad_right, 2>, - (padding == PADDING_VALID) - ? process_tile_row<0, 0, 0, pad_right, 4> - : process_tile_row<1, 1, 0, pad_right, 4>, - }; - const rowfunc process_middle_row[3] = { - (padding == PADDING_VALID) - ? process_tile_row<0, 0, 0, pad_right, 1> - : process_tile_row<0, 1, 0, pad_right, 1>, - (padding == PADDING_VALID) - ? 
process_tile_row<0, 0, 0, pad_right, 2> - : process_tile_row<0, 1, 0, pad_right, 2>, - (padding == PADDING_VALID) - ? process_tile_row<0, 0, 0, pad_right, 4> - : process_tile_row<0, 1, 0, pad_right, 4>, - }; - const rowfunc process_bottom_row[3] = { - (padding == PADDING_VALID) - ? process_tile_row<0, 0, pad_bottom, pad_right, 1> - : process_tile_row<0, 1, pad_bottom, pad_right, 1>, - (padding == PADDING_VALID) - ? process_tile_row<0, 0, pad_bottom, pad_right, 2> - : process_tile_row<0, 1, pad_bottom, pad_right, 2>, - (padding == PADDING_VALID) - ? process_tile_row<0, 0, pad_bottom, pad_right, 4> - : process_tile_row<0, 1, pad_bottom, pad_right, 4>, - }; - - // Method to get an input pointer for the given tile row - const auto get_inptr = [&input, &input_row_stride] (const int tile_i) { - if (padding == PADDING_VALID) { - return input + 2 * tile_i * input_row_stride; - } else { - return input + (2 * tile_i - (tile_i ? 1 : 0)) * input_row_stride; - } - }; - - // Wrapper to process a row of tiles, covering all channels. - const auto process_row = - [tile_N, input_row_stride, input_col_stride, matrix_stride, matrix_row_stride, n_channels] - (const rowfunc f[3], const T *inptr, T *outptr) { - int rem_channels = n_channels; - - // While there remain channels to process continue to process the - // row. 
- for (; rem_channels >= 4; rem_channels -= 4, inptr += 4, outptr += 4) { - f[2](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride); - } - for (; rem_channels >= 2; rem_channels -= 2, inptr += 2, outptr += 2) { - f[1](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride); - } - if (rem_channels) { - f[0](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride); - } - }; - - // Process all rows of tiles in the tensor - for (int tile_i = 0; tile_i < tile_M; tile_i++) { - T* const m_row = matrix + tile_i * tile_N * matrix_row_stride; - const T *row_inptr = get_inptr(tile_i); - - if (tile_i == 0) { - // Top row of the input - process_row(process_top_row, row_inptr, m_row); - } else if (tile_i == tile_M - 1) { - // Bottom row of the input - process_row(process_bottom_row, row_inptr, m_row); - } else { - // Any other row of the input - process_row(process_middle_row, row_inptr, m_row); - } - } -} - -/*****************************************************************************/ -template <typename T> -template <const int pad_top, const int pad_left, - const int pad_bottom, const int pad_right, - const int proc_channels> -void winograd::Winograd2x2_3x3GemmInput<T>::process_tile_row( - const int tile_N, // Number of tiles in the row - const T* const input, // Base input pointer (appropriate to batch, channel and row) - const int input_row_stride, // Stride between rows of the input - const int input_col_stride, // Stride between columns of the input - T* const matrix, // 1st output matrix (appropriate to batch, channel and row) - const int matrix_stride, // Stride between matrices - const int matrix_row_stride // Stride between rows of the output matrix -) { - // Construct copies of the pointers - const T *inptr = input; - T *outptr = matrix; - - // Storage for the tensors x, X.T x, and X.T x X. 
- T x[4][4][proc_channels], XTx[4][4][proc_channels], XTxX[4][4][proc_channels]; - - // For every tile in the row - for (int tile_j = 0; tile_j < tile_N; tile_j++) { - // Determine the padding for the tile - const int tile_pad_left = (tile_j == 0) ? pad_left : 0; - const int tile_pad_right = (tile_j == tile_N - 1) ? pad_right : 0; - - // Load tile values. If this is the first tile in the row then we must load - // all values, otherwise we can just load the final two columns of the input. - for (int i = 0; i < 4; i++) { - for (int j = ((tile_j == 0) ? 0 : 2); j < 4; j++) { - // Fill with padding if required - if (i < pad_top || 4 - pad_bottom <= i || - j < tile_pad_left || 4 - tile_pad_right <= j) { - for (int c = 0; c < proc_channels; c++) { - x[i][j][c] = static_cast<T>(0); // Padding - } - } else { - // Load values, note that the initial padding offsets the pointer we - // were provided. - for (int c = 0; c < proc_channels; c++) { - const int row_offset = (i - pad_top) * input_row_stride; - const int col_offset = (j - tile_pad_left) * input_col_stride; - x[i][j][c] = inptr[row_offset + col_offset + c]; - } - } - } - } - - // Compute the matrix X.T x. Note, can elide operations depending on the - // padding. Furthermore, if this isn't the left-most tile we can skip half - // of the operations by copying results from the previous version of X.T x. - // This latter optimisation can be simplified by unrolling the outermost - // loop by two and by renaming the registers containing XTx. 
- if (tile_j == 0) { - for (int j = 0; j < 4; j++) { - for (int c = 0; c < proc_channels; c++) { - XTx[0][j][c] = x[0][j][c] - x[2][j][c]; - XTx[1][j][c] = x[1][j][c] + x[2][j][c]; - XTx[2][j][c] = -x[1][j][c] + x[2][j][c]; - XTx[3][j][c] = x[1][j][c] - x[3][j][c]; - } - } - } else { - for (int j = 0; j < 2; j++) { - for (int c = 0; c < proc_channels; c++) { - XTx[0][j][c] = XTx[0][j + 2][c]; - XTx[1][j][c] = XTx[1][j + 2][c]; - XTx[2][j][c] = XTx[2][j + 2][c]; - XTx[3][j][c] = XTx[3][j + 2][c]; - } - } - for (int j = 2; j < 4; j++) { - for (int c = 0; c < proc_channels; c++) { - XTx[0][j][c] = x[0][j][c] - x[2][j][c]; - XTx[1][j][c] = x[1][j][c] + x[2][j][c]; - XTx[2][j][c] = -x[1][j][c] + x[2][j][c]; - XTx[3][j][c] = x[1][j][c] - x[3][j][c]; - } - } - } - - // Compute the matrix X.T x X. Note, can elide operations based on the - // padding. - for (int i = 0; i < 4; i++) { - for (int c = 0; c < proc_channels; c++) { - XTxX[i][0][c] = XTx[i][0][c] - XTx[i][2][c]; - XTxX[i][1][c] = XTx[i][1][c] + XTx[i][2][c]; - XTxX[i][2][c] = -XTx[i][1][c] + XTx[i][2][c]; - XTxX[i][3][c] = XTx[i][1][c] - XTx[i][3][c]; - } - } - - // Store the output matrix (X.T x X) - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 4; j++) { - // Get a pointer to the relevant output matrix - T *mptr = outptr + (i*4 + j)*matrix_stride; - - // Write out the channels - for (int c = 0; c < proc_channels; c++) { - mptr[c] = XTxX[i][j][c]; - } - } - } - - // Update the pointers - inptr += input_col_stride * ((tile_j == 0 && pad_left) ? 
1 : 2); - outptr += matrix_row_stride; - } -} - -/*****************************************************************************/ -template <typename T> -void winograd::Winograd2x2_3x3GemmInputChannelwise<T>::execute( - const T *inptr, - const Tensor4DShape& input_shape, - const PaddingType padding_type, - const int tile_M, - const int tile_N, - T *outptr_base, - const int matrix_stride, - const int matrix_batch_stride, - const int matrix_row_stride -) { - const int n_channels = input_shape.n_channels; - const int input_col_stride = n_channels; - const int input_row_stride = input_shape.n_cols * input_col_stride; - - // Determine the padding and hence select appropriate methods for each tile. - tilefunc fs[3][3]; - - if (padding_type == PADDING_VALID) { - constexpr int pad_top = 0; - constexpr int pad_left = 0; - const int pad_right = input_shape.n_cols % 2 == 0; - - fs[0][0] = process_tile<pad_top, pad_left, 0, 0>; - fs[0][1] = process_tile<pad_top, 0, 0, 0>; - fs[0][2] = (pad_right) ? process_tile<pad_top, 0, 0, 0> : process_tile<pad_top, 0, 0, 1>; - - fs[1][0] = process_tile<0, pad_left, 0, 0>; - fs[1][1] = process_tile<0, 0, 0, 0>; - fs[1][2] = (pad_right) ? process_tile<0, 0, 0, 0> : process_tile<0, 0, 0, 1>; - - if (input_shape.n_rows % 2 == 0) { - constexpr int pad_bottom = 0; - fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>; - fs[2][1] = process_tile<0, 0, pad_bottom, 0>; - fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 0> : process_tile<0, 0, pad_bottom, 1>; - } else { - constexpr int pad_bottom = 1; - fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>; - fs[2][1] = process_tile<0, 0, pad_bottom, 0>; - fs[2][2] = (pad_right) ? 
process_tile<0, 0, pad_bottom, 0> : process_tile<0, 0, pad_bottom, 1>; - } - } else { - constexpr int pad_top = 1; - constexpr int pad_left = 1; - const int pad_right = input_shape.n_cols % 2 == 0; - - fs[0][0] = process_tile<pad_top, pad_left, 0, 0>; - fs[0][1] = process_tile<pad_top, 0, 0, 0>; - fs[0][2] = (pad_right) ? process_tile<pad_top, 0, 0, 1> : process_tile<pad_top, 0, 0, 2>; - - fs[1][0] = process_tile<0, pad_left, 0, 0>; - fs[1][1] = process_tile<0, 0, 0, 0>; - fs[1][2] = (pad_right) ? process_tile<0, 0, 0, 1> : process_tile<0, 0, 0, 2>; - - if (input_shape.n_rows % 2 == 0) { - constexpr int pad_bottom = 1; - fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>; - fs[2][1] = process_tile<0, 0, pad_bottom, 0>; - fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 1> : process_tile<0, 0, pad_bottom, 2>; - } else { - constexpr int pad_bottom = 2; - fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>; - fs[2][1] = process_tile<0, 0, pad_bottom, 0>; - fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 1> : process_tile<0, 0, pad_bottom, 2>; - } - } - - // Process each tile in turn - for (int batch = 0; batch < input_shape.n_batches; batch++) { - const T* const input_base_batch = inptr + batch*input_shape.n_rows*input_shape.n_cols*n_channels; - - for (int tile_i = 0; tile_i < tile_M; tile_i++) { - const int row_offset = (tile_i == 0) ? 0 : ((padding_type == PADDING_VALID) ? 0 : 1); - const T* const input_base_row = input_base_batch + (2*tile_i - row_offset)*input_shape.n_cols*n_channels; - - // Select the set of functions for the row - const int fs_i = (tile_i == 0) ? 0 : ((tile_i < tile_M - 1) ? 1 : 2); - - for (int tile_j = 0; tile_j < tile_N; tile_j++) { - // Select the function for the column - const int fs_j = (tile_j == 0) ? 0 : ((tile_j < tile_N - 1) ? 1 : 2); - const auto f = fs[fs_i][fs_j]; - - // Get pointers into the input and outputs - const int col_offset = (tile_j == 0) ? 0 : ((padding_type == PADDING_VALID) ? 
0 : 1); - const T* const input_base_col = input_base_row + (2*tile_j - col_offset)*n_channels; - T* const matrix_base = outptr_base + batch*matrix_batch_stride + (tile_i*tile_N + tile_j)*matrix_row_stride; - f(n_channels, input_base_col, input_row_stride, input_col_stride, - matrix_base, matrix_stride); - } - } - } -} - -template <typename T> -template <const int pad_top, - const int pad_left, - const int pad_bottom, - const int pad_right> -void winograd::Winograd2x2_3x3GemmInputChannelwise<T>::process_tile( - int n_channels, // Number of channels in the tile - const T* const input_base, - const int input_row_stride, - const int input_col_stride, - T* const matrix_base, - const int matrix_stride -) { - // Copy pointers - const T *inptr = input_base; - T *outptr = matrix_base; - - // Process channels (modifies inptr, outptr and n_channels) - _process_tile<pad_top, pad_left, pad_bottom, pad_right, 4>( - n_channels, inptr, input_row_stride, input_col_stride, - outptr, matrix_stride - ); - _process_tile<pad_top, pad_left, pad_bottom, pad_right, 2>( - n_channels, inptr, input_row_stride, input_col_stride, - outptr, matrix_stride - ); - _process_tile<pad_top, pad_left, pad_bottom, pad_right, 1>( - n_channels, inptr, input_row_stride, input_col_stride, - outptr, matrix_stride - ); -} - -template <typename T> -template <const int pad_top, - const int pad_left, - const int pad_bottom, - const int pad_right, - const int proc_channels> -void winograd::Winograd2x2_3x3GemmInputChannelwise<T>::_process_tile( - int &n_channels, - const T* &inptr, const int input_row_stride, const int input_col_stride, - T* &outptr, const int matrix_stride -) { - // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three - // offsets to access the intermediate matrices. - T* outptrs[4] = { - outptr, - outptr + matrix_stride * 4, - outptr + matrix_stride * 8, - outptr + matrix_stride * 12 - }; - - // The matrix X; zeroed to account for padding. 
- T x[4][4]; - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 4; j++) { - x[i][j] = 0; - } - } - - // The matrices X.T x and U - T XTx[4][4], U[4][4]; - - // Now progress through each channel - for (; n_channels >= proc_channels; n_channels -= proc_channels) { - for (int n = 0; n < proc_channels; n++) { - // Load the matrix X - for (int cell_i = pad_top, i = 0; cell_i < 4 - pad_bottom; cell_i++, i++) { - for (int cell_j = pad_left, j = 0; cell_j < 4 - pad_right; cell_j++, j++) { - x[cell_i][cell_j] = inptr[i*input_row_stride + j*input_col_stride]; - } - } - inptr++; - - // Compute the matrix X.T - for (int j = 0; j < 4; j++) { - XTx[0][j] = x[0][j] - x[2][j]; - XTx[1][j] = x[1][j] + x[2][j]; - XTx[2][j] = x[2][j] - x[1][j]; - XTx[3][j] = x[1][j] - x[3][j]; - } - - // Hence compute the matrix U - for (int i = 0; i < 4; i++) { - U[i][0] = XTx[i][0] - XTx[i][2]; - U[i][1] = XTx[i][1] + XTx[i][2]; - U[i][2] = XTx[i][2] - XTx[i][1]; - U[i][3] = XTx[i][1] - XTx[i][3]; - } - - // Store the matrix U - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 4; j++) { - outptrs[i][j * matrix_stride] = U[i][j]; - } - outptrs[i]++; - } - } - } - - // Update the output pointer for future calls - outptr = outptrs[0]; -} diff --git a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp b/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp deleted file mode 100644 index a99cbe325b..0000000000 --- a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp +++ /dev/null @@ -1,1498 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. 
- * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- */ -#pragma once -#include "../input_2x2_3x3.hpp" - -#ifdef __aarch64__ -namespace winograd { - -// Pad left by one column, pad right by one column, no upper or lower padding, 4 channels -template <> -template <> -inline void Winograd2x2_3x3GemmInput<float>::process_tile_row<0, 1, 0, 1, 4>( - const int tile_N, // Number of tiles in the row - const float* const input, // Base input pointer (appropriate to batch, channel and row) - const int input_row_stride, // Stride between rows of the input - const int input_col_stride, // Stride between columns of the input - float* const matrix, // 1st output matrix (appropriate to batch, channel and row) - const int matrix_stride, // Stride between matrices - const int matrix_row_stride // Stride between rows of the output matrix -) { - /* SIMD register allocation - * ======================== - * - * In the following code we read 4x4 tiles of a matrix `x`, with which we - * compute another matrix `X.T x` where: - * - * / 1 0 0 0 \ - * X = | 0 1 -1 1 | - * | -1 1 1 0 | - * \ 0 0 0 -1 / - * - * Hence, `X.T` is a program which operates upon rows of the matrix `X`. - * We subsequently compute and store the matrix `U = (X.T x) X`. - * - * Importantly, each iteration of the loop below loads a new matrix `x'` - * where the final two columns of `x'` are the first two columns of the - * previous `x`. That is: - * - * x11 x12 x13 x14 - * x21 x22 x23 x24 - * x31 x32 x33 x34 - * x41 x42 x43 x44 - * - * x'11 x'12 x'13 x'14 - * x'21 x'22 x'23 x'24 - * x'31 x'32 x'33 x'34 - * x'41 x'42 x'43 x'44 - * - * Consequently, while the first iteration of the below loop must load 16 - * values for `x`, the second need load only 8. *Furthermore*, since we noted - * above that the operation `X.T x` was a program which operated upon *rows* - * of the matrix `x` it follows that that the relation that `x'[i][1] = - * x[i][3]` and `x'[i][2] = x[i][4]` applies also the matrices `X.T x'` and - * `X.T x`. 
That is: - * - * (X.T x)11 (X.T x)12 (X.T x)13 (X.T x)14 - * (X.T x)21 (X.T x)22 (X.T x)23 (X.T x)24 - * (X.T x)31 (X.T x)32 (X.T x)33 (X.T x)34 - * (X.T x)41 (X.T x)42 (X.T x)43 (X.T x)44 - * - * (X.T x')11 (X.T x')12 (X.T x')13 (X.T x')14 - * (X.T x')12 (X.T x')12 (X.T x')12 (X.T x')12 - * (X.T x')13 (X.T x')13 (X.T x')13 (X.T x')13 - * (X.T x')14 (X.T x')14 (X.T x')14 (X.T x')14 - * - * Hence, as well as not needing to load new values for x'[i][1..2] it is - * also unnecessary to recompute values for (X.T x')[i][1..2]. - * - * Following this we break the registers into blocks `A` and `B` used by the - * two stages of the unrolled loop. These registers named such that the - * latter columns of `A` become the earlier columns of `B` and vice-versa: - * - * AXTx11 AXTx12 > AXTx13 AXTx14 | - * AXTx21 AXTx22 > AXTx23 AXTx24 | - * AXTx31 AXTx32 > AXTx33 AXTx34 | - * AXTx41 AXTx42 > AXTx43 AXTx44 | - * - * BXTx13 BXTx14 | BXTx11 BXTx12 > - * BXTx23 BXTx24 | BXTx21 BXTx22 > - * BXTx33 BXTx34 | BXTx31 BXTx32 > - * BXTx43 BXTx44 | BXTx41 BXTx42 > - * - * These 32 named registers require only 16 architectural registers. 1 - * additional architectural register is used as scratch space and 8 - * architectural registers are used to load in the values x[1..4][3,4]. 
- * - * Input and output addressing - * =========================== - * TODO Description - */ - const float *inptr0 = input; - const float *inptr1 = input + input_row_stride; - const float *inptr2 = input + input_row_stride * 2; - const float *inptr3 = input + input_row_stride * 3; - - float *outptr0 = matrix; - float *outptr4 = matrix + matrix_stride * 4; - float *outptr8 = matrix + matrix_stride * 8; - float *outptr12 = matrix + matrix_stride * 12; - - int tile_j = tile_N; // Tiles to process - - asm volatile ( - // Named SIMD registers according to the policy given above - // Registers into which to load the latter two columns of `x` - "x_13 .req v0\n qx_13 .req q0\n" "x_14 .req v4\n qx_14 .req q4\n" - "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n" - "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n" - "x_43 .req v3\n qx_43 .req q3\n" "x_44 .req v7\n qx_44 .req q7\n" - - // Registers for storing X.T x (both A and B halves) - "AXTx11 .req v8\n" "BXTx13 .req v8\n" - "AXTx12 .req v9\n" "BXTx14 .req v9\n" "qAXTx12 .req q9\n" - "AXTx21 .req v10\n" "BXTx23 .req v10\n" - "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n" - "AXTx31 .req v12\n" "BXTx33 .req v12\n" - "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n" - "AXTx41 .req v14\n" "BXTx43 .req v14\n" - "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n" - "AXTx13 .req v16\n" "BXTx11 .req v16\n" - "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n" - "AXTx23 .req v18\n" "BXTx21 .req v18\n" - "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n" - "AXTx33 .req v20\n" "BXTx31 .req v20\n" - "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n" - "AXTx43 .req v22\n" "BXTx41 .req v22\n" - "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n" - - // Result register (TODO Does using more registers yield better - // performance) - "U .req v24\n qU .req q24\n" - - // 
---------------------------------------------------------------------- - // Head of loop - // Loads a complete 4x4 tile of x, computes X.T x, computes and stores - // `U = X.T x X`. Prepares for the 'A' half of the loop. - // NOTE: Since the first tile has the leftmost column padded we can - // skip 4 loads and 4 calculations for the matrix X.T x X. - - // Temporarily alias registers for computing the first (non-padded) - // column of x. - "x_12 .req v0\n qx_12 .req q0\n" - "x_22 .req v1\n qx_22 .req q1\n" - "x_32 .req v2\n qx_32 .req q2\n" - "x_42 .req v3\n qx_42 .req q3\n" - - "ldr qx_12, [%x[inptr0]]\n" - "ldr qx_22, [%x[inptr1]]\n" - "ldr qx_32, [%x[inptr2]]\n" - "ldr qx_42, [%x[inptr3]]\n" - - "fsub BXTx12.4s, x_12.4s, x_32.4s\n" - "fadd BXTx22.4s, x_22.4s, x_32.4s\n" - "fsub BXTx32.4s, x_32.4s, x_22.4s\n" - "fsub BXTx42.4s, x_22.4s, x_42.4s\n" - - ".unreq x_12\n .unreq qx_12\n" - ".unreq x_22\n .unreq qx_22\n" - ".unreq x_32\n .unreq qx_32\n" - ".unreq x_42\n .unreq qx_42\n" - - // Load and compute latter two columns of the first tile. Progress the - // input pointers (by three columns so that the each points are the - // second column of the next tile, that is, each points at the first - // column which must be read for the next tile. 
- "ldr qx_13, [%x[inptr0], %x[colstride1]]\n" - "ldr qx_23, [%x[inptr1], %x[colstride1]]\n" - "ldr qx_33, [%x[inptr2], %x[colstride1]]\n" - "ldr qx_43, [%x[inptr3], %x[colstride1]]\n" - - "fsub BXTx13.4s, x_13.4s, x_33.4s\n" - "ldr qx_14, [%x[inptr0], %x[colstride2]]\n" - - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride2]]\n" - - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride2]]\n" - - "fsub BXTx43.4s, x_23.4s, x_43.4s\n" - "ldr qx_44, [%x[inptr3], %x[colstride2]]\n" - - "fsub BXTx14.4s, x_14.4s, x_34.4s\n" - "add %x[inptr0], %x[inptr0], %x[colstride3]\n" - - "fadd BXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride3]\n" - - "fsub BXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride3]\n" - - "fsub BXTx44.4s, x_24.4s, x_44.4s\n" - "add %x[inptr3], %x[inptr3], %x[colstride3]\n" - - // Compute and store U for the first tile - // First row - "fneg U.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, BXTx12.4s, BXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fneg U.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, BXTx22.4s, BXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fneg U.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, BXTx32.4s, BXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add 
%x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row, simultaneously load the first column of inputs for the - // next tile. - "fneg U.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "ldr qx_13, [%x[inptr0]]\n" - - "fadd U.4s, BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, BXTx42.4s, BXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "ldr qx_43, [%x[inptr3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - - // Update the loop counter, subtract two to account for both the head and - // the tail. - "subs %x[tile_j], %x[tile_j], #2\n" - "beq 2f\n" // Jump to "A" tail if out of tiles - - // ---------------------------------------------------------------------- - "1:" - // Start part A - // Load last column of this tile (the first column has already been - // loaded) and compute latter two columns of X.T x. - "fsub AXTx13.4s, x_13.4s, x_33.4s\n" - "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" - "fadd AXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" - "fsub AXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" - "fsub AXTx43.4s, x_23.4s, x_43.4s\n" - "ldr qx_44, [%x[inptr3], %x[colstride1]]\n" - "fsub AXTx14.4s, x_14.4s, x_34.4s\n" - "add %x[inptr0], %x[inptr0], %x[colstride2]\n" - "fadd AXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride2]\n" - "fsub AXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride2]\n" - "fsub AXTx44.4s, x_24.4s, x_44.4s\n" - "add %x[inptr3], %x[inptr3], %x[colstride2]\n" - - // Compute and store U. 
- // First row - "fsub U.4s, AXTx11.4s, AXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, AXTx12.4s, AXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, AXTx13.4s, AXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, AXTx12.4s, AXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fsub U.4s, AXTx21.4s, AXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, AXTx22.4s, AXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, AXTx23.4s, AXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, AXTx22.4s, AXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fsub U.4s, AXTx31.4s, AXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, AXTx32.4s, AXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, AXTx33.4s, AXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, AXTx32.4s, AXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row - "fsub U.4s, AXTx41.4s, AXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "ldr qx_13, [%x[inptr0]]\n" - - "fadd U.4s, AXTx42.4s, AXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, AXTx43.4s, AXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, AXTx42.4s, AXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "ldr qx_43, [%x[inptr3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - - "subs %x[tile_j], %x[tile_j], #1\n" - "beq 3f\n" // Jump to 'B' tail - - // Start part B - // Load last column of this tile (the first column has already been - // loaded) and compute latter two columns of X.T x. 
- "fsub BXTx13.4s, x_13.4s, x_33.4s\n" - "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" - "fsub BXTx43.4s, x_23.4s, x_43.4s\n" - "ldr qx_44, [%x[inptr3], %x[colstride1]]\n" - "fsub BXTx14.4s, x_14.4s, x_34.4s\n" - "add %x[inptr0], %x[inptr0], %x[colstride2]\n" - "fadd BXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride2]\n" - "fsub BXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride2]\n" - "fsub BXTx44.4s, x_24.4s, x_44.4s\n" - "add %x[inptr3], %x[inptr3], %x[colstride2]\n" - - // Compute and store U. - // First row - "fsub U.4s, BXTx11.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, BXTx12.4s, BXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fsub U.4s, BXTx21.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, BXTx22.4s, BXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fsub U.4s, BXTx31.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, BXTx32.4s, BXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row - "fsub U.4s, BXTx41.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "ldr qx_13, [%x[inptr0]]\n" - - "fadd U.4s, 
BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, BXTx42.4s, BXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "ldr qx_43, [%x[inptr3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - "subs %x[tile_j], %x[tile_j], #1\n" - "bne 1b\n" // Continue loop, otherwise flow into 'A' tail - - // ---------------------------------------------------------------------- - "2:" - // 'A' tail - // Since the final column is padding and the last-but-one column has - // already been loaded just compute the 3rd column of `X.T x'. - "fsub AXTx13.4s, x_13.4s, x_33.4s\n" - "fadd AXTx23.4s, x_23.4s, x_33.4s\n" - "fsub AXTx33.4s, x_33.4s, x_23.4s\n" - "fsub AXTx43.4s, x_23.4s, x_43.4s\n" - - // Compute and store U. Modified to account for the final column of X.T - // x containing padding. Note, it is also unnecessary to update the - // output pointers. 
- // First row - "fsub U.4s, AXTx11.4s, AXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, AXTx12.4s, AXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, AXTx13.4s, AXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "str qAXTx12, [%x[outptr0], %x[mstride3]]\n" - - // Second row - "fsub U.4s, AXTx21.4s, AXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, AXTx22.4s, AXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, AXTx23.4s, AXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "str qAXTx22, [%x[outptr4], %x[mstride3]]\n" - - // Third row - "fsub U.4s, AXTx31.4s, AXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, AXTx32.4s, AXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, AXTx33.4s, AXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "str qAXTx32, [%x[outptr8], %x[mstride3]]\n" - - // Fourth row - "fsub U.4s, AXTx41.4s, AXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "fadd U.4s, AXTx42.4s, AXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, AXTx43.4s, AXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "str qAXTx42, [%x[outptr12], %x[mstride3]]\n" - - "b 4f\n" // Jump to end of function - - // ---------------------------------------------------------------------- - "3:" - // 'B' tail - // Since the final column is padding and the last-but-one column has - // already been loaded just compute the 3rd column of `X.T x'. - "fsub BXTx13.4s, x_13.4s, x_33.4s\n" - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "fsub BXTx43.4s, x_23.4s, x_43.4s\n" - - // Compute and store U. Modified to account for the final column of X.T - // x containing padding. Note, it is also unnecessary to update the - // output pointers. 
- // First row - "fsub U.4s, BXTx11.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "str qBXTx12, [%x[outptr0], %x[mstride3]]\n" - - // Second row - "fsub U.4s, BXTx21.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "str qBXTx22, [%x[outptr4], %x[mstride3]]\n" - - // Third row - "fsub U.4s, BXTx31.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "str qBXTx32, [%x[outptr8], %x[mstride3]]\n" - - // Fourth row - "fsub U.4s, BXTx41.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "fadd U.4s, BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "str qBXTx42, [%x[outptr12], %x[mstride3]]\n" - - // ---------------------------------------------------------------------- - "4:" - // End of function - - // Clear names - ".unreq x_13\n" ".unreq qx_13\n" ".unreq x_14\n" ".unreq qx_14\n" - ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n" - ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n" - ".unreq x_43\n" ".unreq qx_43\n" ".unreq x_44\n" ".unreq qx_44\n" - ".unreq AXTx11\n" ".unreq BXTx13\n" - ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n" - ".unreq AXTx21\n" ".unreq BXTx23\n" - ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n" - ".unreq AXTx31\n" ".unreq BXTx33\n" - ".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n" - ".unreq AXTx41\n" ".unreq BXTx43\n" - ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n" - ".unreq AXTx13\n" ".unreq BXTx11\n" - ".unreq AXTx14\n" 
".unreq BXTx12\n" ".unreq qBXTx12\n" - ".unreq AXTx23\n" ".unreq BXTx21\n" - ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n" - ".unreq AXTx33\n" ".unreq BXTx31\n" - ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n" - ".unreq AXTx43\n" ".unreq BXTx41\n" - ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n" - ".unreq U\n" ".unreq qU\n" - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [inptr3] "+r" (inptr3), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr4), - [outptr8] "+r" (outptr8), - [outptr12] "+r" (outptr12), - [tile_j] "+r" (tile_j) // Tile counter - : [colstride1] "r" (1 * input_col_stride * sizeof(float)), - [colstride2] "r" (2 * input_col_stride * sizeof(float)), - [colstride3] "r" (3 * input_col_stride * sizeof(float)), - [mstride1] "r" (1 * matrix_stride * sizeof(float)), - [mstride2] "r" (2 * matrix_stride * sizeof(float)), - [mstride3] "r" (3 * matrix_stride * sizeof(float)), - [matrix_row_stride] "r" (matrix_row_stride * sizeof(float)) - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", - "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24" - ); -} - -// Pad top, left and right by 1. 
/* Specialisation of process_tile_row for tiles whose TOP row is padding
 * (labelled "Pad top, left and right by 1" in this file). Transforms one row
 * of 4x4 input tiles into the Winograd domain, writing the 16 transformed
 * values of each tile into 16 matrices separated by `matrix_stride`.
 *
 * Because the top row of every tile is zero padding, only three input rows
 * are read: the asm operands inptr1..inptr3 are deliberately bound to the
 * C++ pointers inptr0..inptr2 (see the constraint list below), and terms
 * involving the padded row x_1* are replaced by `fneg`/`mov` simplifications.
 *
 * Parameters:
 *   tile_N            - number of tiles in this row to process.
 *   input             - pointer to the first (non-padded) input row.
 *   input_row_stride  - stride (in elements) between input rows.
 *   input_col_stride  - stride (in elements) between input columns.
 *   matrix            - base pointer of the output matrices.
 *   matrix_stride     - stride (in elements) between consecutive matrices.
 *   matrix_row_stride - stride (in elements) between rows within a matrix.
 *
 * NOTE(review): the meaning of the template arguments <1, 1, 0, 1, 4> is
 * presumably <pad top, pad left, pad bottom, pad right, tile rows> given the
 * surrounding specialisations — confirm against the primary template, which
 * is not visible in this chunk.
 */
template <>
template <>
inline void Winograd2x2_3x3GemmInput<float>::process_tile_row<1, 1, 0, 1, 4>(
  const int tile_N,
  const float* const input,
  const int input_row_stride,
  const int input_col_stride,
  float* const matrix,
  const int matrix_stride,
  const int matrix_row_stride
) {
  // Pointers to the three non-padded input rows of the tile row.
  const float *inptr0 = input;
  const float *inptr1 = input + input_row_stride;
  const float *inptr2 = input + input_row_stride * 2;

  // Output pointers for matrices 0, 4, 8 and 12; matrices 1-3, 5-7, 9-11 and
  // 13-15 are reached through the mstride1..mstride3 offsets in the asm.
  float *outptr0 = matrix;
  float *outptr4 = matrix + matrix_stride * 4;
  float *outptr8 = matrix + matrix_stride * 8;
  float *outptr12 = matrix + matrix_stride * 12;

  int tile_j = tile_N;  // Tiles to process

  asm volatile (
    // Named SIMD registers according to the policy given above
    // Registers into which to load the latter two columns of `x`
    // NOTE: We need only load the latter three rows since we know that the
    // first row is padded.
    "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n"
    "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n"
    "x_43 .req v3\n qx_43 .req q3\n" "x_44 .req v7\n qx_44 .req q7\n"

    // Registers for storing X.T x (both A and B halves). The A and B halves
    // alias the same physical registers with the column roles swapped, so the
    // main loop can double-buffer: while one half is consumed to form U, the
    // other half's columns 3/4 are computed for the next tile.
    "AXTx11 .req v8\n"  "BXTx13 .req v8\n"
    "AXTx12 .req v9\n"  "BXTx14 .req v9\n"  "qAXTx12 .req q9\n"
    "AXTx21 .req v10\n" "BXTx23 .req v10\n"
    "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n"
    "AXTx31 .req v12\n" "BXTx33 .req v12\n"
    "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n"
    "AXTx41 .req v14\n" "BXTx43 .req v14\n"
    "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n"
    "AXTx13 .req v16\n" "BXTx11 .req v16\n"
    "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n"
    "AXTx23 .req v18\n" "BXTx21 .req v18\n"
    "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n"
    "AXTx33 .req v20\n" "BXTx31 .req v20\n"
    "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n"
    "AXTx43 .req v22\n" "BXTx41 .req v22\n"
    "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n"

    // Result register (TODO Does using more registers yield better
    // performance)
    "U .req v24\n qU .req q24\n"

    // ----------------------------------------------------------------------
    // Head of loop
    // Loads a complete 4x4 tile of x, computes X.T x, computes and stores
    // `U = X.T x X`. Prepares for the 'A' half of the loop.
    // NOTE: Since the first tile has the leftmost column padded we can
    // skip 4 loads and 4 calculations for the matrix X.T x X.

    // Temporarily alias registers for computing the first (non-padded)
    // column of x.
    "x_22 .req v1\n qx_22 .req q1\n"
    "x_32 .req v2\n qx_32 .req q2\n"
    "x_42 .req v3\n qx_42 .req q3\n"

    "ldr qx_22, [%x[inptr1]]\n"
    "ldr qx_32, [%x[inptr2]]\n"
    "ldr qx_42, [%x[inptr3]]\n"

    // With the padded top row x_12 == 0, so (x_12 - x_32) reduces to -x_32.
    "fneg BXTx12.4s, x_32.4s\n"
    "fadd BXTx22.4s, x_22.4s, x_32.4s\n"
    "fsub BXTx32.4s, x_32.4s, x_22.4s\n"
    "fsub BXTx42.4s, x_22.4s, x_42.4s\n"

    ".unreq x_22\n .unreq qx_22\n"
    ".unreq x_32\n .unreq qx_32\n"
    ".unreq x_42\n .unreq qx_42\n"

    // Load and compute latter two columns of the first tile. Progress the
    // input pointers (by three columns so that the each points are the
    // second column of the next tile, that is, each points at the first
    // column which must be read for the next tile.
    "ldr qx_23, [%x[inptr1], %x[colstride1]]\n"
    "ldr qx_33, [%x[inptr2], %x[colstride1]]\n"
    "ldr qx_43, [%x[inptr3], %x[colstride1]]\n"

    "fneg BXTx13.4s, x_33.4s\n"

    "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
    "ldr qx_24, [%x[inptr1], %x[colstride2]]\n"

    "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
    "ldr qx_34, [%x[inptr2], %x[colstride2]]\n"

    "fsub BXTx43.4s, x_23.4s, x_43.4s\n"
    "ldr qx_44, [%x[inptr3], %x[colstride2]]\n"

    "fneg BXTx14.4s, x_34.4s\n"

    "fadd BXTx24.4s, x_24.4s, x_34.4s\n"
    "add %x[inptr1], %x[inptr1], %x[colstride3]\n"

    "fsub BXTx34.4s, x_34.4s, x_24.4s\n"
    "add %x[inptr2], %x[inptr2], %x[colstride3]\n"

    "fsub BXTx44.4s, x_24.4s, x_44.4s\n"
    "add %x[inptr3], %x[inptr3], %x[colstride3]\n"

    // Compute and store U for the first tile
    // First row
    "fneg U.4s, BXTx13.4s\n"
    "str qU, [%x[outptr0]]\n"
    "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
    "str qU, [%x[outptr0], %x[mstride1]]\n"
    "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
    "str qU, [%x[outptr0], %x[mstride2]]\n"
    "fsub U.4s, BXTx12.4s, BXTx14.4s\n"
    "str qU, [%x[outptr0], %x[mstride3]]\n"
    "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"

    // Second row
    "fneg U.4s, BXTx23.4s\n"
    "str qU, [%x[outptr4]]\n"
    "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
    "str qU, [%x[outptr4], %x[mstride1]]\n"
    "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
    "str qU, [%x[outptr4], %x[mstride2]]\n"
    "fsub U.4s, BXTx22.4s, BXTx24.4s\n"
    "str qU, [%x[outptr4], %x[mstride3]]\n"
    "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"

    // Third row
    "fneg U.4s, BXTx33.4s\n"
    "str qU, [%x[outptr8]]\n"
    "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
    "str qU, [%x[outptr8], %x[mstride1]]\n"
    "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
    "str qU, [%x[outptr8], %x[mstride2]]\n"
    "fsub U.4s, BXTx32.4s, BXTx34.4s\n"
    "str qU, [%x[outptr8], %x[mstride3]]\n"
    "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"

    // Fourth row, simultaneously load the first column of inputs for the
    // next tile.
    "fneg U.4s, BXTx43.4s\n"
    "str qU, [%x[outptr12]]\n"

    "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
    "str qU, [%x[outptr12], %x[mstride1]]\n"
    "ldr qx_23, [%x[inptr1]]\n"

    "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
    "str qU, [%x[outptr12], %x[mstride2]]\n"
    "ldr qx_33, [%x[inptr2]]\n"

    "fsub U.4s, BXTx42.4s, BXTx44.4s\n"
    "str qU, [%x[outptr12], %x[mstride3]]\n"
    "ldr qx_43, [%x[inptr3]]\n"

    "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"

    // Update the loop counter, subtract two to account for both the head and
    // the tail.
    "subs %x[tile_j], %x[tile_j], #2\n"
    "beq 2f\n"  // Jump to "A" tail if out of tiles

    // ----------------------------------------------------------------------
    "1:"
    // Start part A
    // Load last column of this tile (the first column has already been
    // loaded) and compute latter two columns of X.T x.
    "fneg AXTx13.4s, x_33.4s\n"
    "fadd AXTx23.4s, x_23.4s, x_33.4s\n"
    "ldr qx_24, [%x[inptr1], %x[colstride1]]\n"
    "fsub AXTx33.4s, x_33.4s, x_23.4s\n"
    "ldr qx_34, [%x[inptr2], %x[colstride1]]\n"
    "fsub AXTx43.4s, x_23.4s, x_43.4s\n"
    "ldr qx_44, [%x[inptr3], %x[colstride1]]\n"
    "fneg AXTx14.4s, x_34.4s\n"
    "fadd AXTx24.4s, x_24.4s, x_34.4s\n"
    "add %x[inptr1], %x[inptr1], %x[colstride2]\n"
    "fsub AXTx34.4s, x_34.4s, x_24.4s\n"
    "add %x[inptr2], %x[inptr2], %x[colstride2]\n"
    "fsub AXTx44.4s, x_24.4s, x_44.4s\n"
    "add %x[inptr3], %x[inptr3], %x[colstride2]\n"

    // Compute and store U.
    // First row
    "fsub U.4s, AXTx11.4s, AXTx13.4s\n"
    "str qU, [%x[outptr0]]\n"
    "fadd U.4s, AXTx12.4s, AXTx13.4s\n"
    "str qU, [%x[outptr0], %x[mstride1]]\n"
    "fsub U.4s, AXTx13.4s, AXTx12.4s\n"
    "str qU, [%x[outptr0], %x[mstride2]]\n"
    "fsub U.4s, AXTx12.4s, AXTx14.4s\n"
    "str qU, [%x[outptr0], %x[mstride3]]\n"
    "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"

    // Second row
    "fsub U.4s, AXTx21.4s, AXTx23.4s\n"
    "str qU, [%x[outptr4]]\n"
    "fadd U.4s, AXTx22.4s, AXTx23.4s\n"
    "str qU, [%x[outptr4], %x[mstride1]]\n"
    "fsub U.4s, AXTx23.4s, AXTx22.4s\n"
    "str qU, [%x[outptr4], %x[mstride2]]\n"
    "fsub U.4s, AXTx22.4s, AXTx24.4s\n"
    "str qU, [%x[outptr4], %x[mstride3]]\n"
    "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"

    // Third row
    "fsub U.4s, AXTx31.4s, AXTx33.4s\n"
    "str qU, [%x[outptr8]]\n"
    "fadd U.4s, AXTx32.4s, AXTx33.4s\n"
    "str qU, [%x[outptr8], %x[mstride1]]\n"
    "fsub U.4s, AXTx33.4s, AXTx32.4s\n"
    "str qU, [%x[outptr8], %x[mstride2]]\n"
    "fsub U.4s, AXTx32.4s, AXTx34.4s\n"
    "str qU, [%x[outptr8], %x[mstride3]]\n"
    "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"

    // Fourth row
    "fsub U.4s, AXTx41.4s, AXTx43.4s\n"
    "str qU, [%x[outptr12]]\n"

    "fadd U.4s, AXTx42.4s, AXTx43.4s\n"
    "str qU, [%x[outptr12], %x[mstride1]]\n"
    "ldr qx_23, [%x[inptr1]]\n"

    "fsub U.4s, AXTx43.4s, AXTx42.4s\n"
    "str qU, [%x[outptr12], %x[mstride2]]\n"
    "ldr qx_33, [%x[inptr2]]\n"

    "fsub U.4s, AXTx42.4s, AXTx44.4s\n"
    "str qU, [%x[outptr12], %x[mstride3]]\n"
    "ldr qx_43, [%x[inptr3]]\n"

    "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"

    "subs %x[tile_j], %x[tile_j], #1\n"
    "beq 3f\n"  // Jump to 'B' tail

    // Start part B
    // Load last column of this tile (the first column has already been
    // loaded) and compute latter two columns of X.T x.
    "fneg BXTx13.4s, x_33.4s\n"
    "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
    "ldr qx_24, [%x[inptr1], %x[colstride1]]\n"
    "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
    "ldr qx_34, [%x[inptr2], %x[colstride1]]\n"
    "fsub BXTx43.4s, x_23.4s, x_43.4s\n"
    "ldr qx_44, [%x[inptr3], %x[colstride1]]\n"
    "fneg BXTx14.4s, x_34.4s\n"
    "fadd BXTx24.4s, x_24.4s, x_34.4s\n"
    "add %x[inptr1], %x[inptr1], %x[colstride2]\n"
    "fsub BXTx34.4s, x_34.4s, x_24.4s\n"
    "add %x[inptr2], %x[inptr2], %x[colstride2]\n"
    "fsub BXTx44.4s, x_24.4s, x_44.4s\n"
    "add %x[inptr3], %x[inptr3], %x[colstride2]\n"

    // Compute and store U.
    // First row
    "fsub U.4s, BXTx11.4s, BXTx13.4s\n"
    "str qU, [%x[outptr0]]\n"
    "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
    "str qU, [%x[outptr0], %x[mstride1]]\n"
    "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
    "str qU, [%x[outptr0], %x[mstride2]]\n"
    "fsub U.4s, BXTx12.4s, BXTx14.4s\n"
    "str qU, [%x[outptr0], %x[mstride3]]\n"
    "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"

    // Second row
    "fsub U.4s, BXTx21.4s, BXTx23.4s\n"
    "str qU, [%x[outptr4]]\n"
    "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
    "str qU, [%x[outptr4], %x[mstride1]]\n"
    "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
    "str qU, [%x[outptr4], %x[mstride2]]\n"
    "fsub U.4s, BXTx22.4s, BXTx24.4s\n"
    "str qU, [%x[outptr4], %x[mstride3]]\n"
    "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"

    // Third row
    "fsub U.4s, BXTx31.4s, BXTx33.4s\n"
    "str qU, [%x[outptr8]]\n"
    "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
    "str qU, [%x[outptr8], %x[mstride1]]\n"
    "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
    "str qU, [%x[outptr8], %x[mstride2]]\n"
    "fsub U.4s, BXTx32.4s, BXTx34.4s\n"
    "str qU, [%x[outptr8], %x[mstride3]]\n"
    "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"

    // Fourth row
    "fsub U.4s, BXTx41.4s, BXTx43.4s\n"
    "str qU, [%x[outptr12]]\n"

    "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
    "str qU, [%x[outptr12], %x[mstride1]]\n"
    "ldr qx_23, [%x[inptr1]]\n"

    "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
    "str qU, [%x[outptr12], %x[mstride2]]\n"
    "ldr qx_33, [%x[inptr2]]\n"

    "fsub U.4s, BXTx42.4s, BXTx44.4s\n"
    "str qU, [%x[outptr12], %x[mstride3]]\n"
    "ldr qx_43, [%x[inptr3]]\n"

    "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
    "subs %x[tile_j], %x[tile_j], #1\n"
    "bne 1b\n"  // Continue loop, otherwise flow into 'A' tail

    // ----------------------------------------------------------------------
    "2:"
    // 'A' tail
    // Since the final column is padding and the last-but-one column has
    // already been loaded just compute the 3rd column of `X.T x'.
    "fneg AXTx13.4s, x_33.4s\n"
    "fadd AXTx23.4s, x_23.4s, x_33.4s\n"
    "fsub AXTx33.4s, x_33.4s, x_23.4s\n"
    "fsub AXTx43.4s, x_23.4s, x_43.4s\n"

    // Compute and store U. Modified to account for the final column of X.T
    // x containing padding. Note, it is also unnecessary to update the
    // output pointers.
    // First row
    "fsub U.4s, AXTx11.4s, AXTx13.4s\n"
    "str qU, [%x[outptr0]]\n"
    "fadd U.4s, AXTx12.4s, AXTx13.4s\n"
    "str qU, [%x[outptr0], %x[mstride1]]\n"
    "fsub U.4s, AXTx13.4s, AXTx12.4s\n"
    "str qU, [%x[outptr0], %x[mstride2]]\n"
    "str qAXTx12, [%x[outptr0], %x[mstride3]]\n"

    // Second row
    "fsub U.4s, AXTx21.4s, AXTx23.4s\n"
    "str qU, [%x[outptr4]]\n"
    "fadd U.4s, AXTx22.4s, AXTx23.4s\n"
    "str qU, [%x[outptr4], %x[mstride1]]\n"
    "fsub U.4s, AXTx23.4s, AXTx22.4s\n"
    "str qU, [%x[outptr4], %x[mstride2]]\n"
    "str qAXTx22, [%x[outptr4], %x[mstride3]]\n"

    // Third row
    "fsub U.4s, AXTx31.4s, AXTx33.4s\n"
    "str qU, [%x[outptr8]]\n"
    "fadd U.4s, AXTx32.4s, AXTx33.4s\n"
    "str qU, [%x[outptr8], %x[mstride1]]\n"
    "fsub U.4s, AXTx33.4s, AXTx32.4s\n"
    "str qU, [%x[outptr8], %x[mstride2]]\n"
    "str qAXTx32, [%x[outptr8], %x[mstride3]]\n"

    // Fourth row
    "fsub U.4s, AXTx41.4s, AXTx43.4s\n"
    "str qU, [%x[outptr12]]\n"
    "fadd U.4s, AXTx42.4s, AXTx43.4s\n"
    "str qU, [%x[outptr12], %x[mstride1]]\n"
    "fsub U.4s, AXTx43.4s, AXTx42.4s\n"
    "str qU, [%x[outptr12], %x[mstride2]]\n"
    "str qAXTx42, [%x[outptr12], %x[mstride3]]\n"

    "b 4f\n"  // Jump to end of function

    // ----------------------------------------------------------------------
    "3:"
    // 'B' tail
    // Since the final column is padding and the last-but-one column has
    // already been loaded just compute the 3rd column of `X.T x'.
    "fneg BXTx13.4s, x_33.4s\n"
    "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
    "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
    "fsub BXTx43.4s, x_23.4s, x_43.4s\n"

    // Compute and store U. Modified to account for the final column of X.T
    // x containing padding. Note, it is also unnecessary to update the
    // output pointers.
    // First row
    "fsub U.4s, BXTx11.4s, BXTx13.4s\n"
    "str qU, [%x[outptr0]]\n"
    "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
    "str qU, [%x[outptr0], %x[mstride1]]\n"
    "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
    "str qU, [%x[outptr0], %x[mstride2]]\n"
    "str qBXTx12, [%x[outptr0], %x[mstride3]]\n"

    // Second row
    "fsub U.4s, BXTx21.4s, BXTx23.4s\n"
    "str qU, [%x[outptr4]]\n"
    "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
    "str qU, [%x[outptr4], %x[mstride1]]\n"
    "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
    "str qU, [%x[outptr4], %x[mstride2]]\n"
    "str qBXTx22, [%x[outptr4], %x[mstride3]]\n"

    // Third row
    "fsub U.4s, BXTx31.4s, BXTx33.4s\n"
    "str qU, [%x[outptr8]]\n"
    "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
    "str qU, [%x[outptr8], %x[mstride1]]\n"
    "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
    "str qU, [%x[outptr8], %x[mstride2]]\n"
    "str qBXTx32, [%x[outptr8], %x[mstride3]]\n"

    // Fourth row
    "fsub U.4s, BXTx41.4s, BXTx43.4s\n"
    "str qU, [%x[outptr12]]\n"
    "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
    "str qU, [%x[outptr12], %x[mstride1]]\n"
    "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
    "str qU, [%x[outptr12], %x[mstride2]]\n"
    "str qBXTx42, [%x[outptr12], %x[mstride3]]\n"

    // ----------------------------------------------------------------------
    "4:"
    // End of function

    // Clear names
    ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n"
    ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n"
    ".unreq x_43\n" ".unreq qx_43\n" ".unreq x_44\n" ".unreq qx_44\n"
    ".unreq AXTx11\n" ".unreq BXTx13\n"
    ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n"
    ".unreq AXTx21\n" ".unreq BXTx23\n"
    ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n"
    ".unreq AXTx31\n" ".unreq BXTx33\n"
    ".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n"
    ".unreq AXTx41\n" ".unreq BXTx43\n"
    ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n"
    ".unreq AXTx13\n" ".unreq BXTx11\n"
    ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n"
    ".unreq AXTx23\n" ".unreq BXTx21\n"
    ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n"
    ".unreq AXTx33\n" ".unreq BXTx31\n"
    ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n"
    ".unreq AXTx43\n" ".unreq BXTx41\n"
    ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n"
    ".unreq U\n" ".unreq qU\n"
    // The asm names inptr1..inptr3 are intentionally bound to the C++
    // pointers for rows 0..2: the asm's "row 1" is padding and never loaded.
    : [inptr1] "+r" (inptr0),  // Offset to account for padded row
      [inptr2] "+r" (inptr1),  // Offset to account for padded row
      [inptr3] "+r" (inptr2),  // Offset to account for padded row
      [outptr0] "+r" (outptr0),
      [outptr4] "+r" (outptr4),
      [outptr8] "+r" (outptr8),
      [outptr12] "+r" (outptr12),
      [tile_j] "+r" (tile_j)  // Tile counter
    : [colstride1] "r" (1 * input_col_stride * sizeof(float)),
      [colstride2] "r" (2 * input_col_stride * sizeof(float)),
      [colstride3] "r" (3 * input_col_stride * sizeof(float)),
      [mstride1] "r" (1 * matrix_stride * sizeof(float)),
      [mstride2] "r" (2 * matrix_stride * sizeof(float)),
      [mstride3] "r" (3 * matrix_stride * sizeof(float)),
      [matrix_row_stride] "r" (matrix_row_stride * sizeof(float))
    : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
      "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
      "v22", "v23", "v24"
  );
}

// Pad left, right and bottom by 1.
-template <> -template <> -inline void Winograd2x2_3x3GemmInput<float>::process_tile_row<0, 1, 1, 1, 4>( - const int tile_N, - const float* const input, - const int input_row_stride, - const int input_col_stride, - float* const matrix, - const int matrix_stride, - const int matrix_row_stride -) { - const float *inptr0 = input; - const float *inptr1 = input + input_row_stride; - const float *inptr2 = input + input_row_stride * 2; - - float *outptr0 = matrix; - float *outptr4 = matrix + matrix_stride * 4; - float *outptr8 = matrix + matrix_stride * 8; - float *outptr12 = matrix + matrix_stride * 12; - - int tile_j = tile_N; // Tiles to process - - asm volatile ( - // Named SIMD registers according to the policy given above - // Registers into which to load the latter two columns of `x` - // NOTE: Bottom row is not required since since it is padded. - "x_13 .req v0\n qx_13 .req q0\n" "x_14 .req v4\n qx_14 .req q4\n" - "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n" - "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n" - - // Registers for storing X.T x (both A and B halves) - "AXTx11 .req v8\n" "BXTx13 .req v8\n" - "AXTx12 .req v9\n" "BXTx14 .req v9\n" "qAXTx12 .req q9\n" - "AXTx21 .req v10\n" "BXTx23 .req v10\n" - "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n" - "AXTx31 .req v12\n" "BXTx33 .req v12\n" - "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n" - "AXTx41 .req v14\n" "BXTx43 .req v14\n" - "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n" - "AXTx13 .req v16\n" "BXTx11 .req v16\n" - "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n" - "AXTx23 .req v18\n" "BXTx21 .req v18\n" - "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n" - "AXTx33 .req v20\n" "BXTx31 .req v20\n" - "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n" - "AXTx43 .req v22\n" "BXTx41 .req v22\n" - "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n" - - // Result register (TODO Does using more 
registers yield better - // performance) - "U .req v24\n qU .req q24\n" - - // ---------------------------------------------------------------------- - // Head of loop - // Loads a complete 4x4 tile of x, computes X.T x, computes and stores - // `U = X.T x X`. Prepares for the 'A' half of the loop. - // NOTE: Since the first tile has the leftmost column padded we can - // skip 4 loads and 4 calculations for the matrix X.T x X. - - // Temporarily alias registers for computing the first (non-padded) - // column of x. - "x_12 .req v0\n qx_12 .req q0\n" - "x_22 .req v1\n qx_22 .req q1\n" - "x_32 .req v2\n qx_32 .req q2\n" - - "ldr qx_12, [%x[inptr0]]\n" - "ldr qx_22, [%x[inptr1]]\n" - "ldr qx_32, [%x[inptr2]]\n" - - "fsub BXTx12.4s, x_12.4s, x_32.4s\n" - "fadd BXTx22.4s, x_22.4s, x_32.4s\n" - "fsub BXTx32.4s, x_32.4s, x_22.4s\n" - "mov BXTx42.16b, x_22.16b\n" // Probably should do better - - ".unreq x_12\n .unreq qx_12\n" - ".unreq x_22\n .unreq qx_22\n" - ".unreq x_32\n .unreq qx_32\n" - - // Load and compute latter two columns of the first tile. Progress the - // input pointers (by three columns so that the each points are the - // second column of the next tile, that is, each points at the first - // column which must be read for the next tile. 
- "ldr qx_13, [%x[inptr0], %x[colstride1]]\n" - "ldr qx_23, [%x[inptr1], %x[colstride1]]\n" - "ldr qx_33, [%x[inptr2], %x[colstride1]]\n" - - "fsub BXTx13.4s, x_13.4s, x_33.4s\n" - "ldr qx_14, [%x[inptr0], %x[colstride2]]\n" - - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride2]]\n" - - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride2]]\n" - - "mov BXTx43.16b, x_23.16b\n" - "fsub BXTx14.4s, x_14.4s, x_34.4s\n" - "add %x[inptr0], %x[inptr0], %x[colstride3]\n" - - "fadd BXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride3]\n" - - "fsub BXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride3]\n" - - "mov BXTx44.16b, x_24.16b\n" - - // Compute and store U for the first tile - // First row - "fneg U.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, BXTx12.4s, BXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fneg U.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, BXTx22.4s, BXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fneg U.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, BXTx32.4s, BXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row, simultaneously load the first column of inputs for the - // next tile. 
- "fneg U.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "ldr qx_13, [%x[inptr0]]\n" - - "fadd U.4s, BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, BXTx42.4s, BXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - - // Update the loop counter, subtract two to account for both the head and - // the tail. - "subs %x[tile_j], %x[tile_j], #2\n" - "beq 2f\n" // Jump to "A" tail if out of tiles - - // ---------------------------------------------------------------------- - "1:" - // Start part A - // Load last column of this tile (the first column has already been - // loaded) and compute latter two columns of X.T x. - "fsub AXTx13.4s, x_13.4s, x_33.4s\n" - "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" - "fadd AXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" - "fsub AXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" - "mov AXTx43.16b, x_23.16b\n" - - "fsub AXTx14.4s, x_14.4s, x_34.4s\n" - "add %x[inptr0], %x[inptr0], %x[colstride2]\n" - "fadd AXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride2]\n" - "fsub AXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride2]\n" - "mov AXTx44.16b, x_24.16b\n" - - // Compute and store U. 
- // First row - "fsub U.4s, AXTx11.4s, AXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, AXTx12.4s, AXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, AXTx13.4s, AXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, AXTx12.4s, AXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fsub U.4s, AXTx21.4s, AXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, AXTx22.4s, AXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, AXTx23.4s, AXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, AXTx22.4s, AXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fsub U.4s, AXTx31.4s, AXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, AXTx32.4s, AXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, AXTx33.4s, AXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, AXTx32.4s, AXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row - "fsub U.4s, AXTx41.4s, AXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "ldr qx_13, [%x[inptr0]]\n" - - "fadd U.4s, AXTx42.4s, AXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, AXTx43.4s, AXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, AXTx42.4s, AXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - - "subs %x[tile_j], %x[tile_j], #1\n" - "beq 3f\n" // Jump to 'B' tail - - // Start part B - // Load last column of this tile (the first column has already been - // loaded) and compute latter two columns of X.T x. 
- "fsub BXTx13.4s, x_13.4s, x_33.4s\n" - "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" - "mov BXTx43.16b, x_23.16b\n" - - "fsub BXTx14.4s, x_14.4s, x_34.4s\n" - "add %x[inptr0], %x[inptr0], %x[colstride2]\n" - "fadd BXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride2]\n" - "fsub BXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride2]\n" - "mov BXTx44.16b, x_24.16b\n" - - // Compute and store U. - // First row - "fsub U.4s, BXTx11.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, BXTx12.4s, BXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fsub U.4s, BXTx21.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, BXTx22.4s, BXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fsub U.4s, BXTx31.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, BXTx32.4s, BXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row - "fsub U.4s, BXTx41.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "ldr qx_13, [%x[inptr0]]\n" - - "fadd U.4s, BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, 
BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, BXTx42.4s, BXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - "subs %x[tile_j], %x[tile_j], #1\n" - "bne 1b\n" // Continue loop, otherwise flow into 'A' tail - - // ---------------------------------------------------------------------- - "2:" - // 'A' tail - // Since the final column is padding and the last-but-one column has - // already been loaded just compute the 3rd column of `X.T x'. - "fsub AXTx13.4s, x_13.4s, x_33.4s\n" - "fadd AXTx23.4s, x_23.4s, x_33.4s\n" - "fsub AXTx33.4s, x_33.4s, x_23.4s\n" - "mov AXTx43.16b, x_23.16b\n" - - // Compute and store U. Modified to account for the final column of X.T - // x containing padding. Note, it is also unnecessary to update the - // output pointers. - // First row - "fsub U.4s, AXTx11.4s, AXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, AXTx12.4s, AXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, AXTx13.4s, AXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "str qAXTx12, [%x[outptr0], %x[mstride3]]\n" - - // Second row - "fsub U.4s, AXTx21.4s, AXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, AXTx22.4s, AXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, AXTx23.4s, AXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "str qAXTx22, [%x[outptr4], %x[mstride3]]\n" - - // Third row - "fsub U.4s, AXTx31.4s, AXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, AXTx32.4s, AXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, AXTx33.4s, AXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "str qAXTx32, [%x[outptr8], %x[mstride3]]\n" - - // Fourth row - "fsub U.4s, AXTx41.4s, AXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "fadd U.4s, AXTx42.4s, AXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, AXTx43.4s, AXTx42.4s\n" - "str qU, [%x[outptr12], 
%x[mstride2]]\n" - "str qAXTx42, [%x[outptr12], %x[mstride3]]\n" - - "b 4f\n" // Jump to end of function - - // ---------------------------------------------------------------------- - "3:" - // 'B' tail - // Since the final column is padding and the last-but-one column has - // already been loaded just compute the 3rd column of `X.T x'. - "fsub BXTx13.4s, x_13.4s, x_33.4s\n" - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "mov BXTx43.16b, x_23.16b\n" - - // Compute and store U. Modified to account for the final column of X.T - // x containing padding. Note, it is also unnecessary to update the - // output pointers. - // First row - "fsub U.4s, BXTx11.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "str qBXTx12, [%x[outptr0], %x[mstride3]]\n" - - // Second row - "fsub U.4s, BXTx21.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "str qBXTx22, [%x[outptr4], %x[mstride3]]\n" - - // Third row - "fsub U.4s, BXTx31.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "str qBXTx32, [%x[outptr8], %x[mstride3]]\n" - - // Fourth row - "fsub U.4s, BXTx41.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "fadd U.4s, BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "str qBXTx42, [%x[outptr12], %x[mstride3]]\n" - - // ---------------------------------------------------------------------- - "4:" - // End of function - - // Clear names - ".unreq x_13\n" ".unreq qx_13\n" ".unreq x_14\n" 
".unreq qx_14\n" - ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n" - ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n" - ".unreq AXTx11\n" ".unreq BXTx13\n" - ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n" - ".unreq AXTx21\n" ".unreq BXTx23\n" - ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n" - ".unreq AXTx31\n" ".unreq BXTx33\n" - ".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n" - ".unreq AXTx41\n" ".unreq BXTx43\n" - ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n" - ".unreq AXTx13\n" ".unreq BXTx11\n" - ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n" - ".unreq AXTx23\n" ".unreq BXTx21\n" - ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n" - ".unreq AXTx33\n" ".unreq BXTx31\n" - ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n" - ".unreq AXTx43\n" ".unreq BXTx41\n" - ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n" - ".unreq U\n" ".unreq qU\n" - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr4), - [outptr8] "+r" (outptr8), - [outptr12] "+r" (outptr12), - [tile_j] "+r" (tile_j) // Tile counter - : [colstride1] "r" (1 * input_col_stride * sizeof(float)), - [colstride2] "r" (2 * input_col_stride * sizeof(float)), - [colstride3] "r" (3 * input_col_stride * sizeof(float)), - [mstride1] "r" (1 * matrix_stride * sizeof(float)), - [mstride2] "r" (2 * matrix_stride * sizeof(float)), - [mstride3] "r" (3 * matrix_stride * sizeof(float)), - [matrix_row_stride] "r" (matrix_row_stride * sizeof(float)) - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", - "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24" - ); -} -} -#endif // __aarch64__ diff --git a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp b/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp deleted file mode 100644 index 
ad1ad55291..0000000000 --- a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp +++ /dev/null @@ -1,961 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once -#include "../input_2x2_3x3.hpp" - -#ifdef __aarch64__ - -namespace winograd { - -template <> -template <> -inline void Winograd2x2_3x3GemmInputChannelwise<float>::_process_tile<0, 0, 0, 0, 4>( - int &n_channels, // Number of channels in the tile - const float* &inptr0, - const int input_row_stride, - const int input_col_stride, - float* &outptr0, - const int matrix_stride -) { - // We use 4 pointers to point to the starting position on each row and use - // three offsets to extract elements from each of the other 3 columns. 
- auto inptr1 = inptr0 + 1*input_row_stride; - auto inptr2 = inptr0 + 2*input_row_stride; - auto inptr3 = inptr0 + 3*input_row_stride; - - // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three - // offsets to access the intermediate matrices. - auto outptr1 = outptr0 + matrix_stride * 4; - auto outptr2 = outptr0 + matrix_stride * 8; - auto outptr3 = outptr0 + matrix_stride * 12; - - for (; n_channels > 3; n_channels -= 4) { - asm volatile ( - "X_11 .req v0\n" "qX_11 .req q0\n" - "X_12 .req v1\n" "qX_12 .req q1\n" - "X_13 .req v2\n" "qX_13 .req q2\n" - "X_14 .req v3\n" "qX_14 .req q3\n" - "X_21 .req v4\n" "qX_21 .req q4\n" - "X_22 .req v5\n" "qX_22 .req q5\n" - "X_23 .req v6\n" "qX_23 .req q6\n" - "X_24 .req v7\n" "qX_24 .req q7\n" - "X_31 .req v8\n" "qX_31 .req q8\n" - "X_32 .req v9\n" "qX_32 .req q9\n" - "X_33 .req v10\n" "qX_33 .req q10\n" - "X_34 .req v11\n" "qX_34 .req q11\n" - "X_41 .req v12\n" "qX_41 .req q12\n" - "X_42 .req v13\n" "qX_42 .req q13\n" - "X_43 .req v14\n" "qX_43 .req q14\n" - "X_44 .req v15\n" "qX_44 .req q15\n" - "xX_11 .req v16\n" - "xX_12 .req v17\n" - "xX_13 .req v18\n" - "xX_14 .req v19\n" - "xX_21 .req v20\n" - "xX_22 .req v21\n" - "xX_23 .req v22\n" - "xX_24 .req v23\n" - "xX_31 .req v24\n" - "xX_32 .req v25\n" - "xX_33 .req v26\n" - "xX_34 .req v27\n" - "xX_41 .req v28\n" - "xX_42 .req v29\n" - "xX_43 .req v30\n" - "xX_44 .req v31\n" - " U .req v0\n" - "qU .req q0\n" - - // Load the tile, and compute compute the matrix xX - "ldr qX_11, [%x[inptr0]]\n" - "ldr qX_12, [%x[inptr0], %x[colstride1]]\n" - "ldr qX_13, [%x[inptr0], %x[colstride2]]\n" - "ldr qX_14, [%x[inptr0], %x[colstride3]]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "ldr qX_21, [%x[inptr1]]\n" - "fsub xX_11.4s, x_11.4s, x_13.4s\n" - "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" - "fadd xX_12.4s, x_12.4s, x_13.4s\n" - "ldr qX_23, [%x[inptr1], %x[colstride2]]\n" - "fsub xX_13.4s, x_13.4s, x_12.4s\n" - "ldr qX_24, [%x[inptr1], %x[colstride3]]\n" - "fsub 
xX_14.4s, x_12.4s, x_14.4s\n" - "add %x[inptr1], %x[inptr1], #0x10\n" - - "ldr qX_31, [%x[inptr2]]\n" - "fsub xX_21.4s, x_21.4s, x_23.4s\n" - "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" - "fadd xX_22.4s, x_22.4s, x_23.4s\n" - "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" - "fsub xX_23.4s, x_23.4s, x_22.4s\n" - "ldr qX_34, [%x[inptr2], %x[colstride3]]\n" - "fsub xX_24.4s, x_22.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - "ldr qX_41, [%x[inptr3]]\n" - "fsub xX_31.4s, x_31.4s, x_33.4s\n" - "ldr qX_42, [%x[inptr3], %x[colstride1]]\n" - "fadd xX_32.4s, x_32.4s, x_33.4s\n" - "ldr qX_43, [%x[inptr3], %x[colstride2]]\n" - "fsub xX_33.4s, x_33.4s, x_32.4s\n" - "ldr qX_44, [%x[inptr3], %x[colstride3]]\n" - "fsub xX_34.4s, x_32.4s, x_34.4s\n" - "add %x[inptr3], %x[inptr3], #0x10\n" - - // Complete computing xX while beginning to compute and store - // $U = X.T x X$ - - "fsub xX_41.4s, x_41.4s, x_43.4s\n" - - "fsub U.4s, xX_11.4s, xX_31.4s\n" - "str qU, [%x[outptr0]]\n" - "fsub U.4s, xX_12.4s, xX_32.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, xX_13.4s, xX_33.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, xX_14.4s, xX_34.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd xX_42.4s, x_42.4s, x_43.4s\n" - - "fadd U.4s, xX_21.4s, xX_31.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, xX_22.4s, xX_32.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fadd U.4s, xX_23.4s, xX_33.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fadd U.4s, xX_24.4s, xX_34.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fsub xX_43.4s, x_43.4s, x_42.4s\n" - - "fsub U.4s, xX_31.4s, xX_21.4s\n" - "str qU, [%x[outptr8]]\n" - "fsub U.4s, xX_32.4s, xX_22.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, xX_33.4s, xX_23.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, xX_34.4s, xX_24.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add 
%x[outptr8], %x[outptr8], #0x10\n" - - "fsub xX_44.4s, x_42.4s, x_44.4s\n" - - "fsub U.4s, xX_21.4s, xX_41.4s\n" - "str qU, [%x[outptr12]]\n" - "fsub U.4s, xX_22.4s, xX_42.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, xX_23.4s, xX_43.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "fsub U.4s, xX_24.4s, xX_44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - ".unreq qU\n" - ".unreq U\n" - ".unreq X_11\n" ".unreq qX_11\n" - ".unreq X_12\n" ".unreq qX_12\n" - ".unreq X_13\n" ".unreq qX_13\n" - ".unreq X_14\n" ".unreq qX_14\n" - ".unreq X_21\n" ".unreq qX_21\n" - ".unreq X_22\n" ".unreq qX_22\n" - ".unreq X_23\n" ".unreq qX_23\n" - ".unreq X_24\n" ".unreq qX_24\n" - ".unreq X_31\n" ".unreq qX_31\n" - ".unreq X_32\n" ".unreq qX_32\n" - ".unreq X_33\n" ".unreq qX_33\n" - ".unreq X_34\n" ".unreq qX_34\n" - ".unreq X_41\n" ".unreq qX_41\n" - ".unreq X_42\n" ".unreq qX_42\n" - ".unreq X_43\n" ".unreq qX_43\n" - ".unreq X_44\n" ".unreq qX_44\n" - ".unreq xX_11\n" - ".unreq xX_12\n" - ".unreq xX_13\n" - ".unreq xX_14\n" - ".unreq xX_21\n" - ".unreq xX_22\n" - ".unreq xX_23\n" - ".unreq xX_24\n" - ".unreq xX_31\n" - ".unreq xX_32\n" - ".unreq xX_33\n" - ".unreq xX_34\n" - ".unreq xX_41\n" - ".unreq xX_42\n" - ".unreq xX_43\n" - ".unreq xX_44\n" - - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [inptr3] "+r" (inptr3), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr1), - [outptr8] "+r" (outptr2), - [outptr12] "+r" (outptr3) - : [colstride1] "r" (input_col_stride * sizeof(float)), - [colstride2] "r" (input_col_stride * sizeof(float) * 2), - [colstride3] "r" (input_col_stride * sizeof(float) * 3), - [mstride1] "r" (matrix_stride * sizeof(float)), - [mstride2] "r" (matrix_stride * sizeof(float) * 2), - [mstride3] "r" (matrix_stride * sizeof(float) * 3) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", 
"v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", - "v30", "v31" - ); - } -} - -// Pad top by 1 -template <> -template <> -inline void Winograd2x2_3x3GemmInputChannelwise<float>::_process_tile<1, 0, 0, 0, 4>( - int &n_channels, // Number of channels in the tile - const float* &inptr0, - const int input_row_stride, - const int input_col_stride, - float* &outptr0, - const int matrix_stride -) { - // We use 4 pointers to point to the starting position on each row and use - // three offsets to extract elements from each of the other 3 columns. - auto inptr1 = inptr0 + 0*input_row_stride; - auto inptr2 = inptr0 + 1*input_row_stride; - - // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three - // offsets to access the intermediate matrices. - auto outptr1 = outptr0 + matrix_stride * 4; - auto outptr2 = outptr0 + matrix_stride * 8; - auto outptr3 = outptr0 + matrix_stride * 12; - - for (; n_channels > 3; n_channels -= 4) { - asm volatile ( - "X_21 .req v4\n" "qX_21 .req q4\n" - "X_22 .req v5\n" "qX_22 .req q5\n" - "X_23 .req v6\n" "qX_23 .req q6\n" - "X_24 .req v7\n" "qX_24 .req q7\n" - "X_31 .req v8\n" "qX_31 .req q8\n" - "X_32 .req v9\n" "qX_32 .req q9\n" - "X_33 .req v10\n" "qX_33 .req q10\n" - "X_34 .req v11\n" "qX_34 .req q11\n" - "X_41 .req v12\n" "qX_41 .req q12\n" - "X_42 .req v13\n" "qX_42 .req q13\n" - "X_43 .req v14\n" "qX_43 .req q14\n" - "X_44 .req v15\n" "qX_44 .req q15\n" - "xX_21 .req v20\n" - "xX_22 .req v21\n" - "xX_23 .req v22\n" - "xX_24 .req v23\n" - "xX_31 .req v24\n" - "xX_32 .req v25\n" - "xX_33 .req v26\n" - "xX_34 .req v27\n" - "xX_41 .req v28\n" - "xX_42 .req v29\n" - "xX_43 .req v30\n" - "xX_44 .req v31\n" - " U .req v0\n" - "qU .req q0\n" - - // Load the tile, and compute compute the matrix xX - "ldr qX_21, [%x[inptr1]]\n" - "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" - "ldr qX_23, [%x[inptr1], %x[colstride2]]\n" - "ldr qX_24, [%x[inptr1], %x[colstride3]]\n" - "add %x[inptr1], 
%x[inptr1], #0x10\n" - - "ldr qX_31, [%x[inptr2]]\n" - "fsub xX_21.4s, x_21.4s, x_23.4s\n" - "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" - "fadd xX_22.4s, x_22.4s, x_23.4s\n" - "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" - "fsub xX_23.4s, x_23.4s, x_22.4s\n" - "ldr qX_34, [%x[inptr2], %x[colstride3]]\n" - "fsub xX_24.4s, x_22.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - "ldr qX_41, [%x[inptr3]]\n" - "fsub xX_31.4s, x_31.4s, x_33.4s\n" - "ldr qX_42, [%x[inptr3], %x[colstride1]]\n" - "fadd xX_32.4s, x_32.4s, x_33.4s\n" - "ldr qX_43, [%x[inptr3], %x[colstride2]]\n" - "fsub xX_33.4s, x_33.4s, x_32.4s\n" - "ldr qX_44, [%x[inptr3], %x[colstride3]]\n" - "fsub xX_34.4s, x_32.4s, x_34.4s\n" - "add %x[inptr3], %x[inptr3], #0x10\n" - - // Complete computing xX while beginning to compute and store - // $U = X.T x X$ - - "fsub xX_41.4s, x_41.4s, x_43.4s\n" - - "fneg U.4s, xX_31.4s\n" - "str qU, [%x[outptr0]]\n" - "fneg U.4s, xX_32.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fneg U.4s, xX_33.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fneg U.4s, xX_34.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd xX_42.4s, x_42.4s, x_43.4s\n" - - "fadd U.4s, xX_21.4s, xX_31.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, xX_22.4s, xX_32.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fadd U.4s, xX_23.4s, xX_33.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fadd U.4s, xX_24.4s, xX_34.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fsub xX_43.4s, x_43.4s, x_42.4s\n" - - "fsub U.4s, xX_31.4s, xX_21.4s\n" - "str qU, [%x[outptr8]]\n" - "fsub U.4s, xX_32.4s, xX_22.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, xX_33.4s, xX_23.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, xX_34.4s, xX_24.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], #0x10\n" - - "fsub xX_44.4s, x_42.4s, x_44.4s\n" - - "fsub U.4s, 
xX_21.4s, xX_41.4s\n" - "str qU, [%x[outptr12]]\n" - "fsub U.4s, xX_22.4s, xX_42.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, xX_23.4s, xX_43.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "fsub U.4s, xX_24.4s, xX_44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - ".unreq qU\n" - ".unreq U\n" - ".unreq X_21\n" ".unreq qX_21\n" - ".unreq X_22\n" ".unreq qX_22\n" - ".unreq X_23\n" ".unreq qX_23\n" - ".unreq X_24\n" ".unreq qX_24\n" - ".unreq X_31\n" ".unreq qX_31\n" - ".unreq X_32\n" ".unreq qX_32\n" - ".unreq X_33\n" ".unreq qX_33\n" - ".unreq X_34\n" ".unreq qX_34\n" - ".unreq X_41\n" ".unreq qX_41\n" - ".unreq X_42\n" ".unreq qX_42\n" - ".unreq X_43\n" ".unreq qX_43\n" - ".unreq X_44\n" ".unreq qX_44\n" - ".unreq xX_21\n" - ".unreq xX_22\n" - ".unreq xX_23\n" - ".unreq xX_24\n" - ".unreq xX_31\n" - ".unreq xX_32\n" - ".unreq xX_33\n" - ".unreq xX_34\n" - ".unreq xX_41\n" - ".unreq xX_42\n" - ".unreq xX_43\n" - ".unreq xX_44\n" - - : [inptr1] "+r" (inptr0), // Offset for missing row - [inptr2] "+r" (inptr1), // Offset for missing row - [inptr3] "+r" (inptr2), // Offset for missing row - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr1), - [outptr8] "+r" (outptr2), - [outptr12] "+r" (outptr3) - : [colstride1] "r" (input_col_stride * sizeof(float)), - [colstride2] "r" (input_col_stride * sizeof(float) * 2), - [colstride3] "r" (input_col_stride * sizeof(float) * 3), - [mstride1] "r" (matrix_stride * sizeof(float)), - [mstride2] "r" (matrix_stride * sizeof(float) * 2), - [mstride3] "r" (matrix_stride * sizeof(float) * 3) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", - "v30", "v31" - ); - } -} - -// Pad left by 1 -template <> -template <> -inline void Winograd2x2_3x3GemmInputChannelwise<float>::_process_tile<0, 1, 0, 0, 4>( - 
int &n_channels, // Number of channels in the tile - const float* &inptr0, - const int input_row_stride, - const int input_col_stride, - float* &outptr0, - const int matrix_stride -) { - // We use 4 pointers to point to the starting position on each row and use - // three offsets to extract elements from each of the other 3 columns. - auto inptr1 = inptr0 + 1*input_row_stride; - auto inptr2 = inptr0 + 2*input_row_stride; - auto inptr3 = inptr0 + 3*input_row_stride; - - // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three - // offsets to access the intermediate matrices. - auto outptr1 = outptr0 + matrix_stride * 4; - auto outptr2 = outptr0 + matrix_stride * 8; - auto outptr3 = outptr0 + matrix_stride * 12; - - for (; n_channels > 3; n_channels -= 4) { - asm volatile ( - "X_12 .req v1\n" "qX_12 .req q1\n" - "X_13 .req v2\n" "qX_13 .req q2\n" - "X_14 .req v3\n" "qX_14 .req q3\n" - "X_22 .req v5\n" "qX_22 .req q5\n" - "X_23 .req v6\n" "qX_23 .req q6\n" - "X_24 .req v7\n" "qX_24 .req q7\n" - "X_32 .req v9\n" "qX_32 .req q9\n" - "X_33 .req v10\n" "qX_33 .req q10\n" - "X_34 .req v11\n" "qX_34 .req q11\n" - "X_42 .req v13\n" "qX_42 .req q13\n" - "X_43 .req v14\n" "qX_43 .req q14\n" - "X_44 .req v15\n" "qX_44 .req q15\n" - "xX_11 .req v16\n" - "xX_12 .req v17\n" - "xX_13 .req v18\n" - "xX_14 .req v19\n" - "xX_21 .req v20\n" - "xX_22 .req v21\n" - "xX_23 .req v22\n" - "xX_24 .req v23\n" - "xX_31 .req v24\n" - "xX_32 .req v25\n" - "xX_33 .req v26\n" - "xX_34 .req v27\n" - "xX_41 .req v28\n" - "xX_42 .req v29\n" - "xX_43 .req v30\n" - "xX_44 .req v31\n" - " U .req v0\n" - "qU .req q0\n" - - // Load the tile, and compute compute the matrix xX - "ldr qX_12, [%x[inptr0]]\n" - "ldr qX_13, [%x[inptr0], %x[colstride1]]\n" - "ldr qX_14, [%x[inptr0], %x[colstride2]]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "fneg xX_11.4s, x_13.4s\n" - "ldr qX_22, [%x[inptr1]]\n" - "fadd xX_12.4s, x_12.4s, x_13.4s\n" - "ldr qX_23, [%x[inptr1], %x[colstride1]]\n" - "fsub 
xX_13.4s, x_13.4s, x_12.4s\n" - "ldr qX_24, [%x[inptr1], %x[colstride2]]\n" - "fsub xX_14.4s, x_12.4s, x_14.4s\n" - "add %x[inptr1], %x[inptr1], #0x10\n" - - "fneg xX_21.4s, x_23.4s\n" - "ldr qX_32, [%x[inptr2]]\n" - "fadd xX_22.4s, x_22.4s, x_23.4s\n" - "ldr qX_33, [%x[inptr2], %x[colstride1]]\n" - "fsub xX_23.4s, x_23.4s, x_22.4s\n" - "ldr qX_34, [%x[inptr2], %x[colstride2]]\n" - "fsub xX_24.4s, x_22.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - "fneg xX_31.4s, x_33.4s\n" - "ldr qX_42, [%x[inptr3]]\n" - "fadd xX_32.4s, x_32.4s, x_33.4s\n" - "ldr qX_43, [%x[inptr3], %x[colstride1]]\n" - "fsub xX_33.4s, x_33.4s, x_32.4s\n" - "ldr qX_44, [%x[inptr3], %x[colstride2]]\n" - "fsub xX_34.4s, x_32.4s, x_34.4s\n" - "add %x[inptr3], %x[inptr3], #0x10\n" - - // Complete computing xX while beginning to compute and store - // $U = X.T x X$ - - "fneg xX_41.4s, x_43.4s\n" - - "fsub U.4s, xX_11.4s, xX_31.4s\n" - "str qU, [%x[outptr0]]\n" - "fsub U.4s, xX_12.4s, xX_32.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, xX_13.4s, xX_33.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, xX_14.4s, xX_34.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd xX_42.4s, x_42.4s, x_43.4s\n" - - "fadd U.4s, xX_21.4s, xX_31.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, xX_22.4s, xX_32.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fadd U.4s, xX_23.4s, xX_33.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fadd U.4s, xX_24.4s, xX_34.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fsub xX_43.4s, x_43.4s, x_42.4s\n" - - "fsub U.4s, xX_31.4s, xX_21.4s\n" - "str qU, [%x[outptr8]]\n" - "fsub U.4s, xX_32.4s, xX_22.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, xX_33.4s, xX_23.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, xX_34.4s, xX_24.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], #0x10\n" - - "fsub 
xX_44.4s, x_42.4s, x_44.4s\n" - - "fsub U.4s, xX_21.4s, xX_41.4s\n" - "str qU, [%x[outptr12]]\n" - "fsub U.4s, xX_22.4s, xX_42.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, xX_23.4s, xX_43.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "fsub U.4s, xX_24.4s, xX_44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - ".unreq X_12\n" ".unreq qX_12\n" - ".unreq X_13\n" ".unreq qX_13\n" - ".unreq X_14\n" ".unreq qX_14\n" - ".unreq X_22\n" ".unreq qX_22\n" - ".unreq X_23\n" ".unreq qX_23\n" - ".unreq X_24\n" ".unreq qX_24\n" - ".unreq X_32\n" ".unreq qX_32\n" - ".unreq X_33\n" ".unreq qX_33\n" - ".unreq X_34\n" ".unreq qX_34\n" - ".unreq X_42\n" ".unreq qX_42\n" - ".unreq X_43\n" ".unreq qX_43\n" - ".unreq X_44\n" ".unreq qX_44\n" - ".unreq xX_11\n" - ".unreq xX_12\n" - ".unreq xX_13\n" - ".unreq xX_14\n" - ".unreq xX_21\n" - ".unreq xX_22\n" - ".unreq xX_23\n" - ".unreq xX_24\n" - ".unreq xX_31\n" - ".unreq xX_32\n" - ".unreq xX_33\n" - ".unreq xX_34\n" - ".unreq xX_41\n" - ".unreq xX_42\n" - ".unreq xX_43\n" - ".unreq xX_44\n" - ".unreq U\n" - ".unreq qU\n" - - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [inptr3] "+r" (inptr3), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr1), - [outptr8] "+r" (outptr2), - [outptr12] "+r" (outptr3) - : [colstride1] "r" (input_col_stride * sizeof(float)), - [colstride2] "r" (input_col_stride * sizeof(float) * 2), - [mstride1] "r" (matrix_stride * sizeof(float)), - [mstride2] "r" (matrix_stride * sizeof(float) * 2), - [mstride3] "r" (matrix_stride * sizeof(float) * 3) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", - "v30", "v31" - ); - } -} - -// Pad bottom by 1 -template <> -template <> -inline void Winograd2x2_3x3GemmInputChannelwise<float>::_process_tile<0, 0, 
1, 0, 4>( - int &n_channels, // Number of channels in the tile - const float* &inptr0, - const int input_row_stride, - const int input_col_stride, - float* &outptr0, - const int matrix_stride -) { - // We use 4 pointers to point to the starting position on each row and use - // three offsets to extract elements from each of the other 3 columns. - auto inptr1 = inptr0 + 1*input_row_stride; - auto inptr2 = inptr0 + 2*input_row_stride; - - // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three - // offsets to access the intermediate matrices. - auto outptr1 = outptr0 + matrix_stride * 4; - auto outptr2 = outptr0 + matrix_stride * 8; - auto outptr3 = outptr0 + matrix_stride * 12; - - for (; n_channels > 3; n_channels -= 4) { - asm volatile ( - "X_11 .req v0\n" "qX_11 .req q0\n" - "X_12 .req v1\n" "qX_12 .req q1\n" - "X_13 .req v2\n" "qX_13 .req q2\n" - "X_14 .req v3\n" "qX_14 .req q3\n" - "X_21 .req v4\n" "qX_21 .req q4\n" - "X_22 .req v5\n" "qX_22 .req q5\n" - "X_23 .req v6\n" "qX_23 .req q6\n" - "X_24 .req v7\n" "qX_24 .req q7\n" - "X_31 .req v8\n" "qX_31 .req q8\n" - "X_32 .req v9\n" "qX_32 .req q9\n" - "X_33 .req v10\n" "qX_33 .req q10\n" - "X_34 .req v11\n" "qX_34 .req q11\n" - "xX_11 .req v16\n" - "xX_12 .req v17\n" - "xX_13 .req v18\n" - "xX_14 .req v19\n" - "xX_21 .req v20\n" "qxX_21 .req q20\n" - "xX_22 .req v21\n" "qxX_22 .req q21\n" - "xX_23 .req v22\n" "qxX_23 .req q22\n" - "xX_24 .req v23\n" "qxX_24 .req q23\n" - "xX_31 .req v24\n" - "xX_32 .req v25\n" - "xX_33 .req v26\n" - "xX_34 .req v27\n" - " U .req v0\n" - "qU .req q0\n" - - // Load the tile, and compute compute the matrix xX - "ldr qX_11, [%x[inptr0]]\n" - "ldr qX_12, [%x[inptr0], %x[colstride1]]\n" - "ldr qX_13, [%x[inptr0], %x[colstride2]]\n" - "ldr qX_14, [%x[inptr0], %x[colstride3]]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "ldr qX_21, [%x[inptr1]]\n" - "fsub xX_11.4s, x_11.4s, x_13.4s\n" - "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" - "fadd xX_12.4s, x_12.4s, x_13.4s\n" - 
"ldr qX_23, [%x[inptr1], %x[colstride2]]\n" - "fsub xX_13.4s, x_13.4s, x_12.4s\n" - "ldr qX_24, [%x[inptr1], %x[colstride3]]\n" - "fsub xX_14.4s, x_12.4s, x_14.4s\n" - "add %x[inptr1], %x[inptr1], #0x10\n" - - "ldr qX_31, [%x[inptr2]]\n" - "fsub xX_21.4s, x_21.4s, x_23.4s\n" - "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" - "fadd xX_22.4s, x_22.4s, x_23.4s\n" - "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" - "fsub xX_23.4s, x_23.4s, x_22.4s\n" - "ldr qX_34, [%x[inptr2], %x[colstride3]]\n" - "fsub xX_24.4s, x_22.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - "fsub xX_31.4s, x_31.4s, x_33.4s\n" - "fadd xX_32.4s, x_32.4s, x_33.4s\n" - "fsub xX_33.4s, x_33.4s, x_32.4s\n" - "fsub xX_34.4s, x_32.4s, x_34.4s\n" - - // Complete computing xX while beginning to compute and store - // $U = X.T x X$ - - "fsub U.4s, xX_11.4s, xX_31.4s\n" - "str qU, [%x[outptr0]]\n" - "fsub U.4s, xX_12.4s, xX_32.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, xX_13.4s, xX_33.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, xX_14.4s, xX_34.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd U.4s, xX_21.4s, xX_31.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, xX_22.4s, xX_32.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fadd U.4s, xX_23.4s, xX_33.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fadd U.4s, xX_24.4s, xX_34.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fsub U.4s, xX_31.4s, xX_21.4s\n" - "str qU, [%x[outptr8]]\n" - "fsub U.4s, xX_32.4s, xX_22.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, xX_33.4s, xX_23.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, xX_34.4s, xX_24.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], #0x10\n" - - "str qxX_21, [%x[outptr12]]\n" - "str qxX_22, [%x[outptr12], %x[mstride1]]\n" - "str qxX_23, [%x[outptr12], %x[mstride2]]\n" - "str qxX_24, [%x[outptr12], 
%x[mstride3]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - ".unreq qU\n" - ".unreq U\n" - ".unreq X_11\n" ".unreq qX_11\n" - ".unreq X_12\n" ".unreq qX_12\n" - ".unreq X_13\n" ".unreq qX_13\n" - ".unreq X_14\n" ".unreq qX_14\n" - ".unreq X_21\n" ".unreq qX_21\n" - ".unreq X_22\n" ".unreq qX_22\n" - ".unreq X_23\n" ".unreq qX_23\n" - ".unreq X_24\n" ".unreq qX_24\n" - ".unreq X_31\n" ".unreq qX_31\n" - ".unreq X_32\n" ".unreq qX_32\n" - ".unreq X_33\n" ".unreq qX_33\n" - ".unreq X_34\n" ".unreq qX_34\n" - ".unreq xX_11\n" - ".unreq xX_12\n" - ".unreq xX_13\n" - ".unreq xX_14\n" - ".unreq xX_21\n" ".unreq qxX_21\n" - ".unreq xX_22\n" ".unreq qxX_22\n" - ".unreq xX_23\n" ".unreq qxX_23\n" - ".unreq xX_24\n" ".unreq qxX_24\n" - ".unreq xX_31\n" - ".unreq xX_32\n" - ".unreq xX_33\n" - ".unreq xX_34\n" - - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr1), - [outptr8] "+r" (outptr2), - [outptr12] "+r" (outptr3) - : [colstride1] "r" (input_col_stride * sizeof(float)), - [colstride2] "r" (input_col_stride * sizeof(float) * 2), - [colstride3] "r" (input_col_stride * sizeof(float) * 3), - [mstride1] "r" (matrix_stride * sizeof(float)), - [mstride2] "r" (matrix_stride * sizeof(float) * 2), - [mstride3] "r" (matrix_stride * sizeof(float) * 3) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", - "v30", "v31" - ); - } -} - -// Pad right by 1 -template <> -template <> -inline void Winograd2x2_3x3GemmInputChannelwise<float>::_process_tile<0, 0, 0, 1, 4>( - int &n_channels, // Number of channels in the tile - const float* &inptr0, - const int input_row_stride, - const int input_col_stride, - float* &outptr0, - const int matrix_stride -) { - // We use 4 pointers to point to the starting position on each row and use - // three 
offsets to extract elements from each of the other 3 columns. - auto inptr1 = inptr0 + 1*input_row_stride; - auto inptr2 = inptr0 + 2*input_row_stride; - auto inptr3 = inptr0 + 3*input_row_stride; - - // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three - // offsets to access the intermediate matrices. - auto outptr1 = outptr0 + matrix_stride * 4; - auto outptr2 = outptr0 + matrix_stride * 8; - auto outptr3 = outptr0 + matrix_stride * 12; - - for (; n_channels > 3; n_channels -= 4) { - asm volatile ( - "X_11 .req v0\n" "qX_11 .req q0\n" - "X_12 .req v1\n" "qX_12 .req q1\n" - "X_13 .req v2\n" "qX_13 .req q2\n" - "X_21 .req v4\n" "qX_21 .req q4\n" - "X_22 .req v5\n" "qX_22 .req q5\n" - "X_23 .req v6\n" "qX_23 .req q6\n" - "X_31 .req v8\n" "qX_31 .req q8\n" - "X_32 .req v9\n" "qX_32 .req q9\n" - "X_33 .req v10\n" "qX_33 .req q10\n" - "X_41 .req v12\n" "qX_41 .req q12\n" - "X_42 .req v13\n" "qX_42 .req q13\n" - "X_43 .req v14\n" "qX_43 .req q14\n" - "xX_11 .req v16\n" - "xX_12 .req v17\n" - "xX_13 .req v18\n" - "xX_14 .req x_12\n" - "xX_21 .req v20\n" - "xX_22 .req v21\n" - "xX_23 .req v22\n" - "xX_24 .req x_22\n" - "xX_31 .req v24\n" - "xX_32 .req v25\n" - "xX_33 .req v26\n" - "xX_34 .req x_32\n" - "xX_41 .req v28\n" - "xX_42 .req v29\n" - "xX_43 .req v30\n" - "xX_44 .req x_42\n" - " U .req v0\n" - "qU .req q0\n" - - // Load the tile, and compute compute the matrix xX - "ldr qX_11, [%x[inptr0]]\n" - "ldr qX_12, [%x[inptr0], %x[colstride1]]\n" - "ldr qX_13, [%x[inptr0], %x[colstride2]]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "ldr qX_21, [%x[inptr1]]\n" - "fsub xX_11.4s, x_11.4s, x_13.4s\n" - "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" - "fadd xX_12.4s, x_12.4s, x_13.4s\n" - "ldr qX_23, [%x[inptr1], %x[colstride2]]\n" - "fsub xX_13.4s, x_13.4s, x_12.4s\n" - "add %x[inptr1], %x[inptr1], #0x10\n" - - "ldr qX_31, [%x[inptr2]]\n" - "fsub xX_21.4s, x_21.4s, x_23.4s\n" - "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" - "fadd xX_22.4s, x_22.4s, x_23.4s\n" 
- "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" - "fsub xX_23.4s, x_23.4s, x_22.4s\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - "ldr qX_41, [%x[inptr3]]\n" - "fsub xX_31.4s, x_31.4s, x_33.4s\n" - "ldr qX_42, [%x[inptr3], %x[colstride1]]\n" - "fadd xX_32.4s, x_32.4s, x_33.4s\n" - "ldr qX_43, [%x[inptr3], %x[colstride2]]\n" - "fsub xX_33.4s, x_33.4s, x_32.4s\n" - "add %x[inptr3], %x[inptr3], #0x10\n" - - // Complete computing xX while beginning to compute and store - // $U = X.T x X$ - - "fsub xX_41.4s, x_41.4s, x_43.4s\n" - - "fsub U.4s, xX_11.4s, xX_31.4s\n" - "str qU, [%x[outptr0]]\n" - "fsub U.4s, xX_12.4s, xX_32.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, xX_13.4s, xX_33.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, xX_14.4s, xX_34.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd xX_42.4s, x_42.4s, x_43.4s\n" - - "fadd U.4s, xX_21.4s, xX_31.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, xX_22.4s, xX_32.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fadd U.4s, xX_23.4s, xX_33.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fadd U.4s, xX_24.4s, xX_34.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fsub xX_43.4s, x_43.4s, x_42.4s\n" - - "fsub U.4s, xX_31.4s, xX_21.4s\n" - "str qU, [%x[outptr8]]\n" - "fsub U.4s, xX_32.4s, xX_22.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, xX_33.4s, xX_23.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, xX_34.4s, xX_24.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], #0x10\n" - - "fsub U.4s, xX_21.4s, xX_41.4s\n" - "str qU, [%x[outptr12]]\n" - "fsub U.4s, xX_22.4s, xX_42.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, xX_23.4s, xX_43.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "fsub U.4s, xX_24.4s, xX_44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - ".unreq qU\n" 
- ".unreq U\n" - ".unreq X_11\n" ".unreq qX_11\n" - ".unreq X_12\n" ".unreq qX_12\n" - ".unreq X_13\n" ".unreq qX_13\n" - ".unreq X_21\n" ".unreq qX_21\n" - ".unreq X_22\n" ".unreq qX_22\n" - ".unreq X_23\n" ".unreq qX_23\n" - ".unreq X_31\n" ".unreq qX_31\n" - ".unreq X_32\n" ".unreq qX_32\n" - ".unreq X_33\n" ".unreq qX_33\n" - ".unreq X_41\n" ".unreq qX_41\n" - ".unreq X_42\n" ".unreq qX_42\n" - ".unreq X_43\n" ".unreq qX_43\n" - ".unreq xX_11\n" - ".unreq xX_12\n" - ".unreq xX_13\n" - ".unreq xX_14\n" - ".unreq xX_21\n" - ".unreq xX_22\n" - ".unreq xX_23\n" - ".unreq xX_24\n" - ".unreq xX_31\n" - ".unreq xX_32\n" - ".unreq xX_33\n" - ".unreq xX_34\n" - ".unreq xX_41\n" - ".unreq xX_42\n" - ".unreq xX_43\n" - ".unreq xX_44\n" - - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [inptr3] "+r" (inptr3), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr1), - [outptr8] "+r" (outptr2), - [outptr12] "+r" (outptr3) - : [colstride1] "r" (input_col_stride * sizeof(float)), - [colstride2] "r" (input_col_stride * sizeof(float) * 2), - [mstride1] "r" (matrix_stride * sizeof(float)), - [mstride2] "r" (matrix_stride * sizeof(float) * 2), - [mstride3] "r" (matrix_stride * sizeof(float) * 3) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", - "v30", "v31" - ); - } -} -} -#endif diff --git a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3_fp32.cpp new file mode 100644 index 0000000000..381ae92182 --- /dev/null +++ b/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3_fp32.cpp @@ -0,0 +1,409 @@ +/* + * Copyright (c) 2017 ARM Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "transforms/input.hpp" +#include "winograd_gemm.hpp" +#include "arm.hpp" + +namespace winograd +{ + +using Transform = WinogradGEMM<2, 2, 3, 3>::InputTransform<float>; + +/****************************************************************************** + * Cost methods for the input transform. + * ===================================== + */ +template <> +template <> +int Transform::ops_performed(const Tensor4DShape &input_shape) +{ + // NOTE: Cost in FLOPs rather than instructions or uops. 
+ const int tile_M = iceildiv(input_shape.n_rows, inner_tile_rows); + const int tile_N = iceildiv(input_shape.n_cols, inner_tile_cols); + return 16 * 16 * tile_M * tile_N * input_shape.n_channels; +} +/*****************************************************************************/ + +/***************************************************************************** +* F(2x2, 3x3) implies the use of a 4x4 input tile. Such tiles can require a +* variety of padding types. For example, tiles at the top and left of an image +* can require one row or column of padding on their top and left sides if the +* padding type is SAME (where X represents a padded value): +* +* _______ _______ +* |X X X X| |X X X X| +* |X | | | . . . +* |X | | | +* |X______| |_______| +* _______ +* |X | . +* |X | . . . . +* |X | . +* |X______| +* +* For tiles near the right or bottom of the image it is more complicated. Such +* tiles might require padding by 0 or 1 rows or columns if the padding type is +* VALID or 1 or 2 rows or columns if the padding type is SAME: +* +* _______ _______ _______ _______ +* |X X X X| |X X X X| |X X X X| |X X X X| +* |X | | | | X| | X X| +* |X | | | | X| | X X| +* |X______| |_______| |______X| |____X_X| +* _______ _______ _______ _______ +* |X | | | | X| | X X| +* |X | | | | X| | X X| +* |X | | | | X| | X X| +* |X______| |_______| |______X| |____X_X| +* _______ _______ _______ _______ +* |X | | | | X| | X X| +* |X | | | | X| | X X| +* |X | | | | X| | X X| +* |X_X_X_X| |X_X_X_X| |X_X_X_X| |X_X_X_X| +* _______ _______ _______ _______ +* |X | | | | X| | X X| +* |X | | | | X| | X X| +* |X X X X| |X X X X| |X X X X| |X X X X| +* |X_X_X_X| |X_X_X_X| |X_X_X_X| |X_X_X_X| +* +* Additional tiles are required for especially small input images. +* +* Build an array of the specialised methods that deal with each of the +* different padding combinations which may be required. 
These padding +* constraints are the space: +* +* Padding top in {0, 1} +* Padding left in {0, 1} +* Padding bottom in {0, 1, 2} +* Padding right in {0, 1, 2} +*/ +template <> +template <> +template <int pad_top, int pad_left, int pad_bottom, int pad_right> +void Transform::process_tile( + int n_channels, + const float* const input_base, + const int input_row_stride, + const int input_col_stride, + float* const matrix_base, + const int matrix_stride +) +{ + constexpr int inner_tile_i = 4, inner_tile_j = 4; + constexpr int cells_i = inner_tile_i - pad_bottom; + constexpr int cells_j = inner_tile_i - pad_right; + + float *outptr = matrix_base; + + // Get pointers into the input tile + const float *x_ptrs[inner_tile_i][inner_tile_j]; + for (int i = pad_top, xi = 0; i < cells_i; i++, xi++) + { + // Get a pointer into the row + const float* const row_ptr = input_base + xi*input_row_stride; + + for (int j = pad_left, xj = 0; j < cells_j; j++, xj++) + { + x_ptrs[i][j] = row_ptr + xj*input_col_stride; + } + } + + // Matrices used/computed in this kernel. + float x[inner_tile_i][inner_tile_j]; + float XTx[inner_tile_i][inner_tile_j]; + float U[inner_tile_i][inner_tile_j]; + + for (int i = 0; i < inner_tile_i; i++) + { + for (int j = 0; j < inner_tile_j; j++) + { + x[i][j] = XTx[i][j] = 0.0f; + } + } + + // Perform the Winograd input transformation for each channel in the input + // tensor. + int channels_remaining = n_channels; +#ifdef __aarch64__ + for (; channels_remaining >= 4; channels_remaining -= 4) + { + // Matrices used/computed in this kernel. 
+ float32x4_t x[inner_tile_i][inner_tile_j]; + float32x4_t XTx[inner_tile_i][inner_tile_j]; + float32x4_t U[inner_tile_i][inner_tile_j]; + + for (int i = 0; i < inner_tile_i; i++) + { + for (int j = 0; j < inner_tile_j; j++) + { + x[i][j] = vdupq_n_f32(0.0f); + XTx[i][j] = vdupq_n_f32(0.0f); + } + } + + // Load x + for (int i = pad_top; i < cells_i; i++) + { + for (int j = pad_left; j < cells_j; j++) + { + x[i][j] = vld1q_f32(x_ptrs[i][j]); + x_ptrs[i][j] += 4; + } + } + + // Compute XT . x + for (int j = pad_left; j < cells_j; j++) + { + // XTx[0][j] = x[0][j] - x[2][j]; + XTx[0][j] = vsubq_f32(x[0][j], x[2][j]); + + // XTx[1][j] = x[1][j] + x[2][j]; + XTx[1][j] = vaddq_f32(x[1][j], x[2][j]); + + // XTx[2][j] = x[2][j] - x[1][j]; + XTx[2][j] = vsubq_f32(x[2][j], x[1][j]); + + // XTx[3][j] = x[1][j] - x[3][j]; + XTx[3][j] = vsubq_f32(x[1][j], x[3][j]); + } + + // Compute U = XT . x . X + for (int i = 0; i < inner_tile_i; i++) + { + // U[i][0] = XTx[i][0] - XTx[i][2]; + U[i][0] = vsubq_f32(XTx[i][0], XTx[i][2]); + + // U[i][1] = XTx[i][1] + XTx[i][2]; + U[i][1] = vaddq_f32(XTx[i][1], XTx[i][2]); + + // U[i][2] = XTx[i][2] - XTx[i][1]; + U[i][2] = vsubq_f32(XTx[i][2], XTx[i][1]); + + // U[i][3] = XTx[i][1] - XTx[i][3]; + U[i][3] = vsubq_f32(XTx[i][1], XTx[i][3]); + } + + // Store the transformed matrix + for (int i = 0, m = 0; i < inner_tile_i; i++) + { + for (int j = 0; j < inner_tile_j; j++, m++) + { + vst1q_f32(outptr + m*matrix_stride, U[i][j]); + } + } + outptr += 4; + } +#endif // __aarch64__ +#ifdef __arm_any__ + for (; channels_remaining >= 2; channels_remaining -= 2) + { + // Matrices used/computed in this kernel. 
+ float32x2_t x[inner_tile_i][inner_tile_j]; + float32x2_t XTx[inner_tile_i][inner_tile_j]; + float32x2_t U[inner_tile_i][inner_tile_j]; + + for (int i = 0; i < inner_tile_i; i++) + { + for (int j = 0; j < inner_tile_j; j++) + { + x[i][j] = vdup_n_f32(0.0f); + XTx[i][j] = vdup_n_f32(0.0f); + } + } + + // Load x + for (int i = pad_top; i < cells_i; i++) + { + for (int j = pad_left; j < cells_j; j++) + { + x[i][j] = vld1_f32(x_ptrs[i][j]); + x_ptrs[i][j] += 2; + } + } + + // Compute XT . x + for (int j = pad_left; j < cells_j; j++) + { + // XTx[0][j] = x[0][j] - x[2][j]; + XTx[0][j] = vsub_f32(x[0][j], x[2][j]); + + // XTx[1][j] = x[1][j] + x[2][j]; + XTx[1][j] = vadd_f32(x[1][j], x[2][j]); + + // XTx[2][j] = x[2][j] - x[1][j]; + XTx[2][j] = vsub_f32(x[2][j], x[1][j]); + + // XTx[3][j] = x[1][j] - x[3][j]; + XTx[3][j] = vsub_f32(x[1][j], x[3][j]); + } + + // Compute U = XT . x . X + for (int i = 0; i < inner_tile_i; i++) + { + // U[i][0] = XTx[i][0] - XTx[i][2]; + U[i][0] = vsub_f32(XTx[i][0], XTx[i][2]); + + // U[i][1] = XTx[i][1] + XTx[i][2]; + U[i][1] = vadd_f32(XTx[i][1], XTx[i][2]); + + // U[i][2] = XTx[i][2] - XTx[i][1]; + U[i][2] = vsub_f32(XTx[i][2], XTx[i][1]); + + // U[i][3] = XTx[i][1] - XTx[i][3]; + U[i][3] = vsub_f32(XTx[i][1], XTx[i][3]); + } + + // Store the transformed matrix + for (int i = 0, m = 0; i < inner_tile_i; i++) + { + for (int j = 0; j < inner_tile_j; j++, m++) + { + vst1_f32(outptr + m*matrix_stride, U[i][j]); + } + } + outptr += 2; + } +#endif // __arm_any__ + for (; channels_remaining; channels_remaining--) + { + // Load x + for (int i = pad_top; i < cells_i; i++) + { + for (int j = pad_left; j < cells_j; j++) + { + x[i][j] = *(x_ptrs[i][j]++); + } + } + + // Compute XT . x + for (int j = pad_left; j < cells_j; j++) + { + XTx[0][j] = x[0][j] - x[2][j]; + XTx[1][j] = x[1][j] + x[2][j]; + XTx[2][j] = x[2][j] - x[1][j]; + XTx[3][j] = x[1][j] - x[3][j]; + } + + // Compute U = XT . x . 
X + for (int i = 0; i < inner_tile_i; i++) + { + U[i][0] = XTx[i][0] - XTx[i][2]; + U[i][1] = XTx[i][1] + XTx[i][2]; + U[i][2] = XTx[i][2] - XTx[i][1]; + U[i][3] = XTx[i][1] - XTx[i][3]; + } + + // Store the transformed matrix + for (int i = 0, m = 0; i < inner_tile_i; i++) + { + for (int j = 0; j < inner_tile_j; j++, m++) + { + *(outptr + m*matrix_stride) = U[i][j]; + } + } + outptr++; + } +} + +template <> +template <> +const Transform::TileFn Transform::tile_fns[2][2][max_pad_bottom][max_pad_right] = +{ + { + { + { + Transform::template process_tile<0, 0, 0, 0>, // No padding + Transform::template process_tile<0, 0, 0, 1>, // Right + Transform::template process_tile<0, 0, 0, 2>, // Right + }, + { + Transform::template process_tile<0, 0, 1, 0>, // Bottom + Transform::template process_tile<0, 0, 1, 1>, // Bottom-right + Transform::template process_tile<0, 0, 1, 2>, // Bottom-right + }, + { + Transform::template process_tile<0, 0, 2, 0>, // Bottom + Transform::template process_tile<0, 0, 2, 1>, // Bottom-right + Transform::template process_tile<0, 0, 2, 2>, // Bottom-right + } + }, + { + { + Transform::template process_tile<0, 1, 0, 0>, // Left + Transform::template process_tile<0, 1, 0, 1>, // Left AND right + Transform::template process_tile<0, 1, 0, 2>, // Left AND right + }, + { + Transform::template process_tile<0, 1, 1, 0>, // Left-bottom + Transform::template process_tile<0, 1, 1, 1>, // Left, bottom AND right + Transform::template process_tile<0, 1, 1, 2>, // Left, bottom AND right + }, + { + Transform::template process_tile<0, 1, 2, 0>, // Left-bottom + Transform::template process_tile<0, 1, 2, 1>, // Left, bottom AND right + Transform::template process_tile<0, 1, 2, 2>, // Left, bottom AND right + } + }, + }, + { + { + { + Transform::template process_tile<1, 0, 0, 0>, // Top + Transform::template process_tile<1, 0, 0, 1>, // Top-right + Transform::template process_tile<1, 0, 0, 2>, // Top-right + }, + { + Transform::template process_tile<1, 0, 1, 0>, // 
Top AND bottom + Transform::template process_tile<1, 0, 1, 1>, // Top, bottom AND right + Transform::template process_tile<1, 0, 1, 2>, // Top, bottom AND right + }, + { + Transform::template process_tile<1, 0, 2, 0>, // Top AND bottom + Transform::template process_tile<1, 0, 2, 1>, // Top, bottom AND right + Transform::template process_tile<1, 0, 2, 2>, // Top, bottom AND right + } + }, + { + { + Transform::template process_tile<1, 1, 0, 0>, // Top-left + Transform::template process_tile<1, 1, 0, 1>, // Top, left AND right + Transform::template process_tile<1, 1, 0, 2>, // Top, left AND right + }, + { + Transform::template process_tile<1, 1, 1, 0>, // Top, left AND bottom + Transform::template process_tile<1, 1, 1, 1>, // All padded + Transform::template process_tile<1, 1, 1, 2>, // All padded + }, + { + Transform::template process_tile<1, 1, 2, 0>, // Top, left AND bottom + Transform::template process_tile<1, 1, 2, 1>, // All padded + Transform::template process_tile<1, 1, 2, 2>, // All padded + } + } + } +}; + +template struct WinogradGEMM<2, 2, 3, 3>::InputTransform<float>; +} // namespace winograd diff --git a/src/core/NEON/kernels/winograd/transforms/input_4x4_3x3_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/input_4x4_3x3_fp32.cpp new file mode 100644 index 0000000000..477aaaf34e --- /dev/null +++ b/src/core/NEON/kernels/winograd/transforms/input_4x4_3x3_fp32.cpp @@ -0,0 +1,486 @@ +/* + * Copyright (c) 2017 ARM Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "transforms/input.hpp" +#include "winograd_gemm.hpp" +#include "arm.hpp" + +namespace winograd +{ + +using Transform = WinogradGEMM<4, 4, 3, 3>::InputTransform<float>; + +template <> +template <> +int Transform::ops_performed(const Tensor4DShape &input_shape) +{ + // NOTE: Cost in FLOPs rather than instructions or uops. + const int tile_M = iceildiv(input_shape.n_rows, inner_tile_rows); + const int tile_N = iceildiv(input_shape.n_cols, inner_tile_cols); + return 12 * 24 * tile_M * tile_N * input_shape.n_channels; +} + +/* F(4x4, 3x3) implies the use of a 6x6 input tile. Such tiles can require a +* variety of padding types. 
For example, tiles at the top and left of an +* image can require one row or column of padding on their top and left sides +* if the padding type is SAME (where X represents a padded value): +* +* ___________ ___________ +* |X X X X X X| |X X X X X X| +* |X | | | +* |X | | | +* |X | | | +* |X | | | +* |X__________| |___________| +* ___________ +* |X | +* |X | +* |X | +* |X | +* |X | +* |X__________| +* +* For tiles near the right or bottom of the image it is more complicated. +* Such tiles might require padding by 0, 1, 2 or 3 rows or columns if the +* padding type is VALID or 1, 2, 3 or 4 rows or columns if the padding +* type is SAME. +* +* Build an array of the specialised methods that deal with each of the +* different padding combinations which may be required. These padding +* constraints are the space: +* +* Padding top in {0, 1} +* Padding left in {0, 1} +* Padding bottom in {0, 1, 2, 3, 4} +* Padding right in {0, 1, 2, 3, 4} +*/ +template <> +template <> +template <int pad_top, int pad_left, int pad_bottom, int pad_right> +void Transform::process_tile( + int n_channels, + const float* const input_base, + const int input_row_stride, + const int input_col_stride, + float* const matrix_base, + const int matrix_stride +) +{ + constexpr int cells_i = 6 - pad_bottom; + constexpr int cells_j = 6 - pad_right; + + float *outptr = matrix_base; + + // Get pointers into the input tile + const float *x_ptrs[6][6]; + for (int i = pad_top, xi = 0; i < cells_i; i++, xi++) + { + // Get a pointer into the row + const float* const row_ptr = input_base + xi*input_row_stride; + + for (int j = pad_left, xj = 0; j < cells_j; j++, xj++) + { + x_ptrs[i][j] = row_ptr + xj*input_col_stride; + } + } + + // Matrices used/computed in this kernel. + float x[6][6], XTx[6][6], U[6][6]; + for (int i = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++) + { + x[i][j] = XTx[i][j] = 0.0f; + } + } + + // Perform the Winograd input transformation for each channel in the input + // tensor. 
+ int channels_remaining = n_channels; +#ifdef __aarch64__ + for (; channels_remaining >= 4; channels_remaining -= 4) + { + // Matrices used/computed in this kernel + float32x4_t x[6][6], XTx[6][6], U[6][6]; + for (int i = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++) + { + x[i][j] = vdupq_n_f32(0.0f); + XTx[i][j] = vdupq_n_f32(0.0f); + } + } + + // Read a 6x6 tile in the Winograd domain + for (int i = pad_top; i < cells_i; i++) + { + for (int j = pad_left; j < cells_j; j++) + { + x[i][j] = vld1q_f32(x_ptrs[i][j]); + x_ptrs[i][j] += 4; + } + } + + // Compute XT . x + for (int j = pad_left; j < cells_j; j++) + { + // XTx[0][j] = 4*x[0][j] + -5*x[2][j] + 1*x[4][j]; + XTx[0][j] = vmlsq_n_f32(vmlaq_n_f32(x[4][j], x[0][j], 4.0f), x[2][j], 5.0f); + + // XTx[1][j] = -4*x[1][j] + -4*x[2][j] + 1*x[3][j] + 1*x[4][j]; + XTx[1][j] = vmlsq_n_f32(vaddq_f32(x[3][j], x[4][j]), vaddq_f32(x[1][j], x[2][j]), 4.0f); + + // XTx[2][j] = 4*x[1][j] + -4*x[2][j] + -1*x[3][j] + 1*x[4][j]; + XTx[2][j] = vmlaq_n_f32(vsubq_f32(x[4][j], x[3][j]), vsubq_f32(x[1][j], x[2][j]), 4.0f); + + // XTx[3][j] = -2*x[1][j] + -1*x[2][j] + 2*x[3][j] + 1*x[4][j]; + XTx[3][j] = vmlaq_n_f32(vsubq_f32(x[4][j], x[2][j]), vsubq_f32(x[3][j], x[1][j]), 2.0f); + + // XTx[4][j] = 2*x[1][j] + -1*x[2][j] + -2*x[3][j] + 1*x[4][j]; + XTx[4][j] = vmlaq_n_f32(vsubq_f32(x[4][j], x[2][j]), vsubq_f32(x[1][j], x[3][j]), 2.0f); + + // XTx[5][j] = 4*x[1][j] + -5*x[3][j] + 1*x[5][j]; + XTx[5][j] = vmlsq_n_f32(vmlaq_n_f32(x[5][j], x[1][j], 4.0f), x[3][j], 5.0f); + } + + // Compute U = XT . x . 
X + for (int i = 0; i < 6; i++) + { + // U[i][0] = 4*XTx[i][0] + -5*XTx[i][2] + 1*XTx[i][4]; + U[i][0] = vmlsq_n_f32(vmlaq_n_f32(XTx[i][4], XTx[i][0], 4.0f), XTx[i][2], 5.0f); + + // U[i][1] = -4*XTx[i][1] + -4*XTx[i][2] + 1*XTx[i][3] + 1*XTx[i][4]; + U[i][1] = vmlsq_n_f32(vaddq_f32(XTx[i][3], XTx[i][4]), vaddq_f32(XTx[i][1], XTx[i][2]), 4.0f); + + // U[i][2] = 4*XTx[i][1] + -4*XTx[i][2] + -1*XTx[i][3] + 1*XTx[i][4]; + U[i][2] = vmlaq_n_f32(vsubq_f32(XTx[i][4], XTx[i][3]), vsubq_f32(XTx[i][1], XTx[i][2]), 4.0f); + + // U[i][3] = -2*XTx[i][1] + -1*XTx[i][2] + 2*XTx[i][3] + 1*XTx[i][4]; + U[i][3] = vmlaq_n_f32(vsubq_f32(XTx[i][4], XTx[i][2]), vsubq_f32(XTx[i][3], XTx[i][1]), 2.0f); + + // U[i][4] = 2*XTx[i][1] + -1*XTx[i][2] + -2*XTx[i][3] + 1*XTx[i][4]; + U[i][4] = vmlaq_n_f32(vsubq_f32(XTx[i][4], XTx[i][2]), vsubq_f32(XTx[i][1], XTx[i][3]), 2.0f); + + // U[i][5] = 4*XTx[i][1] + -5*XTx[i][3] + 1*XTx[i][5]; + U[i][5] = vmlsq_n_f32(vmlaq_n_f32(XTx[i][5], XTx[i][1], 4.0f), XTx[i][3], 5.0f); + } + + // Store the transformed matrix + for (int i = 0, m = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++, m++) + { + vst1q_f32(outptr + m*matrix_stride, U[i][j]); + } + } + outptr += 4; + } +#endif // __aarch64__ +#ifdef __arm_any__ + for (; channels_remaining >= 2; channels_remaining -= 2) + { + // Matrices used/computed in this kernel + float32x2_t x[6][6], XTx[6][6], U[6][6]; + for (int i = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++) + { + x[i][j] = vdup_n_f32(0.0f); + XTx[i][j] = vdup_n_f32(0.0f); + } + } + + // Read a 6x6 tile in the Winograd domain + for (int i = pad_top; i < cells_i; i++) + { + for (int j = pad_left; j < cells_j; j++) + { + x[i][j] = vld1_f32(x_ptrs[i][j]); + x_ptrs[i][j] += 2; + } + } + + // Compute XT . 
x + for (int j = pad_left; j < cells_j; j++) + { + // XTx[0][j] = 4*x[0][j] + -5*x[2][j] + 1*x[4][j]; + XTx[0][j] = vmls_n_f32(vmla_n_f32(x[4][j], x[0][j], 4.0f), x[2][j], 5.0f); + + // XTx[1][j] = -4*x[1][j] + -4*x[2][j] + 1*x[3][j] + 1*x[4][j]; + XTx[1][j] = vmls_n_f32(vadd_f32(x[3][j], x[4][j]), vadd_f32(x[1][j], x[2][j]), 4.0f); + + // XTx[2][j] = 4*x[1][j] + -4*x[2][j] + -1*x[3][j] + 1*x[4][j]; + XTx[2][j] = vmla_n_f32(vsub_f32(x[4][j], x[3][j]), vsub_f32(x[1][j], x[2][j]), 4.0f); + + // XTx[3][j] = -2*x[1][j] + -1*x[2][j] + 2*x[3][j] + 1*x[4][j]; + XTx[3][j] = vmla_n_f32(vsub_f32(x[4][j], x[2][j]), vsub_f32(x[3][j], x[1][j]), 2.0f); + + // XTx[4][j] = 2*x[1][j] + -1*x[2][j] + -2*x[3][j] + 1*x[4][j]; + XTx[4][j] = vmla_n_f32(vsub_f32(x[4][j], x[2][j]), vsub_f32(x[1][j], x[3][j]), 2.0f); + + // XTx[5][j] = 4*x[1][j] + -5*x[3][j] + 1*x[5][j]; + XTx[5][j] = vmls_n_f32(vmla_n_f32(x[5][j], x[1][j], 4.0f), x[3][j], 5.0f); + } + + // Compute U = XT . x . X + for (int i = 0; i < 6; i++) + { + // U[i][0] = 4*XTx[i][0] + -5*XTx[i][2] + 1*XTx[i][4]; + U[i][0] = vmls_n_f32(vmla_n_f32(XTx[i][4], XTx[i][0], 4.0f), XTx[i][2], 5.0f); + + // U[i][1] = -4*XTx[i][1] + -4*XTx[i][2] + 1*XTx[i][3] + 1*XTx[i][4]; + U[i][1] = vmls_n_f32(vadd_f32(XTx[i][3], XTx[i][4]), vadd_f32(XTx[i][1], XTx[i][2]), 4.0f); + + // U[i][2] = 4*XTx[i][1] + -4*XTx[i][2] + -1*XTx[i][3] + 1*XTx[i][4]; + U[i][2] = vmla_n_f32(vsub_f32(XTx[i][4], XTx[i][3]), vsub_f32(XTx[i][1], XTx[i][2]), 4.0f); + + // U[i][3] = -2*XTx[i][1] + -1*XTx[i][2] + 2*XTx[i][3] + 1*XTx[i][4]; + U[i][3] = vmla_n_f32(vsub_f32(XTx[i][4], XTx[i][2]), vsub_f32(XTx[i][3], XTx[i][1]), 2.0f); + + // U[i][4] = 2*XTx[i][1] + -1*XTx[i][2] + -2*XTx[i][3] + 1*XTx[i][4]; + U[i][4] = vmla_n_f32(vsub_f32(XTx[i][4], XTx[i][2]), vsub_f32(XTx[i][1], XTx[i][3]), 2.0f); + + // U[i][5] = 4*XTx[i][1] + -5*XTx[i][3] + 1*XTx[i][5]; + U[i][5] = vmls_n_f32(vmla_n_f32(XTx[i][5], XTx[i][1], 4.0f), XTx[i][3], 5.0f); + } + + // Store the transformed matrix + for 
(int i = 0, m = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++, m++) + { + vst1_f32(outptr + m*matrix_stride, U[i][j]); + } + } + outptr += 2; + } +#endif // __arm_any__ + for (; channels_remaining; channels_remaining--) + { + // Load x + for (int i = pad_top; i < cells_i; i++) + { + for (int j = pad_left; j < cells_j; j++) + { + x[i][j] = *(x_ptrs[i][j]++); + } + } + + // Compute XT . x + for (int j = pad_left; j < cells_j; j++) + { + XTx[0][j] = 4*x[0][j] + -5*x[2][j] + 1*x[4][j]; + XTx[1][j] = -4*x[1][j] + -4*x[2][j] + 1*x[3][j] + 1*x[4][j]; + XTx[2][j] = 4*x[1][j] + -4*x[2][j] + -1*x[3][j] + 1*x[4][j]; + XTx[3][j] = -2*x[1][j] + -1*x[2][j] + 2*x[3][j] + 1*x[4][j]; + XTx[4][j] = 2*x[1][j] + -1*x[2][j] + -2*x[3][j] + 1*x[4][j]; + XTx[5][j] = 4*x[1][j] + -5*x[3][j] + 1*x[5][j]; + } + + // Compute U = XT . x . X + for (int i = 0; i < 6; i++) + { + U[i][0] = 4*XTx[i][0] + -5*XTx[i][2] + 1*XTx[i][4]; + U[i][1] = -4*XTx[i][1] + -4*XTx[i][2] + 1*XTx[i][3] + 1*XTx[i][4]; + U[i][2] = 4*XTx[i][1] + -4*XTx[i][2] + -1*XTx[i][3] + 1*XTx[i][4]; + U[i][3] = -2*XTx[i][1] + -1*XTx[i][2] + 2*XTx[i][3] + 1*XTx[i][4]; + U[i][4] = 2*XTx[i][1] + -1*XTx[i][2] + -2*XTx[i][3] + 1*XTx[i][4]; + U[i][5] = 4*XTx[i][1] + -5*XTx[i][3] + 1*XTx[i][5]; + } + + // Store the transformed matrix + for (int i = 0, m = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++, m++) + { + *(outptr + m*matrix_stride) = U[i][j]; + } + } + outptr++; + } +} + +/* In the below, unusual or especially small tiles are routed via the slow + * path whereas common or large tiles are routed through a faster path. 
+ */ +template <> +template <> +const Transform::TileFn Transform::tile_fns[2][2][max_pad_bottom][max_pad_right] = +{ + { + { + { + Transform::template process_tile<0, 0, 0, 0>, // No padding + Transform::template process_tile<0, 0, 0, 1>, // Right + Transform::template process_tile<0, 0, 0, 2>, // " " + Transform::template process_tile<0, 0, 0, 3>, // " " + Transform::template process_tile<0, 0, 0, 4>, // " " + }, + { + Transform::template process_tile<0, 0, 1, 0>, // Bottom + Transform::template process_tile<0, 0, 1, 1>, // Bottom right + Transform::template process_tile<0, 0, 1, 2>, // " " + Transform::template process_tile<0, 0, 1, 3>, // " " + Transform::template process_tile<0, 0, 1, 4>, // " " + }, + { + Transform::template process_tile<0, 0, 2, 0>, // Bottom + Transform::template process_tile<0, 0, 2, 1>, // Bottom right + Transform::template process_tile<0, 0, 2, 2>, // " " + Transform::template process_tile<0, 0, 2, 3>, // " " + Transform::template process_tile<0, 0, 2, 4>, // " " + }, + { + Transform::template process_tile<0, 0, 3, 0>, // Bottom + Transform::template process_tile<0, 0, 3, 1>, // Bottom right + Transform::template process_tile<0, 0, 3, 2>, // " " + Transform::template process_tile<0, 0, 3, 3>, // " " + Transform::template process_tile<0, 0, 3, 4>, // " " + }, + { + Transform::template process_tile<0, 0, 4, 0>, // Bottom + Transform::template process_tile<0, 0, 4, 1>, // Bottom right + Transform::template process_tile<0, 0, 4, 2>, // " " + Transform::template process_tile<0, 0, 4, 3>, // " " + Transform::template process_tile<0, 0, 4, 4>, // " " + } + }, + { + { + Transform::template process_tile<0, 1, 0, 0>, // Left + Transform::template process_tile<0, 1, 0, 1>, + Transform::template process_tile<0, 1, 0, 2>, + Transform::template process_tile<0, 1, 0, 3>, + Transform::template process_tile<0, 1, 0, 4>, + }, + { + Transform::template process_tile<0, 1, 1, 0>, // Bottom left + Transform::template process_tile<0, 1, 1, 1>, + 
Transform::template process_tile<0, 1, 1, 2>, + Transform::template process_tile<0, 1, 1, 3>, + Transform::template process_tile<0, 1, 1, 4>, + }, + { + Transform::template process_tile<0, 1, 2, 0>, // " " + Transform::template process_tile<0, 1, 2, 1>, + Transform::template process_tile<0, 1, 2, 2>, + Transform::template process_tile<0, 1, 2, 3>, + Transform::template process_tile<0, 1, 2, 4>, + }, + { + Transform::template process_tile<0, 1, 3, 0>, // " " + Transform::template process_tile<0, 1, 3, 1>, + Transform::template process_tile<0, 1, 3, 2>, + Transform::template process_tile<0, 1, 3, 3>, + Transform::template process_tile<0, 1, 3, 4>, + }, + { + Transform::template process_tile<0, 1, 4, 0>, // " " + Transform::template process_tile<0, 1, 4, 1>, + Transform::template process_tile<0, 1, 4, 2>, + Transform::template process_tile<0, 1, 4, 3>, + Transform::template process_tile<0, 1, 4, 4>, + } + } + }, + { + { + { + Transform::template process_tile<1, 0, 0, 0>, // Top + Transform::template process_tile<1, 0, 0, 1>, // Top right + Transform::template process_tile<1, 0, 0, 2>, // " " + Transform::template process_tile<1, 0, 0, 3>, // " " + Transform::template process_tile<1, 0, 0, 4>, // " " + }, + { + Transform::template process_tile<1, 0, 1, 0>, + Transform::template process_tile<1, 0, 1, 1>, + Transform::template process_tile<1, 0, 1, 2>, + Transform::template process_tile<1, 0, 1, 3>, + Transform::template process_tile<1, 0, 1, 4>, + }, + { + Transform::template process_tile<1, 0, 2, 0>, + Transform::template process_tile<1, 0, 2, 1>, + Transform::template process_tile<1, 0, 2, 2>, + Transform::template process_tile<1, 0, 2, 3>, + Transform::template process_tile<1, 0, 2, 4>, + }, + { + Transform::template process_tile<1, 0, 3, 0>, + Transform::template process_tile<1, 0, 3, 1>, + Transform::template process_tile<1, 0, 3, 2>, + Transform::template process_tile<1, 0, 3, 3>, + Transform::template process_tile<1, 0, 3, 4>, + }, + { + Transform::template 
process_tile<1, 0, 4, 0>, + Transform::template process_tile<1, 0, 4, 1>, + Transform::template process_tile<1, 0, 4, 2>, + Transform::template process_tile<1, 0, 4, 3>, + Transform::template process_tile<1, 0, 4, 4>, + }, + }, + { + { + Transform::template process_tile<1, 1, 0, 0>, // Top left + Transform::template process_tile<1, 1, 0, 1>, + Transform::template process_tile<1, 1, 0, 2>, + Transform::template process_tile<1, 1, 0, 3>, + Transform::template process_tile<1, 1, 0, 4>, + }, + { + Transform::template process_tile<1, 1, 1, 0>, + Transform::template process_tile<1, 1, 1, 1>, + Transform::template process_tile<1, 1, 1, 2>, + Transform::template process_tile<1, 1, 1, 3>, + Transform::template process_tile<1, 1, 1, 4>, + }, + { + Transform::template process_tile<1, 1, 2, 0>, + Transform::template process_tile<1, 1, 2, 1>, + Transform::template process_tile<1, 1, 2, 2>, + Transform::template process_tile<1, 1, 2, 3>, + Transform::template process_tile<1, 1, 2, 4>, + }, + { + Transform::template process_tile<1, 1, 3, 0>, + Transform::template process_tile<1, 1, 3, 1>, + Transform::template process_tile<1, 1, 3, 2>, + Transform::template process_tile<1, 1, 3, 3>, + Transform::template process_tile<1, 1, 3, 4>, + }, + { + Transform::template process_tile<1, 1, 4, 0>, + Transform::template process_tile<1, 1, 4, 1>, + Transform::template process_tile<1, 1, 4, 2>, + Transform::template process_tile<1, 1, 4, 3>, + Transform::template process_tile<1, 1, 4, 4>, + } + } + } +}; + +template struct WinogradGEMM<4, 4, 3, 3>::InputTransform<float>; +} // namespace winograd diff --git a/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp b/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp deleted file mode 100644 index 033442aa14..0000000000 --- a/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. 
- * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -namespace winograd { - /* Transform a kernel into the Winograd domain. - * - * NOTE: It is assumed that the kernel is in the form [height x width x - * input_channels x output_channel]. - */ - template <typename T> - struct winograd2x2_3x3_gemm_kernel_transform_impl{ - static void execute( - const KernelShape &shape, - const T* const kernel, - T* const matrix_base, - const int matrix_stride, - const int matrix_row_stride - ); - - protected: - template <const int output_channel_tail> - static void transform_kernel( - const T* const kernel, - const int n_input_channels, - const int n_output_channels, - T* const matrix_base, - const int matrix_stride, - const int matrix_row_stride - ); - }; -} - -/*****************************************************************************/ -/* Transform a fp32 kernel into the Winograd domain. 
- */ -#include "kernel_2x2_3x3/a64_float.hpp" // AArch64 specialisations - -namespace winograd -{ -template <> -inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::execute( - const KernelShape &shape, - const float* const kernel, - float* const matrix_base, - const int matrix_stride, - const int matrix_row_stride -) { - // Delegate based on tail size - const int n_input_channels = shape.n_input_channels; - const int n_output_channels = shape.n_output_channels; - - switch (n_output_channels % 4) { - case 0: - transform_kernel<0>( - kernel, n_input_channels, n_output_channels, - matrix_base, matrix_stride, matrix_row_stride - ); - break; - case 1: - transform_kernel<1>( - kernel, n_input_channels, n_output_channels, - matrix_base, matrix_stride, matrix_row_stride - ); - break; - case 2: - transform_kernel<2>( - kernel, n_input_channels, n_output_channels, - matrix_base, matrix_stride, matrix_row_stride - ); - break; - case 3: - transform_kernel<3>( - kernel, n_input_channels, n_output_channels, - matrix_base, matrix_stride, matrix_row_stride - ); - break; - default: - ARM_COMPUTE_ERROR("Cannot happen"); - break; - } -} - -template <> -template<const int output_channel_tail> -inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::transform_kernel( - const float* const kernel, - const int n_input_channels, - const int n_output_channels, - float* const matrix_base, - const int mstride, - const int matrix_row_stride -) { - // Use one input pointer for each row of the kernel, use two additional - // offsets to extract columns. - const int kernel_col_stride = n_input_channels * n_output_channels; - const int kernel_row_stride = 3 * kernel_col_stride; - const float *inptr0 = kernel; - const float *inptr1 = kernel + kernel_row_stride; - const float *inptr2 = kernel + kernel_row_stride*2; - - // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three - // offsets to extract further matrices. 
- float *outptr0 = matrix_base; - float *outptr4 = matrix_base + mstride * 4; - float *outptr8 = matrix_base + mstride * 8; - float *outptr12 = matrix_base + mstride * 12; - - // For every input channel - for (int in_c = 0; in_c < n_input_channels; in_c++) { - // For every output channel - for (int c = 0; c < n_output_channels; c++) { - // Read in the kernel - float w11 = inptr0[0], w12 = inptr0[kernel_col_stride], w13 = inptr0[kernel_col_stride*2]; - float w21 = inptr1[0], w22 = inptr1[kernel_col_stride], w23 = inptr1[kernel_col_stride*2]; - float w31 = inptr2[0], w32 = inptr2[kernel_col_stride], w33 = inptr2[kernel_col_stride*2]; - - // Progress input pointers - inptr0++; - inptr1++; - inptr2++; - - // Compute the kernel W w, note we need only compute the middle two rows - // (2 and 3) because the first and last rows are merely copies of values - // from the matrix w. - float Ww11 = w11, Ww12 = w12, Ww13 = w13; - float Ww21 = 0.5*(w11 + w21 + w31), Ww22 = 0.5*(w12 + w22 + w32), Ww23 = 0.5*(w13 + w23 + w33); - float Ww31 = 0.5*(w11 - w21 + w31), Ww32 = 0.5*(w12 - w22 + w32), Ww33 = 0.5*(w13 - w23 + w33); - float Ww41 = w31, Ww42 = w32, Ww43 = w33; - - // Hence compute W w W.T; again note we need compute only the middle two - // columns since the first and last columns are copies of the first and - // last columns of the previous matrix. 
- float WwWT11 = Ww11, WwWT12 = 0.5*(Ww11 + Ww12 + Ww13), WwWT13 = 0.5*(Ww11 - Ww12 + Ww13), WwWT14 = Ww13; - float WwWT21 = Ww21, WwWT22 = 0.5*(Ww21 + Ww22 + Ww23), WwWT23 = 0.5*(Ww21 - Ww22 + Ww23), WwWT24 = Ww23; - float WwWT31 = Ww31, WwWT32 = 0.5*(Ww31 + Ww32 + Ww33), WwWT33 = 0.5*(Ww31 - Ww32 + Ww33), WwWT34 = Ww33; - float WwWT41 = Ww41, WwWT42 = 0.5*(Ww41 + Ww42 + Ww43), WwWT43 = 0.5*(Ww41 - Ww42 + Ww43), WwWT44 = Ww43; - - // Store the computed weights - outptr0[0 * mstride] = WwWT11; - outptr0[1 * mstride] = WwWT12; - outptr0[2 * mstride] = WwWT13; - outptr0[3 * mstride] = WwWT14; - - outptr4[0 * mstride] = WwWT21; - outptr4[1 * mstride] = WwWT22; - outptr4[2 * mstride] = WwWT23; - outptr4[3 * mstride] = WwWT24; - - outptr8[0 * mstride] = WwWT31; - outptr8[1 * mstride] = WwWT32; - outptr8[2 * mstride] = WwWT33; - outptr8[3 * mstride] = WwWT34; - - outptr12[0 * mstride] = WwWT41; - outptr12[1 * mstride] = WwWT42; - outptr12[2 * mstride] = WwWT43; - outptr12[3 * mstride] = WwWT44; - - // Progress output pointers - outptr0++; - outptr4++; - outptr8++; - outptr12++; - } - - // Progression to complete stride - outptr0 += matrix_row_stride - n_output_channels; - outptr4 += matrix_row_stride - n_output_channels; - outptr8 += matrix_row_stride - n_output_channels; - outptr12 += matrix_row_stride - n_output_channels; - } -} -} diff --git a/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp b/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp deleted file mode 100644 index 3dd62d1ac1..0000000000 --- a/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp +++ /dev/null @@ -1,822 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. 
- * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -#ifdef __aarch64__ -namespace winograd { -template <> -template <> -inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::transform_kernel<0>( - const float* const kernel, - const int n_input_channels, - const int n_output_channels, - float* const matrix_base, - const int mstride, - const int matrix_row_stride -) { - // Use one input pointer for each row of the kernel, use two additional - // offsets to extract columns. - const int kernel_col_stride = n_input_channels * n_output_channels; - const int kernel_row_stride = 3 * kernel_col_stride; - const float *inptr0 = kernel; - const float *inptr1 = kernel + kernel_row_stride; - const float *inptr2 = kernel + kernel_row_stride*2; - - // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three - // offsets to extract further matrices. 
- float *outptr0 = matrix_base; - float *outptr4 = matrix_base + mstride * 4; - float *outptr8 = matrix_base + mstride * 8; - float *outptr12 = matrix_base + mstride * 12; - - // For every input channel - for (int in_c = 0; in_c < n_input_channels; in_c++) { - int n_remaining_channels = n_output_channels; - - asm volatile ( - // Registers into which to read the kernel - "w_11 .req v0\n" "qw_11 .req q0\n" - "w_12 .req v1\n" "qw_12 .req q1\n" - "w_13 .req v2\n" "qw_13 .req q2\n" - "w_21 .req v3\n" "qw_21 .req q3\n" - "w_22 .req v4\n" "qw_22 .req q4\n" - "w_23 .req v5\n" "qw_23 .req q5\n" - "w_31 .req v6\n" "qw_31 .req q6\n" - "w_32 .req v7\n" "qw_32 .req q7\n" - "w_33 .req v8\n" "qw_33 .req q8\n" - - // Transformed matrix Ww - "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n" - "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n" - "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n" - "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n" - - // Output matrix U = WwWT - "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n" - "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n" - "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n" - "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n" - - // Storage view of output matrices - "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n" - "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n" - "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n" - "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n" - - "half .req v23\n" // {0.5, ..., 0.5} - "dup half.4s, %w[one_half]\n" - "scratch .req v24\n" - - "1:" - // Load tile of the kernel - "ldr qw_11, [%x[inptr0]]\n" - "str qU11, [%x[outptr0]]\n" - "ldr qw_12, [%x[inptr0], %x[colstride1]]\n" - "ldr qw_13, [%x[inptr0], %x[colstride2]]\n" - "str qU14, [%x[outptr0], %x[mstride3]]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "ldr qw_21, 
[%x[inptr1]]\n" - "ldr qw_22, [%x[inptr1], %x[colstride1]]\n" - "ldr qw_23, [%x[inptr1], %x[colstride2]]\n" - "add %x[inptr1], %x[inptr1], #0x10\n" - - "ldr qw_31, [%x[inptr2]]\n" - "str qU41, [%x[outptr12]]\n" - "ldr qw_32, [%x[inptr2], %x[colstride1]]\n" - "ldr qw_33, [%x[inptr2], %x[colstride2]]\n" - "str qU44, [%x[outptr12], %x[mstride3]]\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - // Compute 2nd and 3rd rows of Ww - "fadd scratch.4s, w_11.4s, w_31.4s\n" - "fmul Ww21.4s, scratch.4s, half.4s\n" - "fmla Ww21.4s, w_21.4s, half.4s\n" - "str qU21, [%x[outptr4]]\n" - "fmul Ww31.4s, scratch.4s, half.4s\n" - "fmls Ww31.4s, w_21.4s, half.4s\n" - "str qU31, [%x[outptr8]]\n" - - "fadd scratch.4s, w_12.4s, w_32.4s\n" - "fmul Ww22.4s, scratch.4s, half.4s\n" - "fmla Ww22.4s, w_22.4s, half.4s\n" - "fmul Ww32.4s, scratch.4s, half.4s\n" - "fmls Ww32.4s, w_22.4s, half.4s\n" - - "fadd scratch.4s, w_13.4s, w_33.4s\n" - "fmul Ww23.4s, scratch.4s, half.4s\n" - "fmla Ww23.4s, w_23.4s, half.4s\n" - "str qU24, [%x[outptr4], %x[mstride3]]\n" - "fmul Ww33.4s, scratch.4s, half.4s\n" - "fmls Ww33.4s, w_23.4s, half.4s\n" - "str qU34, [%x[outptr8], %x[mstride3]]\n" - - // Compute and store U, only need to compute the 2nd and 3rd columns - // of U and update output pointers - "fadd scratch.4s, Ww11.4s, Ww13.4s\n" - "fmul U12.4s, scratch.4s, half.4s\n" - "fmla U12.4s, Ww12.4s, half.4s\n" - "str qU12, [%x[outptr0], %x[mstride1]]\n" - "fmul U13.4s, scratch.4s, half.4s\n" - "fmls U13.4s, Ww12.4s, half.4s\n" - "str qU13, [%x[outptr0], %x[mstride2]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd scratch.4s, Ww21.4s, Ww23.4s\n" - "fmul U22.4s, scratch.4s, half.4s\n" - "fmla U22.4s, Ww22.4s, half.4s\n" - "str qU22, [%x[outptr4], %x[mstride1]]\n" - "fmul U23.4s, scratch.4s, half.4s\n" - "fmls U23.4s, Ww22.4s, half.4s\n" - "str qU23, [%x[outptr4], %x[mstride2]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fadd scratch.4s, Ww31.4s, Ww33.4s\n" - "fmul U32.4s, scratch.4s, half.4s\n" - 
"fmla U32.4s, Ww32.4s, half.4s\n" - "str qU32, [%x[outptr8], %x[mstride1]]\n" - "fmul U33.4s, scratch.4s, half.4s\n" - "fmls U33.4s, Ww32.4s, half.4s\n" - "str qU33, [%x[outptr8], %x[mstride2]]\n" - "add %x[outptr8], %x[outptr8], #0x10\n" - - "fadd scratch.4s, Ww41.4s, Ww43.4s\n" - "fmul U42.4s, scratch.4s, half.4s\n" - "fmla U42.4s, Ww42.4s, half.4s\n" - "str qU42, [%x[outptr12], %x[mstride1]]\n" - "fmul U43.4s, scratch.4s, half.4s\n" - "fmls U43.4s, Ww42.4s, half.4s\n" - "str qU43, [%x[outptr12], %x[mstride2]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n" - "bne 1b\n" - - // Clear aliases - ".unreq half\n" - ".unreq scratch\n" - ".unreq w_11\n" ".unreq qw_11\n" - ".unreq w_12\n" ".unreq qw_12\n" - ".unreq w_13\n" ".unreq qw_13\n" - ".unreq w_21\n" ".unreq qw_21\n" - ".unreq w_22\n" ".unreq qw_22\n" - ".unreq w_23\n" ".unreq qw_23\n" - ".unreq w_31\n" ".unreq qw_31\n" - ".unreq w_32\n" ".unreq qw_32\n" - ".unreq w_33\n" ".unreq qw_33\n" - ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n" - ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n" - ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n" - ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n" - ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n" - ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n" - ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n" - ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n" - ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n" - ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n" - ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n" - ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq qU44\n" - - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr4), - [outptr8] "+r" (outptr8), - [outptr12] "+r" (outptr12), - [n_remaining_channels] "+r" (n_remaining_channels) - : [mstride1] 
"r" (sizeof(float) * mstride), - [mstride2] "r" (sizeof(float) * mstride * 2), - [mstride3] "r" (sizeof(float) * mstride * 3), - [colstride1] "r" (sizeof(float) * kernel_col_stride), - [colstride2] "r" (sizeof(float) * kernel_col_stride * 2), - [one_half] "r" (0.5f) - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24" - ); - - // Progression to complete stride - outptr0 += matrix_row_stride - n_output_channels; - outptr4 += matrix_row_stride - n_output_channels; - outptr8 += matrix_row_stride - n_output_channels; - outptr12 += matrix_row_stride - n_output_channels; - } -} - -template <> -template <> -inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::transform_kernel<2>( - const float* const kernel, - const int n_input_channels, - const int n_output_channels, - float* const matrix_base, - const int mstride, - const int matrix_row_stride -) { - // Use one input pointer for each row of the kernel, use two additional - // offsets to extract columns. - const int kernel_col_stride = n_input_channels * n_output_channels; - const int kernel_row_stride = 3 * kernel_col_stride; - const float *inptr0 = kernel; - const float *inptr1 = kernel + kernel_row_stride; - const float *inptr2 = kernel + kernel_row_stride*2; - - // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three - // offsets to extract further matrices. 
- float *outptr0 = matrix_base; - float *outptr4 = matrix_base + mstride * 4; - float *outptr8 = matrix_base + mstride * 8; - float *outptr12 = matrix_base + mstride * 12; - - // For every input channel - for (int in_c = 0; in_c < n_input_channels; in_c++) { - int n_remaining_channels = n_output_channels; - - asm volatile ( - // Registers into which to read the kernel - "w_11 .req v0\n" "qw_11 .req q0\n" "dw_11 .req d0\n" - "w_12 .req v1\n" "qw_12 .req q1\n" "dw_12 .req d1\n" - "w_13 .req v2\n" "qw_13 .req q2\n" "dw_13 .req d2\n" - "w_21 .req v3\n" "qw_21 .req q3\n" "dw_21 .req d3\n" - "w_22 .req v4\n" "qw_22 .req q4\n" "dw_22 .req d4\n" - "w_23 .req v5\n" "qw_23 .req q5\n" "dw_23 .req d5\n" - "w_31 .req v6\n" "qw_31 .req q6\n" "dw_31 .req d6\n" - "w_32 .req v7\n" "qw_32 .req q7\n" "dw_32 .req d7\n" - "w_33 .req v8\n" "qw_33 .req q8\n" "dw_33 .req d8\n" - - // Transformed matrix Ww - "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n" - "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n" - "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n" - "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n" - - // Output matrix U = WwWT - "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n" - "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n" - "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n" - "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n" - - // Storage view of output matrices - "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n" - "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n" - "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n" - "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n" - - "dU11 .req d0\n" "dU12 .req d15\n" "dU13 .req d16\n" "dU14 .req d2\n" - "dU21 .req d9\n" "dU22 .req d17\n" "dU23 .req d18\n" "dU24 .req d11\n" - "dU31 .req d12\n" "dU32 .req d19\n" "dU33 .req d20\n" "dU34 .req d14\n" - 
"dU41 .req d6\n" "dU42 .req d21\n" "dU43 .req d22\n" "dU44 .req d8\n" - - "half .req v23\n" // {0.5, ..., 0.5} - "dup half.4s, %w[one_half]\n" - "scratch .req v24\n" - - // Subtract the tail from the number of remaining channels and jump to - // the tail if necessary. - "subs %x[n_remaining_channels], %x[n_remaining_channels], #2\n" - "beq 2f\n" - - "1:" - // Load tile of the kernel - "ldr qw_11, [%x[inptr0]]\n" - "str qU11, [%x[outptr0]]\n" - "ldr qw_12, [%x[inptr0], %x[colstride1]]\n" - "ldr qw_13, [%x[inptr0], %x[colstride2]]\n" - "str qU14, [%x[outptr0], %x[mstride3]]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "ldr qw_21, [%x[inptr1]]\n" - "ldr qw_22, [%x[inptr1], %x[colstride1]]\n" - "ldr qw_23, [%x[inptr1], %x[colstride2]]\n" - "add %x[inptr1], %x[inptr1], #0x10\n" - - "ldr qw_31, [%x[inptr2]]\n" - "str qU41, [%x[outptr12]]\n" - "ldr qw_32, [%x[inptr2], %x[colstride1]]\n" - "ldr qw_33, [%x[inptr2], %x[colstride2]]\n" - "str qU44, [%x[outptr12], %x[mstride3]]\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - // Compute 2nd and 3rd rows of Ww - "fadd scratch.4s, w_11.4s, w_31.4s\n" - "fmul Ww21.4s, scratch.4s, half.4s\n" - "fmla Ww21.4s, w_21.4s, half.4s\n" - "str qU21, [%x[outptr4]]\n" - "fmul Ww31.4s, scratch.4s, half.4s\n" - "fmls Ww31.4s, w_21.4s, half.4s\n" - "str qU31, [%x[outptr8]]\n" - - "fadd scratch.4s, w_12.4s, w_32.4s\n" - "fmul Ww22.4s, scratch.4s, half.4s\n" - "fmla Ww22.4s, w_22.4s, half.4s\n" - "fmul Ww32.4s, scratch.4s, half.4s\n" - "fmls Ww32.4s, w_22.4s, half.4s\n" - - "fadd scratch.4s, w_13.4s, w_33.4s\n" - "fmul Ww23.4s, scratch.4s, half.4s\n" - "fmla Ww23.4s, w_23.4s, half.4s\n" - "str qU24, [%x[outptr4], %x[mstride3]]\n" - "fmul Ww33.4s, scratch.4s, half.4s\n" - "fmls Ww33.4s, w_23.4s, half.4s\n" - "str qU34, [%x[outptr8], %x[mstride3]]\n" - - // Compute and store U, only need to compute the 2nd and 3rd columns - // of U and update output pointers - "fadd scratch.4s, Ww11.4s, Ww13.4s\n" - "fmul U12.4s, scratch.4s, half.4s\n" - "fmla 
U12.4s, Ww12.4s, half.4s\n" - "str qU12, [%x[outptr0], %x[mstride1]]\n" - "fmul U13.4s, scratch.4s, half.4s\n" - "fmls U13.4s, Ww12.4s, half.4s\n" - "str qU13, [%x[outptr0], %x[mstride2]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd scratch.4s, Ww21.4s, Ww23.4s\n" - "fmul U22.4s, scratch.4s, half.4s\n" - "fmla U22.4s, Ww22.4s, half.4s\n" - "str qU22, [%x[outptr4], %x[mstride1]]\n" - "fmul U23.4s, scratch.4s, half.4s\n" - "fmls U23.4s, Ww22.4s, half.4s\n" - "str qU23, [%x[outptr4], %x[mstride2]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fadd scratch.4s, Ww31.4s, Ww33.4s\n" - "fmul U32.4s, scratch.4s, half.4s\n" - "fmla U32.4s, Ww32.4s, half.4s\n" - "str qU32, [%x[outptr8], %x[mstride1]]\n" - "fmul U33.4s, scratch.4s, half.4s\n" - "fmls U33.4s, Ww32.4s, half.4s\n" - "str qU33, [%x[outptr8], %x[mstride2]]\n" - "add %x[outptr8], %x[outptr8], #0x10\n" - - "fadd scratch.4s, Ww41.4s, Ww43.4s\n" - "fmul U42.4s, scratch.4s, half.4s\n" - "fmla U42.4s, Ww42.4s, half.4s\n" - "str qU42, [%x[outptr12], %x[mstride1]]\n" - "fmul U43.4s, scratch.4s, half.4s\n" - "fmls U43.4s, Ww42.4s, half.4s\n" - "str qU43, [%x[outptr12], %x[mstride2]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n" - "bne 1b\n" - - // Tail size 2 - "2:" - // Load tile of the kernel - "ldr dw_11, [%x[inptr0]]\n" - "str dU11, [%x[outptr0]]\n" - "ldr dw_12, [%x[inptr0], %x[colstride1]]\n" - "ldr dw_13, [%x[inptr0], %x[colstride2]]\n" - "str dU14, [%x[outptr0], %x[mstride3]]\n" - "add %x[inptr0], %x[inptr0], #0x08\n" - - "ldr dw_21, [%x[inptr1]]\n" - "ldr dw_22, [%x[inptr1], %x[colstride1]]\n" - "ldr dw_23, [%x[inptr1], %x[colstride2]]\n" - "add %x[inptr1], %x[inptr1], #0x08\n" - - "ldr dw_31, [%x[inptr2]]\n" - "str dU41, [%x[outptr12]]\n" - "ldr dw_32, [%x[inptr2], %x[colstride1]]\n" - "ldr dw_33, [%x[inptr2], %x[colstride2]]\n" - "str dU44, [%x[outptr12], %x[mstride3]]\n" - "add %x[inptr2], %x[inptr2], #0x08\n" - - // Compute 2nd 
and 3rd rows of Ww - "fadd scratch.2s, w_11.2s, w_31.2s\n" - "fmul Ww21.2s, scratch.2s, half.2s\n" - "fmla Ww21.2s, w_21.2s, half.2s\n" - "str dU21, [%x[outptr4]]\n" - "fmul Ww31.2s, scratch.2s, half.2s\n" - "fmls Ww31.2s, w_21.2s, half.2s\n" - "str dU31, [%x[outptr8]]\n" - - "fadd scratch.2s, w_12.2s, w_32.2s\n" - "fmul Ww22.2s, scratch.2s, half.2s\n" - "fmla Ww22.2s, w_22.2s, half.2s\n" - "fmul Ww32.2s, scratch.2s, half.2s\n" - "fmls Ww32.2s, w_22.2s, half.2s\n" - - "fadd scratch.2s, w_13.2s, w_33.2s\n" - "fmul Ww23.2s, scratch.2s, half.2s\n" - "fmla Ww23.2s, w_23.2s, half.2s\n" - "str dU24, [%x[outptr4], %x[mstride3]]\n" - "fmul Ww33.2s, scratch.2s, half.2s\n" - "fmls Ww33.2s, w_23.2s, half.2s\n" - "str dU34, [%x[outptr8], %x[mstride3]]\n" - - // Compute and store U, only need to compute the 2nd and 3rd columns of - // U and update output pointers - "fadd scratch.2s, Ww11.2s, Ww13.2s\n" - "fmul U12.2s, scratch.2s, half.2s\n" - "fmla U12.2s, Ww12.2s, half.2s\n" - "str dU12, [%x[outptr0], %x[mstride1]]\n" - "fmul U13.2s, scratch.2s, half.2s\n" - "fmls U13.2s, Ww12.2s, half.2s\n" - "str dU13, [%x[outptr0], %x[mstride2]]\n" - "add %x[outptr0], %x[outptr0], #0x08\n" - - "fadd scratch.2s, Ww21.2s, Ww23.2s\n" - "fmul U22.2s, scratch.2s, half.2s\n" - "fmla U22.2s, Ww22.2s, half.2s\n" - "str dU22, [%x[outptr4], %x[mstride1]]\n" - "fmul U23.2s, scratch.2s, half.2s\n" - "fmls U23.2s, Ww22.2s, half.2s\n" - "str dU23, [%x[outptr4], %x[mstride2]]\n" - "add %x[outptr4], %x[outptr4], #0x08\n" - - "fadd scratch.2s, Ww31.2s, Ww33.2s\n" - "fmul U32.2s, scratch.2s, half.2s\n" - "fmla U32.2s, Ww32.2s, half.2s\n" - "str dU32, [%x[outptr8], %x[mstride1]]\n" - "fmul U33.2s, scratch.2s, half.2s\n" - "fmls U33.2s, Ww32.2s, half.2s\n" - "str dU33, [%x[outptr8], %x[mstride2]]\n" - "add %x[outptr8], %x[outptr8], #0x08\n" - - "fadd scratch.2s, Ww41.2s, Ww43.2s\n" - "fmul U42.2s, scratch.2s, half.2s\n" - "fmla U42.2s, Ww42.2s, half.2s\n" - "str dU42, [%x[outptr12], %x[mstride1]]\n" - "fmul 
U43.2s, scratch.2s, half.2s\n" - "fmls U43.2s, Ww42.2s, half.2s\n" - "str dU43, [%x[outptr12], %x[mstride2]]\n" - "add %x[outptr12], %x[outptr12], #0x08\n" - - // Clear aliases - ".unreq half\n" - ".unreq scratch\n" - ".unreq w_11\n" ".unreq qw_11\n" ".unreq dw_11\n" - ".unreq w_12\n" ".unreq qw_12\n" ".unreq dw_12\n" - ".unreq w_13\n" ".unreq qw_13\n" ".unreq dw_13\n" - ".unreq w_21\n" ".unreq qw_21\n" ".unreq dw_21\n" - ".unreq w_22\n" ".unreq qw_22\n" ".unreq dw_22\n" - ".unreq w_23\n" ".unreq qw_23\n" ".unreq dw_23\n" - ".unreq w_31\n" ".unreq qw_31\n" ".unreq dw_31\n" - ".unreq w_32\n" ".unreq qw_32\n" ".unreq dw_32\n" - ".unreq w_33\n" ".unreq qw_33\n" ".unreq dw_33\n" - ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n" - ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n" - ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n" - ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n" - ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n" - ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n" - ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n" - ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n" - ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n" - ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n" - ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n" - ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq qU44\n" - ".unreq dU11\n" ".unreq dU12\n" ".unreq dU13\n" ".unreq dU14\n" - ".unreq dU21\n" ".unreq dU22\n" ".unreq dU23\n" ".unreq dU24\n" - ".unreq dU31\n" ".unreq dU32\n" ".unreq dU33\n" ".unreq dU34\n" - ".unreq dU41\n" ".unreq dU42\n" ".unreq dU43\n" ".unreq dU44\n" - - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr4), - [outptr8] "+r" (outptr8), - [outptr12] "+r" (outptr12), - [n_remaining_channels] "+r" (n_remaining_channels) - : [mstride1] "r" (sizeof(float) * mstride), - [mstride2] "r" (sizeof(float) * mstride * 
2), - [mstride3] "r" (sizeof(float) * mstride * 3), - [colstride1] "r" (sizeof(float) * kernel_col_stride), - [colstride2] "r" (sizeof(float) * kernel_col_stride * 2), - [one_half] "r" (0.5f) - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24" - ); - - // Progression to complete stride - outptr0 += matrix_row_stride - n_output_channels; - outptr4 += matrix_row_stride - n_output_channels; - outptr8 += matrix_row_stride - n_output_channels; - outptr12 += matrix_row_stride - n_output_channels; - } -} - -template <> -template <> -inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::transform_kernel<1>( - const float* const kernel, - const int n_input_channels, - const int n_output_channels, - float* const matrix_base, - const int mstride, - const int matrix_row_stride -) { - // Use one input pointer for each row of the kernel, use two additional - // offsets to extract columns. - const int kernel_col_stride = n_input_channels * n_output_channels; - const int kernel_row_stride = 3 * kernel_col_stride; - const float *inptr0 = kernel; - const float *inptr1 = kernel + kernel_row_stride; - const float *inptr2 = kernel + kernel_row_stride*2; - - // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three - // offsets to extract further matrices. 
- float *outptr0 = matrix_base; - float *outptr4 = matrix_base + mstride * 4; - float *outptr8 = matrix_base + mstride * 8; - float *outptr12 = matrix_base + mstride * 12; - - // For every input channel - for (int in_c = 0; in_c < n_input_channels; in_c++) { - int n_remaining_channels = n_output_channels; - - asm volatile ( - // Registers into which to read the kernel - "w_11 .req v0\n" "qw_11 .req q0\n" "sw_11 .req s0\n" - "w_12 .req v1\n" "qw_12 .req q1\n" "sw_12 .req s1\n" - "w_13 .req v2\n" "qw_13 .req q2\n" "sw_13 .req s2\n" - "w_21 .req v3\n" "qw_21 .req q3\n" "sw_21 .req s3\n" - "w_22 .req v4\n" "qw_22 .req q4\n" "sw_22 .req s4\n" - "w_23 .req v5\n" "qw_23 .req q5\n" "sw_23 .req s5\n" - "w_31 .req v6\n" "qw_31 .req q6\n" "sw_31 .req s6\n" - "w_32 .req v7\n" "qw_32 .req q7\n" "sw_32 .req s7\n" - "w_33 .req v8\n" "qw_33 .req q8\n" "sw_33 .req s8\n" - - // Transformed matrix Ww - "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n" - "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n" - "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n" - "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n" - - // Output matrix U = WwWT - "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n" - "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n" - "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n" - "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n" - - // Storage view of output matrices - "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n" - "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n" - "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n" - "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n" - - "sU11 .req s0\n" "sU12 .req s15\n" "sU13 .req s16\n" "sU14 .req s2\n" - "sU21 .req s9\n" "sU22 .req s17\n" "sU23 .req s18\n" "sU24 .req s11\n" - "sU31 .req s12\n" "sU32 .req s19\n" "sU33 .req s20\n" "sU34 .req s14\n" - 
"sU41 .req s6\n" "sU42 .req s21\n" "sU43 .req s22\n" "sU44 .req s8\n" - - "half .req v23\n" // {0.5, ..., 0.5} - "dup half.4s, %w[one_half]\n" - "scratch .req v24\n" - - // Subtract the tail from the number of remaining channels and jump to - // the tail if necessary. - "subs %x[n_remaining_channels], %x[n_remaining_channels], #1\n" - "beq 2f\n" - - "1:" - // Load tile of the kernel - "ldr qw_11, [%x[inptr0]]\n" - "str qU11, [%x[outptr0]]\n" - "ldr qw_12, [%x[inptr0], %x[colstride1]]\n" - "ldr qw_13, [%x[inptr0], %x[colstride2]]\n" - "str qU14, [%x[outptr0], %x[mstride3]]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "ldr qw_21, [%x[inptr1]]\n" - "ldr qw_22, [%x[inptr1], %x[colstride1]]\n" - "ldr qw_23, [%x[inptr1], %x[colstride2]]\n" - "add %x[inptr1], %x[inptr1], #0x10\n" - - "ldr qw_31, [%x[inptr2]]\n" - "str qU41, [%x[outptr12]]\n" - "ldr qw_32, [%x[inptr2], %x[colstride1]]\n" - "ldr qw_33, [%x[inptr2], %x[colstride2]]\n" - "str qU44, [%x[outptr12], %x[mstride3]]\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - // Compute 2nd and 3rd rows of Ww - "fadd scratch.4s, w_11.4s, w_31.4s\n" - "fmul Ww21.4s, scratch.4s, half.4s\n" - "fmla Ww21.4s, w_21.4s, half.4s\n" - "str qU21, [%x[outptr4]]\n" - "fmul Ww31.4s, scratch.4s, half.4s\n" - "fmls Ww31.4s, w_21.4s, half.4s\n" - "str qU31, [%x[outptr8]]\n" - - "fadd scratch.4s, w_12.4s, w_32.4s\n" - "fmul Ww22.4s, scratch.4s, half.4s\n" - "fmla Ww22.4s, w_22.4s, half.4s\n" - "fmul Ww32.4s, scratch.4s, half.4s\n" - "fmls Ww32.4s, w_22.4s, half.4s\n" - - "fadd scratch.4s, w_13.4s, w_33.4s\n" - "fmul Ww23.4s, scratch.4s, half.4s\n" - "fmla Ww23.4s, w_23.4s, half.4s\n" - "str qU24, [%x[outptr4], %x[mstride3]]\n" - "fmul Ww33.4s, scratch.4s, half.4s\n" - "fmls Ww33.4s, w_23.4s, half.4s\n" - "str qU34, [%x[outptr8], %x[mstride3]]\n" - - // Compute and store U, only need to compute the 2nd and 3rd columns - // of U and update output pointers - "fadd scratch.4s, Ww11.4s, Ww13.4s\n" - "fmul U12.4s, scratch.4s, half.4s\n" - "fmla 
U12.4s, Ww12.4s, half.4s\n" - "str qU12, [%x[outptr0], %x[mstride1]]\n" - "fmul U13.4s, scratch.4s, half.4s\n" - "fmls U13.4s, Ww12.4s, half.4s\n" - "str qU13, [%x[outptr0], %x[mstride2]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd scratch.4s, Ww21.4s, Ww23.4s\n" - "fmul U22.4s, scratch.4s, half.4s\n" - "fmla U22.4s, Ww22.4s, half.4s\n" - "str qU22, [%x[outptr4], %x[mstride1]]\n" - "fmul U23.4s, scratch.4s, half.4s\n" - "fmls U23.4s, Ww22.4s, half.4s\n" - "str qU23, [%x[outptr4], %x[mstride2]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fadd scratch.4s, Ww31.4s, Ww33.4s\n" - "fmul U32.4s, scratch.4s, half.4s\n" - "fmla U32.4s, Ww32.4s, half.4s\n" - "str qU32, [%x[outptr8], %x[mstride1]]\n" - "fmul U33.4s, scratch.4s, half.4s\n" - "fmls U33.4s, Ww32.4s, half.4s\n" - "str qU33, [%x[outptr8], %x[mstride2]]\n" - "add %x[outptr8], %x[outptr8], #0x10\n" - - "fadd scratch.4s, Ww41.4s, Ww43.4s\n" - "fmul U42.4s, scratch.4s, half.4s\n" - "fmla U42.4s, Ww42.4s, half.4s\n" - "str qU42, [%x[outptr12], %x[mstride1]]\n" - "fmul U43.4s, scratch.4s, half.4s\n" - "fmls U43.4s, Ww42.4s, half.4s\n" - "str qU43, [%x[outptr12], %x[mstride2]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n" - "bne 1b\n" - - // Tail size 1 - "2:" - // Load tile of the kernel - "ldr sw_11, [%x[inptr0]]\n" - "str sU11, [%x[outptr0]]\n" - "ldr sw_12, [%x[inptr0], %x[colstride1]]\n" - "ldr sw_13, [%x[inptr0], %x[colstride2]]\n" - "str sU14, [%x[outptr0], %x[mstride3]]\n" - "add %x[inptr0], %x[inptr0], #0x04\n" - - "ldr sw_21, [%x[inptr1]]\n" - "ldr sw_22, [%x[inptr1], %x[colstride1]]\n" - "ldr sw_23, [%x[inptr1], %x[colstride2]]\n" - "add %x[inptr1], %x[inptr1], #0x04\n" - - "ldr sw_31, [%x[inptr2]]\n" - "str sU41, [%x[outptr12]]\n" - "ldr sw_32, [%x[inptr2], %x[colstride1]]\n" - "ldr sw_33, [%x[inptr2], %x[colstride2]]\n" - "str sU44, [%x[outptr12], %x[mstride3]]\n" - "add %x[inptr2], %x[inptr2], #0x04\n" - - // Compute 2nd 
and 3rd rows of Ww - "fadd scratch.2s, w_11.2s, w_31.2s\n" - "fmul Ww21.2s, scratch.2s, half.2s\n" - "fmla Ww21.2s, w_21.2s, half.2s\n" - "str sU21, [%x[outptr4]]\n" - "fmul Ww31.2s, scratch.2s, half.2s\n" - "fmls Ww31.2s, w_21.2s, half.2s\n" - "str sU31, [%x[outptr8]]\n" - - "fadd scratch.2s, w_12.2s, w_32.2s\n" - "fmul Ww22.2s, scratch.2s, half.2s\n" - "fmla Ww22.2s, w_22.2s, half.2s\n" - "fmul Ww32.2s, scratch.2s, half.2s\n" - "fmls Ww32.2s, w_22.2s, half.2s\n" - - "fadd scratch.2s, w_13.2s, w_33.2s\n" - "fmul Ww23.2s, scratch.2s, half.2s\n" - "fmla Ww23.2s, w_23.2s, half.2s\n" - "str sU24, [%x[outptr4], %x[mstride3]]\n" - "fmul Ww33.2s, scratch.2s, half.2s\n" - "fmls Ww33.2s, w_23.2s, half.2s\n" - "str sU34, [%x[outptr8], %x[mstride3]]\n" - - // Compute and store U, only need to compute the 2nd and 3rd columns of - // U and update output pointers - "fadd scratch.2s, Ww11.2s, Ww13.2s\n" - "fmul U12.2s, scratch.2s, half.2s\n" - "fmla U12.2s, Ww12.2s, half.2s\n" - "str sU12, [%x[outptr0], %x[mstride1]]\n" - "fmul U13.2s, scratch.2s, half.2s\n" - "fmls U13.2s, Ww12.2s, half.2s\n" - "str sU13, [%x[outptr0], %x[mstride2]]\n" - "add %x[outptr0], %x[outptr0], #0x04\n" - - "fadd scratch.2s, Ww21.2s, Ww23.2s\n" - "fmul U22.2s, scratch.2s, half.2s\n" - "fmla U22.2s, Ww22.2s, half.2s\n" - "str sU22, [%x[outptr4], %x[mstride1]]\n" - "fmul U23.2s, scratch.2s, half.2s\n" - "fmls U23.2s, Ww22.2s, half.2s\n" - "str sU23, [%x[outptr4], %x[mstride2]]\n" - "add %x[outptr4], %x[outptr4], #0x04\n" - - "fadd scratch.2s, Ww31.2s, Ww33.2s\n" - "fmul U32.2s, scratch.2s, half.2s\n" - "fmla U32.2s, Ww32.2s, half.2s\n" - "str sU32, [%x[outptr8], %x[mstride1]]\n" - "fmul U33.2s, scratch.2s, half.2s\n" - "fmls U33.2s, Ww32.2s, half.2s\n" - "str sU33, [%x[outptr8], %x[mstride2]]\n" - "add %x[outptr8], %x[outptr8], #0x04\n" - - "fadd scratch.2s, Ww41.2s, Ww43.2s\n" - "fmul U42.2s, scratch.2s, half.2s\n" - "fmla U42.2s, Ww42.2s, half.2s\n" - "str sU42, [%x[outptr12], %x[mstride1]]\n" - "fmul 
U43.2s, scratch.2s, half.2s\n" - "fmls U43.2s, Ww42.2s, half.2s\n" - "str sU43, [%x[outptr12], %x[mstride2]]\n" - "add %x[outptr12], %x[outptr12], #0x04\n" - - // Clear aliases - ".unreq half\n" - ".unreq scratch\n" - ".unreq w_11\n" ".unreq qw_11\n" ".unreq sw_11\n" - ".unreq w_12\n" ".unreq qw_12\n" ".unreq sw_12\n" - ".unreq w_13\n" ".unreq qw_13\n" ".unreq sw_13\n" - ".unreq w_21\n" ".unreq qw_21\n" ".unreq sw_21\n" - ".unreq w_22\n" ".unreq qw_22\n" ".unreq sw_22\n" - ".unreq w_23\n" ".unreq qw_23\n" ".unreq sw_23\n" - ".unreq w_31\n" ".unreq qw_31\n" ".unreq sw_31\n" - ".unreq w_32\n" ".unreq qw_32\n" ".unreq sw_32\n" - ".unreq w_33\n" ".unreq qw_33\n" ".unreq sw_33\n" - ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n" - ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n" - ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n" - ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n" - ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n" - ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n" - ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n" - ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n" - ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n" - ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n" - ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n" - ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq qU44\n" - ".unreq sU11\n" ".unreq sU12\n" ".unreq sU13\n" ".unreq sU14\n" - ".unreq sU21\n" ".unreq sU22\n" ".unreq sU23\n" ".unreq sU24\n" - ".unreq sU31\n" ".unreq sU32\n" ".unreq sU33\n" ".unreq sU34\n" - ".unreq sU41\n" ".unreq sU42\n" ".unreq sU43\n" ".unreq sU44\n" - - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr4), - [outptr8] "+r" (outptr8), - [outptr12] "+r" (outptr12), - [n_remaining_channels] "+r" (n_remaining_channels) - : [mstride1] "r" (sizeof(float) * mstride), - [mstride2] "r" (sizeof(float) * mstride * 
2), - [mstride3] "r" (sizeof(float) * mstride * 3), - [colstride1] "r" (sizeof(float) * kernel_col_stride), - [colstride2] "r" (sizeof(float) * kernel_col_stride * 2), - [one_half] "r" (0.5f) - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24" - ); - - // Progression to complete stride - outptr0 += matrix_row_stride - n_output_channels; - outptr4 += matrix_row_stride - n_output_channels; - outptr8 += matrix_row_stride - n_output_channels; - outptr12 += matrix_row_stride - n_output_channels; - } -} -} -#endif // __aarch64__ diff --git a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp b/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp deleted file mode 100644 index 0992c0bb44..0000000000 --- a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp +++ /dev/null @@ -1,356 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -namespace winograd { - /* Transform from the Winograd domain back to the spatial domain. - */ - template <typename T> - struct Winograd2x2_3x3GemmOutput { - static void execute( - const Tensor4DShape &output_shape, - T* const matrix_base, - const int matrix_stride, - const int matrix_row_stride, - T* const output - ); - - protected: - /* Specialised implementation method. */ - template <bool tail_M, bool tail_N, int channel_tail> - static void _execute( - const Tensor4DShape &output_shape, - T *output, - const T *input, - const int matrix_stride, - const int matrix_row_stride - ); - }; - - /* Two-stage implementation of the transformation from the Winograd domain. - * - * First computes Z.F and then computes (Z.F).Z^T. 
- */ - template <typename T> - struct Winograd2x2_3x3GemmOutput_TwoStage { - static void execute( - const Tensor4DShape &output_shape, - T* const matrix_base, - const int matrix_stride, - const int matrix_row_stride, - T* const output - ); - - protected: - template <int channel_tail> - static void compute_zf( - const int n_rows, const int n_channels, - T* const zf, const T* const input[16] - ); - - template <bool tail_M, bool tail_N, int channel_tail> - static void compute_zfzT( - const Tensor4DShape &output_shape, - T* const output, const T* const zf - ); - }; -} - -#include "output_2x2_3x3/a64_float.hpp" -// #include "output_2x2_3x3/a64_float_two_stage.hpp" - -/*****************************************************************************/ -/* -template <typename T> -void winograd::Winograd2x2_3x3GemmOutput<T>::execute( - const Tensor4DShape &output_shape, - const int tile_M, - const int tile_N, - T* const matrix_base, - const int matrix_stride, - const int matrix_row_stride, - T* const output -) { - T* const antipadding = reinterpret_cast<T *>(malloc(sizeof(T) * output_shape.n_channels)); - - // Get input pointers - const T* inptrs[16]; - for (int i = 0; i < 16; i++) { - inptrs[i] = matrices[i]; - } - - for (int batch = 0; batch < output_shape.n_batches; batch++) { - for (int tile_i = 0; tile_i < tile_M; tile_i++) { - for (int tile_j = 0; tile_j < tile_N; tile_j++) { - // Get pointers for each of the 4 output cells required for this computation - T* outptrs[4]; - for (int cell_i = 0, c = 0; cell_i < 2; cell_i++) { - for (int cell_j = 0; cell_j < 2; cell_j++, c++) { - const int i = tile_i*2 + cell_i; - const int j = tile_j*2 + cell_j; - - if (i < output_shape.n_rows && j < output_shape.n_cols) { - outptrs[c] = output + ( - (batch*output_shape.n_rows + i) * output_shape.n_cols + - j) * output_shape.n_channels; - } else { - outptrs[c] = antipadding; - } - } // cell_j - } // cell_i - - for (int n = 0; n < output_shape.n_channels; n++) { - // Read 16 values and 
progress pointers - T v[16]; - for (int i = 0; i < 16; i++) { - v[i] = *(inptrs[i]++); - } - - // Compute output for 4 pixels - *(outptrs[0]++) = v[ 0] + v[ 1] + v[ 2] + - v[ 4] + v[ 5] + v[ 6] + - v[ 8] + v[ 9] + v[10]; - *(outptrs[1]++) = v[ 1] - v[ 2] - v[ 3] + - v[ 5] - v[ 6] - v[ 7] + - v[ 9] - v[10] - v[11]; - *(outptrs[2]++) = v[ 4] + v[ 5] + v[ 6] - - v[ 8] - v[ 9] - v[10] - - v[12] - v[13] - v[14]; - *(outptrs[3]++) = v[ 5] - v[ 6] - v[ 7] - - v[ 9] + v[10] + v[11] - - v[13] + v[14] + v[15]; - } // output_channel - } // tile_j - } // tile_i - } // batch - - free(antipadding); -} -*/ - -/*****************************************************************************/ -/* -template <typename T> -void winograd::Winograd2x2_3x3GemmOutput_TwoStage<T>::execute( - const Tensor4DShape &output_shape, - T* const matrices[16], T* const output -) { - // Allocate memory for the intermediate matrices - const int tile_M = iceildiv(output_shape.n_rows, 2); - const int tile_N = iceildiv(output_shape.n_cols, 2); - const int n_rows = output_shape.n_batches * tile_M * tile_N; - const int n_channels = output_shape.n_channels; - T* matrices_zf = reinterpret_cast<T*>( - calloc(8 * n_rows * n_channels, sizeof(T)) - ); - - // Perform the first stage transform, computing ZF. - // Specializations should dispatch to different methods based on tail size. - compute_zf<0>(n_rows, n_channels, matrices_zf, matrices); - - // Perform the second stage transform, finishing Z F Z^T - variable dispatch - // based on size of the output. Specialisations can also dispatch based on - // the tail-size of the channel. 
- if (output_shape.n_rows % 2 && output_shape.n_cols % 2) { - compute_zfzT<true, true, 0>(output_shape, output, matrices_zf); - } else if (output_shape.n_rows % 2) { - compute_zfzT<true, false, 0>(output_shape, output, matrices_zf); - } else if (output_shape.n_cols % 2) { - compute_zfzT<false, true, 0>(output_shape, output, matrices_zf); - } else { - compute_zfzT<false, false, 0>(output_shape, output, matrices_zf); - } - - free(reinterpret_cast<void*>(matrices_zf)); -} - -template <typename T> -template <int channel_tail> -void winograd::Winograd2x2_3x3GemmOutput_TwoStage<T>::compute_zf( - const int n_rows, const int n_channels, - T* output, const T* const input[16] -) { - // Extract 8 output pointers - T* outptr[8]; - for (int i = 0; i < 8; i++) { - outptr[i] = output + i*n_rows*n_channels; - } - - // Copy the 16 input pointers - const T* inptr[16]; - for (int i = 0; i < 16; i++) { - inptr[i] = input[i]; - } - - // For every row of the matrices - for (int i = 0; i < n_rows; i++) { - // For every channel - for (int j = 0; j < n_channels; j++) { - // Extract values from the input matrices - T val[16]; - for (int n = 0; n < 16; n++) { - val[n] = *(inptr[n]++); - } - - // Compute output values - *(outptr[0]++) = val[0] + val[1] + val[2]; - *(outptr[1]++) = val[1] - val[2] - val[3]; - *(outptr[2]++) = val[4] + val[5] + val[6]; - *(outptr[3]++) = val[5] - val[6] - val[7]; - *(outptr[4]++) = val[8] + val[9] + val[10]; - *(outptr[5]++) = val[9] - val[10] - val[11]; - *(outptr[6]++) = val[12] + val[13] + val[14]; - *(outptr[7]++) = val[13] - val[14] - val[15]; - } - } -} - -template <typename T> -template <bool tail_M, bool tail_N, int channel_tail> -void winograd::Winograd2x2_3x3GemmOutput_TwoStage<T>::compute_zfzT( - const Tensor4DShape &output_shape, - T* const output, const T* const input -) { - // Sizing information - const int tile_M = output_shape.n_rows / 2; - const int tile_N = output_shape.n_cols / 2; - - const int n_rows = (output_shape.n_batches * - (tile_M + 
(tail_M ? 1 : 0)) * - (tile_N + (tail_N ? 1 : 0))); - const int n_channels = output_shape.n_channels; - - // Extract 8 input pointers - const T* inptr[8]; - for (int i = 0; i < 8; i++) { - inptr[i] = input + i*n_rows*n_channels; - } - - // Extract 4 output pointers - T* outptr00 = output; - T* outptr01 = outptr00 + n_channels; - T* outptr10 = outptr00 + output_shape.n_cols * n_channels; - T* outptr11 = outptr10 + n_channels; - - // Progress over the output tiles, generating output values. - for (int batch = 0; batch < output_shape.n_batches; batch++) { - for (int tile_i = 0; tile_i < tile_M; tile_i++) { - for (int tile_j = 0; tile_j < tile_N; tile_j++) { - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - T v[8]; - for (int i = 0; i < 8; i++) { - v[i] = *(inptr[i]++); - } - - // Compute the output values and progress the output pointers. - *(outptr00++) = v[0] + v[2] + v[4]; - *(outptr01++) = v[1] + v[3] + v[5]; - *(outptr10++) = v[2] - v[4] - v[6]; - *(outptr11++) = v[3] - v[5] - v[7]; - } - - // Progress the output pointers to the next column - outptr00 += n_channels; - outptr01 += n_channels; - outptr10 += n_channels; - outptr11 += n_channels; - } - - if (tail_N) { - // Only evaluate the left-most columns of the output - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - T v[8]; - for (int i = 0; i < 4; i++) { - v[i * 2] = *inptr[i * 2]; - } - for (int i = 0; i < 8; i++) { - inptr[i]++; - } - - // Compute the output values and progress the output pointers. 
- *(outptr00++) = v[0] + v[2] + v[4]; - *(outptr10++) = v[2] - v[4] - v[6]; - } - - // Progress the output pointers to the next column - outptr01 += n_channels; // Account for being skipped above - outptr11 += n_channels; // Account for being skipped above - } - - // Progress the output pointers to the next row - outptr00 += output_shape.n_cols * n_channels; - outptr01 += output_shape.n_cols * n_channels; - outptr10 += output_shape.n_cols * n_channels; - outptr11 += output_shape.n_cols * n_channels; - } - - if (tail_M) { - // Only work on the upper row of the output - for (int tile_j = 0; tile_j < tile_N; tile_j++) { - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - T v[8]; - for (int i = 0; i < 8; i++) { - v[i] = *(inptr[i]++); - } - - // Compute the output values and progress the output pointers. - *(outptr00++) = v[0] + v[2] + v[4]; - *(outptr01++) = v[1] + v[3] + v[5]; - } - - // Progress the output pointers to the next column - outptr00 += n_channels; - outptr01 += n_channels; - outptr10 += 2 * n_channels; // Account for being skipped above - outptr11 += 2 * n_channels; // Account for being skipped above - } - - if (tail_N) { - // Only evaluate the upper-left cell of the output - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - T v[8]; - for (int i = 0; i < 3; i++) { - v[i * 2] = *inptr[i * 2]; - } - for (int i = 0; i < 8; i++) { - inptr[i]++; - } - - // Compute the output values and progress the output pointers. 
- *(outptr00++) = v[0] + v[2] + v[4]; - } - - // Progress the output pointers to the next column - outptr01 += n_channels; // Account for being skipped above - outptr10 += n_channels; // Account for being skipped above - outptr11 += n_channels; // Account for being skipped above - } - } - } -} -*/ diff --git a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp b/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp deleted file mode 100644 index bf6ba907b9..0000000000 --- a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp +++ /dev/null @@ -1,650 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -/* Float implementation for AArch64. 
- */ -#ifdef __aarch64__ -namespace winograd { - - -template <> -template <> -inline void Winograd2x2_3x3GemmOutput<float>::_execute<false, false, 0>( - const Tensor4DShape &output_shape, - float *output, - const float *input, - const int mstride, - const int matrix_row_stride -) { - const int tile_M = output_shape.n_rows / 2; - const int tile_N = output_shape.n_cols / 2; - int batch = output_shape.n_batches; - float *outptr = output; - - const float *inptr0 = input; - const float *inptr4 = input + 4 * mstride; - const float *inptr8 = input + 8 * mstride; - const float *inptr12 = input + 12 * mstride; - - const size_t col_stride = sizeof(float) * output_shape.n_channels; - const size_t row_stride = col_stride * tile_N * 2; - - asm volatile ( - // Aliases for elements of the input matrix `F` - // V-register Q-register - "F11 .req v0\n" "qF11 .req q0\n" - "F12 .req v1\n" "qF12 .req q1\n" - "F13 .req v2\n" "qF13 .req q2\n" - "F14 .req v3\n" "qF14 .req q3\n" - "F21 .req v4\n" "qF21 .req q4\n" - "F22 .req v5\n" "qF22 .req q5\n" - "F23 .req v6\n" "qF23 .req q6\n" - "F24 .req v7\n" "qF24 .req q7\n" - "F31 .req v8\n" "qF31 .req q8\n" - "F32 .req v9\n" "qF32 .req q9\n" - "F33 .req v10\n" "qF33 .req q10\n" - "F34 .req v11\n" "qF34 .req q11\n" - "F41 .req v12\n" "qF41 .req q12\n" - "F42 .req v13\n" "qF42 .req q13\n" - "F43 .req v14\n" "qF43 .req q14\n" - "F44 .req v15\n" "qF44 .req q15\n" - - // Aliases for elements of the intermediate matrix `FZ` - "FZ11 .req v16\n" - "FZ12 .req v17\n" - "FZ21 .req v18\n" - "FZ22 .req v19\n" - "FZ31 .req v20\n" - "FZ32 .req v21\n" - "FZ41 .req v22\n" - "FZ42 .req v23\n" - - // Aliases for elements of the output matrix `f` (called `g` due to case - // insensitivity of aliases). 
- " g11 .req v24\n" - "qg11 .req q24\n" - " g12 .req v25\n" - "qg12 .req q25\n" - " g21 .req v26\n" - "qg21 .req q26\n" - " g22 .req v27\n" - "qg22 .req q27\n" - - // Prepare the various strides - "col_stride .req %x[col_stride]\n" - "row_stride .req %x[row_stride]\n" - "row_plus_col_stride .req %x[row_plus_col_stride]\n" - - "mstride1 .req %x[mstride1]\n" - "mstride2 .req %x[mstride2]\n" - "mstride3 .req %x[mstride3]\n" - - "tile_i .req x19\n" // Tile row counter - "tile_j .req x20\n" // Tile column counter - "channel .req x21\n" // Channel counter - - "1:" // Loop over batches - "mov tile_i, %x[tile_M]\n" // Reset tile row counter - - "2:" // Loop over rows of tiles - "mov tile_j, %x[tile_N]\n" // Reset tile column counter - - "3:" // Loop over columns of tiles - // Perform initial loads of the matrix `F` - "ldr qF11, [%x[inptr0]]\n" - "ldr qF12, [%x[inptr0], mstride1]\n" - "ldr qF13, [%x[inptr0], mstride2]\n" - "ldr qF14, [%x[inptr0], mstride3]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - "ldr qF21, [%x[inptr4]]\n" - "ldr qF22, [%x[inptr4], mstride1]\n" - "subs channel, %x[n_channels], #4\n" // Reset channel counter - - "ldr qF23, [%x[inptr4], mstride2]\n" - "ldr qF24, [%x[inptr4], mstride3]\n" - "add %x[inptr4], %x[inptr4], #0x10\n" - "beq 5f\n" // Jump straight to tail if necessary - - "4:" // Loop over channels - "ldr qF31, [%x[inptr8]]\n" - "fadd FZ11.4s, F11.4s, F12.4s\n" - - "ldr qF32, [%x[inptr8], mstride1]\n" - "fsub FZ12.4s, F12.4s, F13.4s\n" - - "ldr qF33, [%x[inptr8], mstride2]\n" - "fadd FZ11.4s, FZ11.4s, F13.4s\n" - - "ldr qF34, [%x[inptr8], mstride3]\n" - "fsub FZ12.4s, FZ12.4s, F14.4s\n" - - "ldr qF41, [%x[inptr12]]\n" - "fadd FZ21.4s, F21.4s, F22.4s\n" - - "ldr qF42, [%x[inptr12], mstride1]\n" - "fsub FZ22.4s, F22.4s, F23.4s\n" - - "ldr qF43, [%x[inptr12], mstride2]\n" - "fadd FZ21.4s, FZ21.4s, F23.4s\n" - - "ldr qF44, [%x[inptr12], mstride3]\n" - "fsub FZ22.4s, FZ22.4s, F24.4s\n" - - "fadd FZ31.4s, F31.4s, F32.4s\n" - "add %x[inptr8], 
%x[inptr8], #0x10\n" - - "fsub FZ32.4s, F32.4s, F33.4s\n" - "add %x[inptr12], %x[inptr12], #0x10\n" - - "fadd FZ31.4s, FZ31.4s, F33.4s\n" - - "fsub FZ32.4s, FZ32.4s, F34.4s\n" - - "fadd g11.4s, FZ11.4s, FZ21.4s\n" - - "fadd g12.4s, FZ12.4s, FZ22.4s\n" - - "fadd g11.4s, g11.4s, FZ31.4s\n" - - "fadd g12.4s, g12.4s, FZ32.4s\n" - - "ldr qF11, [%x[inptr0]]\n" - "fadd FZ41.4s, F41.4s, F42.4s\n" - - "ldr qF12, [%x[inptr0], mstride1]\n" - "fsub g21.4s, FZ21.4s, FZ31.4s\n" - - "ldr qF13, [%x[inptr0], mstride2]\n" - "fsub FZ42.4s, F42.4s, F43.4s\n" - - "ldr qF14, [%x[inptr0], mstride3]\n" - "str qg11, [%x[outptr]]\n" - - "ldr qF21, [%x[inptr4]]\n" - "fadd FZ41.4s, FZ41.4s, F43.4s\n" - - "ldr qF22, [%x[inptr4], mstride1]\n" - "str qg12, [%x[outptr], col_stride]\n" - - "ldr qF23, [%x[inptr4], mstride2]\n" - "fsub FZ42.4s, FZ42.4s, F44.4s\n" - - "ldr qF24, [%x[inptr4], mstride3]\n" - "fsub g22.4s, FZ22.4s, FZ32.4s\n" - - "fsub g21.4s, g21.4s, FZ41.4s\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "fsub g22.4s, g22.4s, FZ42.4s\n" - "add %x[inptr4], %x[inptr4], #0x10\n" - - "subs channel, channel, #4\n" - - "str qg21, [%x[outptr], row_stride]\n" - - "str qg22, [%x[outptr], row_plus_col_stride]\n" - - "add %x[outptr], %x[outptr], #0x10\n" - - "bne 4b\n" - - "5:" // Channel tail - "ldr qF31, [%x[inptr8]]\n" - "fadd FZ11.4s, F11.4s, F12.4s\n" - - "ldr qF32, [%x[inptr8], mstride1]\n" - "fsub FZ12.4s, F12.4s, F13.4s\n" - - "ldr qF33, [%x[inptr8], mstride2]\n" - "fadd FZ11.4s, FZ11.4s, F13.4s\n" - - "ldr qF34, [%x[inptr8], mstride3]\n" - "fsub FZ12.4s, FZ12.4s, F14.4s\n" - - "ldr qF41, [%x[inptr12]]\n" - "fadd FZ21.4s, F21.4s, F22.4s\n" - - "ldr qF42, [%x[inptr12], mstride1]\n" - "fsub FZ22.4s, F22.4s, F23.4s\n" - - "ldr qF43, [%x[inptr12], mstride2]\n" - "fadd FZ21.4s, FZ21.4s, F23.4s\n" - - "ldr qF44, [%x[inptr12], mstride3]\n" - "fsub FZ22.4s, FZ22.4s, F24.4s\n" - - "fadd FZ31.4s, F31.4s, F32.4s\n" - "add %x[inptr8], %x[inptr8], #0x10\n" - - "fsub FZ32.4s, F32.4s, F33.4s\n" - "add 
%x[inptr12], %x[inptr12], #0x10\n" - - "fadd FZ31.4s, FZ31.4s, F33.4s\n" - - "fsub FZ32.4s, FZ32.4s, F34.4s\n" - - "fadd g11.4s, FZ11.4s, FZ21.4s\n" - - "fadd g12.4s, FZ12.4s, FZ22.4s\n" - - "fadd g11.4s, g11.4s, FZ31.4s\n" - - "fadd g12.4s, g12.4s, FZ32.4s\n" - - "fadd FZ41.4s, F41.4s, F42.4s\n" - - "fsub g21.4s, FZ21.4s, FZ31.4s\n" - - "fsub FZ42.4s, F42.4s, F43.4s\n" - - "str qg11, [%x[outptr]]\n" - - "fadd FZ41.4s, FZ41.4s, F43.4s\n" - - "str qg12, [%x[outptr], col_stride]\n" - - "fsub FZ42.4s, FZ42.4s, F44.4s\n" - - "fsub g22.4s, FZ22.4s, FZ32.4s\n" - - "fsub g21.4s, g21.4s, FZ41.4s\n" - - "fsub g22.4s, g22.4s, FZ42.4s\n" - - "subs channel, channel, #4\n" - - "str qg21, [%x[outptr], row_stride]\n" - - // Progress input pointers to the next row of the matrix - "add %x[inptr0], %x[inptr0], %x[mrowpad]\n" - "add %x[inptr4], %x[inptr4], %x[mrowpad]\n" - "add %x[inptr8], %x[inptr8], %x[mrowpad]\n" - "add %x[inptr12], %x[inptr12], %x[mrowpad]\n" - - "str qg22, [%x[outptr], row_plus_col_stride]\n" - - "add %x[outptr], %x[outptr], #0x10\n" - - - "add %x[outptr], %x[outptr], col_stride\n" - "subs tile_j, tile_j, #1\n" - "bne 3b\n" - - "add %x[outptr], %x[outptr], row_stride\n" - "subs tile_i, tile_i, #1\n" - "bne 2b\n" - - "subs %w[batch], %w[batch], #1\n" - "bne 1b\n" - - ".unreq F11\n" ".unreq qF11\n" - ".unreq F12\n" ".unreq qF12\n" - ".unreq F13\n" ".unreq qF13\n" - ".unreq F14\n" ".unreq qF14\n" - ".unreq F21\n" ".unreq qF21\n" - ".unreq F22\n" ".unreq qF22\n" - ".unreq F23\n" ".unreq qF23\n" - ".unreq F24\n" ".unreq qF24\n" - ".unreq F31\n" ".unreq qF31\n" - ".unreq F32\n" ".unreq qF32\n" - ".unreq F33\n" ".unreq qF33\n" - ".unreq F34\n" ".unreq qF34\n" - ".unreq F41\n" ".unreq qF41\n" - ".unreq F42\n" ".unreq qF42\n" - ".unreq F43\n" ".unreq qF43\n" - ".unreq F44\n" ".unreq qF44\n" - - ".unreq FZ11\n" ".unreq FZ12\n" - ".unreq FZ21\n" ".unreq FZ22\n" - ".unreq FZ31\n" ".unreq FZ32\n" - ".unreq FZ41\n" ".unreq FZ42\n" - - ".unreq g11\n" ".unreq qg11\n" - ".unreq 
g12\n" ".unreq qg12\n" - ".unreq g21\n" ".unreq qg21\n" - ".unreq g22\n" ".unreq qg22\n" - - ".unreq col_stride\n" - ".unreq row_stride\n" - ".unreq row_plus_col_stride\n" - - ".unreq mstride1\n" - ".unreq mstride2\n" - ".unreq mstride3\n" - - ".unreq tile_i \n" - ".unreq tile_j \n" - ".unreq channel\n" - - : [batch] "+r" (batch), - [outptr] "+r" (outptr), - [inptr0] "+r" (inptr0), - [inptr4] "+r" (inptr4), - [inptr8] "+r" (inptr8), - [inptr12] "+r" (inptr12) - : [tile_M] "r" (tile_M), - [tile_N] "r" (tile_N), - [n_channels] "r" (output_shape.n_channels), - [col_stride] "r" (col_stride), - [row_stride] "r" (row_stride), - [row_plus_col_stride] "r" (row_stride + col_stride), - [mstride1] "r" (mstride * sizeof(float)), - [mstride2] "r" (2 * mstride * sizeof(float)), - [mstride3] "r" (3 * mstride * sizeof(float)), - [mrowpad] "r" ((matrix_row_stride - output_shape.n_channels) * sizeof(float)) - : "x19", "x20", "x21", - "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", - "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24", "v25", "v26", "v27", - "cc", "memory" - ); -} - -template <> -template <bool tail_M, bool tail_N, const int channel_tail> -inline void Winograd2x2_3x3GemmOutput<float>::_execute( - const Tensor4DShape &output_shape, - float *output, - const float *input, - const int mstride, - const int matrix_row_stride -) { - // Compute basic information about the shape of the matrices - const int tile_M = output_shape.n_rows / 2; - const int tile_N = output_shape.n_cols / 2; - const int n_channels = output_shape.n_channels; - - // Extract 16 input pointers - const float* inptr[16]; - for (int i = 0; i < 16; i++) { - inptr[i] = input + i*mstride; - } - - // Extract 4 output pointers - float *outptr00 = output; - float *outptr01 = outptr00 + n_channels; - float *outptr10 = outptr00 + output_shape.n_cols * n_channels; - float *outptr11 = outptr10 + n_channels; - - // Progress over the output tiles, 
generating output values. - for (int batch = 0; batch < output_shape.n_batches; batch++) { - for (int tile_i = 0; tile_i < tile_M; tile_i++) { - for (int tile_j = 0; tile_j < tile_N; tile_j++) { - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - float F[4][4]; - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 4; j++) { - F[i][j] = *(inptr[i*4 + j]++); - } - } - - // Compute the matrix F.Z - float ZF[4][2]; - ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; - ZF[0][1] = F[0][1] - F[0][2] - F[0][3]; - ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; - ZF[1][1] = F[1][1] - F[1][2] - F[1][3]; - ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; - ZF[2][1] = F[2][1] - F[2][2] - F[2][3]; - ZF[3][0] = F[3][0] + F[3][1] + F[3][2]; - ZF[3][1] = F[3][1] - F[3][2] - F[3][3]; - - // Hence compute the output matrix Z^T . (F.Z) - *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; - *(outptr01++) = ZF[0][1] + ZF[1][1] + ZF[2][1]; - *(outptr10++) = ZF[1][0] - ZF[2][0] - ZF[3][0]; - *(outptr11++) = ZF[1][1] - ZF[2][1] - ZF[3][1]; - } - - // Progress the input pointers to the next row - for (int i = 0; i < 16; i++) { - inptr[i] += matrix_row_stride - n_channels; - } - - // Progress the output pointers to the next column - outptr00 += n_channels; - outptr01 += n_channels; - outptr10 += n_channels; - outptr11 += n_channels; - } - - if (tail_N) { - // Only evaluate the left-most columns of the output - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - float F[4][3]; - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 3; j++) { - F[i][j] = *(inptr[i*4 + j]++); - } - } - for (int i = 0; i < 4; i++) { - inptr[i*4 + 3]++; - } - - // Compute the matrix F.Z - float ZF[4][1]; - ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; - ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; - ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; - ZF[3][0] = F[3][0] + F[3][1] + F[3][2]; - - // Hence compute the output matrix Z^T . 
(F.Z) - *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; - *(outptr10++) = ZF[1][0] - ZF[2][0] - ZF[3][0]; - } - - // Progress the input pointers to the next row - for (int i = 0; i < 16; i++) { - inptr[i] += matrix_row_stride - n_channels; - } - - // Progress the output pointers to the next column - outptr01 += n_channels; // Account for being skipped above - outptr11 += n_channels; // Account for being skipped above - } - - // Progress the output pointers to the next row - outptr00 += output_shape.n_cols * n_channels; - outptr01 += output_shape.n_cols * n_channels; - outptr10 += output_shape.n_cols * n_channels; - outptr11 += output_shape.n_cols * n_channels; - } - - if (tail_M) { - // Only work on the upper row of the output - for (int tile_j = 0; tile_j < tile_N; tile_j++) { - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - float F[3][4]; - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 4; j++) { - F[i][j] = *(inptr[i*4 + j]++); - } - } - for (int j = 0; j < 4; j++) { - inptr[12 + j]++; - } - - // Compute the matrix F.Z - float ZF[3][2]; - ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; - ZF[0][1] = F[0][1] - F[0][2] - F[0][3]; - ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; - ZF[1][1] = F[1][1] - F[1][2] - F[1][3]; - ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; - ZF[2][1] = F[2][1] - F[2][2] - F[2][3]; - - // Hence compute the output matrix Z^T . 
(F.Z) - *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; - *(outptr01++) = ZF[0][1] + ZF[1][1] + ZF[2][1]; - } - - // Progress the input pointers to the next row - for (int i = 0; i < 16; i++) { - inptr[i] += matrix_row_stride - n_channels; - } - - // Progress the output pointers to the next column - outptr00 += n_channels; - outptr01 += n_channels; - outptr10 += 2 * n_channels; // Account for being skipped above - outptr11 += 2 * n_channels; // Account for being skipped above - } - - if (tail_N) { - // Only evaluate the upper-left cell of the output - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - float F[3][3]; - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 3; j++) { - F[i][j] = *(inptr[i*4 + j]); - } - } - for (int i = 0; i < 16; i++) { - inptr[i]++; - } - - // Compute the matrix F.Z - float ZF[3][1]; - ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; - ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; - ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; - - // Hence compute the output matrix Z^T . (F.Z) - *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; - } - - // Progress the input pointers to the next row - for (int i = 0; i < 16; i++) { - inptr[i] += matrix_row_stride - n_channels; - } - - // Progress the output pointers to the next column - outptr01 += n_channels; // Account for being skipped above - outptr10 += n_channels; // Account for being skipped above - outptr11 += n_channels; // Account for being skipped above - } - } - } -} - -/*****************************************************************************/ -template <> -inline void Winograd2x2_3x3GemmOutput<float>::execute( - const Tensor4DShape &output_shape, - float* const matrix_base, - const int matrix_stride, - const int matrix_row_stride, - float* const output -) { - // Dispatch to an appropriate implementation based on the shape of the output - // tensor. 
- if (output_shape.n_rows % 2 && output_shape.n_cols % 2) { - constexpr bool tail_M = true, tail_N = true; - switch (output_shape.n_channels % 4) { - case 0: - _execute<tail_M, tail_N, 0>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 1: - _execute<tail_M, tail_N, 1>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 2: - _execute<tail_M, tail_N, 2>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 3: - _execute<tail_M, tail_N, 3>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - default: - assert(0); - break; - } - } else if (output_shape.n_rows % 2) { - constexpr bool tail_M = true, tail_N = false; - switch (output_shape.n_channels % 4) { - case 0: - _execute<tail_M, tail_N, 0>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 1: - _execute<tail_M, tail_N, 1>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 2: - _execute<tail_M, tail_N, 2>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 3: - _execute<tail_M, tail_N, 3>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - default: - assert(0); - break; - } - } else if (output_shape.n_cols % 2) { - constexpr bool tail_M = false, tail_N = true; - switch (output_shape.n_channels % 4) { - case 0: - _execute<tail_M, tail_N, 0>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 1: - _execute<tail_M, tail_N, 1>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 2: - _execute<tail_M, tail_N, 2>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 3: - _execute<tail_M, tail_N, 3>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - default: - assert(0); - break; - - } - } else { - constexpr bool tail_M = false, tail_N = 
false; - switch (output_shape.n_channels % 4) { - case 0: - _execute<tail_M, tail_N, 0>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 1: - _execute<tail_M, tail_N, 1>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 2: - _execute<tail_M, tail_N, 2>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 3: - _execute<tail_M, tail_N, 3>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - default: - assert(0); - break; - - } - } -} -/*****************************************************************************/ - -} // namespace winograd -#endif // __aarch64__ diff --git a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp b/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp deleted file mode 100644 index f551b12b52..0000000000 --- a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp +++ /dev/null @@ -1,655 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -#ifdef __aarch64__ - -/*****************************************************************************/ -// Compute ZF specializations - -template <> -template <> -inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage<float>::compute_zf<0>( - const int n_rows, const int n_channels, - float* output, const float* const input[16] -) { - // Make copies of some variables - int row = n_rows; - float* outptr = output; - const float* inptr = input[0]; - - // Perform the transformation - asm volatile ( - // "inptr0 .req %x[inptr]\n" - "inptr1 .req x0\n" - "inptr2 .req x1\n" - "inptr3 .req x2\n" - "inptr4 .req x3\n" - "inptr5 .req x4\n" - "inptr6 .req x5\n" - "inptr7 .req x6\n" - "inptr8 .req x7\n" - "inptr9 .req x8\n" - "inptr10 .req x9\n" - "inptr11 .req x10\n" - "inptr12 .req x11\n" - "inptr13 .req x12\n" - "inptr14 .req x13\n" - "inptr15 .req x14\n" - - // "outptr0 .req %x[outptr]\n" - "outptr1 .req x15\n" - "outptr2 .req x16\n" - "outptr3 .req x17\n" - "outptr4 .req x18\n" - "outptr5 .req x19\n" - "outptr6 .req x20\n" - "outptr7 .req x21\n" - - // Compute additional pointers into the input and output matrices. 
- "mstride .req x22\n" // Matrix stride - "mul mstride, %x[row], %x[n_channels]\n" - "lsl mstride, mstride, #2\n" // * sizeof(float) - - "add inptr1, %x[inptr], mstride\n" - "add inptr2, %x[inptr], mstride, LSL #1\n" - "add inptr3, inptr2, mstride\n" - "add inptr4, inptr3, mstride\n" - "add inptr5, inptr4, mstride\n" - "add inptr6, inptr5, mstride\n" - "add inptr7, inptr6, mstride\n" - "add inptr8, inptr7, mstride\n" - "add inptr9, inptr8, mstride\n" - "add inptr10, inptr9, mstride\n" - "add inptr11, inptr10, mstride\n" - "add inptr12, inptr11, mstride\n" - "add inptr13, inptr12, mstride\n" - "add inptr14, inptr13, mstride\n" - "add inptr15, inptr14, mstride\n" - - "add outptr1, %[outptr], mstride\n" - "add outptr2, outptr1, mstride\n" - "add outptr3, outptr2, mstride\n" - "add outptr4, outptr3, mstride\n" - "add outptr5, outptr4, mstride\n" - "add outptr6, outptr5, mstride\n" - "add outptr7, outptr6, mstride\n" - - ".unreq mstride\n" - - "column .req x22\n" // Column loop counter - - "1:" // Loop over rows - "ldr q0, [%x[inptr]], #0x10\n" - "ldr q1, [inptr1], #0x10\n" - "ldr q2, [inptr2], #0x10\n" - "ldr q3, [inptr3], #0x10\n" - "ldr q4, [inptr4], #0x10\n" - "ldr q5, [inptr5], #0x10\n" - "ldr q6, [inptr6], #0x10\n" - "ldr q7, [inptr7], #0x10\n" - "subs column, %x[n_channels], #0x4\n" - "beq 3f\n" - - "2:" // Loop over columns - "ldr q8, [inptr8], #0x10\n" - "prfm pldl1keep, [%x[inptr], #196]\n" - "fadd v16.4s, v0.4s, v1.4s\n" - - "ldr q9, [inptr9], #0x10\n" - "prfm pldl1keep, [inptr1, #196]\n" - "fsub v17.4s, v1.4s, v2.4s\n" - - "ldr q10, [inptr10], #0x10\n" - "prfm pldl1keep, [inptr2, #196]\n" - "fadd v16.4s, v16.4s, v2.4s\n" - - "ldr q11, [inptr11], #0x10\n" - "prfm pldl1keep, [inptr3, #196]\n" - "fsub v17.4s, v17.4s, v3.4s\n" - - "ldr q12, [inptr12], #0x10\n" - "prfm pldl1keep, [inptr4, #196]\n" - "str q16, [%x[outptr]], #0x10\n" - - "ldr q13, [inptr13], #0x10\n" - "prfm pldl1keep, [inptr5, #196]\n" - "str q17, [outptr1], #0x10\n" - - "ldr q14, [inptr14], 
#0x10\n" - "prfm pldl1keep, [inptr6, #196]\n" - "fadd v16.4s, v4.4s, v5.4s\n" - - "ldr q15, [inptr15], #0x10\n" - "prfm pldl1keep, [inptr7, #196]\n" - "fsub v17.4s, v5.4s, v6.4s\n" - - "ldr q0, [%x[inptr]], #0x10\n" - "prfm pldl1keep, [inptr8, #196]\n" - "fadd v16.4s, v16.4s, v6.4s\n" - - "ldr q1, [inptr1], #0x10\n" - "prfm pldl1keep, [inptr9, #196]\n" - "fsub v17.4s, v17.4s, v7.4s\n" - - "ldr q2, [inptr2], #0x10\n" - "prfm pldl1keep, [inptr10, #196]\n" - "str q16, [outptr2], #0x10\n" - - "ldr q3, [inptr3], #0x10\n" - "prfm pldl1keep, [inptr11, #196]\n" - "str q17, [outptr3], #0x10\n" - - "ldr q4, [inptr4], #0x10\n" - "prfm pldl1keep, [inptr12, #196]\n" - "fadd v16.4s, v8.4s, v9.4s\n" - - "ldr q5, [inptr5], #0x10\n" - "prfm pldl1keep, [inptr13, #196]\n" - "fsub v17.4s, v9.4s, v10.4s\n" - - "ldr q6, [inptr6], #0x10\n" - "prfm pldl1keep, [inptr14, #196]\n" - "fadd v16.4s, v16.4s, v10.4s\n" - - "ldr q7, [inptr7], #0x10\n" - "prfm pldl1keep, [inptr15, #196]\n" - "fsub v17.4s, v17.4s, v11.4s\n" - - "str q16, [outptr4], #0x10\n" - "fadd v16.4s, v12.4s, v13.4s\n" - "fsub v18.4s, v13.4s, v14.4s\n" - - "str q17, [outptr5], #0x10\n" - "fadd v16.4s, v16.4s, v14.4s\n" - "fsub v18.4s, v18.4s, v15.4s\n" - - "str q16, [outptr6], #0x10\n" - "subs column, column, #0x4\n" - - "str q18, [outptr7], #0x10\n" - "bne 2b\n" - - "3:" // Tail - "ldr q8, [inptr8], #0x10\n" - "prfm pldl1keep, [%x[inptr], #196]\n" - "fadd v16.4s, v0.4s, v1.4s\n" - - "ldr q9, [inptr9], #0x10\n" - "prfm pldl1keep, [inptr1, #196]\n" - "fsub v17.4s, v1.4s, v2.4s\n" - - "ldr q10, [inptr10], #0x10\n" - "prfm pldl1keep, [inptr2, #196]\n" - "fadd v16.4s, v16.4s, v2.4s\n" - - "ldr q11, [inptr11], #0x10\n" - "prfm pldl1keep, [inptr3, #196]\n" - "fsub v17.4s, v17.4s, v3.4s\n" - - "ldr q12, [inptr12], #0x10\n" - "prfm pldl1keep, [inptr4, #196]\n" - "str q16, [%x[outptr]], #0x10\n" - - "ldr q13, [inptr13], #0x10\n" - "prfm pldl1keep, [inptr5, #196]\n" - "str q17, [outptr1], #0x10\n" - - "ldr q14, [inptr14], #0x10\n" - 
"prfm pldl1keep, [inptr6, #196]\n" - "fadd v16.4s, v4.4s, v5.4s\n" - - "ldr q15, [inptr15], #0x10\n" - "prfm pldl1keep, [inptr7, #196]\n" - "fsub v17.4s, v5.4s, v6.4s\n" - - "prfm pldl1keep, [inptr8, #196]\n" - "prfm pldl1keep, [inptr9, #196]\n" - "fadd v16.4s, v16.4s, v6.4s\n" - - "prfm pldl1keep, [inptr10, #196]\n" - "prfm pldl1keep, [inptr11, #196]\n" - "fsub v17.4s, v17.4s, v7.4s\n" - - "prfm pldl1keep, [inptr12, #196]\n" - "prfm pldl1keep, [inptr13, #196]\n" - "str q16, [outptr2], #0x10\n" - - "prfm pldl1keep, [inptr14, #196]\n" - "prfm pldl1keep, [inptr15, #196]\n" - "str q17, [outptr3], #0x10\n" - - "fadd v16.4s, v8.4s, v9.4s\n" - "fsub v17.4s, v9.4s, v10.4s\n" - - "fadd v16.4s, v16.4s, v10.4s\n" - "fsub v17.4s, v17.4s, v11.4s\n" - - "str q16, [outptr4], #0x10\n" - "fadd v16.4s, v12.4s, v13.4s\n" - "fsub v18.4s, v13.4s, v14.4s\n" - - "str q17, [outptr5], #0x10\n" - "fadd v16.4s, v16.4s, v14.4s\n" - "fsub v18.4s, v18.4s, v15.4s\n" - - "str q16, [outptr6], #0x10\n" - "str q18, [outptr7], #0x10\n" - - "subs %x[row], %x[row], #0x1\n" - "bne 1b\n" - - ".unreq inptr1\n" - ".unreq inptr2\n" - ".unreq inptr3\n" - ".unreq inptr4\n" - ".unreq inptr5\n" - ".unreq inptr6\n" - ".unreq inptr7\n" - ".unreq inptr8\n" - ".unreq inptr9\n" - ".unreq inptr10\n" - ".unreq inptr11\n" - ".unreq inptr12\n" - ".unreq inptr13\n" - ".unreq inptr14\n" - ".unreq inptr15\n" - ".unreq outptr1\n" - ".unreq outptr2\n" - ".unreq outptr3\n" - ".unreq outptr4\n" - ".unreq outptr5\n" - ".unreq outptr6\n" - ".unreq outptr7\n" - - : [row] "+r" (row), - [inptr] "+r" (inptr), - [outptr] "+r" (outptr) - : [n_channels] "r" (n_channels), - [sizeof_float] "i" (sizeof(float)) - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", - "q12", "q13", "q14", "q15", "q16", "q17", "x0", "x1", "x2", "x3", "x4", - "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", - "x16", "x17", "x18", "x19", "x20", "x21", "x22", "cc", "memory" - ); -} - 
-/*****************************************************************************/ -// Compute ZFZ^T specializations - -template <> -template <> -inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage<float>::compute_zfzT<false, false, 0>( - const Tensor4DShape &output_shape, - float* const output, const float* const input -) { - const int tile_M = output_shape.n_rows / 2; - const int tile_N = output_shape.n_cols / 2; - int batch = output_shape.n_batches; - float *outptr = output; - const float *inptr = input; - - asm volatile ( - // Compute input pointers - "inptr1 .req x0\n" - "inptr2 .req x1\n" - "inptr3 .req x2\n" - "inptr4 .req x3\n" - "inptr5 .req x4\n" - "inptr6 .req x5\n" - "inptr7 .req x6\n" - "inptr8 .req x7\n" - - "mstride .req x8\n" - "mul mstride, %x[tile_M], %x[tile_N]\n" - "mul mstride, mstride, %x[n_channels]\n" - "lsl mstride, mstride, #2\n" // * sizeof(float) - - "add inptr1, %[inptr], mstride\n" - "add inptr2, inptr1, mstride\n" - "add inptr3, inptr2, mstride\n" - "add inptr4, inptr3, mstride\n" - "add inptr5, inptr4, mstride\n" - "add inptr6, inptr5, mstride\n" - "add inptr7, inptr6, mstride\n" - "add inptr8, inptr7, mstride\n" - - ".unreq mstride\n" - - // Compute initial output pointers - "outptr01 .req x8\n" - "outptr10 .req x9\n" - "outptr11 .req x10\n" - - "add outptr01, %x[outptr], %x[n_channels], LSL #2\n" - "add outptr10, %x[outptr], %x[row_stride], LSL #2\n" - "add outptr11, outptr10, %x[n_channels], LSL #2\n" - - "tile_i .req x11\n" - "tile_j .req x12\n" - "channel .req x13\n" - - "1:" // Loop over batches - "mov tile_i, %x[tile_M]\n" - - "2:" // Loop over rows of output tiles - "mov tile_j, %x[tile_N]\n" - - "3:" // Loop over columns of output tiles - "ldr q0, [%x[inptr]], #0x10\n" - "ldr q2, [inptr2], #0x10\n" - "subs channel, %x[n_channels], #0x4\n" - - "ldr q1, [inptr1], #0x10\n" - "ldr q3, [inptr3], #0x10\n" - "beq 6f\n" - - "4:" - "ldr q4, [inptr4], #0x10\n" - "ldr q5, [inptr5], #0x10\n" - "fadd v16.4s, v0.4s, v2.4s\n" - - "ldr 
q6, [inptr6], #0x10\n" - "ldr q7, [inptr7], #0x10\n" - "fadd v17.4s, v1.4s, v3.4s\n" - - "ldr q8, [%x[inptr]], #0x10\n" - "ldr q10, [inptr2], #0x10\n" - "fadd v16.4s, v16.4s, v4.4s\n" - - "ldr q9, [inptr1], #0x10\n" - "ldr q11, [inptr3], #0x10\n" - "fadd v17.4s, v17.4s, v5.4s\n" - - "str q16, [%x[outptr]], #0x10\n" - "prfm pldl1strm, [%x[inptr], #196]\n" - "fsub v18.4s, v2.4s, v4.4s\n" - - "str q17, [outptr01], #0x10\n" - "prfm pldl1strm, [inptr2, #196]\n" - "fsub v19.4s, v3.4s, v5.4s\n" - - "prfm pldl1strm, [inptr1, #196]\n" - "prfm pldl1strm, [inptr3, #196]\n" - "fsub v18.4s, v18.4s, v6.4s\n" - - "prfm pldl1strm, [inptr4, #196]\n" - "prfm pldl1strm, [inptr5, #196]\n" - "fsub v19.4s, v19.4s, v7.4s\n" - - "str q18, [outptr10], #0x10\n" - "prfm pldl1strm, [inptr6, #196]\n" - "prfm pldl1strm, [inptr7, #196]\n" - - "subs channel, channel, #0x4\n" - - "str q19, [outptr11], #0x10\n" - "beq 6f\n" // Branch to tail - - "ldr q12, [inptr4], #0x10\n" - "ldr q13, [inptr5], #0x10\n" - "fadd v16.4s, v8.4s, v10.4s\n" - - "ldr q14, [inptr6], #0x10\n" - "ldr q15, [inptr7], #0x10\n" - "fadd v17.4s, v9.4s, v11.4s\n" - - "ldr q0, [%x[inptr]], #0x10\n" - "ldr q2, [inptr2], #0x10\n" - "fadd v16.4s, v16.4s, v12.4s\n" - - "ldr q1, [inptr1], #0x10\n" - "ldr q3, [inptr3], #0x10\n" - "fadd v17.4s, v17.4s, v13.4s\n" - - "str q16, [%x[outptr]], #0x10\n" - "prfm pldl1strm, [%x[inptr], #196]\n" - "fsub v18.4s, v10.4s, v12.4s\n" - - "str q17, [outptr01], #0x10\n" - "prfm pldl1strm, [inptr2, #196]\n" - "fsub v19.4s, v11.4s, v13.4s\n" - - "prfm pldl1strm, [inptr1, #196]\n" - "prfm pldl1strm, [inptr3, #196]\n" - "fsub v18.4s, v18.4s, v14.4s\n" - - "prfm pldl1strm, [inptr4, #196]\n" - "prfm pldl1strm, [inptr5, #196]\n" - "fsub v19.4s, v19.4s, v15.4s\n" - - "str q18, [outptr10], #0x10\n" - "prfm pldl1strm, [inptr6, #196]\n" - "prfm pldl1strm, [inptr7, #196]\n" - - "subs channel, channel, #0x4\n" - - "str q19, [outptr11], #0x10\n" - "bne 4b\n" // Continue loop - - "5:" // Tail - "ldr q12, [inptr4], 
#0x10\n" - "ldr q13, [inptr5], #0x10\n" - "fadd v16.4s, v8.4s, v10.4s\n" - - "ldr q14, [inptr6], #0x10\n" - "ldr q15, [inptr7], #0x10\n" - "fadd v17.4s, v9.4s, v11.4s\n" - - "fadd v16.4s, v16.4s, v12.4s\n" - - "fadd v17.4s, v17.4s, v13.4s\n" - - "str q16, [%x[outptr]], #0x10\n" - "fsub v18.4s, v10.4s, v12.4s\n" - "fsub v19.4s, v11.4s, v13.4s\n" - - "str q17, [outptr01], #0x10\n" - "fsub v18.4s, v18.4s, v14.4s\n" - "fsub v19.4s, v19.4s, v15.4s\n" - - "str q18, [outptr10], #0x10\n" - "str q19, [outptr11], #0x10\n" - "b 7f\n" - - "6:" // Tail - "ldr q4, [inptr4], #0x10\n" - "ldr q5, [inptr5], #0x10\n" - "fadd v16.4s, v0.4s, v2.4s\n" - - "ldr q6, [inptr6], #0x10\n" - "ldr q7, [inptr7], #0x10\n" - "fadd v17.4s, v1.4s, v3.4s\n" - - "fadd v16.4s, v16.4s, v4.4s\n" - - "fadd v17.4s, v17.4s, v5.4s\n" - - "str q16, [%x[outptr]], #0x10\n" - "fsub v18.4s, v2.4s, v4.4s\n" - "fsub v19.4s, v3.4s, v5.4s\n" - - "str q17, [outptr01], #0x10\n" - "fsub v18.4s, v18.4s, v6.4s\n" - "fsub v19.4s, v19.4s, v7.4s\n" - - "str q18, [outptr10], #0x10\n" - "str q19, [outptr11], #0x10\n" - - "7:" - "add %x[outptr], %x[outptr], %x[n_channels], LSL #2\n" - "add outptr01, outptr01, %x[n_channels], LSL #2\n" - "add outptr10, outptr10, %x[n_channels], LSL #2\n" - "add outptr11, outptr11, %x[n_channels], LSL #2\n" - - "subs tile_j, tile_j, #1\n" - "bne 3b\n" - - // Progress the output pointers to the new row - "add %x[outptr], %x[outptr], %x[row_stride], LSL #2\n" - "add outptr01, outptr01, %x[row_stride], LSL #2\n" - "add outptr10, outptr10, %x[row_stride], LSL #2\n" - "add outptr11, outptr11, %x[row_stride], LSL #2\n" - - "subs tile_i, tile_i, #1\n" - "bne 2b\n" - - "subs %[batch], %[batch], #1\n" - "bne 1b\n" - "5:" - - ".unreq inptr1\n" - ".unreq inptr2\n" - ".unreq inptr3\n" - ".unreq inptr4\n" - ".unreq inptr5\n" - ".unreq inptr6\n" - ".unreq inptr7\n" - ".unreq inptr8\n" - ".unreq outptr01\n" - ".unreq outptr10\n" - ".unreq outptr11\n" - : [batch] "+r" (batch), - [outptr] "+r" (outptr), - [inptr] 
"+r" (inptr) - : [tile_M] "r" (tile_M), - [tile_N] "r" (tile_N), - [n_channels] "r" (output_shape.n_channels), - [row_stride] "r" (output_shape.n_cols * output_shape.n_channels) - : "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", - "x12", "x13", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", - "cc", "memory" - ); -} -/*****************************************************************************/ - -/*****************************************************************************/ -template <> -inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage<float>::execute( - const Tensor4DShape &output_shape, - float* const matrices[16], float* const output -) { - // profiler prof; - - // Allocate memory for the intermediate matrices - const int tile_M = iceildiv(output_shape.n_rows, 2); - const int tile_N = iceildiv(output_shape.n_cols, 2); - const int n_rows = output_shape.n_batches * tile_M * tile_N; - const int n_channels = output_shape.n_channels; - float* matrices_zf = reinterpret_cast<float*>( - calloc(8 * n_rows * n_channels, sizeof(float)) - ); - - // Perform the first stage transform, computing ZF. - const auto f_compute_zf = [&] () { - switch (n_channels % 4) { - case 0: - compute_zf<0>(n_rows, n_channels, matrices_zf, matrices); - break; - case 1: - compute_zf<1>(n_rows, n_channels, matrices_zf, matrices); - break; - case 2: - compute_zf<2>(n_rows, n_channels, matrices_zf, matrices); - break; - case 3: - compute_zf<3>(n_rows, n_channels, matrices_zf, matrices); - }; - }; - // prof("Compute ZF", f_compute_zf, 16 * n_rows * n_channels * sizeof(float), 0, 8 * n_rows * n_channels * sizeof(float)); - f_compute_zf(); - - // Perform the second stage transform, finishing Z F Z^T - variable dispatch - // based on size of the output and the channel tail. 
- const auto f_compute_zfzT = [&] () { - if (output_shape.n_rows % 2 && output_shape.n_cols % 2) { - constexpr bool tail_M = true, tail_N = true; - switch (n_channels % 4) { - case 0: - compute_zfzT<tail_M, tail_N, 0>(output_shape, output, matrices_zf); - break; - case 1: - compute_zfzT<tail_M, tail_N, 1>(output_shape, output, matrices_zf); - break; - case 2: - compute_zfzT<tail_M, tail_N, 2>(output_shape, output, matrices_zf); - break; - case 3: - compute_zfzT<tail_M, tail_N, 3>(output_shape, output, matrices_zf); - } - } else if (output_shape.n_rows % 2) { - constexpr bool tail_M = true, tail_N = false; - switch (n_channels % 4) { - case 0: - compute_zfzT<tail_M, tail_N, 0>(output_shape, output, matrices_zf); - break; - case 1: - compute_zfzT<tail_M, tail_N, 1>(output_shape, output, matrices_zf); - break; - case 2: - compute_zfzT<tail_M, tail_N, 2>(output_shape, output, matrices_zf); - break; - case 3: - compute_zfzT<tail_M, tail_N, 3>(output_shape, output, matrices_zf); - } - } else if (output_shape.n_cols % 2) { - constexpr bool tail_M = false, tail_N = true; - switch (n_channels % 4) { - case 0: - compute_zfzT<tail_M, tail_N, 0>(output_shape, output, matrices_zf); - break; - case 1: - compute_zfzT<tail_M, tail_N, 1>(output_shape, output, matrices_zf); - break; - case 2: - compute_zfzT<tail_M, tail_N, 2>(output_shape, output, matrices_zf); - break; - case 3: - compute_zfzT<tail_M, tail_N, 3>(output_shape, output, matrices_zf); - } - } else { - constexpr bool tail_M = false, tail_N = false; - switch (n_channels % 4) { - case 0: - compute_zfzT<tail_M, tail_N, 0>(output_shape, output, matrices_zf); - break; - case 1: - compute_zfzT<tail_M, tail_N, 1>(output_shape, output, matrices_zf); - break; - case 2: - compute_zfzT<tail_M, tail_N, 2>(output_shape, output, matrices_zf); - break; - case 3: - compute_zfzT<tail_M, tail_N, 3>(output_shape, output, matrices_zf); - } - } - }; - // prof("Compute ZFZT", f_compute_zfzT, 8 * n_rows * n_channels * sizeof(float), 0, 4 * 
n_rows * n_channels * sizeof(float)); - f_compute_zfzT(); - - free(reinterpret_cast<void*>(matrices_zf)); -} -/*****************************************************************************/ - -#endif // __aarch64__ diff --git a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3_fp32.cpp new file mode 100644 index 0000000000..e7907d18c0 --- /dev/null +++ b/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3_fp32.cpp @@ -0,0 +1,238 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +#include "transforms/output.hpp" +#include "winograd_gemm.hpp" +#include "arm.hpp" + +namespace winograd +{ + +using Transform = WinogradGEMM<2, 2, 3, 3>::OutputTransform<float>; + +template <> +template <> +int Transform::ops_performed(const Tensor4DShape &shape) +{ + // NOTE: Cost in FLOPs rather than instructions or uops. + const int tile_M = iceildiv(shape.n_rows, 2); + const int tile_N = iceildiv(shape.n_cols, 2); + return 24 * tile_M * tile_N * shape.n_channels; +} + +/* F(2x2, 3x3) constructs 2x2 output tiles from a 3x3 convolution. Since we use + * enough tiles to cover the output space each output tile may contain 0 or 1 + * padded values to the right and bottom columns or rows of the tile, e.g.: + * + * ___ ___ + * | | | X| + * |___| |__X| + * + * ___ ___ + * | | | X| + * |X_X| |X_X| + * + * + * We provide a specialised output transform for each of these instances. + * Consequently we below construct an array of the various padding options, the + * array contains pointers to the specific implementations. 
+ */ +template <> +template <> +template <int pad_bottom, int pad_right> +void Transform::process_tile( + const int n_channels, + const float* const matrix_base, + const int matrix_stride, + float* const output, + const int output_row_stride, + const int output_col_stride +) +{ + constexpr int cells_i = 2 - pad_bottom; + constexpr int cells_j = 2 - pad_right; + + // Construct a map to the output cells + float *outptrs[cells_i][cells_j]; + for (int i = 0; i < cells_i; i++) + { + for (int j = 0; j < cells_j; j++) + { + outptrs[i][j] = output + i*output_row_stride + j*output_col_stride; + } + } + const float *inptr = matrix_base; + + // For each channel of the output + int channels_remaining = n_channels; +#ifdef __aarch64__ + for (; channels_remaining >= 4; channels_remaining -= 4) + { + // Matrices used and computed during this transform + float32x4_t F[4][4], FZ[4][2], f[2][2]; + + // Read a 4x4 tile in the Winograd domain + for (int i = 0, m = 0; i < 4; i++) + { + for (int j = 0; j < 4; j++, m++) + { + F[i][j] = vld1q_f32(inptr + m*matrix_stride); + } + } + inptr += 4; + + // Compute the matrix F Z + for (int i = 0; i < 4; i++) + { + // FZ[i][0] = F[i][0] + F[i][1] + F[i][2]; + FZ[i][0] = vaddq_f32(vaddq_f32(F[i][0], F[i][1]), F[i][2]); + + // FZ[i][1] = F[i][1] - F[i][2] - F[i][3]; + FZ[i][1] = vsubq_f32(vsubq_f32(F[i][1], F[i][2]), F[i][3]); + } + + // Compute the output tile f = ZT F Z + for (int j = 0; j < 2; j++) + { + // f[0][j] = FZ[0][j] + FZ[1][j] + FZ[2][j]; + f[0][j] = vaddq_f32(vaddq_f32(FZ[0][j], FZ[1][j]), FZ[2][j]); + + // f[1][j] = FZ[1][j] - FZ[2][j] - FZ[3][j]; + f[1][j] = vsubq_f32(vsubq_f32(FZ[1][j], FZ[2][j]), FZ[3][j]); + } + + // Write out the output tile + for (int i = 0; i < cells_i; i++) + { + for (int j = 0; j < cells_j; j++) + { + vst1q_f32(outptrs[i][j], f[i][j]); + outptrs[i][j] += 4; + } + } + } +#endif // __aarch64__ +#ifdef __arm_any__ + for (; channels_remaining >= 2; channels_remaining -= 2) + { + // Matrices used and computed 
during this transform + float32x2_t F[4][4], FZ[4][2], f[2][2]; + + // Read a 4x4 tile in the Winograd domain + for (int i = 0, m = 0; i < 4; i++) + { + for (int j = 0; j < 4; j++, m++) + { + F[i][j] = vld1_f32(inptr + m*matrix_stride); + } + } + inptr += 2; + + // Compute the matrix F Z + for (int i = 0; i < 4; i++) + { + // FZ[i][0] = F[i][0] + F[i][1] + F[i][2]; + FZ[i][0] = vadd_f32(vadd_f32(F[i][0], F[i][1]), F[i][2]); + + // FZ[i][1] = F[i][1] - F[i][2] - F[i][3]; + FZ[i][1] = vsub_f32(vsub_f32(F[i][1], F[i][2]), F[i][3]); + } + + // Compute the output tile f = ZT F Z + for (int j = 0; j < 2; j++) + { + // f[0][j] = FZ[0][j] + FZ[1][j] + FZ[2][j]; + f[0][j] = vadd_f32(vadd_f32(FZ[0][j], FZ[1][j]), FZ[2][j]); + + // f[1][j] = FZ[1][j] - FZ[2][j] - FZ[3][j]; + f[1][j] = vsub_f32(vsub_f32(FZ[1][j], FZ[2][j]), FZ[3][j]); + } + + // Write out the output tile + for (int i = 0; i < cells_i; i++) + { + for (int j = 0; j < cells_j; j++) + { + vst1_f32(outptrs[i][j], f[i][j]); + outptrs[i][j] += 2; + } + } + } +#endif // __arm_any__ + for (; channels_remaining; channels_remaining--) + { + // Matrices used and computed during this transform + float F[4][4], FZ[4][2], f[2][2]; + + // Read a 4x4 tile in the Winograd domain + for (int i = 0, m = 0; i < 4; i++) + { + for (int j = 0; j < 4; j++, m++) + { + F[i][j] = *(inptr + m*matrix_stride); + } + } + inptr++; + + // Compute the matrix F Z + for (int i = 0; i < 4; i++) + { + FZ[i][0] = F[i][0] + F[i][1] + F[i][2]; + FZ[i][1] = F[i][1] - F[i][2] - F[i][3]; + } + + // Compute the output tile f = ZT F Z + for (int j = 0; j < 2; j++) + { + f[0][j] = FZ[0][j] + FZ[1][j] + FZ[2][j]; + f[1][j] = FZ[1][j] - FZ[2][j] - FZ[3][j]; + } + + // Write out the output tile + for (int i = 0; i < cells_i; i++) + { + for (int j = 0; j < cells_j; j++) + { + *(outptrs[i][j]++) = f[i][j]; + } + } + } +} + +template <> +template <> +const Transform::TileFn Transform::tile_fns[max_pad_bottom][max_pad_right] = +{ + { + Transform::template 
process_tile<0, 0>, // No padding + Transform::template process_tile<0, 1>, // Right padding + }, + { + Transform::template process_tile<1, 0>, // Bottom padding + Transform::template process_tile<1, 1>, // Bottom and right padding + } +}; + +template struct WinogradGEMM<2, 2, 3, 3>::OutputTransform<float>; +} // namespace winograd diff --git a/src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp new file mode 100644 index 0000000000..483e5c110b --- /dev/null +++ b/src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp @@ -0,0 +1,299 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +#include "transforms/output.hpp" +#include "winograd_gemm.hpp" +#include "arm.hpp" + +namespace winograd +{ + +using Transform = WinogradGEMM<4, 4, 3, 3>::OutputTransform<float>; + +template <> +template <> +int Transform::ops_performed(const Tensor4DShape &shape) +{ + // NOTE: Cost in FLOPs rather than instructions or uops. + const int tile_M = iceildiv(shape.n_rows, 4); + const int tile_N = iceildiv(shape.n_cols, 4); + return 170 * tile_M * tile_N * shape.n_channels; +} + +// Instantiate cost methods +template int Transform::ops_performed(const Tensor4DShape&); + +/* F(4x4, 3x3) constructs 4x4 output tiles from a 3x3 convolution. Since we use + * enough tiles to cover the output space each output tile may contain up to 3 + * padded values to the right and bottom columns or rows of the tile, e.g.: +* +* ________ ________ ________ ________ +* | | | X| | X X| | X X X| +* | | | X| | X X| | X X X| +* | | | X| | X X| | X X X| +* |_______| |______X| |____X_X| |__X_X_X| +* +* ________ ________ ________ ________ +* | | | X| | X X| | X X X| +* | | | X| | X X| | X X X| +* | | | X| | X X| | X X X| +* |X_X_X_X| |X_X_X_X| |X_X_X_X| |X_X_X_X| +* +* ________ ________ ________ ________ +* | | | X| | X X| | X X X| +* | | | X| | X X| | X X X| +* |X X X X| |X X X X| |X X X X| |X X X X| +* |X_X_X_X| |X_X_X_X| |X_X_X_X| |X_X_X_X| +* +* ________ ________ ________ ________ +* | | | X| | X X| | X X X| +* |X X X X| |X X X X| |X X X X| |X X X X| +* |X X X X| |X X X X| |X X X X| |X X X X| +* |X_X_X_X| |X_X_X_X| |X_X_X_X| |X_X_X_X| +* +* +* We provide a specialised output transform for each of these instances. 
+*/ +template <> +template <> +template <int pad_bottom, int pad_right> +void Transform::process_tile( + const int n_channels, + const float* const matrix_base, + const int matrix_stride, + float* const output, + const int output_row_stride, + const int output_col_stride +) +{ + constexpr int cells_i = 4 - pad_bottom; + constexpr int cells_j = 4 - pad_right; + + // Construct a map to the output cells + float *outptrs[cells_i][cells_j]; + for (int i = 0; i < cells_i; i++) + { + for (int j = 0; j < cells_j; j++) + { + outptrs[i][j] = output + i*output_row_stride + j*output_col_stride; + } + } + const float *inptr = matrix_base; + + // For each channel of the output + int channels_remaining = n_channels; +#ifdef __aarch64__ + for (; channels_remaining >= 4; channels_remaining -= 4) + { + // Matrices used and computed during this transform + float32x4_t F[6][6], FZ[6][4], f[4][4]; + + // Read a 6x6 tile in the Winograd domain + for (int i = 0, m = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++, m++) + { + F[i][j] = vld1q_f32(inptr + m*matrix_stride); + } + } + inptr += 4; + + // Compute the matrix F Z + for (int i = 0; i < 6; i++) + { + // FZ[i][0] = 1*F[i][0] + 1*F[i][1] + 1*F[i][2] + 1*F[i][3] + 1*F[i][4]; + FZ[i][0] = vaddq_f32(vaddq_f32(vaddq_f32(F[i][0], F[i][1]), vaddq_f32(F[i][2], F[i][3])), F[i][4]); + + // FZ[i][1] = 1*F[i][1] + -1*F[i][2] + 2*F[i][3] + -2*F[i][4]; + FZ[i][1] = vmlaq_n_f32(vsubq_f32(F[i][1], F[i][2]), vsubq_f32(F[i][3], F[i][4]), 2.0f); + + // FZ[i][2] = 1*F[i][1] + 1*F[i][2] + 4*F[i][3] + 4*F[i][4]; + FZ[i][2] = vmlaq_n_f32(vaddq_f32(F[i][1], F[i][2]), vaddq_f32(F[i][3], F[i][4]), 4.0f); + + // FZ[i][3] = 1*F[i][1] + -1*F[i][2] + 8*F[i][3] + -8*F[i][4] + 1*F[i][5]; + FZ[i][3] = vaddq_f32(vmlaq_n_f32(vsubq_f32(F[i][1], F[i][2]), vsubq_f32(F[i][3], F[i][4]), 8.0f), F[i][5]); + } + + // Compute the output tile f = ZT F Z + for (int j = 0; j < 4; j++) + { + // f[0][j] = 1*FZ[0][j] + 1*FZ[1][j] + 1*FZ[2][j] + 1*FZ[3][j] + 1*FZ[4][j]; + f[0][j] = 
vaddq_f32(vaddq_f32(vaddq_f32(FZ[0][j], FZ[1][j]), vaddq_f32(FZ[2][j], FZ[3][j])), FZ[4][j]); + + // f[1][j] = 1*FZ[1][j] + -1*FZ[2][j] + 2*FZ[3][j] + -2*FZ[4][j]; + f[1][j] = vmlaq_n_f32(vsubq_f32(FZ[1][j], FZ[2][j]), vsubq_f32(FZ[3][j], FZ[4][j]), 2.0f); + + // f[2][j] = 1*FZ[1][j] + 1*FZ[2][j] + 4*FZ[3][j] + 4*FZ[4][j]; + f[2][j] = vmlaq_n_f32(vaddq_f32(FZ[1][j], FZ[2][j]), vaddq_f32(FZ[3][j], FZ[4][j]), 4.0f); + + // f[3][j] = 1*FZ[1][j] + -1*FZ[2][j] + 8*FZ[3][j] + -8*FZ[4][j] + 1*FZ[5][j]; + f[3][j] = vaddq_f32(vmlaq_n_f32(vsubq_f32(FZ[1][j], FZ[2][j]), vsubq_f32(FZ[3][j], FZ[4][j]), 8.0f), FZ[5][j]); + } + + // Write out the output tile + for (int i = 0; i < cells_i; i++) + { + for (int j = 0; j < cells_j; j++) + { + vst1q_f32(outptrs[i][j], f[i][j]); + outptrs[i][j] += 4; + } + } + } +#endif // __aarch64__ +#ifdef __arm_any__ + for (; channels_remaining >= 2; channels_remaining -= 2) + { + // Matrices used and computed during this transform + float32x2_t F[6][6], FZ[6][4], f[4][4]; + + // Read a 6x6 tile in the Winograd domain + for (int i = 0, m = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++, m++) + { + F[i][j] = vld1_f32(inptr + m*matrix_stride); + } + } + inptr += 2; + + // Compute the matrix F Z + for (int i = 0; i < 6; i++) + { + // FZ[i][0] = 1*F[i][0] + 1*F[i][1] + 1*F[i][2] + 1*F[i][3] + 1*F[i][4]; + FZ[i][0] = vadd_f32(vadd_f32(vadd_f32(F[i][0], F[i][1]), vadd_f32(F[i][2], F[i][3])), F[i][4]); + + // FZ[i][1] = 1*F[i][1] + -1*F[i][2] + 2*F[i][3] + -2*F[i][4]; + FZ[i][1] = vmla_n_f32(vsub_f32(F[i][1], F[i][2]), vsub_f32(F[i][3], F[i][4]), 2.0f); + + // FZ[i][2] = 1*F[i][1] + 1*F[i][2] + 4*F[i][3] + 4*F[i][4]; + FZ[i][2] = vmla_n_f32(vadd_f32(F[i][1], F[i][2]), vadd_f32(F[i][3], F[i][4]), 4.0f); + + // FZ[i][3] = 1*F[i][1] + -1*F[i][2] + 8*F[i][3] + -8*F[i][4] + 1*F[i][5]; + FZ[i][3] = vadd_f32(vmla_n_f32(vsub_f32(F[i][1], F[i][2]), vsub_f32(F[i][3], F[i][4]), 8.0f), F[i][5]); + } + + // Compute the output tile f = ZT F Z + for (int j = 0; j < 4; 
j++) + { + // f[0][j] = 1*FZ[0][j] + 1*FZ[1][j] + 1*FZ[2][j] + 1*FZ[3][j] + 1*FZ[4][j]; + f[0][j] = vadd_f32(vadd_f32(vadd_f32(FZ[0][j], FZ[1][j]), vadd_f32(FZ[2][j], FZ[3][j])), FZ[4][j]); + + // f[1][j] = 1*FZ[1][j] + -1*FZ[2][j] + 2*FZ[3][j] + -2*FZ[4][j]; + f[1][j] = vmla_n_f32(vsub_f32(FZ[1][j], FZ[2][j]), vsub_f32(FZ[3][j], FZ[4][j]), 2.0f); + + // f[2][j] = 1*FZ[1][j] + 1*FZ[2][j] + 4*FZ[3][j] + 4*FZ[4][j]; + f[2][j] = vmla_n_f32(vadd_f32(FZ[1][j], FZ[2][j]), vadd_f32(FZ[3][j], FZ[4][j]), 4.0f); + + // f[3][j] = 1*FZ[1][j] + -1*FZ[2][j] + 8*FZ[3][j] + -8*FZ[4][j] + 1*FZ[5][j]; + f[3][j] = vadd_f32(vmla_n_f32(vsub_f32(FZ[1][j], FZ[2][j]), vsub_f32(FZ[3][j], FZ[4][j]), 8.0f), FZ[5][j]); + } + + // Write out the output tile + for (int i = 0; i < cells_i; i++) + { + for (int j = 0; j < cells_j; j++) + { + vst1_f32(outptrs[i][j], f[i][j]); + outptrs[i][j] += 2; + } + } + } +#endif + for (; channels_remaining; channels_remaining--) + { + // Matrices used and computed during this transform + float F[6][6], FZ[6][4], f[4][4]; + + // Read a 6x6 tile in the Winograd domain + for (int i = 0, m = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++, m++) + { + F[i][j] = *(inptr + m*matrix_stride); + } + } + inptr++; + + // Compute the matrix F Z + for (int i = 0; i < 6; i++) + { + FZ[i][0] = 1*F[i][0] + 1*F[i][1] + 1*F[i][2] + 1*F[i][3] + 1*F[i][4]; + FZ[i][1] = 1*F[i][1] + -1*F[i][2] + 2*F[i][3] + -2*F[i][4]; + FZ[i][2] = 1*F[i][1] + 1*F[i][2] + 4*F[i][3] + 4*F[i][4]; + FZ[i][3] = 1*F[i][1] + -1*F[i][2] + 8*F[i][3] + -8*F[i][4] + 1*F[i][5]; + } + + // Compute the output tile f = ZT F Z + for (int j = 0; j < 4; j++) + { + f[0][j] = 1*FZ[0][j] + 1*FZ[1][j] + 1*FZ[2][j] + 1*FZ[3][j] + 1*FZ[4][j]; + f[1][j] = 1*FZ[1][j] + -1*FZ[2][j] + 2*FZ[3][j] + -2*FZ[4][j]; + f[2][j] = 1*FZ[1][j] + 1*FZ[2][j] + 4*FZ[3][j] + 4*FZ[4][j]; + f[3][j] = 1*FZ[1][j] + -1*FZ[2][j] + 8*FZ[3][j] + -8*FZ[4][j] + 1*FZ[5][j]; + } + + // Write out the output tile + for (int i = 0; i < cells_i; i++) + { + 
for (int j = 0; j < cells_j; j++) + { + *(outptrs[i][j]++) = f[i][j]; + } + } + } +} + +template <> +template <> +const Transform::TileFn Transform::tile_fns[max_pad_bottom][max_pad_right] = +{ + { + Transform::template process_tile<0, 0>, + Transform::template process_tile<0, 1>, + Transform::template process_tile<0, 2>, + Transform::template process_tile<0, 3>, + }, + { + Transform::template process_tile<1, 0>, + Transform::template process_tile<1, 1>, + Transform::template process_tile<1, 2>, + Transform::template process_tile<1, 3>, + }, + { + Transform::template process_tile<2, 0>, + Transform::template process_tile<2, 1>, + Transform::template process_tile<2, 2>, + Transform::template process_tile<2, 3>, + }, + { + Transform::template process_tile<3, 0>, + Transform::template process_tile<3, 1>, + Transform::template process_tile<3, 2>, + Transform::template process_tile<3, 3>, + } +}; + +template struct WinogradGEMM<4, 4, 3, 3>::OutputTransform<float>; +} // namespace winograd diff --git a/src/core/NEON/kernels/winograd/transforms/weights_2x2_3x3_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/weights_2x2_3x3_fp32.cpp new file mode 100644 index 0000000000..c0b282431e --- /dev/null +++ b/src/core/NEON/kernels/winograd/transforms/weights_2x2_3x3_fp32.cpp @@ -0,0 +1,228 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "arm.hpp" +#include "winograd_gemm.hpp" +#include "transforms/kernel.hpp" + +namespace winograd +{ + template <> + template <> + void WinogradGEMM<2, 2, 3, 3>::WeightsTransform<float>::execute( + const int n_output_channels, + const int n_input_channels, + const float* const input, + float* const output, + const int matrix_stride, + const int matrix_row_stride + ) + { + constexpr int inner_tile_i = 4; + constexpr int inner_tile_j = 4; + + // Get pointers to each cell of the weight tensor + const auto weight_col_stride = n_input_channels * n_output_channels; + const auto weight_row_stride = 3 * weight_col_stride; + const float *inptrs[3][3]; + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < 3; j++) + { + inptrs[i][j] = input + i*weight_row_stride + j*weight_col_stride; + } + } + + // For each input channel + for (int ic = 0; ic < n_input_channels; ic++) + { + float *outptr = output + ic * matrix_row_stride; + + // For each output channel + int channels_remaining = n_output_channels; +#ifdef __aarch64__ + for (; channels_remaining >= 4; channels_remaining -= 4) + { + // Matrices used and computed in this kernel + float32x4_t w[3][3], Ww[inner_tile_i][3], V[inner_tile_i][inner_tile_j]; + + // Read weights + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < 3; j++) + { + w[i][j] = vld1q_f32(inptrs[i][j]); + inptrs[i][j] += 4; + } + } + + // Compute the matrix W w + for (int j = 0; j < 3; j++) + { + Ww[0][j] = w[0][j]; + + // Ww[1][j] = 0.5*(w[0][j] + w[1][j] 
+ w[2][j]); + Ww[1][j] = vmulq_n_f32(vaddq_f32(vaddq_f32(w[0][j], w[1][j]), w[2][j]), 0.5f); + + // Ww[2][j] = 0.5*(w[0][j] - w[1][j] + w[2][j]); + Ww[2][j] = vmulq_n_f32(vaddq_f32(vsubq_f32(w[0][j], w[1][j]), w[2][j]), 0.5f); + + Ww[3][j] = w[2][j]; + } + + // Compute V = W w WT + for (int i = 0; i < inner_tile_i; i++) + { + V[i][0] = Ww[i][0]; + + // V[i][1] = 0.5*(Ww[i][0] + Ww[i][1] + Ww[i][2]); + V[i][1] = vmulq_n_f32(vaddq_f32(vaddq_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f); + + // V[i][2] = 0.5*(Ww[i][0] - Ww[i][1] + Ww[i][2]); + V[i][2] = vmulq_n_f32(vaddq_f32(vsubq_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f); + + V[i][3] = Ww[i][2]; + } + + // Store the transformed weights + for (int i = 0, m = 0; i < inner_tile_i; i++) + { + for (int j = 0; j < inner_tile_j; j++, m++) + { + vst1q_f32(outptr + m*matrix_stride, V[i][j]); + } + } + outptr += 4; + } +#endif // __aarch64__ +#ifdef __arm_any__ + for (; channels_remaining >= 2; channels_remaining -= 2) + { + // Matrices used and computed in this kernel + float32x2_t w[3][3], Ww[inner_tile_i][3], V[inner_tile_i][inner_tile_j]; + + // Read weights + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < 3; j++) + { + w[i][j] = vld1_f32(inptrs[i][j]); + inptrs[i][j] += 2; + } + } + + // Compute the matrix W w + for (int j = 0; j < 3; j++) + { + Ww[0][j] = w[0][j]; + + // Ww[1][j] = 0.5*(w[0][j] + w[1][j] + w[2][j]); + Ww[1][j] = vmul_n_f32(vadd_f32(vadd_f32(w[0][j], w[1][j]), w[2][j]), 0.5f); + + // Ww[2][j] = 0.5*(w[0][j] - w[1][j] + w[2][j]); + Ww[2][j] = vmul_n_f32(vadd_f32(vsub_f32(w[0][j], w[1][j]), w[2][j]), 0.5f); + + Ww[3][j] = w[2][j]; + } + + // Compute V = W w WT + for (int i = 0; i < inner_tile_i; i++) + { + V[i][0] = Ww[i][0]; + + // V[i][1] = 0.5*(Ww[i][0] + Ww[i][1] + Ww[i][2]); + V[i][1] = vmul_n_f32(vadd_f32(vadd_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f); + + // V[i][2] = 0.5*(Ww[i][0] - Ww[i][1] + Ww[i][2]); + V[i][2] = vmul_n_f32(vadd_f32(vsub_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f); + + V[i][3] = 
Ww[i][2]; + } + + // Store the transformed weights + for (int i = 0, m = 0; i < inner_tile_i; i++) + { + for (int j = 0; j < inner_tile_j; j++, m++) + { + vst1_f32(outptr + m*matrix_stride, V[i][j]); + } + } + outptr += 2; + } +#endif // __arm_any__ + for (; channels_remaining; channels_remaining--) + { + // Matrices used and computed in this kernel + float w[3][3], Ww[inner_tile_i][3], V[inner_tile_i][inner_tile_j]; + + // Read weights + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < 3; j++) + { + w[i][j] = *(inptrs[i][j]++); + } + } + + // Compute the matrix W w + for (int j = 0; j < 3; j++) + { + Ww[0][j] = w[0][j]; + Ww[1][j] = 0.5*(w[0][j] + w[1][j] + w[2][j]); + Ww[2][j] = 0.5*(w[0][j] - w[1][j] + w[2][j]); + Ww[3][j] = w[2][j]; + } + + // Compute V = W w WT + for (int i = 0; i < inner_tile_i; i++) + { + V[i][0] = Ww[i][0]; + V[i][1] = 0.5*(Ww[i][0] + Ww[i][1] + Ww[i][2]); + V[i][2] = 0.5*(Ww[i][0] - Ww[i][1] + Ww[i][2]); + V[i][3] = Ww[i][2]; + } + + // Store the transformed weights + for (int i = 0, m = 0; i < inner_tile_i; i++) + { + for (int j = 0; j < inner_tile_j; j++, m++) + { + *(outptr + m*matrix_stride) = V[i][j]; + } + } + outptr++; + } + } + } + + template <> + template <> + int WinogradGEMM<2, 2, 3, 3>::WeightsTransform<float>::ops_performed(const KernelShape &shape) + { + const int channel_prod = shape.n_input_channels * shape.n_output_channels; + return 2 * 18 * channel_prod; + } + + template struct WinogradGEMM<2, 2, 3, 3>::WeightsTransform<float>; +} // namespace winograd diff --git a/src/core/NEON/kernels/winograd/transforms/weights_4x4_3x3_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/weights_4x4_3x3_fp32.cpp new file mode 100644 index 0000000000..de659c38e0 --- /dev/null +++ b/src/core/NEON/kernels/winograd/transforms/weights_4x4_3x3_fp32.cpp @@ -0,0 +1,266 @@ +/* + * Copyright (c) 2017 ARM Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +#include "arm.hpp" +#include "winograd_gemm.hpp" +#include "transforms/kernel.hpp" + +namespace winograd +{ + /* Float implementation for kernel transform F(4x4, 3x3) */ + template <> + template <> + void WinogradGEMM<4, 4, 3, 3>::WeightsTransform<float>::execute( + const int n_output_channels, + const int n_input_channels, + const float* const input, // NOTE: Data in HWIO order + float* const output, + const int matrix_stride, + const int matrix_row_stride + ) + { + // Get pointers to each cell of the weight tensor + const auto weight_col_stride = n_input_channels * n_output_channels; + const auto weight_row_stride = 3 * weight_col_stride; + const float *inptrs[3][3]; + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < 3; j++) + { + inptrs[i][j] = input + i*weight_row_stride + j*weight_col_stride; + } + } + + // For each input channel + for (int ic = 0; ic < n_input_channels; ic++) + { + float *outptr = output + ic * matrix_row_stride; + + // For each output channel + int channels_remaining = n_output_channels; +#ifdef __aarch64__ + for (; channels_remaining >= 4; channels_remaining -= 4) + { + // Matrices used and computed in this kernel + float32x4_t w[3][3], Ww[6][3], V[6][6]; + + // Read weights + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < 3; j++) + { + w[i][j] = vld1q_f32(inptrs[i][j]); + inptrs[i][j] += 4; + } + } + + // Compute the matrix W w + for (int j = 0; j < 3; j++) + { + // Ww[0][j] = 6*w[0][j]; + Ww[0][j] = vmulq_n_f32(w[0][j], 6.0); + + // Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j]; + Ww[1][j] = vmulq_n_f32(vaddq_f32(vaddq_f32(w[0][j], w[1][j]), w[2][j]), -4.0); + + // Ww[2][j] = -4*w[0][j] + 4*w[1][j] + -4*w[2][j]; + Ww[2][j] = vmulq_n_f32(vsubq_f32(vsubq_f32(w[1][j], w[0][j]), w[2][j]), 4.0); + + // Ww[3][j] = 1*w[0][j] + 2*w[1][j] + 4*w[2][j]; + Ww[3][j] = vmlaq_n_f32(vmlaq_n_f32(w[0][j], w[1][j], 2.0f), w[2][j], 4.0f); + + // Ww[4][j] = 1*w[0][j] + -2*w[1][j] + 4*w[2][j]; + Ww[4][j] = vmlaq_n_f32(vmlsq_n_f32(w[0][j], 
w[1][j], 2.0f), w[2][j], 4.0f); + + // Ww[5][j] = 24*w[2][j]; + Ww[5][j] = vmulq_n_f32(w[2][j], 24.0f); + } + + // Compute V = W w WT + for (int i = 0; i < 6; i++) + { + const float recip576 = 1.0f / 576.0f; + + // V[i][0] = 6*Ww[i][0]; + V[i][0] = vmulq_n_f32(vmulq_n_f32(Ww[i][0], 6.0), recip576); + + // V[i][1] = -4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2]; + V[i][1] = vmulq_n_f32(vmulq_n_f32(vaddq_f32(vaddq_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), -4.0), recip576); + + // V[i][2] = -4*Ww[i][0] + 4*Ww[i][1] + -4*Ww[i][2]; + V[i][2] = vmulq_n_f32(vmulq_n_f32(vsubq_f32(vsubq_f32(Ww[i][1], Ww[i][0]), Ww[i][2]), 4.0), recip576); + + // V[i][3] = 1*Ww[i][0] + 2*Ww[i][1] + 4*Ww[i][2]; + V[i][3] = vmulq_n_f32(vmlaq_n_f32(vmlaq_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576); + + // V[i][4] = 1*Ww[i][0] + -2*Ww[i][1] + 4*Ww[i][2]; + V[i][4] = vmulq_n_f32(vmlaq_n_f32(vmlsq_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576); + + // V[i][5] = 24*Ww[i][2]; + V[i][5] = vmulq_n_f32(vmulq_n_f32(Ww[i][2], 24.0f), recip576); + } + + // Store the transformed weights + for (int i = 0, m = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++, m++) + { + vst1q_f32(outptr + m*matrix_stride, V[i][j]); + } + } + outptr += 4; + } +#endif // __aarch64__ +#ifdef __arm_any__ + for (; channels_remaining >= 2; channels_remaining -= 2) + { + // Matrices used and computed in this kernel + float32x2_t w[3][3], Ww[6][3], V[6][6]; + + // Read weights + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < 3; j++) + { + w[i][j] = vld1_f32(inptrs[i][j]); + inptrs[i][j] += 2; + } + } + + // Compute the matrix W w + for (int j = 0; j < 3; j++) + { + // Ww[0][j] = 6*w[0][j]; + Ww[0][j] = vmul_n_f32(w[0][j], 6.0); + + // Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j]; + Ww[1][j] = vmul_n_f32(vadd_f32(vadd_f32(w[0][j], w[1][j]), w[2][j]), -4.0); + + // Ww[2][j] = -4*w[0][j] + 4*w[1][j] + -4*w[2][j]; + Ww[2][j] = vmul_n_f32(vsub_f32(vsub_f32(w[1][j], w[0][j]), w[2][j]), 4.0); + + // Ww[3][j] = 
1*w[0][j] + 2*w[1][j] + 4*w[2][j]; + Ww[3][j] = vmla_n_f32(vmla_n_f32(w[0][j], w[1][j], 2.0f), w[2][j], 4.0f); + + // Ww[4][j] = 1*w[0][j] + -2*w[1][j] + 4*w[2][j]; + Ww[4][j] = vmla_n_f32(vmls_n_f32(w[0][j], w[1][j], 2.0f), w[2][j], 4.0f); + + // Ww[5][j] = 24*w[2][j]; + Ww[5][j] = vmul_n_f32(w[2][j], 24.0f); + } + + // Compute V = W w WT + for (int i = 0; i < 6; i++) + { + const float recip576 = 1.0f / 576.0f; + + // V[i][0] = 6*Ww[i][0]; + V[i][0] = vmul_n_f32(vmul_n_f32(Ww[i][0], 6.0), recip576); + + // V[i][1] = -4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2]; + V[i][1] = vmul_n_f32(vmul_n_f32(vadd_f32(vadd_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), -4.0), recip576); + + // V[i][2] = -4*Ww[i][0] + 4*Ww[i][1] + -4*Ww[i][2]; + V[i][2] = vmul_n_f32(vmul_n_f32(vsub_f32(vsub_f32(Ww[i][1], Ww[i][0]), Ww[i][2]), 4.0), recip576); + + // V[i][3] = 1*Ww[i][0] + 2*Ww[i][1] + 4*Ww[i][2]; + V[i][3] = vmul_n_f32(vmla_n_f32(vmla_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576); + + // V[i][4] = 1*Ww[i][0] + -2*Ww[i][1] + 4*Ww[i][2]; + V[i][4] = vmul_n_f32(vmla_n_f32(vmls_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576); + + // V[i][5] = 24*Ww[i][2]; + V[i][5] = vmul_n_f32(vmul_n_f32(Ww[i][2], 24.0f), recip576); + } + + // Store the transformed weights + for (int i = 0, m = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++, m++) + { + vst1_f32(outptr + m*matrix_stride, V[i][j]); + } + } + outptr += 2; + } +#endif // __arm_any__ + for (; channels_remaining; channels_remaining--) + { + // Matrices used and computed in this kernel + float w[3][3], Ww[6][3], V[6][6]; + + // Read weights + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < 3; j++) + { + w[i][j] = *(inptrs[i][j]++); + } + } + + // Compute the matrix W w + for (int j = 0; j < 3; j++) + { + Ww[0][j] = 6*w[0][j]; + Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j]; + Ww[2][j] = -4*w[0][j] + 4*w[1][j] + -4*w[2][j]; + Ww[3][j] = 1*w[0][j] + 2*w[1][j] + 4*w[2][j]; + Ww[4][j] = 1*w[0][j] + -2*w[1][j] + 4*w[2][j]; + 
Ww[5][j] = 24*w[2][j]; + } + + // Compute V = W w WT + for (int i = 0; i < 6; i++) + { + V[i][0] = ( 6*Ww[i][0]) / 576.0; + V[i][1] = (-4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2]) / 576.0; + V[i][2] = (-4*Ww[i][0] + 4*Ww[i][1] + -4*Ww[i][2]) / 576.0; + V[i][3] = ( 1*Ww[i][0] + 2*Ww[i][1] + 4*Ww[i][2]) / 576.0; + V[i][4] = ( 1*Ww[i][0] + -2*Ww[i][1] + 4*Ww[i][2]) / 576.0; + V[i][5] = (24*Ww[i][2]) / 576.0; + } + + // Store the transformed weights + for (int i = 0, m = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++, m++) + { + *(outptr + m*matrix_stride) = V[i][j]; + } + } + outptr++; + } + } + } + + template <> + template <> + int WinogradGEMM<4, 4, 3, 3>::WeightsTransform<float>::ops_performed(const KernelShape &shape) + { + const int channel_prod = shape.n_input_channels * shape.n_output_channels; + return 9 * 16 * channel_prod; + } + + template struct WinogradGEMM<4, 4, 3, 3>::WeightsTransform<float>; +} |