/* * Copyright (c) 2017 ARM Limited. * * SPDX-License-Identifier: MIT * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to * deal in the Software without restriction, including without limitation the * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or * sell copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ #pragma once #ifdef __aarch64__ /*****************************************************************************/ // Compute ZF specializations template <> template <> inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage::compute_zf<0>( const int n_rows, const int n_channels, float* output, const float* const input[16] ) { // Make copies of some variables int row = n_rows; float* outptr = output; const float* inptr = input[0]; // Perform the transformation asm volatile ( // "inptr0 .req %x[inptr]\n" "inptr1 .req x0\n" "inptr2 .req x1\n" "inptr3 .req x2\n" "inptr4 .req x3\n" "inptr5 .req x4\n" "inptr6 .req x5\n" "inptr7 .req x6\n" "inptr8 .req x7\n" "inptr9 .req x8\n" "inptr10 .req x9\n" "inptr11 .req x10\n" "inptr12 .req x11\n" "inptr13 .req x12\n" "inptr14 .req x13\n" "inptr15 .req x14\n" // "outptr0 .req %x[outptr]\n" "outptr1 .req x15\n" "outptr2 .req x16\n" "outptr3 .req x17\n" "outptr4 .req x18\n" "outptr5 .req x19\n" "outptr6 .req x20\n" "outptr7 .req x21\n" // Compute additional pointers into the input and output matrices. "mstride .req x22\n" // Matrix stride "mul mstride, %x[row], %x[n_channels]\n" "lsl mstride, mstride, #2\n" // * sizeof(float) "add inptr1, %x[inptr], mstride\n" "add inptr2, %x[inptr], mstride, LSL #1\n" "add inptr3, inptr2, mstride\n" "add inptr4, inptr3, mstride\n" "add inptr5, inptr4, mstride\n" "add inptr6, inptr5, mstride\n" "add inptr7, inptr6, mstride\n" "add inptr8, inptr7, mstride\n" "add inptr9, inptr8, mstride\n" "add inptr10, inptr9, mstride\n" "add inptr11, inptr10, mstride\n" "add inptr12, inptr11, mstride\n" "add inptr13, inptr12, mstride\n" "add inptr14, inptr13, mstride\n" "add inptr15, inptr14, mstride\n" "add outptr1, %[outptr], mstride\n" "add outptr2, outptr1, mstride\n" "add outptr3, outptr2, mstride\n" "add outptr4, outptr3, mstride\n" "add outptr5, outptr4, mstride\n" "add outptr6, outptr5, mstride\n" "add outptr7, outptr6, mstride\n" ".unreq mstride\n" "column .req x22\n" // Column loop counter "1:" // Loop over rows "ldr q0, [%x[inptr]], #0x10\n" "ldr q1, [inptr1], #0x10\n" "ldr q2, [inptr2], #0x10\n" "ldr q3, [inptr3], #0x10\n" "ldr q4, [inptr4], #0x10\n" "ldr q5, [inptr5], #0x10\n" "ldr q6, [inptr6], #0x10\n" "ldr q7, [inptr7], #0x10\n" "subs column, %x[n_channels], #0x4\n" "beq 3f\n" "2:" // Loop over columns "ldr q8, [inptr8], #0x10\n" "prfm pldl1keep, [%x[inptr], #196]\n" "fadd v16.4s, v0.4s, v1.4s\n" "ldr q9, [inptr9], #0x10\n" "prfm pldl1keep, [inptr1, #196]\n" "fsub v17.4s, v1.4s, v2.4s\n" "ldr q10, [inptr10], #0x10\n" "prfm pldl1keep, [inptr2, #196]\n" "fadd v16.4s, v16.4s, v2.4s\n" "ldr q11, [inptr11], #0x10\n" "prfm pldl1keep, [inptr3, #196]\n" "fsub v17.4s, v17.4s, v3.4s\n" "ldr q12, [inptr12], #0x10\n" "prfm pldl1keep, [inptr4, #196]\n" "str q16, [%x[outptr]], #0x10\n" "ldr q13, [inptr13], #0x10\n" "prfm pldl1keep, [inptr5, #196]\n" "str q17, [outptr1], #0x10\n" "ldr q14, [inptr14], #0x10\n" "prfm pldl1keep, [inptr6, #196]\n" "fadd v16.4s, v4.4s, v5.4s\n" "ldr q15, [inptr15], #0x10\n" "prfm pldl1keep, [inptr7, #196]\n" "fsub v17.4s, v5.4s, v6.4s\n" "ldr q0, [%x[inptr]], #0x10\n" "prfm pldl1keep, [inptr8, #196]\n" "fadd v16.4s, v16.4s, v6.4s\n" "ldr q1, [inptr1], #0x10\n" "prfm pldl1keep, [inptr9, #196]\n" "fsub v17.4s, v17.4s, v7.4s\n" "ldr q2, [inptr2], #0x10\n" "prfm pldl1keep, [inptr10, #196]\n" "str q16, [outptr2], #0x10\n" "ldr q3, [inptr3], #0x10\n" "prfm pldl1keep, [inptr11, #196]\n" "str q17, [outptr3], #0x10\n" "ldr q4, [inptr4], #0x10\n" "prfm pldl1keep, [inptr12, #196]\n" "fadd v16.4s, v8.4s, v9.4s\n" "ldr q5, [inptr5], #0x10\n" "prfm pldl1keep, [inptr13, #196]\n" "fsub v17.4s, v9.4s, v10.4s\n" "ldr q6, [inptr6], #0x10\n" "prfm pldl1keep, [inptr14, #196]\n" "fadd v16.4s, v16.4s, v10.4s\n" "ldr q7, [inptr7], #0x10\n" "prfm pldl1keep, [inptr15, #196]\n" "fsub v17.4s, v17.4s, v11.4s\n" "str q16, [outptr4], #0x10\n" "fadd v16.4s, v12.4s, v13.4s\n" "fsub v18.4s, v13.4s, v14.4s\n" "str q17, [outptr5], #0x10\n" "fadd v16.4s, v16.4s, v14.4s\n" "fsub v18.4s, v18.4s, v15.4s\n" "str q16, [outptr6], #0x10\n" "subs column, column, #0x4\n" "str q18, [outptr7], #0x10\n" "bne 2b\n" "3:" // Tail "ldr q8, [inptr8], #0x10\n" "prfm pldl1keep, [%x[inptr], #196]\n" "fadd v16.4s, v0.4s, v1.4s\n" "ldr q9, [inptr9], #0x10\n" "prfm pldl1keep, [inptr1, #196]\n" "fsub v17.4s, v1.4s, v2.4s\n" "ldr q10, [inptr10], #0x10\n" "prfm pldl1keep, [inptr2, #196]\n" "fadd v16.4s, v16.4s, v2.4s\n" "ldr q11, [inptr11], #0x10\n" "prfm pldl1keep, [inptr3, #196]\n" "fsub v17.4s, v17.4s, v3.4s\n" "ldr q12, [inptr12], #0x10\n" "prfm pldl1keep, [inptr4, #196]\n" "str q16, [%x[outptr]], #0x10\n" "ldr q13, [inptr13], #0x10\n" "prfm pldl1keep, [inptr5, #196]\n" "str q17, [outptr1], #0x10\n" "ldr q14, [inptr14], #0x10\n" "prfm pldl1keep, [inptr6, #196]\n" "fadd v16.4s, v4.4s, v5.4s\n" "ldr q15, [inptr15], #0x10\n" "prfm pldl1keep, [inptr7, #196]\n" "fsub v17.4s, v5.4s, v6.4s\n" "prfm pldl1keep, [inptr8, #196]\n" "prfm pldl1keep, [inptr9, #196]\n" "fadd v16.4s, v16.4s, v6.4s\n" "prfm pldl1keep, [inptr10, #196]\n" "prfm pldl1keep, [inptr11, #196]\n" "fsub v17.4s, v17.4s, v7.4s\n" "prfm pldl1keep, [inptr12, #196]\n" "prfm pldl1keep, [inptr13, #196]\n" "str q16, [outptr2], #0x10\n" "prfm pldl1keep, [inptr14, #196]\n" "prfm pldl1keep, [inptr15, #196]\n" "str q17, [outptr3], #0x10\n" "fadd v16.4s, v8.4s, v9.4s\n" "fsub v17.4s, v9.4s, v10.4s\n" "fadd v16.4s, v16.4s, v10.4s\n" "fsub v17.4s, v17.4s, v11.4s\n" "str q16, [outptr4], #0x10\n" "fadd v16.4s, v12.4s, v13.4s\n" "fsub v18.4s, v13.4s, v14.4s\n" "str q17, [outptr5], #0x10\n" "fadd v16.4s, v16.4s, v14.4s\n" "fsub v18.4s, v18.4s, v15.4s\n" "str q16, [outptr6], #0x10\n" "str q18, [outptr7], #0x10\n" "subs %x[row], %x[row], #0x1\n" "bne 1b\n" ".unreq inptr1\n" ".unreq inptr2\n" ".unreq inptr3\n" ".unreq inptr4\n" ".unreq inptr5\n" ".unreq inptr6\n" ".unreq inptr7\n" ".unreq inptr8\n" ".unreq inptr9\n" ".unreq inptr10\n" ".unreq inptr11\n" ".unreq inptr12\n" ".unreq inptr13\n" ".unreq inptr14\n" ".unreq inptr15\n" ".unreq outptr1\n" ".unreq outptr2\n" ".unreq outptr3\n" ".unreq outptr4\n" ".unreq outptr5\n" ".unreq outptr6\n" ".unreq outptr7\n" : [row] "+r" (row), [inptr] "+r" (inptr), [outptr] "+r" (outptr) : [n_channels] "r" (n_channels), [sizeof_float] "i" (sizeof(float)) : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "q16", "q17", "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x18", "x19", "x20", "x21", "x22", "cc", "memory" ); } /*****************************************************************************/ // Compute ZFZ^T specializations template <> template <> inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage::compute_zfzT( const Tensor4DShape &output_shape, float* const output, const float* const input ) { const int tile_M = output_shape.n_rows / 2; const int tile_N = output_shape.n_cols / 2; int batch = output_shape.n_batches; float *outptr = output; const float *inptr = input; asm volatile ( // Compute input pointers "inptr1 .req x0\n" "inptr2 .req x1\n" "inptr3 .req x2\n" "inptr4 .req x3\n" "inptr5 .req x4\n" "inptr6 .req x5\n" "inptr7 .req x6\n" "inptr8 .req x7\n" "mstride .req x8\n" "mul mstride, %x[tile_M], %x[tile_N]\n" "mul mstride, mstride, %x[n_channels]\n" "lsl mstride, mstride, #2\n" // * sizeof(float) "add inptr1, %[inptr], mstride\n" "add inptr2, inptr1, mstride\n" "add inptr3, inptr2, mstride\n" "add inptr4, inptr3, mstride\n" "add inptr5, inptr4, mstride\n" "add inptr6, inptr5, mstride\n" "add inptr7, inptr6, mstride\n" "add inptr8, inptr7, mstride\n" ".unreq mstride\n" // Compute initial output pointers "outptr01 .req x8\n" "outptr10 .req x9\n" "outptr11 .req x10\n" "add outptr01, %x[outptr], %x[n_channels], LSL #2\n" "add outptr10, %x[outptr], %x[row_stride], LSL #2\n" "add outptr11, outptr10, %x[n_channels], LSL #2\n" "tile_i .req x11\n" "tile_j .req x12\n" "channel .req x13\n" "1:" // Loop over batches "mov tile_i, %x[tile_M]\n" "2:" // Loop over rows of output tiles "mov tile_j, %x[tile_N]\n" "3:" // Loop over columns of output tiles "ldr q0, [%x[inptr]], #0x10\n" "ldr q2, [inptr2], #0x10\n" "subs channel, %x[n_channels], #0x4\n" "ldr q1, [inptr1], #0x10\n" "ldr q3, [inptr3], #0x10\n" "beq 6f\n" "4:" "ldr q4, [inptr4], #0x10\n" "ldr q5, [inptr5], #0x10\n" "fadd v16.4s, v0.4s, v2.4s\n" "ldr q6, [inptr6], #0x10\n" "ldr q7, [inptr7], #0x10\n" "fadd v17.4s, v1.4s, v3.4s\n" "ldr q8, [%x[inptr]], #0x10\n" "ldr q10, [inptr2], #0x10\n" "fadd v16.4s, v16.4s, v4.4s\n" "ldr q9, [inptr1], #0x10\n" "ldr q11, [inptr3], #0x10\n" "fadd v17.4s, v17.4s, v5.4s\n" "str q16, [%x[outptr]], #0x10\n" "prfm pldl1strm, [%x[inptr], #196]\n" "fsub v18.4s, v2.4s, v4.4s\n" "str q17, [outptr01], #0x10\n" "prfm pldl1strm, [inptr2, #196]\n" "fsub v19.4s, v3.4s, v5.4s\n" "prfm pldl1strm, [inptr1, #196]\n" "prfm pldl1strm, [inptr3, #196]\n" "fsub v18.4s, v18.4s, v6.4s\n" "prfm pldl1strm, [inptr4, #196]\n" "prfm pldl1strm, [inptr5, #196]\n" "fsub v19.4s, v19.4s, v7.4s\n" "str q18, [outptr10], #0x10\n" "prfm pldl1strm, [inptr6, #196]\n" "prfm pldl1strm, [inptr7, #196]\n" "subs channel, channel, #0x4\n" "str q19, [outptr11], #0x10\n" "beq 6f\n" // Branch to tail "ldr q12, [inptr4], #0x10\n" "ldr q13, [inptr5], #0x10\n" "fadd v16.4s, v8.4s, v10.4s\n" "ldr q14, [inptr6], #0x10\n" "ldr q15, [inptr7], #0x10\n" "fadd v17.4s, v9.4s, v11.4s\n" "ldr q0, [%x[inptr]], #0x10\n" "ldr q2, [inptr2], #0x10\n" "fadd v16.4s, v16.4s, v12.4s\n" "ldr q1, [inptr1], #0x10\n" "ldr q3, [inptr3], #0x10\n" "fadd v17.4s, v17.4s, v13.4s\n" "str q16, [%x[outptr]], #0x10\n" "prfm pldl1strm, [%x[inptr], #196]\n" "fsub v18.4s, v10.4s, v12.4s\n" "str q17, [outptr01], #0x10\n" "prfm pldl1strm, [inptr2, #196]\n" "fsub v19.4s, v11.4s, v13.4s\n" "prfm pldl1strm, [inptr1, #196]\n" "prfm pldl1strm, [inptr3, #196]\n" "fsub v18.4s, v18.4s, v14.4s\n" "prfm pldl1strm, [inptr4, #196]\n" "prfm pldl1strm, [inptr5, #196]\n" "fsub v19.4s, v19.4s, v15.4s\n" "str q18, [outptr10], #0x10\n" "prfm pldl1strm, [inptr6, #196]\n" "prfm pldl1strm, [inptr7, #196]\n" "subs channel, channel, #0x4\n" "str q19, [outptr11], #0x10\n" "bne 4b\n" // Continue loop "5:" // Tail "ldr q12, [inptr4], #0x10\n" "ldr q13, [inptr5], #0x10\n" "fadd v16.4s, v8.4s, v10.4s\n" "ldr q14, [inptr6], #0x10\n" "ldr q15, [inptr7], #0x10\n" "fadd v17.4s, v9.4s, v11.4s\n" "fadd v16.4s, v16.4s, v12.4s\n" "fadd v17.4s, v17.4s, v13.4s\n" "str q16, [%x[outptr]], #0x10\n" "fsub v18.4s, v10.4s, v12.4s\n" "fsub v19.4s, v11.4s, v13.4s\n" "str q17, [outptr01], #0x10\n" "fsub v18.4s, v18.4s, v14.4s\n" "fsub v19.4s, v19.4s, v15.4s\n" "str q18, [outptr10], #0x10\n" "str q19, [outptr11], #0x10\n" "b 7f\n" "6:" // Tail "ldr q4, [inptr4], #0x10\n" "ldr q5, [inptr5], #0x10\n" "fadd v16.4s, v0.4s, v2.4s\n" "ldr q6, [inptr6], #0x10\n" "ldr q7, [inptr7], #0x10\n" "fadd v17.4s, v1.4s, v3.4s\n" "fadd v16.4s, v16.4s, v4.4s\n" "fadd v17.4s, v17.4s, v5.4s\n" "str q16, [%x[outptr]], #0x10\n" "fsub v18.4s, v2.4s, v4.4s\n" "fsub v19.4s, v3.4s, v5.4s\n" "str q17, [outptr01], #0x10\n" "fsub v18.4s, v18.4s, v6.4s\n" "fsub v19.4s, v19.4s, v7.4s\n" "str q18, [outptr10], #0x10\n" "str q19, [outptr11], #0x10\n" "7:" "add %x[outptr], %x[outptr], %x[n_channels], LSL #2\n" "add outptr01, outptr01, %x[n_channels], LSL #2\n" "add outptr10, outptr10, %x[n_channels], LSL #2\n" "add outptr11, outptr11, %x[n_channels], LSL #2\n" "subs tile_j, tile_j, #1\n" "bne 3b\n" // Progress the output pointers to the new row "add %x[outptr], %x[outptr], %x[row_stride], LSL #2\n" "add outptr01, outptr01, %x[row_stride], LSL #2\n" "add outptr10, outptr10, %x[row_stride], LSL #2\n" "add outptr11, outptr11, %x[row_stride], LSL #2\n" "subs tile_i, tile_i, #1\n" "bne 2b\n" "subs %[batch], %[batch], #1\n" "bne 1b\n" "5:" ".unreq inptr1\n" ".unreq inptr2\n" ".unreq inptr3\n" ".unreq inptr4\n" ".unreq inptr5\n" ".unreq inptr6\n" ".unreq inptr7\n" ".unreq inptr8\n" ".unreq outptr01\n" ".unreq outptr10\n" ".unreq outptr11\n" : [batch] "+r" (batch), [outptr] "+r" (outptr), [inptr] "+r" (inptr) : [tile_M] "r" (tile_M), [tile_N] "r" (tile_N), [n_channels] "r" (output_shape.n_channels), [row_stride] "r" (output_shape.n_cols * output_shape.n_channels) : "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "cc", "memory" ); } /*****************************************************************************/ /*****************************************************************************/ template <> inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage::execute( const Tensor4DShape &output_shape, float* const matrices[16], float* const output ) { // profiler prof; // Allocate memory for the intermediate matrices const int tile_M = iceildiv(output_shape.n_rows, 2); const int tile_N = iceildiv(output_shape.n_cols, 2); const int n_rows = output_shape.n_batches * tile_M * tile_N; const int n_channels = output_shape.n_channels; float* matrices_zf = reinterpret_cast( calloc(8 * n_rows * n_channels, sizeof(float)) ); // Perform the first stage transform, computing ZF. const auto f_compute_zf = [&] () { switch (n_channels % 4) { case 0: compute_zf<0>(n_rows, n_channels, matrices_zf, matrices); break; case 1: compute_zf<1>(n_rows, n_channels, matrices_zf, matrices); break; case 2: compute_zf<2>(n_rows, n_channels, matrices_zf, matrices); break; case 3: compute_zf<3>(n_rows, n_channels, matrices_zf, matrices); }; }; // prof("Compute ZF", f_compute_zf, 16 * n_rows * n_channels * sizeof(float), 0, 8 * n_rows * n_channels * sizeof(float)); f_compute_zf(); // Perform the second stage transform, finishing Z F Z^T - variable dispatch // based on size of the output and the channel tail. const auto f_compute_zfzT = [&] () { if (output_shape.n_rows % 2 && output_shape.n_cols % 2) { constexpr bool tail_M = true, tail_N = true; switch (n_channels % 4) { case 0: compute_zfzT(output_shape, output, matrices_zf); break; case 1: compute_zfzT(output_shape, output, matrices_zf); break; case 2: compute_zfzT(output_shape, output, matrices_zf); break; case 3: compute_zfzT(output_shape, output, matrices_zf); } } else if (output_shape.n_rows % 2) { constexpr bool tail_M = true, tail_N = false; switch (n_channels % 4) { case 0: compute_zfzT(output_shape, output, matrices_zf); break; case 1: compute_zfzT(output_shape, output, matrices_zf); break; case 2: compute_zfzT(output_shape, output, matrices_zf); break; case 3: compute_zfzT(output_shape, output, matrices_zf); } } else if (output_shape.n_cols % 2) { constexpr bool tail_M = false, tail_N = true; switch (n_channels % 4) { case 0: compute_zfzT(output_shape, output, matrices_zf); break; case 1: compute_zfzT(output_shape, output, matrices_zf); break; case 2: compute_zfzT(output_shape, output, matrices_zf); break; case 3: compute_zfzT(output_shape, output, matrices_zf); } } else { constexpr bool tail_M = false, tail_N = false; switch (n_channels % 4) { case 0: compute_zfzT(output_shape, output, matrices_zf); break; case 1: compute_zfzT(output_shape, output, matrices_zf); break; case 2: compute_zfzT(output_shape, output, matrices_zf); break; case 3: compute_zfzT(output_shape, output, matrices_zf); } } }; // prof("Compute ZFZT", f_compute_zfzT, 8 * n_rows * n_channels * sizeof(float), 0, 4 * n_rows * n_channels * sizeof(float)); f_compute_zfzT(); free(reinterpret_cast(matrices_zf)); } /*****************************************************************************/ #endif // __aarch64__