aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3')
-rw-r--r--arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp650
-rw-r--r--arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp655
2 files changed, 0 insertions, 1305 deletions
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp
deleted file mode 100644
index 5925f9d569..0000000000
--- a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp
+++ /dev/null
@@ -1,650 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#pragma once
-
-/* Float implementation for AArch64.
- */
-#ifdef __aarch64__
-namespace winograd {
-
-
-template <>
-template <>
-inline void Winograd2x2_3x3GemmOutput<float>::_execute<false, false, 0>(
- const Tensor4DShape &output_shape,
- float *output,
- const float *input,
- const int mstride,
- const int matrix_row_stride
-) {
- const int tile_M = output_shape.n_rows / 2;
- const int tile_N = output_shape.n_cols / 2;
- int batch = output_shape.n_batches;
- float *outptr = output;
-
- const float *inptr0 = input;
- const float *inptr4 = input + 4 * mstride;
- const float *inptr8 = input + 8 * mstride;
- const float *inptr12 = input + 12 * mstride;
-
- const size_t col_stride = sizeof(float) * output_shape.n_channels;
- const size_t row_stride = col_stride * tile_N * 2;
-
- asm volatile (
- // Aliases for elements of the input matrix `F`
- // V-register Q-register
- "F11 .req v0\n" "qF11 .req q0\n"
- "F12 .req v1\n" "qF12 .req q1\n"
- "F13 .req v2\n" "qF13 .req q2\n"
- "F14 .req v3\n" "qF14 .req q3\n"
- "F21 .req v4\n" "qF21 .req q4\n"
- "F22 .req v5\n" "qF22 .req q5\n"
- "F23 .req v6\n" "qF23 .req q6\n"
- "F24 .req v7\n" "qF24 .req q7\n"
- "F31 .req v8\n" "qF31 .req q8\n"
- "F32 .req v9\n" "qF32 .req q9\n"
- "F33 .req v10\n" "qF33 .req q10\n"
- "F34 .req v11\n" "qF34 .req q11\n"
- "F41 .req v12\n" "qF41 .req q12\n"
- "F42 .req v13\n" "qF42 .req q13\n"
- "F43 .req v14\n" "qF43 .req q14\n"
- "F44 .req v15\n" "qF44 .req q15\n"
-
- // Aliases for elements of the intermediate matrix `FZ`
- "FZ11 .req v16\n"
- "FZ12 .req v17\n"
- "FZ21 .req v18\n"
- "FZ22 .req v19\n"
- "FZ31 .req v20\n"
- "FZ32 .req v21\n"
- "FZ41 .req v22\n"
- "FZ42 .req v23\n"
-
- // Aliases for elements of the output matrix `f` (called `g` due to case
- // insensitivity of aliases).
- " g11 .req v24\n"
- "qg11 .req q24\n"
- " g12 .req v25\n"
- "qg12 .req q25\n"
- " g21 .req v26\n"
- "qg21 .req q26\n"
- " g22 .req v27\n"
- "qg22 .req q27\n"
-
- // Prepare the various strides
- "col_stride .req %x[col_stride]\n"
- "row_stride .req %x[row_stride]\n"
- "row_plus_col_stride .req %x[row_plus_col_stride]\n"
-
- "mstride1 .req %x[mstride1]\n"
- "mstride2 .req %x[mstride2]\n"
- "mstride3 .req %x[mstride3]\n"
-
- "tile_i .req x19\n" // Tile row counter
- "tile_j .req x20\n" // Tile column counter
- "channel .req x21\n" // Channel counter
-
- "1:" // Loop over batches
- "mov tile_i, %x[tile_M]\n" // Reset tile row counter
-
- "2:" // Loop over rows of tiles
- "mov tile_j, %x[tile_N]\n" // Reset tile column counter
-
- "3:" // Loop over columns of tiles
- // Perform initial loads of the matrix `F`
- "ldr qF11, [%x[inptr0]]\n"
- "ldr qF12, [%x[inptr0], mstride1]\n"
- "ldr qF13, [%x[inptr0], mstride2]\n"
- "ldr qF14, [%x[inptr0], mstride3]\n"
- "add %x[inptr0], %x[inptr0], #0x10\n"
- "ldr qF21, [%x[inptr4]]\n"
- "ldr qF22, [%x[inptr4], mstride1]\n"
- "subs channel, %x[n_channels], #4\n" // Reset channel counter
-
- "ldr qF23, [%x[inptr4], mstride2]\n"
- "ldr qF24, [%x[inptr4], mstride3]\n"
- "add %x[inptr4], %x[inptr4], #0x10\n"
- "beq 5f\n" // Jump straight to tail if necessary
-
- "4:" // Loop over channels
- "ldr qF31, [%x[inptr8]]\n"
- "fadd FZ11.4s, F11.4s, F12.4s\n"
-
- "ldr qF32, [%x[inptr8], mstride1]\n"
- "fsub FZ12.4s, F12.4s, F13.4s\n"
-
- "ldr qF33, [%x[inptr8], mstride2]\n"
- "fadd FZ11.4s, FZ11.4s, F13.4s\n"
-
- "ldr qF34, [%x[inptr8], mstride3]\n"
- "fsub FZ12.4s, FZ12.4s, F14.4s\n"
-
- "ldr qF41, [%x[inptr12]]\n"
- "fadd FZ21.4s, F21.4s, F22.4s\n"
-
- "ldr qF42, [%x[inptr12], mstride1]\n"
- "fsub FZ22.4s, F22.4s, F23.4s\n"
-
- "ldr qF43, [%x[inptr12], mstride2]\n"
- "fadd FZ21.4s, FZ21.4s, F23.4s\n"
-
- "ldr qF44, [%x[inptr12], mstride3]\n"
- "fsub FZ22.4s, FZ22.4s, F24.4s\n"
-
- "fadd FZ31.4s, F31.4s, F32.4s\n"
- "add %x[inptr8], %x[inptr8], #0x10\n"
-
- "fsub FZ32.4s, F32.4s, F33.4s\n"
- "add %x[inptr12], %x[inptr12], #0x10\n"
-
- "fadd FZ31.4s, FZ31.4s, F33.4s\n"
-
- "fsub FZ32.4s, FZ32.4s, F34.4s\n"
-
- "fadd g11.4s, FZ11.4s, FZ21.4s\n"
-
- "fadd g12.4s, FZ12.4s, FZ22.4s\n"
-
- "fadd g11.4s, g11.4s, FZ31.4s\n"
-
- "fadd g12.4s, g12.4s, FZ32.4s\n"
-
- "ldr qF11, [%x[inptr0]]\n"
- "fadd FZ41.4s, F41.4s, F42.4s\n"
-
- "ldr qF12, [%x[inptr0], mstride1]\n"
- "fsub g21.4s, FZ21.4s, FZ31.4s\n"
-
- "ldr qF13, [%x[inptr0], mstride2]\n"
- "fsub FZ42.4s, F42.4s, F43.4s\n"
-
- "ldr qF14, [%x[inptr0], mstride3]\n"
- "str qg11, [%x[outptr]]\n"
-
- "ldr qF21, [%x[inptr4]]\n"
- "fadd FZ41.4s, FZ41.4s, F43.4s\n"
-
- "ldr qF22, [%x[inptr4], mstride1]\n"
- "str qg12, [%x[outptr], col_stride]\n"
-
- "ldr qF23, [%x[inptr4], mstride2]\n"
- "fsub FZ42.4s, FZ42.4s, F44.4s\n"
-
- "ldr qF24, [%x[inptr4], mstride3]\n"
- "fsub g22.4s, FZ22.4s, FZ32.4s\n"
-
- "fsub g21.4s, g21.4s, FZ41.4s\n"
- "add %x[inptr0], %x[inptr0], #0x10\n"
-
- "fsub g22.4s, g22.4s, FZ42.4s\n"
- "add %x[inptr4], %x[inptr4], #0x10\n"
-
- "subs channel, channel, #4\n"
-
- "str qg21, [%x[outptr], row_stride]\n"
-
- "str qg22, [%x[outptr], row_plus_col_stride]\n"
-
- "add %x[outptr], %x[outptr], #0x10\n"
-
- "bne 4b\n"
-
- "5:" // Channel tail
- "ldr qF31, [%x[inptr8]]\n"
- "fadd FZ11.4s, F11.4s, F12.4s\n"
-
- "ldr qF32, [%x[inptr8], mstride1]\n"
- "fsub FZ12.4s, F12.4s, F13.4s\n"
-
- "ldr qF33, [%x[inptr8], mstride2]\n"
- "fadd FZ11.4s, FZ11.4s, F13.4s\n"
-
- "ldr qF34, [%x[inptr8], mstride3]\n"
- "fsub FZ12.4s, FZ12.4s, F14.4s\n"
-
- "ldr qF41, [%x[inptr12]]\n"
- "fadd FZ21.4s, F21.4s, F22.4s\n"
-
- "ldr qF42, [%x[inptr12], mstride1]\n"
- "fsub FZ22.4s, F22.4s, F23.4s\n"
-
- "ldr qF43, [%x[inptr12], mstride2]\n"
- "fadd FZ21.4s, FZ21.4s, F23.4s\n"
-
- "ldr qF44, [%x[inptr12], mstride3]\n"
- "fsub FZ22.4s, FZ22.4s, F24.4s\n"
-
- "fadd FZ31.4s, F31.4s, F32.4s\n"
- "add %x[inptr8], %x[inptr8], #0x10\n"
-
- "fsub FZ32.4s, F32.4s, F33.4s\n"
- "add %x[inptr12], %x[inptr12], #0x10\n"
-
- "fadd FZ31.4s, FZ31.4s, F33.4s\n"
-
- "fsub FZ32.4s, FZ32.4s, F34.4s\n"
-
- "fadd g11.4s, FZ11.4s, FZ21.4s\n"
-
- "fadd g12.4s, FZ12.4s, FZ22.4s\n"
-
- "fadd g11.4s, g11.4s, FZ31.4s\n"
-
- "fadd g12.4s, g12.4s, FZ32.4s\n"
-
- "fadd FZ41.4s, F41.4s, F42.4s\n"
-
- "fsub g21.4s, FZ21.4s, FZ31.4s\n"
-
- "fsub FZ42.4s, F42.4s, F43.4s\n"
-
- "str qg11, [%x[outptr]]\n"
-
- "fadd FZ41.4s, FZ41.4s, F43.4s\n"
-
- "str qg12, [%x[outptr], col_stride]\n"
-
- "fsub FZ42.4s, FZ42.4s, F44.4s\n"
-
- "fsub g22.4s, FZ22.4s, FZ32.4s\n"
-
- "fsub g21.4s, g21.4s, FZ41.4s\n"
-
- "fsub g22.4s, g22.4s, FZ42.4s\n"
-
- "subs channel, channel, #4\n"
-
- "str qg21, [%x[outptr], row_stride]\n"
-
- // Progress input pointers to the next row of the matrix
- "add %x[inptr0], %x[inptr0], %x[mrowpad]\n"
- "add %x[inptr4], %x[inptr4], %x[mrowpad]\n"
- "add %x[inptr8], %x[inptr8], %x[mrowpad]\n"
- "add %x[inptr12], %x[inptr12], %x[mrowpad]\n"
-
- "str qg22, [%x[outptr], row_plus_col_stride]\n"
-
- "add %x[outptr], %x[outptr], #0x10\n"
-
-
- "add %x[outptr], %x[outptr], col_stride\n"
- "subs tile_j, tile_j, #1\n"
- "bne 3b\n"
-
- "add %x[outptr], %x[outptr], row_stride\n"
- "subs tile_i, tile_i, #1\n"
- "bne 2b\n"
-
- "subs %[batch], %[batch], #1\n"
- "bne 1b\n"
-
- ".unreq F11\n" ".unreq qF11\n"
- ".unreq F12\n" ".unreq qF12\n"
- ".unreq F13\n" ".unreq qF13\n"
- ".unreq F14\n" ".unreq qF14\n"
- ".unreq F21\n" ".unreq qF21\n"
- ".unreq F22\n" ".unreq qF22\n"
- ".unreq F23\n" ".unreq qF23\n"
- ".unreq F24\n" ".unreq qF24\n"
- ".unreq F31\n" ".unreq qF31\n"
- ".unreq F32\n" ".unreq qF32\n"
- ".unreq F33\n" ".unreq qF33\n"
- ".unreq F34\n" ".unreq qF34\n"
- ".unreq F41\n" ".unreq qF41\n"
- ".unreq F42\n" ".unreq qF42\n"
- ".unreq F43\n" ".unreq qF43\n"
- ".unreq F44\n" ".unreq qF44\n"
-
- ".unreq FZ11\n" ".unreq FZ12\n"
- ".unreq FZ21\n" ".unreq FZ22\n"
- ".unreq FZ31\n" ".unreq FZ32\n"
- ".unreq FZ41\n" ".unreq FZ42\n"
-
- ".unreq g11\n" ".unreq qg11\n"
- ".unreq g12\n" ".unreq qg12\n"
- ".unreq g21\n" ".unreq qg21\n"
- ".unreq g22\n" ".unreq qg22\n"
-
- ".unreq col_stride\n"
- ".unreq row_stride\n"
- ".unreq row_plus_col_stride\n"
-
- ".unreq mstride1\n"
- ".unreq mstride2\n"
- ".unreq mstride3\n"
-
- ".unreq tile_i \n"
- ".unreq tile_j \n"
- ".unreq channel\n"
-
- : [batch] "+r" (batch),
- [outptr] "+r" (outptr),
- [inptr0] "+r" (inptr0),
- [inptr4] "+r" (inptr4),
- [inptr8] "+r" (inptr8),
- [inptr12] "+r" (inptr12)
- : [tile_M] "r" (tile_M),
- [tile_N] "r" (tile_N),
- [n_channels] "r" (output_shape.n_channels),
- [col_stride] "r" (col_stride),
- [row_stride] "r" (row_stride),
- [row_plus_col_stride] "r" (row_stride + col_stride),
- [mstride1] "r" (mstride * sizeof(float)),
- [mstride2] "r" (2 * mstride * sizeof(float)),
- [mstride3] "r" (3 * mstride * sizeof(float)),
- [mrowpad] "r" ((matrix_row_stride - output_shape.n_channels) * sizeof(float))
- : "x19", "x20", "x21",
- "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
- "q12", "q13", "q14", "q15", "q16", "q17", "q18", "q19", "q20", "q21",
- "q22", "q23", "q24", "q25", "q26", "q27",
- "cc", "memory"
- );
-}
-
-template <>
-template <bool tail_M, bool tail_N, const int channel_tail>
-inline void Winograd2x2_3x3GemmOutput<float>::_execute(
- const Tensor4DShape &output_shape,
- float *output,
- const float *input,
- const int mstride,
- const int matrix_row_stride
-) {
- // Compute basic information about the shape of the matrices
- const int tile_M = output_shape.n_rows / 2;
- const int tile_N = output_shape.n_cols / 2;
- const int n_channels = output_shape.n_channels;
-
- // Extract 16 input pointers
- const float* inptr[16];
- for (int i = 0; i < 16; i++) {
- inptr[i] = input + i*mstride;
- }
-
- // Extract 4 output pointers
- float *outptr00 = output;
- float *outptr01 = outptr00 + n_channels;
- float *outptr10 = outptr00 + output_shape.n_cols * n_channels;
- float *outptr11 = outptr10 + n_channels;
-
- // Progress over the output tiles, generating output values.
- for (int batch = 0; batch < output_shape.n_batches; batch++) {
- for (int tile_i = 0; tile_i < tile_M; tile_i++) {
- for (int tile_j = 0; tile_j < tile_N; tile_j++) {
- for (int channel = 0; channel < n_channels; channel++) {
- // Read values from the input pointers
- float F[4][4];
- for (int i = 0; i < 4; i++) {
- for (int j = 0; j < 4; j++) {
- F[i][j] = *(inptr[i*4 + j]++);
- }
- }
-
- // Compute the matrix F.Z
- float ZF[4][2];
- ZF[0][0] = F[0][0] + F[0][1] + F[0][2];
- ZF[0][1] = F[0][1] - F[0][2] - F[0][3];
- ZF[1][0] = F[1][0] + F[1][1] + F[1][2];
- ZF[1][1] = F[1][1] - F[1][2] - F[1][3];
- ZF[2][0] = F[2][0] + F[2][1] + F[2][2];
- ZF[2][1] = F[2][1] - F[2][2] - F[2][3];
- ZF[3][0] = F[3][0] + F[3][1] + F[3][2];
- ZF[3][1] = F[3][1] - F[3][2] - F[3][3];
-
- // Hence compute the output matrix Z^T . (F.Z)
- *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0];
- *(outptr01++) = ZF[0][1] + ZF[1][1] + ZF[2][1];
- *(outptr10++) = ZF[1][0] - ZF[2][0] - ZF[3][0];
- *(outptr11++) = ZF[1][1] - ZF[2][1] - ZF[3][1];
- }
-
- // Progress the input pointers to the next row
- for (int i = 0; i < 16; i++) {
- inptr[i] += matrix_row_stride - n_channels;
- }
-
- // Progress the output pointers to the next column
- outptr00 += n_channels;
- outptr01 += n_channels;
- outptr10 += n_channels;
- outptr11 += n_channels;
- }
-
- if (tail_N) {
- // Only evaluate the left-most columns of the output
- for (int channel = 0; channel < n_channels; channel++) {
- // Read values from the input pointers
- float F[4][3];
- for (int i = 0; i < 4; i++) {
- for (int j = 0; j < 3; j++) {
- F[i][j] = *(inptr[i*4 + j]++);
- }
- }
- for (int i = 0; i < 4; i++) {
- inptr[i*4 + 3]++;
- }
-
- // Compute the matrix F.Z
- float ZF[4][1];
- ZF[0][0] = F[0][0] + F[0][1] + F[0][2];
- ZF[1][0] = F[1][0] + F[1][1] + F[1][2];
- ZF[2][0] = F[2][0] + F[2][1] + F[2][2];
- ZF[3][0] = F[3][0] + F[3][1] + F[3][2];
-
- // Hence compute the output matrix Z^T . (F.Z)
- *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0];
- *(outptr10++) = ZF[1][0] - ZF[2][0] - ZF[3][0];
- }
-
- // Progress the input pointers to the next row
- for (int i = 0; i < 16; i++) {
- inptr[i] += matrix_row_stride - n_channels;
- }
-
- // Progress the output pointers to the next column
- outptr01 += n_channels; // Account for being skipped above
- outptr11 += n_channels; // Account for being skipped above
- }
-
- // Progress the output pointers to the next row
- outptr00 += output_shape.n_cols * n_channels;
- outptr01 += output_shape.n_cols * n_channels;
- outptr10 += output_shape.n_cols * n_channels;
- outptr11 += output_shape.n_cols * n_channels;
- }
-
- if (tail_M) {
- // Only work on the upper row of the output
- for (int tile_j = 0; tile_j < tile_N; tile_j++) {
- for (int channel = 0; channel < n_channels; channel++) {
- // Read values from the input pointers
- float F[3][4];
- for (int i = 0; i < 3; i++) {
- for (int j = 0; j < 4; j++) {
- F[i][j] = *(inptr[i*4 + j]++);
- }
- }
- for (int j = 0; j < 4; j++) {
- inptr[12 + j]++;
- }
-
- // Compute the matrix F.Z
- float ZF[3][2];
- ZF[0][0] = F[0][0] + F[0][1] + F[0][2];
- ZF[0][1] = F[0][1] - F[0][2] - F[0][3];
- ZF[1][0] = F[1][0] + F[1][1] + F[1][2];
- ZF[1][1] = F[1][1] - F[1][2] - F[1][3];
- ZF[2][0] = F[2][0] + F[2][1] + F[2][2];
- ZF[2][1] = F[2][1] - F[2][2] - F[2][3];
-
- // Hence compute the output matrix Z^T . (F.Z)
- *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0];
- *(outptr01++) = ZF[0][1] + ZF[1][1] + ZF[2][1];
- }
-
- // Progress the input pointers to the next row
- for (int i = 0; i < 16; i++) {
- inptr[i] += matrix_row_stride - n_channels;
- }
-
- // Progress the output pointers to the next column
- outptr00 += n_channels;
- outptr01 += n_channels;
- outptr10 += 2 * n_channels; // Account for being skipped above
- outptr11 += 2 * n_channels; // Account for being skipped above
- }
-
- if (tail_N) {
- // Only evaluate the upper-left cell of the output
- for (int channel = 0; channel < n_channels; channel++) {
- // Read values from the input pointers
- float F[3][3];
- for (int i = 0; i < 3; i++) {
- for (int j = 0; j < 3; j++) {
- F[i][j] = *(inptr[i*4 + j]);
- }
- }
- for (int i = 0; i < 16; i++) {
- inptr[i]++;
- }
-
- // Compute the matrix F.Z
- float ZF[3][1];
- ZF[0][0] = F[0][0] + F[0][1] + F[0][2];
- ZF[1][0] = F[1][0] + F[1][1] + F[1][2];
- ZF[2][0] = F[2][0] + F[2][1] + F[2][2];
-
- // Hence compute the output matrix Z^T . (F.Z)
- *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0];
- }
-
- // Progress the input pointers to the next row
- for (int i = 0; i < 16; i++) {
- inptr[i] += matrix_row_stride - n_channels;
- }
-
- // Progress the output pointers to the next column
- outptr01 += n_channels; // Account for being skipped above
- outptr10 += n_channels; // Account for being skipped above
- outptr11 += n_channels; // Account for being skipped above
- }
- }
- }
-}
-
-/*****************************************************************************/
-template <>
-inline void Winograd2x2_3x3GemmOutput<float>::execute(
- const Tensor4DShape &output_shape,
- float* const matrix_base,
- const int matrix_stride,
- const int matrix_row_stride,
- float* const output
-) {
- // Dispatch to an appropriate implementation based on the shape of the output
- // tensor.
- if (output_shape.n_rows % 2 && output_shape.n_cols % 2) {
- constexpr bool tail_M = true, tail_N = true;
- switch (output_shape.n_channels % 4) {
- case 0:
- _execute<tail_M, tail_N, 0>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 1:
- _execute<tail_M, tail_N, 1>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 2:
- _execute<tail_M, tail_N, 2>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 3:
- _execute<tail_M, tail_N, 3>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- default:
- assert(0);
- break;
- }
- } else if (output_shape.n_rows % 2) {
- constexpr bool tail_M = true, tail_N = false;
- switch (output_shape.n_channels % 4) {
- case 0:
- _execute<tail_M, tail_N, 0>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 1:
- _execute<tail_M, tail_N, 1>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 2:
- _execute<tail_M, tail_N, 2>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 3:
- _execute<tail_M, tail_N, 3>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- default:
- assert(0);
- break;
- }
- } else if (output_shape.n_cols % 2) {
- constexpr bool tail_M = false, tail_N = true;
- switch (output_shape.n_channels % 4) {
- case 0:
- _execute<tail_M, tail_N, 0>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 1:
- _execute<tail_M, tail_N, 1>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 2:
- _execute<tail_M, tail_N, 2>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 3:
- _execute<tail_M, tail_N, 3>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- default:
- assert(0);
- break;
-
- }
- } else {
- constexpr bool tail_M = false, tail_N = false;
- switch (output_shape.n_channels % 4) {
- case 0:
- _execute<tail_M, tail_N, 0>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 1:
- _execute<tail_M, tail_N, 1>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 2:
- _execute<tail_M, tail_N, 2>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- case 3:
- _execute<tail_M, tail_N, 3>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
- break;
- default:
- assert(0);
- break;
-
- }
- }
-}
-/*****************************************************************************/
-
-} // namespace winograd
-#endif // __aarch64__
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp
deleted file mode 100644
index f551b12b52..0000000000
--- a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp
+++ /dev/null
@@ -1,655 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#pragma once
-
-#ifdef __aarch64__
-
-/*****************************************************************************/
-// Compute ZF specializations
-
-template <>
-template <>
-inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage<float>::compute_zf<0>(
- const int n_rows, const int n_channels,
- float* output, const float* const input[16]
-) {
- // Make copies of some variables
- int row = n_rows;
- float* outptr = output;
- const float* inptr = input[0];
-
- // Perform the transformation
- asm volatile (
- // "inptr0 .req %x[inptr]\n"
- "inptr1 .req x0\n"
- "inptr2 .req x1\n"
- "inptr3 .req x2\n"
- "inptr4 .req x3\n"
- "inptr5 .req x4\n"
- "inptr6 .req x5\n"
- "inptr7 .req x6\n"
- "inptr8 .req x7\n"
- "inptr9 .req x8\n"
- "inptr10 .req x9\n"
- "inptr11 .req x10\n"
- "inptr12 .req x11\n"
- "inptr13 .req x12\n"
- "inptr14 .req x13\n"
- "inptr15 .req x14\n"
-
- // "outptr0 .req %x[outptr]\n"
- "outptr1 .req x15\n"
- "outptr2 .req x16\n"
- "outptr3 .req x17\n"
- "outptr4 .req x18\n"
- "outptr5 .req x19\n"
- "outptr6 .req x20\n"
- "outptr7 .req x21\n"
-
- // Compute additional pointers into the input and output matrices.
- "mstride .req x22\n" // Matrix stride
- "mul mstride, %x[row], %x[n_channels]\n"
- "lsl mstride, mstride, #2\n" // * sizeof(float)
-
- "add inptr1, %x[inptr], mstride\n"
- "add inptr2, %x[inptr], mstride, LSL #1\n"
- "add inptr3, inptr2, mstride\n"
- "add inptr4, inptr3, mstride\n"
- "add inptr5, inptr4, mstride\n"
- "add inptr6, inptr5, mstride\n"
- "add inptr7, inptr6, mstride\n"
- "add inptr8, inptr7, mstride\n"
- "add inptr9, inptr8, mstride\n"
- "add inptr10, inptr9, mstride\n"
- "add inptr11, inptr10, mstride\n"
- "add inptr12, inptr11, mstride\n"
- "add inptr13, inptr12, mstride\n"
- "add inptr14, inptr13, mstride\n"
- "add inptr15, inptr14, mstride\n"
-
- "add outptr1, %[outptr], mstride\n"
- "add outptr2, outptr1, mstride\n"
- "add outptr3, outptr2, mstride\n"
- "add outptr4, outptr3, mstride\n"
- "add outptr5, outptr4, mstride\n"
- "add outptr6, outptr5, mstride\n"
- "add outptr7, outptr6, mstride\n"
-
- ".unreq mstride\n"
-
- "column .req x22\n" // Column loop counter
-
- "1:" // Loop over rows
- "ldr q0, [%x[inptr]], #0x10\n"
- "ldr q1, [inptr1], #0x10\n"
- "ldr q2, [inptr2], #0x10\n"
- "ldr q3, [inptr3], #0x10\n"
- "ldr q4, [inptr4], #0x10\n"
- "ldr q5, [inptr5], #0x10\n"
- "ldr q6, [inptr6], #0x10\n"
- "ldr q7, [inptr7], #0x10\n"
- "subs column, %x[n_channels], #0x4\n"
- "beq 3f\n"
-
- "2:" // Loop over columns
- "ldr q8, [inptr8], #0x10\n"
- "prfm pldl1keep, [%x[inptr], #196]\n"
- "fadd v16.4s, v0.4s, v1.4s\n"
-
- "ldr q9, [inptr9], #0x10\n"
- "prfm pldl1keep, [inptr1, #196]\n"
- "fsub v17.4s, v1.4s, v2.4s\n"
-
- "ldr q10, [inptr10], #0x10\n"
- "prfm pldl1keep, [inptr2, #196]\n"
- "fadd v16.4s, v16.4s, v2.4s\n"
-
- "ldr q11, [inptr11], #0x10\n"
- "prfm pldl1keep, [inptr3, #196]\n"
- "fsub v17.4s, v17.4s, v3.4s\n"
-
- "ldr q12, [inptr12], #0x10\n"
- "prfm pldl1keep, [inptr4, #196]\n"
- "str q16, [%x[outptr]], #0x10\n"
-
- "ldr q13, [inptr13], #0x10\n"
- "prfm pldl1keep, [inptr5, #196]\n"
- "str q17, [outptr1], #0x10\n"
-
- "ldr q14, [inptr14], #0x10\n"
- "prfm pldl1keep, [inptr6, #196]\n"
- "fadd v16.4s, v4.4s, v5.4s\n"
-
- "ldr q15, [inptr15], #0x10\n"
- "prfm pldl1keep, [inptr7, #196]\n"
- "fsub v17.4s, v5.4s, v6.4s\n"
-
- "ldr q0, [%x[inptr]], #0x10\n"
- "prfm pldl1keep, [inptr8, #196]\n"
- "fadd v16.4s, v16.4s, v6.4s\n"
-
- "ldr q1, [inptr1], #0x10\n"
- "prfm pldl1keep, [inptr9, #196]\n"
- "fsub v17.4s, v17.4s, v7.4s\n"
-
- "ldr q2, [inptr2], #0x10\n"
- "prfm pldl1keep, [inptr10, #196]\n"
- "str q16, [outptr2], #0x10\n"
-
- "ldr q3, [inptr3], #0x10\n"
- "prfm pldl1keep, [inptr11, #196]\n"
- "str q17, [outptr3], #0x10\n"
-
- "ldr q4, [inptr4], #0x10\n"
- "prfm pldl1keep, [inptr12, #196]\n"
- "fadd v16.4s, v8.4s, v9.4s\n"
-
- "ldr q5, [inptr5], #0x10\n"
- "prfm pldl1keep, [inptr13, #196]\n"
- "fsub v17.4s, v9.4s, v10.4s\n"
-
- "ldr q6, [inptr6], #0x10\n"
- "prfm pldl1keep, [inptr14, #196]\n"
- "fadd v16.4s, v16.4s, v10.4s\n"
-
- "ldr q7, [inptr7], #0x10\n"
- "prfm pldl1keep, [inptr15, #196]\n"
- "fsub v17.4s, v17.4s, v11.4s\n"
-
- "str q16, [outptr4], #0x10\n"
- "fadd v16.4s, v12.4s, v13.4s\n"
- "fsub v18.4s, v13.4s, v14.4s\n"
-
- "str q17, [outptr5], #0x10\n"
- "fadd v16.4s, v16.4s, v14.4s\n"
- "fsub v18.4s, v18.4s, v15.4s\n"
-
- "str q16, [outptr6], #0x10\n"
- "subs column, column, #0x4\n"
-
- "str q18, [outptr7], #0x10\n"
- "bne 2b\n"
-
- "3:" // Tail
- "ldr q8, [inptr8], #0x10\n"
- "prfm pldl1keep, [%x[inptr], #196]\n"
- "fadd v16.4s, v0.4s, v1.4s\n"
-
- "ldr q9, [inptr9], #0x10\n"
- "prfm pldl1keep, [inptr1, #196]\n"
- "fsub v17.4s, v1.4s, v2.4s\n"
-
- "ldr q10, [inptr10], #0x10\n"
- "prfm pldl1keep, [inptr2, #196]\n"
- "fadd v16.4s, v16.4s, v2.4s\n"
-
- "ldr q11, [inptr11], #0x10\n"
- "prfm pldl1keep, [inptr3, #196]\n"
- "fsub v17.4s, v17.4s, v3.4s\n"
-
- "ldr q12, [inptr12], #0x10\n"
- "prfm pldl1keep, [inptr4, #196]\n"
- "str q16, [%x[outptr]], #0x10\n"
-
- "ldr q13, [inptr13], #0x10\n"
- "prfm pldl1keep, [inptr5, #196]\n"
- "str q17, [outptr1], #0x10\n"
-
- "ldr q14, [inptr14], #0x10\n"
- "prfm pldl1keep, [inptr6, #196]\n"
- "fadd v16.4s, v4.4s, v5.4s\n"
-
- "ldr q15, [inptr15], #0x10\n"
- "prfm pldl1keep, [inptr7, #196]\n"
- "fsub v17.4s, v5.4s, v6.4s\n"
-
- "prfm pldl1keep, [inptr8, #196]\n"
- "prfm pldl1keep, [inptr9, #196]\n"
- "fadd v16.4s, v16.4s, v6.4s\n"
-
- "prfm pldl1keep, [inptr10, #196]\n"
- "prfm pldl1keep, [inptr11, #196]\n"
- "fsub v17.4s, v17.4s, v7.4s\n"
-
- "prfm pldl1keep, [inptr12, #196]\n"
- "prfm pldl1keep, [inptr13, #196]\n"
- "str q16, [outptr2], #0x10\n"
-
- "prfm pldl1keep, [inptr14, #196]\n"
- "prfm pldl1keep, [inptr15, #196]\n"
- "str q17, [outptr3], #0x10\n"
-
- "fadd v16.4s, v8.4s, v9.4s\n"
- "fsub v17.4s, v9.4s, v10.4s\n"
-
- "fadd v16.4s, v16.4s, v10.4s\n"
- "fsub v17.4s, v17.4s, v11.4s\n"
-
- "str q16, [outptr4], #0x10\n"
- "fadd v16.4s, v12.4s, v13.4s\n"
- "fsub v18.4s, v13.4s, v14.4s\n"
-
- "str q17, [outptr5], #0x10\n"
- "fadd v16.4s, v16.4s, v14.4s\n"
- "fsub v18.4s, v18.4s, v15.4s\n"
-
- "str q16, [outptr6], #0x10\n"
- "str q18, [outptr7], #0x10\n"
-
- "subs %x[row], %x[row], #0x1\n"
- "bne 1b\n"
-
- ".unreq inptr1\n"
- ".unreq inptr2\n"
- ".unreq inptr3\n"
- ".unreq inptr4\n"
- ".unreq inptr5\n"
- ".unreq inptr6\n"
- ".unreq inptr7\n"
- ".unreq inptr8\n"
- ".unreq inptr9\n"
- ".unreq inptr10\n"
- ".unreq inptr11\n"
- ".unreq inptr12\n"
- ".unreq inptr13\n"
- ".unreq inptr14\n"
- ".unreq inptr15\n"
- ".unreq outptr1\n"
- ".unreq outptr2\n"
- ".unreq outptr3\n"
- ".unreq outptr4\n"
- ".unreq outptr5\n"
- ".unreq outptr6\n"
- ".unreq outptr7\n"
-
- : [row] "+r" (row),
- [inptr] "+r" (inptr),
- [outptr] "+r" (outptr)
- : [n_channels] "r" (n_channels),
- [sizeof_float] "i" (sizeof(float))
- : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
- "q12", "q13", "q14", "q15", "q16", "q17", "x0", "x1", "x2", "x3", "x4",
- "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15",
- "x16", "x17", "x18", "x19", "x20", "x21", "x22", "cc", "memory"
- );
-}
-
-/*****************************************************************************/
-// Compute ZFZ^T specializations
-
-template <>
-template <>
-inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage<float>::compute_zfzT<false, false, 0>(
- const Tensor4DShape &output_shape,
- float* const output, const float* const input
-) {
- const int tile_M = output_shape.n_rows / 2;
- const int tile_N = output_shape.n_cols / 2;
- int batch = output_shape.n_batches;
- float *outptr = output;
- const float *inptr = input;
-
- asm volatile (
- // Compute input pointers
- "inptr1 .req x0\n"
- "inptr2 .req x1\n"
- "inptr3 .req x2\n"
- "inptr4 .req x3\n"
- "inptr5 .req x4\n"
- "inptr6 .req x5\n"
- "inptr7 .req x6\n"
- "inptr8 .req x7\n"
-
- "mstride .req x8\n"
- "mul mstride, %x[tile_M], %x[tile_N]\n"
- "mul mstride, mstride, %x[n_channels]\n"
- "lsl mstride, mstride, #2\n" // * sizeof(float)
-
- "add inptr1, %[inptr], mstride\n"
- "add inptr2, inptr1, mstride\n"
- "add inptr3, inptr2, mstride\n"
- "add inptr4, inptr3, mstride\n"
- "add inptr5, inptr4, mstride\n"
- "add inptr6, inptr5, mstride\n"
- "add inptr7, inptr6, mstride\n"
- "add inptr8, inptr7, mstride\n"
-
- ".unreq mstride\n"
-
- // Compute initial output pointers
- "outptr01 .req x8\n"
- "outptr10 .req x9\n"
- "outptr11 .req x10\n"
-
- "add outptr01, %x[outptr], %x[n_channels], LSL #2\n"
- "add outptr10, %x[outptr], %x[row_stride], LSL #2\n"
- "add outptr11, outptr10, %x[n_channels], LSL #2\n"
-
- "tile_i .req x11\n"
- "tile_j .req x12\n"
- "channel .req x13\n"
-
- "1:" // Loop over batches
- "mov tile_i, %x[tile_M]\n"
-
- "2:" // Loop over rows of output tiles
- "mov tile_j, %x[tile_N]\n"
-
- "3:" // Loop over columns of output tiles
- "ldr q0, [%x[inptr]], #0x10\n"
- "ldr q2, [inptr2], #0x10\n"
- "subs channel, %x[n_channels], #0x4\n"
-
- "ldr q1, [inptr1], #0x10\n"
- "ldr q3, [inptr3], #0x10\n"
- "beq 6f\n"
-
- "4:"
- "ldr q4, [inptr4], #0x10\n"
- "ldr q5, [inptr5], #0x10\n"
- "fadd v16.4s, v0.4s, v2.4s\n"
-
- "ldr q6, [inptr6], #0x10\n"
- "ldr q7, [inptr7], #0x10\n"
- "fadd v17.4s, v1.4s, v3.4s\n"
-
- "ldr q8, [%x[inptr]], #0x10\n"
- "ldr q10, [inptr2], #0x10\n"
- "fadd v16.4s, v16.4s, v4.4s\n"
-
- "ldr q9, [inptr1], #0x10\n"
- "ldr q11, [inptr3], #0x10\n"
- "fadd v17.4s, v17.4s, v5.4s\n"
-
- "str q16, [%x[outptr]], #0x10\n"
- "prfm pldl1strm, [%x[inptr], #196]\n"
- "fsub v18.4s, v2.4s, v4.4s\n"
-
- "str q17, [outptr01], #0x10\n"
- "prfm pldl1strm, [inptr2, #196]\n"
- "fsub v19.4s, v3.4s, v5.4s\n"
-
- "prfm pldl1strm, [inptr1, #196]\n"
- "prfm pldl1strm, [inptr3, #196]\n"
- "fsub v18.4s, v18.4s, v6.4s\n"
-
- "prfm pldl1strm, [inptr4, #196]\n"
- "prfm pldl1strm, [inptr5, #196]\n"
- "fsub v19.4s, v19.4s, v7.4s\n"
-
- "str q18, [outptr10], #0x10\n"
- "prfm pldl1strm, [inptr6, #196]\n"
- "prfm pldl1strm, [inptr7, #196]\n"
-
- "subs channel, channel, #0x4\n"
-
- "str q19, [outptr11], #0x10\n"
- "beq 6f\n" // Branch to tail
-
- "ldr q12, [inptr4], #0x10\n"
- "ldr q13, [inptr5], #0x10\n"
- "fadd v16.4s, v8.4s, v10.4s\n"
-
- "ldr q14, [inptr6], #0x10\n"
- "ldr q15, [inptr7], #0x10\n"
- "fadd v17.4s, v9.4s, v11.4s\n"
-
- "ldr q0, [%x[inptr]], #0x10\n"
- "ldr q2, [inptr2], #0x10\n"
- "fadd v16.4s, v16.4s, v12.4s\n"
-
- "ldr q1, [inptr1], #0x10\n"
- "ldr q3, [inptr3], #0x10\n"
- "fadd v17.4s, v17.4s, v13.4s\n"
-
- "str q16, [%x[outptr]], #0x10\n"
- "prfm pldl1strm, [%x[inptr], #196]\n"
- "fsub v18.4s, v10.4s, v12.4s\n"
-
- "str q17, [outptr01], #0x10\n"
- "prfm pldl1strm, [inptr2, #196]\n"
- "fsub v19.4s, v11.4s, v13.4s\n"
-
- "prfm pldl1strm, [inptr1, #196]\n"
- "prfm pldl1strm, [inptr3, #196]\n"
- "fsub v18.4s, v18.4s, v14.4s\n"
-
- "prfm pldl1strm, [inptr4, #196]\n"
- "prfm pldl1strm, [inptr5, #196]\n"
- "fsub v19.4s, v19.4s, v15.4s\n"
-
- "str q18, [outptr10], #0x10\n"
- "prfm pldl1strm, [inptr6, #196]\n"
- "prfm pldl1strm, [inptr7, #196]\n"
-
- "subs channel, channel, #0x4\n"
-
- "str q19, [outptr11], #0x10\n"
- "bne 4b\n" // Continue loop
-
- "5:" // Tail
- "ldr q12, [inptr4], #0x10\n"
- "ldr q13, [inptr5], #0x10\n"
- "fadd v16.4s, v8.4s, v10.4s\n"
-
- "ldr q14, [inptr6], #0x10\n"
- "ldr q15, [inptr7], #0x10\n"
- "fadd v17.4s, v9.4s, v11.4s\n"
-
- "fadd v16.4s, v16.4s, v12.4s\n"
-
- "fadd v17.4s, v17.4s, v13.4s\n"
-
- "str q16, [%x[outptr]], #0x10\n"
- "fsub v18.4s, v10.4s, v12.4s\n"
- "fsub v19.4s, v11.4s, v13.4s\n"
-
- "str q17, [outptr01], #0x10\n"
- "fsub v18.4s, v18.4s, v14.4s\n"
- "fsub v19.4s, v19.4s, v15.4s\n"
-
- "str q18, [outptr10], #0x10\n"
- "str q19, [outptr11], #0x10\n"
- "b 7f\n"
-
- "6:" // Tail
- "ldr q4, [inptr4], #0x10\n"
- "ldr q5, [inptr5], #0x10\n"
- "fadd v16.4s, v0.4s, v2.4s\n"
-
- "ldr q6, [inptr6], #0x10\n"
- "ldr q7, [inptr7], #0x10\n"
- "fadd v17.4s, v1.4s, v3.4s\n"
-
- "fadd v16.4s, v16.4s, v4.4s\n"
-
- "fadd v17.4s, v17.4s, v5.4s\n"
-
- "str q16, [%x[outptr]], #0x10\n"
- "fsub v18.4s, v2.4s, v4.4s\n"
- "fsub v19.4s, v3.4s, v5.4s\n"
-
- "str q17, [outptr01], #0x10\n"
- "fsub v18.4s, v18.4s, v6.4s\n"
- "fsub v19.4s, v19.4s, v7.4s\n"
-
- "str q18, [outptr10], #0x10\n"
- "str q19, [outptr11], #0x10\n"
-
- "7:"
- "add %x[outptr], %x[outptr], %x[n_channels], LSL #2\n"
- "add outptr01, outptr01, %x[n_channels], LSL #2\n"
- "add outptr10, outptr10, %x[n_channels], LSL #2\n"
- "add outptr11, outptr11, %x[n_channels], LSL #2\n"
-
- "subs tile_j, tile_j, #1\n"
- "bne 3b\n"
-
- // Progress the output pointers to the new row
- "add %x[outptr], %x[outptr], %x[row_stride], LSL #2\n"
- "add outptr01, outptr01, %x[row_stride], LSL #2\n"
- "add outptr10, outptr10, %x[row_stride], LSL #2\n"
- "add outptr11, outptr11, %x[row_stride], LSL #2\n"
-
- "subs tile_i, tile_i, #1\n"
- "bne 2b\n"
-
- "subs %[batch], %[batch], #1\n"
- "bne 1b\n"
- "5:"
-
- ".unreq inptr1\n"
- ".unreq inptr2\n"
- ".unreq inptr3\n"
- ".unreq inptr4\n"
- ".unreq inptr5\n"
- ".unreq inptr6\n"
- ".unreq inptr7\n"
- ".unreq inptr8\n"
- ".unreq outptr01\n"
- ".unreq outptr10\n"
- ".unreq outptr11\n"
- : [batch] "+r" (batch),
- [outptr] "+r" (outptr),
- [inptr] "+r" (inptr)
- : [tile_M] "r" (tile_M),
- [tile_N] "r" (tile_N),
- [n_channels] "r" (output_shape.n_channels),
- [row_stride] "r" (output_shape.n_cols * output_shape.n_channels)
- : "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11",
- "x12", "x13", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
- "cc", "memory"
- );
-}
-/*****************************************************************************/
-
-/*****************************************************************************/
-template <>
-inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage<float>::execute(
- const Tensor4DShape &output_shape,
- float* const matrices[16], float* const output
-) {
- // profiler prof;
-
- // Allocate memory for the intermediate matrices
- const int tile_M = iceildiv(output_shape.n_rows, 2);
- const int tile_N = iceildiv(output_shape.n_cols, 2);
- const int n_rows = output_shape.n_batches * tile_M * tile_N;
- const int n_channels = output_shape.n_channels;
- float* matrices_zf = reinterpret_cast<float*>(
- calloc(8 * n_rows * n_channels, sizeof(float))
- );
-
- // Perform the first stage transform, computing ZF.
- const auto f_compute_zf = [&] () {
- switch (n_channels % 4) {
- case 0:
- compute_zf<0>(n_rows, n_channels, matrices_zf, matrices);
- break;
- case 1:
- compute_zf<1>(n_rows, n_channels, matrices_zf, matrices);
- break;
- case 2:
- compute_zf<2>(n_rows, n_channels, matrices_zf, matrices);
- break;
- case 3:
- compute_zf<3>(n_rows, n_channels, matrices_zf, matrices);
- };
- };
- // prof("Compute ZF", f_compute_zf, 16 * n_rows * n_channels * sizeof(float), 0, 8 * n_rows * n_channels * sizeof(float));
- f_compute_zf();
-
- // Perform the second stage transform, finishing Z F Z^T - variable dispatch
- // based on size of the output and the channel tail.
- const auto f_compute_zfzT = [&] () {
- if (output_shape.n_rows % 2 && output_shape.n_cols % 2) {
- constexpr bool tail_M = true, tail_N = true;
- switch (n_channels % 4) {
- case 0:
- compute_zfzT<tail_M, tail_N, 0>(output_shape, output, matrices_zf);
- break;
- case 1:
- compute_zfzT<tail_M, tail_N, 1>(output_shape, output, matrices_zf);
- break;
- case 2:
- compute_zfzT<tail_M, tail_N, 2>(output_shape, output, matrices_zf);
- break;
- case 3:
- compute_zfzT<tail_M, tail_N, 3>(output_shape, output, matrices_zf);
- }
- } else if (output_shape.n_rows % 2) {
- constexpr bool tail_M = true, tail_N = false;
- switch (n_channels % 4) {
- case 0:
- compute_zfzT<tail_M, tail_N, 0>(output_shape, output, matrices_zf);
- break;
- case 1:
- compute_zfzT<tail_M, tail_N, 1>(output_shape, output, matrices_zf);
- break;
- case 2:
- compute_zfzT<tail_M, tail_N, 2>(output_shape, output, matrices_zf);
- break;
- case 3:
- compute_zfzT<tail_M, tail_N, 3>(output_shape, output, matrices_zf);
- }
- } else if (output_shape.n_cols % 2) {
- constexpr bool tail_M = false, tail_N = true;
- switch (n_channels % 4) {
- case 0:
- compute_zfzT<tail_M, tail_N, 0>(output_shape, output, matrices_zf);
- break;
- case 1:
- compute_zfzT<tail_M, tail_N, 1>(output_shape, output, matrices_zf);
- break;
- case 2:
- compute_zfzT<tail_M, tail_N, 2>(output_shape, output, matrices_zf);
- break;
- case 3:
- compute_zfzT<tail_M, tail_N, 3>(output_shape, output, matrices_zf);
- }
- } else {
- constexpr bool tail_M = false, tail_N = false;
- switch (n_channels % 4) {
- case 0:
- compute_zfzT<tail_M, tail_N, 0>(output_shape, output, matrices_zf);
- break;
- case 1:
- compute_zfzT<tail_M, tail_N, 1>(output_shape, output, matrices_zf);
- break;
- case 2:
- compute_zfzT<tail_M, tail_N, 2>(output_shape, output, matrices_zf);
- break;
- case 3:
- compute_zfzT<tail_M, tail_N, 3>(output_shape, output, matrices_zf);
- }
- }
- };
- // prof("Compute ZFZT", f_compute_zfzT, 8 * n_rows * n_channels * sizeof(float), 0, 4 * n_rows * n_channels * sizeof(float));
- f_compute_zfzT();
-
- free(reinterpret_cast<void*>(matrices_zf));
-}
-/*****************************************************************************/
-
-#endif // __aarch64__