From 3d4968ac573cc206ac1c6adcfd6f1d4689a715d1 Mon Sep 17 00:00:00 2001 From: Pablo Tello Date: Mon, 4 Dec 2017 15:03:35 +0000 Subject: COMPMID-687: Winograd refactoring Moved the headers into src/ Added pimpl pattern Change-Id: I227f8b47468d8e14875d710aac8de5eb09463e2a Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/111765 Reviewed-by: Anthony Barbier Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com --- .../core/NEON/kernels/NEWinogradLayerKernel.h | 38 +- arm_compute/core/NEON/kernels/winograd/gemm.hpp | 127 -- .../core/NEON/kernels/winograd/gemm/a64_sgemm.hpp | 355 ----- .../NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp | 1445 ------------------- arm_compute/core/NEON/kernels/winograd/perf.h | 32 - .../core/NEON/kernels/winograd/profiler.hpp | 244 ---- arm_compute/core/NEON/kernels/winograd/shims.hpp | 319 ----- .../core/NEON/kernels/winograd/transforms.hpp | 29 - .../kernels/winograd/transforms/input_2x2_3x3.hpp | 638 --------- .../transforms/input_2x2_3x3/a64_float.hpp | 1498 -------------------- .../input_2x2_3x3/a64_float_channelwise.hpp | 961 ------------- .../kernels/winograd/transforms/kernel_2x2_3x3.hpp | 195 --- .../transforms/kernel_2x2_3x3/a64_float.hpp | 822 ----------- .../kernels/winograd/transforms/output_2x2_3x3.hpp | 356 ----- .../transforms/output_2x2_3x3/a64_float.hpp | 650 --------- .../output_2x2_3x3/a64_float_two_stage.hpp | 655 --------- arm_compute/core/NEON/kernels/winograd/utils.hpp | 55 - .../core/NEON/kernels/winograd/winograd_gemm.hpp | 346 ----- .../NEON/kernels/winograd/winograd_shim_nchw.hpp | 192 --- .../runtime/NEON/functions/NEWinogradLayer.h | 2 - scripts/clang_tidy_rules.py | 1 + src/core/NEON/kernels/NEWinogradLayerKernel.cpp | 79 +- src/core/NEON/kernels/winograd/gemm.hpp | 127 ++ src/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp | 355 +++++ .../NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp | 1445 +++++++++++++++++++ src/core/NEON/kernels/winograd/perf.h | 32 + src/core/NEON/kernels/winograd/profiler.hpp | 244 ++++ src/core/NEON/kernels/winograd/shims.hpp | 319 +++++ src/core/NEON/kernels/winograd/transforms.hpp | 29 + .../kernels/winograd/transforms/input_2x2_3x3.hpp | 639 +++++++++ .../transforms/input_2x2_3x3/a64_float.hpp | 1498 ++++++++++++++++++++ .../input_2x2_3x3/a64_float_channelwise.hpp | 961 +++++++++++++ .../kernels/winograd/transforms/kernel_2x2_3x3.hpp | 195 +++ .../transforms/kernel_2x2_3x3/a64_float.hpp | 822 +++++++++++ .../kernels/winograd/transforms/output_2x2_3x3.hpp | 356 +++++ .../transforms/output_2x2_3x3/a64_float.hpp | 650 +++++++++ .../output_2x2_3x3/a64_float_two_stage.hpp | 655 +++++++++ src/core/NEON/kernels/winograd/utils.hpp | 55 + src/core/NEON/kernels/winograd/winograd_gemm.hpp | 345 +++++ .../NEON/kernels/winograd/winograd_shim_nchw.hpp | 191 +++ src/runtime/NEON/functions/NEWinogradLayer.cpp | 8 +- 41 files changed, 9035 insertions(+), 8930 deletions(-) delete mode 100644 arm_compute/core/NEON/kernels/winograd/gemm.hpp delete mode 100644 arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp delete mode 100644 arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp delete mode 100644 arm_compute/core/NEON/kernels/winograd/perf.h delete mode 100644 arm_compute/core/NEON/kernels/winograd/profiler.hpp delete mode 100644 arm_compute/core/NEON/kernels/winograd/shims.hpp delete mode 100644 arm_compute/core/NEON/kernels/winograd/transforms.hpp delete mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp delete mode 100644 
arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp delete mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp delete mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp delete mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp delete mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp delete mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp delete mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp delete mode 100644 arm_compute/core/NEON/kernels/winograd/utils.hpp delete mode 100644 arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp delete mode 100644 arm_compute/core/NEON/kernels/winograd/winograd_shim_nchw.hpp create mode 100644 src/core/NEON/kernels/winograd/gemm.hpp create mode 100644 src/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp create mode 100644 src/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp create mode 100644 src/core/NEON/kernels/winograd/perf.h create mode 100644 src/core/NEON/kernels/winograd/profiler.hpp create mode 100644 src/core/NEON/kernels/winograd/shims.hpp create mode 100644 src/core/NEON/kernels/winograd/transforms.hpp create mode 100644 src/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp create mode 100644 src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp create mode 100644 src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp create mode 100644 src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp create mode 100644 src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp create mode 100644 src/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp create mode 100644 src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp create mode 100644 src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp create mode 100644 src/core/NEON/kernels/winograd/utils.hpp create mode 100644 src/core/NEON/kernels/winograd/winograd_gemm.hpp create mode 100644 src/core/NEON/kernels/winograd/winograd_shim_nchw.hpp diff --git a/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h b/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h index 1e7ca64b8c..3ab3aa792b 100644 --- a/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h +++ b/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h @@ -25,17 +25,34 @@ #define __ARM_COMPUTE_NEGEMMWINOGRADLAYERKERNEL_H__ #include "arm_compute/core/NEON/INEKernel.h" - -#include "arm_compute/core/NEON/kernels/winograd/winograd_shim_nchw.hpp" +#include "arm_compute/core/NEON/kernels/winograd/tensor.hpp" namespace arm_compute { class ITensor; +class NEWinogradLayerKernel; +class Winograd3x3F32 +{ +public: + friend class NEWinogradLayerKernel; + Winograd3x3F32(const KernelShape &kernel_shape, const Tensor4DShape input_shape, const PaddingType padding_type, void *kernel_storage); + ~Winograd3x3F32(); + std::pair get_nhwc_ptrs(const Tensor4DShape &input_shape, const PaddingType padding_type, void *working_space); + void transform_weights(const void *const kernel, void *transform_working_space); + void reshape_input(const Tensor4DShape &input_shape, const PaddingType padding_type, const void *const input, void *working_space); + void reshape_output(const Tensor4DShape &input_shape, const PaddingType padding_type, void *const output); + void nchw2nhwc(const 
Tensor4DShape &input_shape, const PaddingType padding_type, void *working_space, const void *const input); + void nhwc2nchw(const Tensor4DShape &input_shape, const PaddingType padding_type, void *working_space, void *const output); + +private: + class Private; + std::unique_ptr _pimpl; +}; class NEWinogradLayerKernel : public INEKernel { public: - using Winograd3x3F32 = winograd_shim_nchw::Winograd2x2_3x3GEMM; + // using Winograd3x3F32 = winograd_shim_nchw::Winograd2x2_3x3GEMM; /** Constructor */ NEWinogradLayerKernel(); @@ -61,9 +78,22 @@ public: // Inherited methods overridden: void run(const Window &window, const ThreadInfo &info) override; + /* Get the memory required to instantiate a new Winograd operator. + */ + static size_t get_kernel_storage_size(const KernelShape &shape); + + /* Get the memory required to apply a Winograd operator to some input. + */ + static size_t get_working_space_size(const Tensor4DShape &input_shape, const KernelShape &k_shape, const PaddingType padding); + + /* Get the memory required to transform the kernel. + */ + static size_t get_kernel_transform_working_size(const KernelShape &shape); + protected: Winograd3x3F32 *_convolver; - ITensor *_output; + // std::unique_ptr _conv; + ITensor *_output; }; } // namespace arm_compute diff --git a/arm_compute/core/NEON/kernels/winograd/gemm.hpp b/arm_compute/core/NEON/kernels/winograd/gemm.hpp deleted file mode 100644 index 564016a646..0000000000 --- a/arm_compute/core/NEON/kernels/winograd/gemm.hpp +++ /dev/null @@ -1,127 +0,0 @@ - -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once -#include "utils.hpp" - -template -void Gemm(const TIn* const a, const TIn* const b, TOut *c, - const int M, const int K, const int N, - const int a_row_stride, - const int b_row_stride, - const int c_row_stride, - const bool a_transposed=false, - const bool b_transposed=false) { - // Array access methods - const auto A = [a, a_transposed, M, K, a_row_stride] (const int i, const int j) -> TIn { - return a[(!a_transposed) ? i*a_row_stride + j : i + j*M]; - }; - - const auto B = [b, b_transposed, K, N, b_row_stride] (const int i, const int j) -> TIn { - return b[(!b_transposed) ? 
i*b_row_stride + j : i + j*N]; - }; - - const auto C = [c, c_row_stride] (const int i, const int j) -> TOut& { - return c[i*c_row_stride + j]; - }; - - // Perform the matrix multiplication - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - for (int k = 0; k < K; k++) { - C(i, j) += A(i, k) * B(k, j); - } - } - } -} - -template -void BlockedGemm( - const TIn* const a, const TIn* const b, TOut *c, - const int M, const int K, const int N, - const int a_row_stride, - const int b_row_stride, - const int c_row_stride -) { - // Array access methods - const auto A = [a, M, K, a_row_stride] (const int i, const int j) -> TIn { - return a[i*a_row_stride + j]; - }; - - const auto B = [b, K, N, b_row_stride] (const int i, const int j) -> TIn { - return b[i*b_row_stride + j]; - }; - - const auto C = [c, c_row_stride] (const int i, const int j) -> TOut& { - return c[i*c_row_stride + j]; - }; - - const int M_BLOCKS = iceildiv(M, M_BLOCK); - const int N_BLOCKS = iceildiv(N, N_BLOCK); - - // For each block of output rows - for (int mblock = 0; mblock < M_BLOCKS; mblock++) { - // For each block of output columns - for (int nblock = 0; nblock < N_BLOCKS; nblock++) { - // Create an appropriately sized block of accumulators - TOut accum[M_BLOCK][N_BLOCK]; - for (int i = 0; i < M_BLOCK; i++) { - for (int j = 0; j < N_BLOCK; j++) { - accum[i][j] = static_cast(0); - } - } - - // Perform this portion of the matrix multiply - for (int k = 0; k < K; k++) { - // Load elements of A - TIn elems_a[M_BLOCK]; - for (int i = 0; i < M_BLOCK; i++) { - elems_a[i] = A(mblock*M_BLOCK + i, k); - } - - // Load elements of B - TIn elems_b[N_BLOCK]; - for (int j = 0; j < N_BLOCK; j++) { - elems_b[j] = B(k, nblock*N_BLOCK + j); - } - - // Perform the partial matrix multiply - for (int i = 0; i < M_BLOCK; i++) { - for (int j = 0; j < N_BLOCK; j++) { - accum[i][j] += elems_a[i] * elems_b[j]; - } - } - } - - // Store the partial product - for (int i = 0; i < M_BLOCK; i++) { - for (int j = 0; j < N_BLOCK; j++) { - C(mblock*M_BLOCK + i, nblock*N_BLOCK + j) = accum[i][j]; - } - } - } - } -} - -#include "gemm/a64_sgemm.hpp" diff --git a/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp deleted file mode 100644 index e1b7488c31..0000000000 --- a/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp +++ /dev/null @@ -1,355 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once -#include -#include "../utils.hpp" - -#ifdef __aarch64__ - -template <> -inline void BlockedGemm<8, 12, float, float>( - const float* const a, const float* const b, float *c, - const int M, const int K, const int N, - const int a_row_stride, - const int b_row_stride, - const int c_row_stride -) { - const int M_BLOCK = 8; - const int N_BLOCK = 12; - - const int m_blocks = iceildiv(M, M_BLOCK); - const int n_blocks = iceildiv(N, N_BLOCK); - - // For each block of output rows - for (int mblock = 0; mblock < m_blocks; mblock++) { - // For each block of output columns - for (int nblock = 0; nblock < n_blocks; nblock++) { - const float *aptr = a + mblock*M_BLOCK*a_row_stride; - const float *bptr = b + nblock*N_BLOCK; - float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; - int k = K; - - asm volatile ( - // Create an 8x12 block of accumulators - " A_1 .req v27\n" - "sA_1 .req s27\n" - " A_2 .req v28\n" - "sA_2 .req s28\n" - " A_3 .req v29\n" - "sA_3 .req s29\n" - " A_4 .req v30\n" - "sA_4 .req s30\n" - - " B_1 .req v24\n" " B_2 .req v25\n" " B_3 .req v26\n" - "qB_1 .req q24\n" "qB_2 .req q25\n" "qB_3 .req q26\n" - - " C_11 .req v0\n" " C_12 .req v1\n" " C_13 .req v2\n" - " C_21 .req v3\n" " C_22 .req v4\n" " C_23 .req v5\n" - " C_31 .req v6\n" " C_32 .req v7\n" " C_33 .req v8\n" - " C_41 .req v9\n" " C_42 .req v10\n" " C_43 .req v11\n" - " C_51 .req v12\n" " C_52 .req v13\n" " C_53 .req v14\n" - " C_61 .req v15\n" " C_62 .req v16\n" " C_63 .req v17\n" - " C_71 .req v18\n" " C_72 .req v19\n" " C_73 .req v20\n" - " C_81 .req v21\n" " C_82 .req v22\n" " C_83 .req v23\n" - - "qC_11 .req q0\n" "qC_12 .req q1\n" "qC_13 .req q2\n" - "qC_21 .req q3\n" "qC_22 .req q4\n" "qC_23 .req q5\n" - "qC_31 .req q6\n" "qC_32 .req q7\n" "qC_33 .req q8\n" - "qC_41 .req q9\n" "qC_42 .req q10\n" "qC_43 .req q11\n" - "qC_51 .req q12\n" "qC_52 .req q13\n" "qC_53 .req q14\n" - "qC_61 .req q15\n" "qC_62 .req q16\n" "qC_63 .req q17\n" - "qC_71 .req q18\n" "qC_72 .req q19\n" "qC_73 .req q20\n" - "qC_81 .req q21\n" "qC_82 .req q22\n" "qC_83 .req q23\n" - - "aptr1 .req x17\n" - "aptr2 .req x18\n" - "aptr3 .req x19\n" - "aptr4 .req x20\n" - "aptr5 .req x21\n" - "aptr6 .req x22\n" - "aptr7 .req x23\n" - - // Initialise accumulators with 0 - // Initialise pointers - "movi C_11.4s, #0\n" - "add aptr1, %x[aptr], %x[a_row_stride]\n" - "movi C_12.4s, #0\n" - "add aptr2, aptr1, %x[a_row_stride]\n" - "movi C_13.4s, #0\n" - "add aptr3, aptr2, %x[a_row_stride]\n" - "movi C_21.4s, #0\n" - "add aptr4, aptr3, %x[a_row_stride]\n" - "movi C_22.4s, #0\n" - "add aptr5, aptr4, %x[a_row_stride]\n" - "movi C_23.4s, #0\n" - "add aptr6, aptr5, %x[a_row_stride]\n" - "movi C_31.4s, #0\n" - "add aptr7, aptr6, %x[a_row_stride]\n" - "movi C_32.4s, #0\n" - "ldr qB_1, [%x[bptr]]\n" - "movi C_33.4s, #0\n" - "ldr qB_2, [%x[bptr], #0x10]\n" - "movi C_41.4s, #0\n" - "prfm pldl1keep, [%x[bptr], #0x00]\n" - "movi C_42.4s, #0\n" - "prfm pldl1keep, [%x[bptr], #0x10]\n" - "movi C_43.4s, #0\n" - "prfm pldl1keep, [%x[bptr], #0x20]\n" - "movi C_51.4s, #0\n" - "prfm pldl1keep, [%x[aptr], #0x00]\n" - "movi C_52.4s, #0\n" - "prfm pldl1keep, [ aptr1, #0x00]\n" - "movi C_53.4s, #0\n" - "prfm pldl1keep, [ aptr2, #0x00]\n" - "movi C_61.4s, #0\n" - "prfm pldl1keep, [ aptr3, 
#0x00]\n" - "movi C_62.4s, #0\n" - "prfm pldl1keep, [ aptr4, #0x00]\n" - "movi C_63.4s, #0\n" - "prfm pldl1keep, [ aptr5, #0x00]\n" - "movi C_71.4s, #0\n" - "prfm pldl1keep, [ aptr6, #0x00]\n" - "movi C_72.4s, #0\n" - "prfm pldl1keep, [ aptr7, #0x00]\n" - "movi C_73.4s, #0\n" - "ldr sA_1, [%x[aptr]], #0x4\n" - "movi C_81.4s, #0\n" - "ldr sA_2, [ aptr1], #0x4\n" - "movi C_82.4s, #0\n" - "ldr sA_3, [ aptr2], #0x4\n" - "movi C_83.4s, #0\n" - "subs %x[k], %x[k], #1\n" - "beq 2f\n" - - "1:" - "fmla C_11.4s, B_1.4s, A_1.s[0]\n" - "ldr qB_3, [%x[bptr], #0x20]\n" - "fmla C_12.4s, B_2.4s, A_1.s[0]\n" - "ldr sA_4, [ aptr3], #0x4\n" - "fmla C_13.4s, B_3.4s, A_1.s[0]\n" - "ldr sA_1, [ aptr4], #0x04\n" - - "fmla C_21.4s, B_1.4s, A_2.s[0]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride]\n" - "fmla C_22.4s, B_2.4s, A_2.s[0]\n" - "prfm pldl1keep, [ aptr3, #0x10]\n" - "fmla C_23.4s, B_3.4s, A_2.s[0]\n" - "ldr sA_2, [ aptr5], #0x04\n" - - "fmla C_31.4s, B_1.4s, A_3.s[0]\n" - "prfm pldl1keep, [%x[bptr], #0x00]\n" - "fmla C_32.4s, B_2.4s, A_3.s[0]\n" - "prfm pldl1keep, [%x[bptr], #0x10]\n" - "fmla C_33.4s, B_3.4s, A_3.s[0]\n" - "ldr sA_3, [ aptr6], #0x04\n" - - "fmla C_41.4s, B_1.4s, A_4.s[0]\n" - "prfm pldl1keep, [%x[bptr], #0x20]\n" - "fmla C_42.4s, B_2.4s, A_4.s[0]\n" - "prfm pldl1keep, [ aptr4, #0x10]\n" - "fmla C_43.4s, B_3.4s, A_4.s[0]\n" - "ldr sA_4, [ aptr7], #0x04\n" - - "fmla C_51.4s, B_1.4s, A_1.s[0]\n" - "prfm pldl1keep, [ aptr5, #0x10]\n" - "fmla C_52.4s, B_2.4s, A_1.s[0]\n" - "prfm pldl1keep, [ aptr6, #0x10]\n" - "fmla C_53.4s, B_3.4s, A_1.s[0]\n" - "ldr sA_1, [%x[aptr]], #0x04\n" - - "fmla C_61.4s, B_1.4s, A_2.s[0]\n" - "prfm pldl1keep, [ aptr7, #0x10]\n" - "fmla C_62.4s, B_2.4s, A_2.s[0]\n" - "subs %x[k], %x[k], #1\n" - "fmla C_63.4s, B_3.4s, A_2.s[0]\n" - "ldr sA_2, [ aptr1], #0x04\n" - - "fmla C_71.4s, B_1.4s, A_3.s[0]\n" - "prfm pldl1keep, [%x[aptr], #0x10]\n" - "fmla C_72.4s, B_2.4s, A_3.s[0]\n" - "prfm pldl1keep, [ aptr1, #0x10]\n" - "fmla C_73.4s, B_3.4s, A_3.s[0]\n" - "ldr sA_3, [ aptr2], #0x04\n" - - "fmla C_81.4s, B_1.4s, A_4.s[0]\n" - "prfm pldl1keep, [ aptr2, #0x10]\n" - "fmla C_82.4s, B_2.4s, A_4.s[0]\n" - "ldp qB_1, qB_2, [%x[bptr]]\n" - "fmla C_83.4s, B_3.4s, A_4.s[0]\n" - "bne 1b\n" - - "2:" - "fmla C_11.4s, B_1.4s, A_1.s[0]\n" - "ldr qB_3, [%x[bptr], #0x20]\n" - "fmla C_12.4s, B_2.4s, A_1.s[0]\n" - "stp qC_11, qC_12, [%x[cptr]]\n" - "fmla C_13.4s, B_3.4s, A_1.s[0]\n" - "str qC_13, [%x[cptr], #0x20]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride]\n" - "ldr sA_1, [ aptr4], #0x04\n" - - "fmla C_21.4s, B_1.4s, A_2.s[0]\n" - "ldr sA_4, [ aptr3], #0x4\n" - "fmla C_22.4s, B_2.4s, A_2.s[0]\n" - "stp qC_21, qC_22, [%x[cptr]]\n" - "fmla C_23.4s, B_3.4s, A_2.s[0]\n" - "str qC_23, [%x[cptr], #0x20]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride]\n" - "ldr sA_2, [ aptr5], #0x04\n" - - "fmla C_31.4s, B_1.4s, A_3.s[0]\n" - "fmla C_32.4s, B_2.4s, A_3.s[0]\n" - "stp qC_31, qC_32, [%x[cptr]]\n" - "fmla C_33.4s, B_3.4s, A_3.s[0]\n" - "str qC_33, [%x[cptr], #0x20]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride]\n" - "ldr sA_3, [ aptr6], #0x04\n" - - "fmla C_41.4s, B_1.4s, A_4.s[0]\n" - "fmla C_42.4s, B_2.4s, A_4.s[0]\n" - "stp qC_41, qC_42, [%x[cptr]]\n" - "fmla C_43.4s, B_3.4s, A_4.s[0]\n" - "str qC_43, [%x[cptr], #0x20]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride]\n" - "ldr sA_4, [ aptr7], #0x04\n" - - "fmla C_51.4s, B_1.4s, A_1.s[0]\n" - "fmla C_52.4s, B_2.4s, A_1.s[0]\n" - "stp qC_51, qC_52, [%x[cptr]]\n" - "fmla C_53.4s, B_3.4s, A_1.s[0]\n" - "str qC_53, [%x[cptr], #0x20]\n" - "add %x[cptr], 
%x[cptr], %x[c_row_stride]\n" - - "fmla C_61.4s, B_1.4s, A_2.s[0]\n" - "fmla C_62.4s, B_2.4s, A_2.s[0]\n" - "stp qC_61, qC_62, [%x[cptr]]\n" - "fmla C_63.4s, B_3.4s, A_2.s[0]\n" - "str qC_63, [%x[cptr], #0x20]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride]\n" - - "fmla C_71.4s, B_1.4s, A_3.s[0]\n" - "fmla C_72.4s, B_2.4s, A_3.s[0]\n" - "stp qC_71, qC_72, [%x[cptr]]\n" - "fmla C_73.4s, B_3.4s, A_3.s[0]\n" - "str qC_73, [%x[cptr], #0x20]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride]\n" - - "fmla C_81.4s, B_1.4s, A_4.s[0]\n" - "fmla C_82.4s, B_2.4s, A_4.s[0]\n" - "stp qC_81, qC_82, [%x[cptr]]\n" - "fmla C_83.4s, B_3.4s, A_4.s[0]\n" - "str qC_83, [%x[cptr], #0x20]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride]\n" - - // Clear aliases - ".unreq aptr1\n" - ".unreq aptr2\n" - ".unreq aptr3\n" - ".unreq aptr4\n" - ".unreq aptr5\n" - ".unreq aptr6\n" - ".unreq aptr7\n" - - ".unreq A_1\n" ".unreq A_2\n" ".unreq A_3\n" ".unreq A_4\n" - ".unreq sA_1\n" ".unreq sA_2\n" ".unreq sA_3\n" ".unreq sA_4\n" - - ".unreq B_1\n" ".unreq B_2\n" ".unreq B_3\n" - ".unreq qB_1\n" ".unreq qB_2\n" ".unreq qB_3\n" - - ".unreq C_11\n" ".unreq C_12\n" ".unreq C_13\n" - ".unreq C_21\n" ".unreq C_22\n" ".unreq C_23\n" - ".unreq C_31\n" ".unreq C_32\n" ".unreq C_33\n" - ".unreq C_41\n" ".unreq C_42\n" ".unreq C_43\n" - ".unreq C_51\n" ".unreq C_52\n" ".unreq C_53\n" - ".unreq C_61\n" ".unreq C_62\n" ".unreq C_63\n" - ".unreq C_71\n" ".unreq C_72\n" ".unreq C_73\n" - ".unreq C_81\n" ".unreq C_82\n" ".unreq C_83\n" - - ".unreq qC_11\n" ".unreq qC_12\n" ".unreq qC_13\n" - ".unreq qC_21\n" ".unreq qC_22\n" ".unreq qC_23\n" - ".unreq qC_31\n" ".unreq qC_32\n" ".unreq qC_33\n" - ".unreq qC_41\n" ".unreq qC_42\n" ".unreq qC_43\n" - ".unreq qC_51\n" ".unreq qC_52\n" ".unreq qC_53\n" - ".unreq qC_61\n" ".unreq qC_62\n" ".unreq qC_63\n" - ".unreq qC_71\n" ".unreq qC_72\n" ".unreq qC_73\n" - ".unreq qC_81\n" ".unreq qC_82\n" ".unreq qC_83\n" - : [aptr] "+r" (aptr), - [bptr] "+r" (bptr), - [cptr] "+r" (cptr), - [k] "+r" (k) - : [a_row_stride] "r" (a_row_stride * sizeof(float)), - [b_row_stride] "r" (b_row_stride * sizeof(float)), - [c_row_stride] "r" (c_row_stride * sizeof(float)) - : "cc", "memory", - "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", - "v29", "v30", "x17", "x18", "x19", "x20", "x21", "x22", "x23" - ); - } - } -} - -/*****************************************************************************/ -/* 4x16 blocked GEMM with specialised tails - */ -#include "a64_sgemm_4x16.hpp" - -template <> -inline void BlockedGemm<4, 16, float, float>( - const float* const a, const float* const b, float *c, - const int M, const int K, const int N, - const int a_row_stride, - const int b_row_stride, - const int c_row_stride -) { - // Despatch based on tail of K - switch (K % 4) { - case 3: - sgemm_4x16_impl<3>( - a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride - ); - break; - case 2: - sgemm_4x16_impl<2>( - a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride - ); - break; - case 1: - sgemm_4x16_impl<1>( - a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride - ); - break; - case 0: - sgemm_4x16_impl<0>( - a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride - ); - break; - default: - assert(0); - break; - } -} - -#endif // __aarch64__ diff --git a/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp 
b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp deleted file mode 100644 index e74610ef27..0000000000 --- a/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp +++ /dev/null @@ -1,1445 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -template -inline void sgemm_4x16_impl( - const float* const a, const float* const b, float *c, - const int M, const int K, const int N, - const int a_row_stride, - const int b_row_stride, - const int c_row_stride -); - -template <> -inline void sgemm_4x16_impl<0>( - const float* const a, const float* const b, float *c, - const int M, const int K, const int N, - const int a_row_stride, - const int b_row_stride, - const int c_row_stride -) { - const int TAIL_SIZE = 0; - const int M_BLOCK = 4; - const int N_BLOCK = 16; - - const int m_blocks = iceildiv(M, M_BLOCK); - const int n_blocks = iceildiv(N, N_BLOCK); - - // For each block of output rows - for (int mblock = 0; mblock < m_blocks; mblock++) { - // For each block of output columns - for (int nblock = 0; nblock < n_blocks; nblock++) { - const float *aptr = a + mblock*M_BLOCK*a_row_stride; - const float *bptr = b + nblock*N_BLOCK; - float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; - int k = (K - TAIL_SIZE) / 4; - - asm volatile( - "aptr2 .req X20\n" - "aptr3 .req X21\n" - "aptr4 .req X22\n" - "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" - "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" - "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" - "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" - "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" - "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" - "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" - "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" - "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" - "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" - "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" - "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" - "vB1 .req v20\n" "qB1 .req q20\n" - "vB2 .req v21\n" "qB2 .req q21\n" - "vB3 .req v22\n" "qB3 .req q22\n" - "vB4 .req v23\n" "qB4 .req q23\n" - - // Clear accumulators, initialise pointers - "movi vC11.4s, #0\n" - "add aptr2, 
%x[aptr], %x[a_row_stride_bytes]\n" - "movi vC12.4s, #0\n" - "add aptr3, aptr2, %x[a_row_stride_bytes]\n" - "movi vC13.4s, #0\n" - "add aptr4, aptr3, %x[a_row_stride_bytes]\n" - "movi vC14.4s, #0\n" - "ldr qA1, [%x[aptr]], #0x10\n" - "movi vC21.4s, #0\n" - "ldr qA2, [ aptr2], #0x10\n" - "movi vC22.4s, #0\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "movi vC23.4s, #0\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "movi vC24.4s, #0\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "movi vC31.4s, #0\n" - "movi vC32.4s, #0\n" - "movi vC33.4s, #0\n" - "movi vC34.4s, #0\n" - "movi vC41.4s, #0\n" - "movi vC42.4s, #0\n" - "movi vC43.4s, #0\n" - "movi vC44.4s, #0\n" - "subs %x[k], %x[k], #1\n" - "beq 2f\n" - - "1:" // Loop proper - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr qA3, [ aptr3], #0x10\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "ldr qA4, [ aptr4], #0x10\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "fmla vC33.4s, vB3.4s, vA3.s[0]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[0]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[1]\n" - "fmla vC21.4s, vB1.4s, vA2.s[1]\n" - "fmla vC31.4s, vB1.4s, vA3.s[1]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[1]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[1]\n" - "fmla vC22.4s, vB2.4s, vA2.s[1]\n" - "fmla vC32.4s, vB2.4s, vA3.s[1]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[1]\n" - "fmla vC13.4s, vB3.4s, vA1.s[1]\n" - "fmla vC23.4s, vB3.4s, vA2.s[1]\n" - "fmla vC33.4s, vB3.4s, vA3.s[1]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[1]\n" - "fmla vC14.4s, vB4.4s, vA1.s[1]\n" - "fmla vC24.4s, vB4.4s, vA2.s[1]\n" - "fmla vC34.4s, vB4.4s, vA3.s[1]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[1]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[2]\n" - "fmla vC21.4s, vB1.4s, vA2.s[2]\n" - "fmla vC31.4s, vB1.4s, vA3.s[2]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[2]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[2]\n" - "fmla vC22.4s, vB2.4s, vA2.s[2]\n" - "fmla vC32.4s, vB2.4s, vA3.s[2]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[2]\n" - "fmla vC13.4s, vB3.4s, vA1.s[2]\n" - "fmla vC23.4s, vB3.4s, vA2.s[2]\n" - "fmla vC33.4s, vB3.4s, vA3.s[2]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[2]\n" - "fmla vC14.4s, vB4.4s, vA1.s[2]\n" - "fmla vC24.4s, vB4.4s, vA2.s[2]\n" - "fmla vC34.4s, vB4.4s, vA3.s[2]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[2]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[3]\n" - "fmla vC21.4s, vB1.4s, vA2.s[3]\n" - "fmla vC31.4s, vB1.4s, vA3.s[3]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[3]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[3]\n" - "fmla vC22.4s, vB2.4s, vA2.s[3]\n" - "fmla vC32.4s, vB2.4s, vA3.s[3]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[3]\n" - "fmla vC13.4s, vB3.4s, vA1.s[3]\n" 
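// Editorial aside (not part of the original patch): the assembly above and below is the
// K % 4 == 0 specialisation of the 4x16 blocked SGEMM being moved by this commit.  The
// vC<row><col> registers hold a 4-row by 16-column tile of C (four 4-float vectors per
// row); each unrolled step is a rank-1 update of that tile using one column of the A
// block (one lane of vA1..vA4, rows fed by aptr, aptr2..aptr4) and one row of B
// (vB1..vB4).  A scalar sketch of the same computation, with hypothetical array names:
//
//   for (int k = 0; k < K; k++) {          // the asm consumes k in groups of four
//     for (int i = 0; i < 4; i++) {        // rows of the tile: aptr, aptr2..aptr4
//       const float a_ik = A[i][k];
//       for (int j = 0; j < 16; j++) {     // columns of the tile: vB1..vB4
//         C[i][j] += a_ik * B[k][j];       // one "fmla vC.., vB.., vA..s[..]"
//       }
//     }
//   }
//
// The <1>, <2> and <3> specialisations further down differ only in how the remaining
// K % 4 iterations are peeled off before the accumulators are stored back to C.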
- "fmla vC23.4s, vB3.4s, vA2.s[3]\n" - "fmla vC33.4s, vB3.4s, vA3.s[3]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[3]\n" - "subs %x[k], %x[k], #1\n" - "fmla vC14.4s, vB4.4s, vA1.s[3]\n" - "ldr qA1, [%x[aptr]], #0x10\n" - "fmla vC24.4s, vB4.4s, vA2.s[3]\n" - "ldr qA2, [ aptr2], #0x10\n" - "fmla vC34.4s, vB4.4s, vA3.s[3]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[3]\n" - "bne 1b\n" - - "2:" // Tail - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr qA3, [ aptr3], #0x10\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "ldr qA4, [ aptr4], #0x10\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "fmla vC33.4s, vB3.4s, vA3.s[0]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[0]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[1]\n" - "fmla vC21.4s, vB1.4s, vA2.s[1]\n" - "fmla vC31.4s, vB1.4s, vA3.s[1]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[1]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[1]\n" - "fmla vC22.4s, vB2.4s, vA2.s[1]\n" - "fmla vC32.4s, vB2.4s, vA3.s[1]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[1]\n" - "fmla vC13.4s, vB3.4s, vA1.s[1]\n" - "fmla vC23.4s, vB3.4s, vA2.s[1]\n" - "fmla vC33.4s, vB3.4s, vA3.s[1]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[1]\n" - "fmla vC14.4s, vB4.4s, vA1.s[1]\n" - "fmla vC24.4s, vB4.4s, vA2.s[1]\n" - "fmla vC34.4s, vB4.4s, vA3.s[1]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[1]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[2]\n" - "fmla vC21.4s, vB1.4s, vA2.s[2]\n" - "fmla vC31.4s, vB1.4s, vA3.s[2]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[2]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[2]\n" - "fmla vC22.4s, vB2.4s, vA2.s[2]\n" - "fmla vC32.4s, vB2.4s, vA3.s[2]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[2]\n" - "fmla vC13.4s, vB3.4s, vA1.s[2]\n" - "fmla vC23.4s, vB3.4s, vA2.s[2]\n" - "fmla vC33.4s, vB3.4s, vA3.s[2]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[2]\n" - "fmla vC14.4s, vB4.4s, vA1.s[2]\n" - "fmla vC24.4s, vB4.4s, vA2.s[2]\n" - "fmla vC34.4s, vB4.4s, vA3.s[2]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[2]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[3]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC12.4s, vB2.4s, vA1.s[3]\n" - "stp qC11, qC12, [%x[cptr], #0x00]\n" - "fmla vC13.4s, vB3.4s, vA1.s[3]\n" - "fmla vC14.4s, vB4.4s, vA1.s[3]\n" - "stp qC13, qC14, [%x[cptr], #0x20]\n" - "fmla vC21.4s, vB1.4s, vA2.s[3]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC22.4s, vB2.4s, vA2.s[3]\n" - "stp qC21, qC22, [%x[cptr], #0x00]\n" - "fmla vC23.4s, vB3.4s, vA2.s[3]\n" - "fmla vC24.4s, vB4.4s, vA2.s[3]\n" - "stp qC23, qC24, [%x[cptr], #0x20]\n" - "fmla vC31.4s, vB1.4s, vA3.s[3]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC32.4s, vB2.4s, vA3.s[3]\n" - "stp qC31, qC32, [%x[cptr], #0x00]\n" 
- "fmla vC33.4s, vB3.4s, vA3.s[3]\n" - "fmla vC34.4s, vB4.4s, vA3.s[3]\n" - "stp qC33, qC34, [%x[cptr], #0x20]\n" - "fmla vC41.4s, vB1.4s, vA4.s[3]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC42.4s, vB2.4s, vA4.s[3]\n" - "stp qC41, qC42, [%x[cptr], #0x00]\n" - "fmla vC43.4s, vB3.4s, vA4.s[3]\n" - "fmla vC44.4s, vB4.4s, vA4.s[3]\n" - "stp qC43, qC44, [%x[cptr], #0x20]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - - ".unreq vB4\n" ".unreq qB4\n" - ".unreq vB3\n" ".unreq qB3\n" - ".unreq vB2\n" ".unreq qB2\n" - ".unreq vB1\n" ".unreq qB1\n" - ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" - ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" - ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" - ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" - ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" - ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n" - ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" - ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" - ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" - ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n" - ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" - ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" - ".unreq aptr2\n" - ".unreq aptr3\n" - ".unreq aptr4\n" - - : [aptr] "+r" (aptr), - [bptr] "+r" (bptr), - [cptr] "+r" (cptr), - [k] "+r" (k) - : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), - [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), - [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) - : "cc", "memory", "x20", "x21", "x22", - "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", - "v21", "v22", "v23" - ); - } - } -} - -template <> -inline void sgemm_4x16_impl<1>( - const float* const a, const float* const b, float *c, - const int M, const int K, const int N, - const int a_row_stride, - const int b_row_stride, - const int c_row_stride -) { - const int TAIL_SIZE = 1; - const int M_BLOCK = 4; - const int N_BLOCK = 16; - - const int m_blocks = iceildiv(M, M_BLOCK); - const int n_blocks = iceildiv(N, N_BLOCK); - - // For each block of output rows - for (int mblock = 0; mblock < m_blocks; mblock++) { - // For each block of output columns - for (int nblock = 0; nblock < n_blocks; nblock++) { - const float *aptr = a + mblock*M_BLOCK*a_row_stride; - const float *bptr = b + nblock*N_BLOCK; - float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; - int k = (K - TAIL_SIZE) / 4; - - asm volatile( - "aptr2 .req X20\n" - "aptr3 .req X21\n" - "aptr4 .req X22\n" - "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" - "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" - "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" - "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" - "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" - "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" - "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" - "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" - "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" - "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" - "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" - "vA4 .req v19\n" "qA4 
.req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" - "vB1 .req v20\n" "qB1 .req q20\n" - "vB2 .req v21\n" "qB2 .req q21\n" - "vB3 .req v22\n" "qB3 .req q22\n" - "vB4 .req v23\n" "qB4 .req q23\n" - - // Clear accumulators, initialise pointers - "movi vC11.4s, #0\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "movi vC12.4s, #0\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "movi vC13.4s, #0\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "movi vC14.4s, #0\n" - "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n" - "movi vC21.4s, #0\n" - "add aptr3, aptr2, %x[a_row_stride_bytes]\n" - "movi vC22.4s, #0\n" - "add aptr4, aptr3, %x[a_row_stride_bytes]\n" - "movi vC23.4s, #0\n" - "cbnz %x[k], 3f\n" - - // Prepare for tail in K - "movi vC24.4s, #0\n" - "ldr sA1, [%x[aptr]], #0x04\n" - "movi vC31.4s, #0\n" - "ldr sA2, [ aptr2], #0x04\n" - "movi vC32.4s, #0\n" - "movi vC33.4s, #0\n" - "movi vC34.4s, #0\n" - "movi vC41.4s, #0\n" - "movi vC42.4s, #0\n" - "movi vC43.4s, #0\n" - "movi vC44.4s, #0\n" - "b 2f\n" // Jump to tail - - "3:" // Prepare for loop over K - "movi vC24.4s, #0\n" - "ldr qA1, [%x[aptr]], #0x10\n" - "movi vC31.4s, #0\n" - "ldr qA2, [ aptr2], #0x10\n" - "movi vC32.4s, #0\n" - "movi vC33.4s, #0\n" - "movi vC34.4s, #0\n" - "movi vC41.4s, #0\n" - "movi vC42.4s, #0\n" - "movi vC43.4s, #0\n" - "movi vC44.4s, #0\n" - "subs %x[k], %x[k], #1\n" - "beq 4f\n" - - "1:" // Loop proper - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr qA3, [ aptr3], #0x10\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "ldr qA4, [ aptr4], #0x10\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "fmla vC33.4s, vB3.4s, vA3.s[0]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[0]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[1]\n" - "fmla vC21.4s, vB1.4s, vA2.s[1]\n" - "fmla vC31.4s, vB1.4s, vA3.s[1]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[1]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[1]\n" - "fmla vC22.4s, vB2.4s, vA2.s[1]\n" - "fmla vC32.4s, vB2.4s, vA3.s[1]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[1]\n" - "fmla vC13.4s, vB3.4s, vA1.s[1]\n" - "fmla vC23.4s, vB3.4s, vA2.s[1]\n" - "fmla vC33.4s, vB3.4s, vA3.s[1]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[1]\n" - "fmla vC14.4s, vB4.4s, vA1.s[1]\n" - "fmla vC24.4s, vB4.4s, vA2.s[1]\n" - "fmla vC34.4s, vB4.4s, vA3.s[1]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[1]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[2]\n" - "fmla vC21.4s, vB1.4s, vA2.s[2]\n" - "fmla vC31.4s, vB1.4s, vA3.s[2]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[2]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[2]\n" - "fmla vC22.4s, vB2.4s, vA2.s[2]\n" - "fmla vC32.4s, vB2.4s, vA3.s[2]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[2]\n" - "fmla vC13.4s, vB3.4s, vA1.s[2]\n" - "fmla vC23.4s, vB3.4s, vA2.s[2]\n" - "fmla vC33.4s, vB3.4s, vA3.s[2]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla 
vC43.4s, vB3.4s, vA4.s[2]\n" - "fmla vC14.4s, vB4.4s, vA1.s[2]\n" - "fmla vC24.4s, vB4.4s, vA2.s[2]\n" - "fmla vC34.4s, vB4.4s, vA3.s[2]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[2]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[3]\n" - "fmla vC21.4s, vB1.4s, vA2.s[3]\n" - "fmla vC31.4s, vB1.4s, vA3.s[3]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[3]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[3]\n" - "fmla vC22.4s, vB2.4s, vA2.s[3]\n" - "fmla vC32.4s, vB2.4s, vA3.s[3]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[3]\n" - "fmla vC13.4s, vB3.4s, vA1.s[3]\n" - "fmla vC23.4s, vB3.4s, vA2.s[3]\n" - "fmla vC33.4s, vB3.4s, vA3.s[3]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[3]\n" - "subs %x[k], %x[k], #1\n" - "fmla vC14.4s, vB4.4s, vA1.s[3]\n" - "ldr qA1, [%x[aptr]], #0x10\n" - "fmla vC24.4s, vB4.4s, vA2.s[3]\n" - "ldr qA2, [ aptr2], #0x10\n" - "fmla vC34.4s, vB4.4s, vA3.s[3]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[3]\n" - "bne 1b\n" - - "4:" // Tail iteration - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr qA3, [ aptr3], #0x10\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "ldr qA4, [ aptr4], #0x10\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "fmla vC33.4s, vB3.4s, vA3.s[0]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[0]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[1]\n" - "fmla vC21.4s, vB1.4s, vA2.s[1]\n" - "fmla vC31.4s, vB1.4s, vA3.s[1]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[1]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[1]\n" - "fmla vC22.4s, vB2.4s, vA2.s[1]\n" - "fmla vC32.4s, vB2.4s, vA3.s[1]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[1]\n" - "fmla vC13.4s, vB3.4s, vA1.s[1]\n" - "fmla vC23.4s, vB3.4s, vA2.s[1]\n" - "fmla vC33.4s, vB3.4s, vA3.s[1]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[1]\n" - "fmla vC14.4s, vB4.4s, vA1.s[1]\n" - "fmla vC24.4s, vB4.4s, vA2.s[1]\n" - "fmla vC34.4s, vB4.4s, vA3.s[1]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[1]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[2]\n" - "fmla vC21.4s, vB1.4s, vA2.s[2]\n" - "fmla vC31.4s, vB1.4s, vA3.s[2]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[2]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[2]\n" - "fmla vC22.4s, vB2.4s, vA2.s[2]\n" - "fmla vC32.4s, vB2.4s, vA3.s[2]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[2]\n" - "fmla vC13.4s, vB3.4s, vA1.s[2]\n" - "fmla vC23.4s, vB3.4s, vA2.s[2]\n" - "fmla vC33.4s, vB3.4s, vA3.s[2]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[2]\n" - "fmla vC14.4s, vB4.4s, vA1.s[2]\n" - "fmla vC24.4s, vB4.4s, vA2.s[2]\n" - "fmla vC34.4s, vB4.4s, vA3.s[2]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[2]\n" - - "fmla vC11.4s, vB1.4s, 
vA1.s[3]\n" - "fmla vC21.4s, vB1.4s, vA2.s[3]\n" - "fmla vC31.4s, vB1.4s, vA3.s[3]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[3]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[3]\n" - "fmla vC22.4s, vB2.4s, vA2.s[3]\n" - "fmla vC32.4s, vB2.4s, vA3.s[3]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[3]\n" - "fmla vC13.4s, vB3.4s, vA1.s[3]\n" - "fmla vC23.4s, vB3.4s, vA2.s[3]\n" - "fmla vC33.4s, vB3.4s, vA3.s[3]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[3]\n" - "fmla vC14.4s, vB4.4s, vA1.s[3]\n" - "ldr sA1, [%x[aptr]], #0x10\n" - "fmla vC24.4s, vB4.4s, vA2.s[3]\n" - "ldr sA2, [ aptr2], #0x10\n" - "fmla vC34.4s, vB4.4s, vA3.s[3]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[3]\n" - - "2:" // Common tail - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "stp qC11, qC12, [%x[cptr], #0x00]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "ldr sA3, [ aptr3], #0x10\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "stp qC13, qC14, [%x[cptr], #0x20]\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "stp qC21, qC22, [%x[cptr], #0x00]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "ldr sA4, [ aptr4], #0x10\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "stp qC23, qC24, [%x[cptr], #0x20]\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "stp qC31, qC32, [%x[cptr], #0x00]\n" - "fmla vC33.4s, vB3.4s, vA3.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "stp qC33, qC34, [%x[cptr], #0x20]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "stp qC41, qC42, [%x[cptr], #0x00]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC44.4s, vB4.4s, vA4.s[0]\n" - "stp qC43, qC44, [%x[cptr], #0x20]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - - ".unreq vB4\n" ".unreq qB4\n" - ".unreq vB3\n" ".unreq qB3\n" - ".unreq vB2\n" ".unreq qB2\n" - ".unreq vB1\n" ".unreq qB1\n" - ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" - ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" - ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" - ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" - ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" - ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n" - ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" - ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" - ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" - ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n" - ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" - ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" - ".unreq aptr2\n" - ".unreq aptr3\n" - ".unreq aptr4\n" - - : [aptr] "+r" (aptr), - [bptr] "+r" (bptr), - [cptr] "+r" (cptr), - [k] "+r" (k) - : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), - [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), - [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) - : "cc", "memory", "x20", "x21", "x22", - "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", - "v21", "v22", "v23" - ); - } - } -} - -template <> -inline void sgemm_4x16_impl<2>( - 
const float* const a, const float* const b, float *c, - const int M, const int K, const int N, - const int a_row_stride, - const int b_row_stride, - const int c_row_stride -) { - const int TAIL_SIZE = 2; - const int M_BLOCK = 4; - const int N_BLOCK = 16; - - const int m_blocks = iceildiv(M, M_BLOCK); - const int n_blocks = iceildiv(N, N_BLOCK); - - // For each block of output rows - for (int mblock = 0; mblock < m_blocks; mblock++) { - // For each block of output columns - for (int nblock = 0; nblock < n_blocks; nblock++) { - const float *aptr = a + mblock*M_BLOCK*a_row_stride; - const float *bptr = b + nblock*N_BLOCK; - float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; - int k = (K - TAIL_SIZE) / 4; - - asm volatile( - "aptr2 .req X20\n" - "aptr3 .req X21\n" - "aptr4 .req X22\n" - "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" - "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" - "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" - "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" - "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" - "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" - "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" - "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" - "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" - "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" - "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" - "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" - "vB1 .req v20\n" "qB1 .req q20\n" - "vB2 .req v21\n" "qB2 .req q21\n" - "vB3 .req v22\n" "qB3 .req q22\n" - "vB4 .req v23\n" "qB4 .req q23\n" - - // Clear accumulators, initialise pointers - "movi vC11.4s, #0\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "movi vC12.4s, #0\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "movi vC13.4s, #0\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "movi vC14.4s, #0\n" - "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n" - "movi vC21.4s, #0\n" - "add aptr3, aptr2, %x[a_row_stride_bytes]\n" - "movi vC22.4s, #0\n" - "add aptr4, aptr3, %x[a_row_stride_bytes]\n" - "movi vC23.4s, #0\n" - "cbnz %x[k], 3f\n" - - // Prepare for tail in K - "movi vC24.4s, #0\n" - "ldr dA1, [%x[aptr]], #0x08\n" - "movi vC31.4s, #0\n" - "ldr dA2, [ aptr2], #0x08\n" - "movi vC32.4s, #0\n" - "movi vC33.4s, #0\n" - "movi vC34.4s, #0\n" - "movi vC41.4s, #0\n" - "movi vC42.4s, #0\n" - "movi vC43.4s, #0\n" - "movi vC44.4s, #0\n" - "b 2f\n" // Jump to tail - - "3:" // Prepare for loop over K - "movi vC24.4s, #0\n" - "ldr qA1, [%x[aptr]], #0x10\n" - "movi vC31.4s, #0\n" - "ldr qA2, [ aptr2], #0x10\n" - "movi vC32.4s, #0\n" - "movi vC33.4s, #0\n" - "movi vC34.4s, #0\n" - "movi vC41.4s, #0\n" - "movi vC42.4s, #0\n" - "movi vC43.4s, #0\n" - "movi vC44.4s, #0\n" - "subs %x[k], %x[k], #1\n" - "beq 4f\n" - - "1:" // Loop proper - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr qA3, [ aptr3], #0x10\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "ldr qA4, [ aptr4], #0x10\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "fmla vC33.4s, vB3.4s, 
vA3.s[0]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[0]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[1]\n" - "fmla vC21.4s, vB1.4s, vA2.s[1]\n" - "fmla vC31.4s, vB1.4s, vA3.s[1]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[1]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[1]\n" - "fmla vC22.4s, vB2.4s, vA2.s[1]\n" - "fmla vC32.4s, vB2.4s, vA3.s[1]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[1]\n" - "fmla vC13.4s, vB3.4s, vA1.s[1]\n" - "fmla vC23.4s, vB3.4s, vA2.s[1]\n" - "fmla vC33.4s, vB3.4s, vA3.s[1]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[1]\n" - "fmla vC14.4s, vB4.4s, vA1.s[1]\n" - "fmla vC24.4s, vB4.4s, vA2.s[1]\n" - "fmla vC34.4s, vB4.4s, vA3.s[1]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[1]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[2]\n" - "fmla vC21.4s, vB1.4s, vA2.s[2]\n" - "fmla vC31.4s, vB1.4s, vA3.s[2]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[2]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[2]\n" - "fmla vC22.4s, vB2.4s, vA2.s[2]\n" - "fmla vC32.4s, vB2.4s, vA3.s[2]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[2]\n" - "fmla vC13.4s, vB3.4s, vA1.s[2]\n" - "fmla vC23.4s, vB3.4s, vA2.s[2]\n" - "fmla vC33.4s, vB3.4s, vA3.s[2]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[2]\n" - "fmla vC14.4s, vB4.4s, vA1.s[2]\n" - "fmla vC24.4s, vB4.4s, vA2.s[2]\n" - "fmla vC34.4s, vB4.4s, vA3.s[2]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[2]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[3]\n" - "fmla vC21.4s, vB1.4s, vA2.s[3]\n" - "fmla vC31.4s, vB1.4s, vA3.s[3]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[3]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[3]\n" - "fmla vC22.4s, vB2.4s, vA2.s[3]\n" - "fmla vC32.4s, vB2.4s, vA3.s[3]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[3]\n" - "fmla vC13.4s, vB3.4s, vA1.s[3]\n" - "fmla vC23.4s, vB3.4s, vA2.s[3]\n" - "fmla vC33.4s, vB3.4s, vA3.s[3]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[3]\n" - "subs %x[k], %x[k], #1\n" - "fmla vC14.4s, vB4.4s, vA1.s[3]\n" - "ldr qA1, [%x[aptr]], #0x10\n" - "fmla vC24.4s, vB4.4s, vA2.s[3]\n" - "ldr qA2, [ aptr2], #0x10\n" - "fmla vC34.4s, vB4.4s, vA3.s[3]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[3]\n" - "bne 1b\n" - - "4:" // Tail iteration - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr qA3, [ aptr3], #0x10\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "ldr qA4, [ aptr4], #0x10\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "fmla vC33.4s, vB3.4s, vA3.s[0]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, 
vB4.4s, vA4.s[0]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[1]\n" - "fmla vC21.4s, vB1.4s, vA2.s[1]\n" - "fmla vC31.4s, vB1.4s, vA3.s[1]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[1]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[1]\n" - "fmla vC22.4s, vB2.4s, vA2.s[1]\n" - "fmla vC32.4s, vB2.4s, vA3.s[1]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[1]\n" - "fmla vC13.4s, vB3.4s, vA1.s[1]\n" - "fmla vC23.4s, vB3.4s, vA2.s[1]\n" - "fmla vC33.4s, vB3.4s, vA3.s[1]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[1]\n" - "fmla vC14.4s, vB4.4s, vA1.s[1]\n" - "fmla vC24.4s, vB4.4s, vA2.s[1]\n" - "fmla vC34.4s, vB4.4s, vA3.s[1]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[1]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[2]\n" - "fmla vC21.4s, vB1.4s, vA2.s[2]\n" - "fmla vC31.4s, vB1.4s, vA3.s[2]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[2]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[2]\n" - "fmla vC22.4s, vB2.4s, vA2.s[2]\n" - "fmla vC32.4s, vB2.4s, vA3.s[2]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[2]\n" - "fmla vC13.4s, vB3.4s, vA1.s[2]\n" - "fmla vC23.4s, vB3.4s, vA2.s[2]\n" - "fmla vC33.4s, vB3.4s, vA3.s[2]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[2]\n" - "fmla vC14.4s, vB4.4s, vA1.s[2]\n" - "fmla vC24.4s, vB4.4s, vA2.s[2]\n" - "fmla vC34.4s, vB4.4s, vA3.s[2]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[2]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[3]\n" - "fmla vC21.4s, vB1.4s, vA2.s[3]\n" - "fmla vC31.4s, vB1.4s, vA3.s[3]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[3]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[3]\n" - "fmla vC22.4s, vB2.4s, vA2.s[3]\n" - "fmla vC32.4s, vB2.4s, vA3.s[3]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[3]\n" - "fmla vC13.4s, vB3.4s, vA1.s[3]\n" - "fmla vC23.4s, vB3.4s, vA2.s[3]\n" - "fmla vC33.4s, vB3.4s, vA3.s[3]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[3]\n" - "fmla vC14.4s, vB4.4s, vA1.s[3]\n" - "ldr dA1, [%x[aptr]], #0x10\n" - "fmla vC24.4s, vB4.4s, vA2.s[3]\n" - "ldr dA2, [ aptr2], #0x10\n" - "fmla vC34.4s, vB4.4s, vA3.s[3]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[3]\n" - - "2:" // Common tail - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr dA3, [ aptr3], #0x10\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "ldr dA4, [ aptr4], #0x10\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "fmla vC33.4s, vB3.4s, vA3.s[0]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[0]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[1]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC12.4s, vB2.4s, vA1.s[1]\n" - "stp qC11, qC12, [%x[cptr], #0x00]\n" - "fmla vC13.4s, vB3.4s, vA1.s[1]\n" - "fmla vC14.4s, vB4.4s, vA1.s[1]\n" - "stp qC13, qC14, [%x[cptr], #0x20]\n" - "fmla 
vC21.4s, vB1.4s, vA2.s[1]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC22.4s, vB2.4s, vA2.s[1]\n" - "stp qC21, qC22, [%x[cptr], #0x00]\n" - "fmla vC23.4s, vB3.4s, vA2.s[1]\n" - "fmla vC24.4s, vB4.4s, vA2.s[1]\n" - "stp qC23, qC24, [%x[cptr], #0x20]\n" - "fmla vC31.4s, vB1.4s, vA3.s[1]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC32.4s, vB2.4s, vA3.s[1]\n" - "stp qC31, qC32, [%x[cptr], #0x00]\n" - "fmla vC33.4s, vB3.4s, vA3.s[1]\n" - "fmla vC34.4s, vB4.4s, vA3.s[1]\n" - "stp qC33, qC34, [%x[cptr], #0x20]\n" - "fmla vC41.4s, vB1.4s, vA4.s[1]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC42.4s, vB2.4s, vA4.s[1]\n" - "stp qC41, qC42, [%x[cptr], #0x00]\n" - "fmla vC43.4s, vB3.4s, vA4.s[1]\n" - "fmla vC44.4s, vB4.4s, vA4.s[1]\n" - "stp qC43, qC44, [%x[cptr], #0x20]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - - ".unreq vB4\n" ".unreq qB4\n" - ".unreq vB3\n" ".unreq qB3\n" - ".unreq vB2\n" ".unreq qB2\n" - ".unreq vB1\n" ".unreq qB1\n" - ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" - ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" - ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" - ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" - ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" - ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n" - ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" - ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" - ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" - ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n" - ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" - ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" - ".unreq aptr2\n" - ".unreq aptr3\n" - ".unreq aptr4\n" - - : [aptr] "+r" (aptr), - [bptr] "+r" (bptr), - [cptr] "+r" (cptr), - [k] "+r" (k) - : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), - [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), - [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) - : "cc", "memory", "x20", "x21", "x22", - "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", - "v21", "v22", "v23" - ); - } - } -} - -template <> -inline void sgemm_4x16_impl<3>( - const float* const a, const float* const b, float *c, - const int M, const int K, const int N, - const int a_row_stride, - const int b_row_stride, - const int c_row_stride -) { - const int TAIL_SIZE = 3; - const int M_BLOCK = 4; - const int N_BLOCK = 16; - - const int m_blocks = iceildiv(M, M_BLOCK); - const int n_blocks = iceildiv(N, N_BLOCK); - - // For each block of output rows - for (int mblock = 0; mblock < m_blocks; mblock++) { - // For each block of output columns - for (int nblock = 0; nblock < n_blocks; nblock++) { - const float *aptr = a + mblock*M_BLOCK*a_row_stride; - const float *bptr = b + nblock*N_BLOCK; - float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; - int k = (K - TAIL_SIZE) / 4; - - asm volatile( - "aptr2 .req X20\n" - "aptr3 .req X21\n" - "aptr4 .req X22\n" - "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" - "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" - "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" - "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" - "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" - "qC31 .req 
q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" - "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" - "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" - "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" - "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" - "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" - "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" - "vB1 .req v20\n" "qB1 .req q20\n" - "vB2 .req v21\n" "qB2 .req q21\n" - "vB3 .req v22\n" "qB3 .req q22\n" - "vB4 .req v23\n" "qB4 .req q23\n" - - // Clear accumulators, initialise pointers - "movi vC11.4s, #0\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "movi vC12.4s, #0\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "movi vC13.4s, #0\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "movi vC14.4s, #0\n" - "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n" - "movi vC21.4s, #0\n" - "add aptr3, aptr2, %x[a_row_stride_bytes]\n" - "movi vC22.4s, #0\n" - "add aptr4, aptr3, %x[a_row_stride_bytes]\n" - "movi vC23.4s, #0\n" - "cbnz %x[k], 3f\n" - - // Prepare for tail in K - "movi vC24.4s, #0\n" - "ldr dA1, [%x[aptr]], #0x08\n" - "movi vC31.4s, #0\n" - "ldr dA2, [ aptr2], #0x08\n" - "movi vC32.4s, #0\n" - "movi vC33.4s, #0\n" - "movi vC34.4s, #0\n" - "movi vC41.4s, #0\n" - "movi vC42.4s, #0\n" - "movi vC43.4s, #0\n" - "movi vC44.4s, #0\n" - "b 2f\n" // Jump to tail - - "3:" // Prepare for loop over K - "movi vC24.4s, #0\n" - "ldr qA1, [%x[aptr]], #0x10\n" - "movi vC31.4s, #0\n" - "ldr qA2, [ aptr2], #0x10\n" - "movi vC32.4s, #0\n" - "movi vC33.4s, #0\n" - "movi vC34.4s, #0\n" - "movi vC41.4s, #0\n" - "movi vC42.4s, #0\n" - "movi vC43.4s, #0\n" - "movi vC44.4s, #0\n" - "subs %x[k], %x[k], #1\n" - "beq 4f\n" - - "1:" // Loop proper - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr qA3, [ aptr3], #0x10\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "ldr qA4, [ aptr4], #0x10\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "fmla vC33.4s, vB3.4s, vA3.s[0]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[0]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[1]\n" - "fmla vC21.4s, vB1.4s, vA2.s[1]\n" - "fmla vC31.4s, vB1.4s, vA3.s[1]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[1]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[1]\n" - "fmla vC22.4s, vB2.4s, vA2.s[1]\n" - "fmla vC32.4s, vB2.4s, vA3.s[1]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[1]\n" - "fmla vC13.4s, vB3.4s, vA1.s[1]\n" - "fmla vC23.4s, vB3.4s, vA2.s[1]\n" - "fmla vC33.4s, vB3.4s, vA3.s[1]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[1]\n" - "fmla vC14.4s, vB4.4s, vA1.s[1]\n" - "fmla vC24.4s, vB4.4s, vA2.s[1]\n" - "fmla vC34.4s, vB4.4s, vA3.s[1]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[1]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[2]\n" - "fmla vC21.4s, vB1.4s, vA2.s[2]\n" - "fmla vC31.4s, vB1.4s, vA3.s[2]\n" - "ldr qB4, 
[%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[2]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[2]\n" - "fmla vC22.4s, vB2.4s, vA2.s[2]\n" - "fmla vC32.4s, vB2.4s, vA3.s[2]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[2]\n" - "fmla vC13.4s, vB3.4s, vA1.s[2]\n" - "fmla vC23.4s, vB3.4s, vA2.s[2]\n" - "fmla vC33.4s, vB3.4s, vA3.s[2]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[2]\n" - "fmla vC14.4s, vB4.4s, vA1.s[2]\n" - "fmla vC24.4s, vB4.4s, vA2.s[2]\n" - "fmla vC34.4s, vB4.4s, vA3.s[2]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[2]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[3]\n" - "fmla vC21.4s, vB1.4s, vA2.s[3]\n" - "fmla vC31.4s, vB1.4s, vA3.s[3]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[3]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[3]\n" - "fmla vC22.4s, vB2.4s, vA2.s[3]\n" - "fmla vC32.4s, vB2.4s, vA3.s[3]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[3]\n" - "fmla vC13.4s, vB3.4s, vA1.s[3]\n" - "fmla vC23.4s, vB3.4s, vA2.s[3]\n" - "fmla vC33.4s, vB3.4s, vA3.s[3]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[3]\n" - "subs %x[k], %x[k], #1\n" - "fmla vC14.4s, vB4.4s, vA1.s[3]\n" - "ldr qA1, [%x[aptr]], #0x10\n" - "fmla vC24.4s, vB4.4s, vA2.s[3]\n" - "ldr qA2, [ aptr2], #0x10\n" - "fmla vC34.4s, vB4.4s, vA3.s[3]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[3]\n" - "bne 1b\n" - - "4:" // Tail iteration - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr qA3, [ aptr3], #0x10\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "ldr qA4, [ aptr4], #0x10\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "fmla vC33.4s, vB3.4s, vA3.s[0]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[0]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[1]\n" - "fmla vC21.4s, vB1.4s, vA2.s[1]\n" - "fmla vC31.4s, vB1.4s, vA3.s[1]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[1]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[1]\n" - "fmla vC22.4s, vB2.4s, vA2.s[1]\n" - "fmla vC32.4s, vB2.4s, vA3.s[1]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[1]\n" - "fmla vC13.4s, vB3.4s, vA1.s[1]\n" - "fmla vC23.4s, vB3.4s, vA2.s[1]\n" - "fmla vC33.4s, vB3.4s, vA3.s[1]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[1]\n" - "fmla vC14.4s, vB4.4s, vA1.s[1]\n" - "fmla vC24.4s, vB4.4s, vA2.s[1]\n" - "fmla vC34.4s, vB4.4s, vA3.s[1]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[1]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[2]\n" - "fmla vC21.4s, vB1.4s, vA2.s[2]\n" - "fmla vC31.4s, vB1.4s, vA3.s[2]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[2]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[2]\n" - "fmla vC22.4s, vB2.4s, vA2.s[2]\n" - "fmla vC32.4s, vB2.4s, vA3.s[2]\n" - "ldr qB1, [%x[bptr], 
#0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[2]\n" - "fmla vC13.4s, vB3.4s, vA1.s[2]\n" - "fmla vC23.4s, vB3.4s, vA2.s[2]\n" - "fmla vC33.4s, vB3.4s, vA3.s[2]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[2]\n" - "fmla vC14.4s, vB4.4s, vA1.s[2]\n" - "fmla vC24.4s, vB4.4s, vA2.s[2]\n" - "fmla vC34.4s, vB4.4s, vA3.s[2]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[2]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[3]\n" - "fmla vC21.4s, vB1.4s, vA2.s[3]\n" - "fmla vC31.4s, vB1.4s, vA3.s[3]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[3]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[3]\n" - "fmla vC22.4s, vB2.4s, vA2.s[3]\n" - "fmla vC32.4s, vB2.4s, vA3.s[3]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[3]\n" - "fmla vC13.4s, vB3.4s, vA1.s[3]\n" - "fmla vC23.4s, vB3.4s, vA2.s[3]\n" - "fmla vC33.4s, vB3.4s, vA3.s[3]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[3]\n" - "fmla vC14.4s, vB4.4s, vA1.s[3]\n" - "ldr dA1, [%x[aptr]], #0x10\n" - "fmla vC24.4s, vB4.4s, vA2.s[3]\n" - "ldr dA2, [ aptr2], #0x10\n" - "fmla vC34.4s, vB4.4s, vA3.s[3]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[3]\n" - - "2:" // Common tail - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr dA3, [ aptr3], #0x10\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "ldr dA4, [ aptr4], #0x10\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "fmla vC33.4s, vB3.4s, vA3.s[0]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[0]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[1]\n" - "fmla vC21.4s, vB1.4s, vA2.s[1]\n" - "fmla vC31.4s, vB1.4s, vA3.s[1]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[1]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[1]\n" - "fmla vC22.4s, vB2.4s, vA2.s[1]\n" - "fmla vC32.4s, vB2.4s, vA3.s[1]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[1]\n" - "fmla vC13.4s, vB3.4s, vA1.s[1]\n" - "fmla vC23.4s, vB3.4s, vA2.s[1]\n" - "fmla vC33.4s, vB3.4s, vA3.s[1]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[1]\n" - "fmla vC14.4s, vB4.4s, vA1.s[1]\n" - "ldr sA1, [%x[aptr]], #0x10\n" - "fmla vC24.4s, vB4.4s, vA2.s[1]\n" - "ldr sA2, [ aptr2], #0x10\n" - "fmla vC34.4s, vB4.4s, vA3.s[1]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[1]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "stp qC11, qC12, [%x[cptr], #0x00]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "ldr sA3, [ aptr3], #0x10\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "stp qC13, qC14, [%x[cptr], #0x20]\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "stp qC21, qC22, [%x[cptr], #0x00]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "ldr sA4, [ aptr4], #0x10\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "stp qC23, qC24, [%x[cptr], 
#0x20]\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "stp qC31, qC32, [%x[cptr], #0x00]\n" - "fmla vC33.4s, vB3.4s, vA3.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "stp qC33, qC34, [%x[cptr], #0x20]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "stp qC41, qC42, [%x[cptr], #0x00]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC44.4s, vB4.4s, vA4.s[0]\n" - "stp qC43, qC44, [%x[cptr], #0x20]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - - ".unreq vB4\n" ".unreq qB4\n" - ".unreq vB3\n" ".unreq qB3\n" - ".unreq vB2\n" ".unreq qB2\n" - ".unreq vB1\n" ".unreq qB1\n" - ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" - ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" - ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" - ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" - ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" - ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n" - ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" - ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" - ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" - ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n" - ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" - ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" - ".unreq aptr2\n" - ".unreq aptr3\n" - ".unreq aptr4\n" - - : [aptr] "+r" (aptr), - [bptr] "+r" (bptr), - [cptr] "+r" (cptr), - [k] "+r" (k) - : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), - [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), - [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) - : "cc", "memory", "x20", "x21", "x22", - "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", - "v21", "v22", "v23" - ); - } - } -} diff --git a/arm_compute/core/NEON/kernels/winograd/perf.h b/arm_compute/core/NEON/kernels/winograd/perf.h deleted file mode 100644 index 11fb0c452f..0000000000 --- a/arm_compute/core/NEON/kernels/winograd/perf.h +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- */ -#pragma once - -/* Prototypes from perf.c */ - -void start_counter(int fd); -long long get_counter(int fd); -long long stop_counter(int fd); -int open_instruction_counter(void); -int open_cycle_counter(void); diff --git a/arm_compute/core/NEON/kernels/winograd/profiler.hpp b/arm_compute/core/NEON/kernels/winograd/profiler.hpp deleted file mode 100644 index 143192b589..0000000000 --- a/arm_compute/core/NEON/kernels/winograd/profiler.hpp +++ /dev/null @@ -1,244 +0,0 @@ - -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -#include -#include -#include -#include -#include -#include - -#include "perf.h" -#include - -class profiler { -private: -#ifdef CYCLE_PROFILING - struct ProfileEntry { - int event_id; - long int bytes_read, ops, bytes_written; - long int duration; - }; - - static const int maxevents = 10000; - ProfileEntry events[maxevents]; - int currentevent; - int countfd; - - std::map event_ids; - - int get_event_id(const char *id) { - if (!event_ids.count(id)) { - event_ids.emplace(id, event_ids.size()); - } - return event_ids[id]; - } -#endif // CYCLE_PROFILING - -public: -#ifdef CYCLE_PROFILING - profiler() { - currentevent = 0; - countfd = open_cycle_counter(); - } - - ~profiler() { - close(countfd); - - // Compute performance from recorded events - struct ProfileResult { - ProfileResult() : total_calls(0), - total_duration(0), - total_bytes_read(0), - total_ops(0), - total_bytes_written(0) { - } - - void operator+=(const ProfileEntry &rhs) { - total_calls++; - total_duration += rhs.duration; - total_bytes_read += rhs.bytes_read; - total_ops += rhs.ops; - total_bytes_written = rhs.bytes_written; - } - - float avg_duration(void) const { - return static_cast(total_duration) / - static_cast(total_calls); - } - - float bytes_read_per_cycle(void) const { - return static_cast(total_bytes_read) / - static_cast(total_duration); - } - - float ops_per_cycle(void) const { - return static_cast(total_ops) / - static_cast(total_duration); - } - - float bytes_written_per_cycle(void) const { - return static_cast(total_bytes_written) / - static_cast(total_duration); - } - - long int total_calls, - total_duration, - total_bytes_read, - total_ops, - total_bytes_written; - }; - - std::vector totals; - totals.resize(event_ids.size()); - for (int i = 0; i < currentevent; i++) { - const auto &event = events[i]; - totals[event.event_id] += event; - } - - // Get the longest label - 
int len_label = 0; - for (const auto &kv : event_ids) { - len_label = std::max(len_label, static_cast(strlen(kv.first))); - } - - // Get the longest values for every other field - const auto get_length_of_field = - [totals] (const char *title, auto f, auto len) -> size_t { - size_t l = strlen(title); - for (const auto &v : totals) { - l = std::max(l, len(f(v))); - } - return l; - }; - - // Get the strlen for an int - const auto intlen = [] (long int x) -> size_t { - size_t len = 0; - do { - x /= 10; - len++; - } while (x); - return len; - }; - - // Get the strlen for a float - const auto floatlen = [] (const int precision) { - return [precision] (float x) { - size_t len = 0; - - if (!std::isfinite(x)) { - return static_cast(3); - } - - do { - x /= 10.0f; - len++; - } while (x > 1.0f); - return len + 1 + precision; - }; - }; - - const int len_calls = get_length_of_field( - "Calls", [] (const auto &v) {return v.total_calls;}, - intlen - ); - const int len_duration = get_length_of_field( - "Duration", [] (const auto &v) {return v.total_duration;}, - intlen - ); - const int len_average_duration = get_length_of_field( - "Average", [] (const auto &v) {return v.avg_duration();}, - floatlen(2) - ); - const int len_reads_per_cycle = get_length_of_field( - "Reads / cycle", - [] (const auto &v) {return v.bytes_read_per_cycle();}, - floatlen(6) - ); - const int len_ops_per_cycle = get_length_of_field( - "Ops / cycle", - [] (const auto &v) {return v.ops_per_cycle();}, - floatlen(6) - ); - const int len_writes_per_cycle = get_length_of_field( - "Writes / cycle", - [] (const auto &v) {return v.bytes_written_per_cycle();}, - floatlen(6) - ); - - // Print header - printf( - "%*s %*s %*s %*s %*s %*s %*s\n", - len_label, "", - len_calls, "Calls", - len_duration, "Duration", - len_average_duration, "Average", - len_reads_per_cycle, "Reads / cycle", - len_ops_per_cycle, "Ops / cycle", - len_writes_per_cycle, "Writes / cycle" - ); - for (const auto &kv : event_ids) { - const auto id = kv.second; - printf( - "%*s %*ld %*ld %*.2f %*.6f %*.6f %*.6f\n", - len_label, kv.first, - len_calls, totals[id].total_calls, - len_duration, totals[id].total_duration, - len_average_duration, totals[id].avg_duration(), - len_reads_per_cycle, totals[id].bytes_read_per_cycle(), - len_ops_per_cycle, totals[id].ops_per_cycle(), - len_writes_per_cycle, totals[id].bytes_written_per_cycle() - ); - } - printf("\n"); - } -#endif // CYCLE_PROFILING - - template - void operator() (const char * event, - T func, - long int bytes_read = 0, - long int ops = 0, - long int bytes_written = 0) { -#ifdef CYCLE_PROFILING - if (currentevent==maxevents) { - func(); - } else { - start_counter(countfd); - func(); - long long cycs = stop_counter(countfd); - - // Store the profiling data - events[currentevent++] = { - get_event_id(event), bytes_read, ops, bytes_written, cycs - }; - } -#else - func(); -#endif // CYCLE_PROFILING - } -}; diff --git a/arm_compute/core/NEON/kernels/winograd/shims.hpp b/arm_compute/core/NEON/kernels/winograd/shims.hpp deleted file mode 100644 index 249e5757f0..0000000000 --- a/arm_compute/core/NEON/kernels/winograd/shims.hpp +++ /dev/null @@ -1,319 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. 
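The profiler deleted above is a thin wrapper around the perf-event counters declared in perf.h: when CYCLE_PROFILING is defined, operator() brackets the supplied callable between start_counter() and stop_counter(), records the elapsed cycles together with the caller-supplied byte and op counts, and the destructor prints per-label totals (calls, duration, and reads, ops and writes per cycle). Without CYCLE_PROFILING it simply invokes the callable. A hypothetical call site, with a made-up label and counts, might look like:

#include "profiler.hpp"

void run_input_transform(profiler &prof, const float *in, float *out, const int n_channels)
{
  prof(
    "input_transform",                              // label aggregated in the report
    [&]() { /* ... transform `in` into `out` ... */ },
    16 * n_channels * sizeof(float),                // bytes read
    32 * n_channels,                                // arithmetic operations
    16 * n_channels * sizeof(float));               // bytes written
}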
- * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -#pragma once - -/** Re-order a weight tensor from [Output feature map x Input feature map x - * Height x Width] format to [Height x Width x Input feature map x Output - * feature map] format. - */ -template -inline void ofm_ifm_h_w_to_h_w_ifm_ofm( - const T* const in, // Input in [Output x Input x Height x Width] form - T* const out, // Output in [Height x Width x Input x Output] form - const int n_output_feature_maps, - const int n_input_feature_maps, - const int n_rows, - const int n_cols, - int in_output_feature_map_stride=0, - int in_input_feature_map_stride=0, - int in_row_stride=0, - int out_row_stride=0, - int out_col_stride=0, - int out_input_feature_map_stride=0 -); - -/** Re-order a weight tensor from [Height x Width x Input feature map x Output - * feature map] format to [Output feature map x Input feature map x Height x - * Width] format. - */ -template -inline void h_w_ifm_ofm_to_ofm_ifm_h_w( - const T* const in, // Input in [Height x Width x Input x Output] form - T* const out, // Output in [Output x Input x Height x Width] form - const int n_rows, - const int n_cols, - const int n_input_feature_maps, - const int n_output_feature_maps, - int in_row_stride=0, - int in_col_stride=0, - int in_input_feature_map_stride=0, - int out_output_feature_map_stride=0, - int out_input_feature_map_stride=0, - int out_row_stride=0 -); - - -/* Re-order a tensor from NCHW format to NHWC. - */ -template -inline void nchw_to_nhwc( - const T* const in, - T* const out, - const int n_batches, - const int n_channels, - const int n_rows, - const int n_cols, - int in_batch_stride=0, - int in_channel_stride=0, - int in_row_stride=0, - int out_batch_stride=0, - int out_row_stride=0, - int out_col_stride=0 -) -{ - // Fill in the stride values - in_row_stride = (in_row_stride) ? in_row_stride : n_cols; - in_channel_stride = (in_channel_stride) ? in_channel_stride - : n_rows * in_row_stride; - in_batch_stride = (in_batch_stride) ? in_batch_stride - : n_channels * in_channel_stride; - - out_col_stride = (out_col_stride) ? out_col_stride : n_channels; - out_row_stride = (out_row_stride) ? out_row_stride : n_cols * out_col_stride; - out_batch_stride = (out_batch_stride) ? 
out_batch_stride - : n_rows * out_row_stride; - - // Perform the re-ordering - for (int n = 0; n < n_batches; n++) - { - const T* const in_batch = in + n*in_batch_stride; - T* const out_batch = out + n*out_batch_stride; - - for (int i = 0; i < n_rows; i++) - { - const T* const in_row = in_batch + i*in_row_stride; - T* const out_row = out_batch + i*out_row_stride; - - for (int j = 0; j < n_cols; j++) - { - const T* const in_col = in_row + j; - T* const out_col = out_row + j*out_col_stride; - - for (int c = 0; c < n_channels; c++) - { - const T* const in_channel = in_col + c*in_channel_stride; - out_col[c] = *(in_channel); - } - } - } - } -} - -/* Re-order a tensor from NHWC format to NCHW. - */ -template -inline void nhwc_to_nchw( - const T* const in, // Input data in NHWC form - T* const out, // Output data in NCHW form - const int n_batches, - const int n_rows, - const int n_cols, - const int n_channels, - int in_batch_stride=0, - int in_row_stride=0, - int in_col_stride=0, - int out_batch_stride=0, - int out_channel_stride=0, - int out_row_stride=0 -) -{ - // Fill in stride values - in_col_stride = (in_col_stride) ? in_col_stride : n_channels; - in_row_stride = (in_row_stride) ? in_row_stride : n_cols * in_col_stride; - in_batch_stride = (in_batch_stride) ? in_batch_stride - : n_rows * in_row_stride; - - out_row_stride = (out_row_stride) ? out_row_stride : n_cols; - out_channel_stride = (out_channel_stride) ? out_channel_stride - : n_rows * out_row_stride; - out_batch_stride = (out_batch_stride) ? out_batch_stride - : n_channels * out_channel_stride; - - // Perform the re-ordering - // For every batch - for (int n = 0; n < n_batches; n++) - { - const T* const in_batch = in + n*in_batch_stride; - T* const out_batch = out + n*out_batch_stride; - - // For every row - for (int i = 0; i < n_rows; i++) - { - const T* const in_i = in_batch + i*in_row_stride; - T* const out_i = out_batch + i*out_row_stride; - - // For every column - for (int j = 0; j < n_cols; j++) - { - const T* const in_j = in_i + j*in_col_stride; - T* const out_j = out_i + j; - - // For every channel - for (int c = 0; c < n_channels; c++) - { - const T* const in_channel = in_j + c; - T* const out_channel = out_j + c*out_channel_stride; - *(out_channel) = *(in_channel); - } - } - } - } -} - - -/*****************************************************************************/ -/* Generic weight re-order implementation. - */ -template -inline void ofm_ifm_h_w_to_h_w_ifm_ofm( - const T* const in, // Input in [Output x Input x Height x Width] form - T* const out, // Output in [Height x Width x Input x Output] form - const int n_output_feature_maps, - const int n_input_feature_maps, - const int n_rows, - const int n_cols, - int in_output_feature_map_stride, - int in_input_feature_map_stride, - int in_row_stride, - int out_row_stride, - int out_col_stride, - int out_input_feature_map_stride -) -{ - // Fill in stride values - in_row_stride = (in_row_stride) - ? in_row_stride - : n_cols; - in_input_feature_map_stride = (in_input_feature_map_stride) - ? in_input_feature_map_stride - : n_rows * in_row_stride; - in_output_feature_map_stride = (in_output_feature_map_stride) - ? in_output_feature_map_stride - : n_input_feature_maps * in_input_feature_map_stride; - - out_input_feature_map_stride = (out_input_feature_map_stride) - ? out_input_feature_map_stride - : n_output_feature_maps; - out_col_stride = (out_col_stride) - ? out_col_stride - : n_input_feature_maps * out_input_feature_map_stride; - out_row_stride = (out_row_stride) - ? 
out_row_stride - : n_cols * out_col_stride; - - // Perform the re-ordering - for (int i = 0; i < n_rows; i++) - { - const T* const in_row = in + i * in_row_stride; - T* out_row = out + i * out_row_stride; - - for (int j = 0; j < n_cols; j++) - { - const T* const in_col = in_row + j; - T* const out_col = out_row + j * out_col_stride; - - for (int ifm = 0; ifm < n_input_feature_maps; ifm++) - { - const T* const in_ifm = in_col + ifm * in_input_feature_map_stride; - T* const out_ifm = out_col + ifm * out_input_feature_map_stride; - - for (int ofm = 0; ofm < n_output_feature_maps; ofm++) - { - const T* const in_ofm = in_ifm + ofm * in_output_feature_map_stride; - T* const out_ofm = out_ifm + ofm; - *(out_ofm) = *(in_ofm); - } - } - } - } -} - -/*****************************************************************************/ -/* Generic weight re-order implementation. - */ -template -inline void h_w_ifm_ofm_to_ofm_ifm_h_w( - const T* const in, // Input in [Height x Width x Input x Output] form - T* const out, // Output in [Output x Input x Height x Width] form - const int n_rows, - const int n_cols, - const int n_input_feature_maps, - const int n_output_feature_maps, - int in_row_stride, - int in_col_stride, - int in_input_feature_map_stride, - int out_output_feature_map_stride, - int out_input_feature_map_stride, - int out_row_stride -) -{ - // Fill in the stride values - in_input_feature_map_stride = (in_input_feature_map_stride) - ? in_input_feature_map_stride - : n_output_feature_maps; - in_col_stride = (in_col_stride) - ? in_col_stride - : n_input_feature_maps * in_input_feature_map_stride; - in_row_stride = (in_row_stride) - ? in_row_stride - : n_cols * in_col_stride; - - out_row_stride = (out_row_stride) - ? out_row_stride - : n_cols; - out_input_feature_map_stride = (out_input_feature_map_stride) - ? out_input_feature_map_stride - : n_rows * out_row_stride; - out_output_feature_map_stride = (out_output_feature_map_stride) - ? out_output_feature_map_stride - : n_input_feature_maps * out_input_feature_map_stride; - - // Perform the re-ordering - for (int i = 0; i < n_rows; i++) - { - const T* const in_row = in + i * in_row_stride; - T* const out_row = out + i * out_row_stride; - - for (int j = 0; j < n_cols; j++) - { - const T* const in_col = in_row + j * in_col_stride; - T* const out_col = out_row + j; - - for (int ifm = 0; ifm < n_input_feature_maps; ifm++) - { - const T* const in_ifm = in_col + ifm * in_input_feature_map_stride; - T* const out_ifm = out_col + ifm * out_input_feature_map_stride; - - for (int ofm = 0; ofm < n_output_feature_maps; ofm++) - { - const T* const in_ofm = in_ifm + ofm; - T* const out_ofm = out_ifm + ofm * out_output_feature_map_stride; - *(out_ofm) = *(in_ofm); - } - } - } - } -} - diff --git a/arm_compute/core/NEON/kernels/winograd/transforms.hpp b/arm_compute/core/NEON/kernels/winograd/transforms.hpp deleted file mode 100644 index 8546ee9e2e..0000000000 --- a/arm_compute/core/NEON/kernels/winograd/transforms.hpp +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. 
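The shims deleted above are plain, stride-aware re-ordering loops: nchw_to_nhwc()/nhwc_to_nchw() move activation tensors between NCHW and NHWC storage, and ofm_ifm_h_w_to_h_w_ifm_ofm()/h_w_ifm_ofm_to_ofm_ifm_h_w() do the analogous job for weight tensors. Passing 0 for any stride makes the shim derive a dense stride from the dimensions, as the default arguments above show. A small, hypothetical usage sketch (shape chosen arbitrarily):

#include <vector>
#include "shims.hpp"

void reorder_example()
{
  const int batches = 1, channels = 8, rows = 6, cols = 6;
  std::vector<float> nchw(batches * channels * rows * cols);
  std::vector<float> nhwc(nchw.size());

  // out[n][i][j][c] = in[n][c][i][j]; all strides default to dense packing.
  nchw_to_nhwc(nchw.data(), nhwc.data(), batches, channels, rows, cols);
}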
- * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -#pragma once - -#include "transforms/input_2x2_3x3.hpp" -#include "transforms/kernel_2x2_3x3.hpp" -#include "transforms/output_2x2_3x3.hpp" diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp deleted file mode 100644 index 7013c66ac0..0000000000 --- a/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp +++ /dev/null @@ -1,638 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once -#include "../tensor.hpp" - -namespace winograd { - /* Transform an input tensor into the Winograd domain. 
- */ - template - struct Winograd2x2_3x3GemmInput { - static void execute( - const T *inptr, - const Tensor4DShape& input_shape, - const PaddingType padding_type, - const int tile_M, - const int tile_N, - T *outptr_base, - const int matrix_stride, - const int matrix_batch_stride, - const int matrix_row_stride - ); - - static size_t bytes_read(const Tensor4DShape &input_shape, - const Tensor4DShape &output_shape) { - const int tile_rows = iceildiv(output_shape.n_rows, 2); - const int tile_cols = iceildiv(output_shape.n_cols, 2); - return input_shape.n_batches * tile_rows * (16 + 8*(tile_cols - 1)) * input_shape.n_channels * sizeof(T); - } - - static int flops_performed(const Tensor4DShape &input_shape, - const Tensor4DShape &output_shape) { - const int tile_rows = iceildiv(output_shape.n_rows, 2); - const int tile_cols = iceildiv(output_shape.n_cols, 2); - return input_shape.n_batches * tile_rows * (32 + 24*(tile_cols - 1)) * input_shape.n_channels; - } - - static size_t bytes_written(const Tensor4DShape &input_shape, - const Tensor4DShape &output_shape) { - const int tile_rows = iceildiv(output_shape.n_rows, 2); - const int tile_cols = iceildiv(output_shape.n_cols, 2); - const int M = input_shape.n_batches * tile_rows * tile_cols; - return 16 * M * input_shape.n_channels * sizeof(T); - } - - protected: - template - static void process_tile_tensor( - const int tile_M, // Number of rows of tiles - const int tile_N, // Number of columns of tiles - int n_channels, // Number of input channels - const T* const input, // Base input pointer (appropriate to batch and channel) - const int input_row_stride, // Stride between rows of the input - const int input_col_stride, // Stride between columns of the input - T* const matrix, // 1st output matrix (appropriate to batch and channel) - const int matrix_stride, // Stride between matrices - const int matrix_row_stride // Stride between rows of the output matrix - ); - - template - static void process_tile_row( - const int tile_N, // Number of tiles in the row - const T* const input, // Base input pointer (appropriate to batch, channel and row) - const int input_row_stride, // Stride between rows of the input - const int input_col_stride, // Stride between columns of the input - T* const matrix, // 1st output matrix (appropriate to batch, channel and row) - const int matrix_stride, // Stride between matrices - const int matrix_row_stride // Stride between rows of the output matrix - ); - }; - - template - struct Winograd2x2_3x3GemmInputChannelwise { - static void execute( - const T *inptr, - const Tensor4DShape& input_shape, - const PaddingType padding_type, - const int tile_M, - const int tile_N, - T *outptr_base, - const int matrix_stride, - const int matrix_batch_stride, - const int matrix_row_stride - ); - - static size_t bytes_read(const Tensor4DShape &input_shape, - const Tensor4DShape &output_shape) { - // We read as many bytes as we write - return bytes_written(input_shape, output_shape); - } - - static int flops_performed(const Tensor4DShape &input_shape, - const Tensor4DShape &output_shape) { - const int tile_rows = iceildiv(output_shape.n_rows, 2); - const int tile_cols = iceildiv(output_shape.n_cols, 2); - return input_shape.n_batches * tile_rows * 32 * tile_cols * input_shape.n_channels; - } - - static size_t bytes_written(const Tensor4DShape &input_shape, - const Tensor4DShape &output_shape) { - return winograd::Winograd2x2_3x3GemmInput::bytes_written(input_shape, output_shape); - } - - protected: - typedef void (*tilefunc)(int, const T*, 
int, int, T*, int); - template - static void process_tile( - int n_channels, // Number of channels in the tile - const T* const input_base, - const int input_row_stride, - const int input_col_stride, - T* const matrix_base, - const int matrix_stride - ); - - private: - template - static void _process_tile( - int &n_channels, const T* &inptr, - const int input_row_stride, const int input_col_stride, - T* &outptr, const int matrix_stride - ); - }; -} - -/*****************************************************************************/ -// Include specialised implementations here -#include "input_2x2_3x3/a64_float.hpp" -#include "input_2x2_3x3/a64_float_channelwise.hpp" -/*****************************************************************************/ - -/*****************************************************************************/ -template -void winograd::Winograd2x2_3x3GemmInput::execute( - const T *inptr_base, - const Tensor4DShape& input_shape, - const PaddingType padding_type, - const int tile_M, - const int tile_N, - T *outptr_base, - const int matrix_stride, - const int matrix_batch_stride, - const int matrix_row_stride -) { - // Select an appropriate matrix processing method for the shape and padding - // of the input tensor. - typedef void (*tensorfunc)(int, int, int, const T*, int, int, T*, int, int); - const auto process_tensor = [&padding_type, &input_shape] () -> tensorfunc { - if (padding_type == PADDING_VALID) { - const int pad_bottom = input_shape.n_rows % 2; - const int pad_right = input_shape.n_cols % 2; - - if (pad_bottom == 0 && pad_right == 0) { - return process_tile_tensor; - } else if (pad_bottom == 0 && pad_right == 1) { - return process_tile_tensor; - } else if (pad_bottom == 1 && pad_right == 0) { - return process_tile_tensor; - } else if (pad_bottom == 1 && pad_right == 1) { - return process_tile_tensor; - } - } else { // PADDING_SAME - const int pad_bottom = 1 + input_shape.n_rows % 2; - const int pad_right = 1 + input_shape.n_cols % 2; - - if (pad_bottom == 1 && pad_right == 1) { - return process_tile_tensor; - } else if (pad_bottom == 1 && pad_right == 2) { - return process_tile_tensor; - } else if (pad_bottom == 2 && pad_right == 1) { - return process_tile_tensor; - } else if (pad_bottom == 2 && pad_right == 2) { - return process_tile_tensor; - } - } - - printf("%s::%u Uncovered case.\n", __FILE__, __LINE__); - exit(-1); - return NULL; // No function found - } (); - - // Compute strides - const int input_row_stride = input_shape.n_cols * input_shape.n_channels; - const int input_col_stride = input_shape.n_channels; - - // Process each batch of the tensor in turn. 
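As a concrete check of the dispatch above, take a hypothetical input of 7 rows by 9 columns. With PADDING_VALID, pad_bottom = 7 % 2 = 1 and pad_right = 9 % 2 = 1, so the (pad_bottom == 1 && pad_right == 1) branch is selected and only the last row and column of tiles read a single row or column of implicit zeros. With PADDING_SAME, pad_bottom = 1 + 7 % 2 = 2 and pad_right = 1 + 9 % 2 = 2: the first row and column of tiles are padded by one at the top and left, and the last row and column of tiles read two rows or columns of implicit zeros.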
- for (int batch = 0; batch < input_shape.n_batches; batch++) { - // Work out pointers - const T *inptr = inptr_base + (batch * input_shape.n_rows * - input_shape.n_cols * input_shape.n_channels); - T *outptr = outptr_base + batch * matrix_batch_stride; - - // Delegate doing the actual work - process_tensor( - tile_M, tile_N, input_shape.n_channels, - inptr, input_row_stride, input_col_stride, - outptr, matrix_stride, matrix_row_stride - ); - } -} - -/*****************************************************************************/ -template -template -void winograd::Winograd2x2_3x3GemmInput::process_tile_tensor( - const int tile_M, // Number of rows of tiles - const int tile_N, // Number of columns of tiles - int n_channels, // Number of input channels - const T* const input, // Base input pointer (appropriate to batch and channel) - const int input_row_stride, // Stride between rows of the input - const int input_col_stride, // Stride between columns of the input - T* const matrix, // 1st output matrix (appropriate to batch and channel) - const int matrix_stride, // Stride between matrices - const int matrix_row_stride // Stride between rows of the output matrix -) { - // Base row processing functions - typedef void (*rowfunc)(int, const T*, int, int, T*, int, int); - const rowfunc process_top_row[3] = { - (padding == PADDING_VALID) - ? process_tile_row<0, 0, 0, pad_right, 1> - : process_tile_row<1, 1, 0, pad_right, 1>, - (padding == PADDING_VALID) - ? process_tile_row<0, 0, 0, pad_right, 2> - : process_tile_row<1, 1, 0, pad_right, 2>, - (padding == PADDING_VALID) - ? process_tile_row<0, 0, 0, pad_right, 4> - : process_tile_row<1, 1, 0, pad_right, 4>, - }; - const rowfunc process_middle_row[3] = { - (padding == PADDING_VALID) - ? process_tile_row<0, 0, 0, pad_right, 1> - : process_tile_row<0, 1, 0, pad_right, 1>, - (padding == PADDING_VALID) - ? process_tile_row<0, 0, 0, pad_right, 2> - : process_tile_row<0, 1, 0, pad_right, 2>, - (padding == PADDING_VALID) - ? process_tile_row<0, 0, 0, pad_right, 4> - : process_tile_row<0, 1, 0, pad_right, 4>, - }; - const rowfunc process_bottom_row[3] = { - (padding == PADDING_VALID) - ? process_tile_row<0, 0, pad_bottom, pad_right, 1> - : process_tile_row<0, 1, pad_bottom, pad_right, 1>, - (padding == PADDING_VALID) - ? process_tile_row<0, 0, pad_bottom, pad_right, 2> - : process_tile_row<0, 1, pad_bottom, pad_right, 2>, - (padding == PADDING_VALID) - ? process_tile_row<0, 0, pad_bottom, pad_right, 4> - : process_tile_row<0, 1, pad_bottom, pad_right, 4>, - }; - - // Method to get an input pointer for the given tile row - const auto get_inptr = [&input, &input_row_stride] (const int tile_i) { - if (padding == PADDING_VALID) { - return input + 2 * tile_i * input_row_stride; - } else { - return input + (2 * tile_i - (tile_i ? 1 : 0)) * input_row_stride; - } - }; - - // Wrapper to process a row of tiles, covering all channels. - const auto process_row = - [tile_N, input_row_stride, input_col_stride, matrix_stride, matrix_row_stride, n_channels] - (const rowfunc f[3], const T *inptr, T *outptr) { - int rem_channels = n_channels; - - // While there remain channels to process continue to process the - // row. 
- for (; rem_channels >= 4; rem_channels -= 4, inptr += 4, outptr += 4) { - f[2](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride); - } - for (; rem_channels >= 2; rem_channels -= 2, inptr += 2, outptr += 2) { - f[1](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride); - } - if (rem_channels) { - f[0](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride); - } - }; - - // Process all rows of tiles in the tensor - for (int tile_i = 0; tile_i < tile_M; tile_i++) { - T* const m_row = matrix + tile_i * tile_N * matrix_row_stride; - const T *row_inptr = get_inptr(tile_i); - - if (tile_i == 0) { - // Top row of the input - process_row(process_top_row, row_inptr, m_row); - } else if (tile_i == tile_M - 1) { - // Bottom row of the input - process_row(process_bottom_row, row_inptr, m_row); - } else { - // Any other row of the input - process_row(process_middle_row, row_inptr, m_row); - } - } -} - -/*****************************************************************************/ -template -template -void winograd::Winograd2x2_3x3GemmInput::process_tile_row( - const int tile_N, // Number of tiles in the row - const T* const input, // Base input pointer (appropriate to batch, channel and row) - const int input_row_stride, // Stride between rows of the input - const int input_col_stride, // Stride between columns of the input - T* const matrix, // 1st output matrix (appropriate to batch, channel and row) - const int matrix_stride, // Stride between matrices - const int matrix_row_stride // Stride between rows of the output matrix -) { - // Construct copies of the pointers - const T *inptr = input; - T *outptr = matrix; - - // Storage for the tensors x, X.T x, and X.T x X. - T x[4][4][proc_channels], XTx[4][4][proc_channels], XTxX[4][4][proc_channels]; - - // For every tile in the row - for (int tile_j = 0; tile_j < tile_N; tile_j++) { - // Determine the padding for the tile - const int tile_pad_left = (tile_j == 0) ? pad_left : 0; - const int tile_pad_right = (tile_j == tile_N - 1) ? pad_right : 0; - - // Load tile values. If this is the first tile in the row then we must load - // all values, otherwise we can just load the final two columns of the input. - for (int i = 0; i < 4; i++) { - for (int j = ((tile_j == 0) ? 0 : 2); j < 4; j++) { - // Fill with padding if required - if (i < pad_top || 4 - pad_bottom <= i || - j < tile_pad_left || 4 - tile_pad_right <= j) { - for (int c = 0; c < proc_channels; c++) { - x[i][j][c] = static_cast(0); // Padding - } - } else { - // Load values, note that the initial padding offsets the pointer we - // were provided. - for (int c = 0; c < proc_channels; c++) { - const int row_offset = (i - pad_top) * input_row_stride; - const int col_offset = (j - tile_pad_left) * input_col_stride; - x[i][j][c] = inptr[row_offset + col_offset + c]; - } - } - } - } - - // Compute the matrix X.T x. Note, can elide operations depending on the - // padding. Furthermore, if this isn't the left-most tile we can skip half - // of the operations by copying results from the previous version of X.T x. - // This latter optimisation can be simplified by unrolling the outermost - // loop by two and by renaming the registers containing XTx. 
- if (tile_j == 0) { - for (int j = 0; j < 4; j++) { - for (int c = 0; c < proc_channels; c++) { - XTx[0][j][c] = x[0][j][c] - x[2][j][c]; - XTx[1][j][c] = x[1][j][c] + x[2][j][c]; - XTx[2][j][c] = -x[1][j][c] + x[2][j][c]; - XTx[3][j][c] = x[1][j][c] - x[3][j][c]; - } - } - } else { - for (int j = 0; j < 2; j++) { - for (int c = 0; c < proc_channels; c++) { - XTx[0][j][c] = XTx[0][j + 2][c]; - XTx[1][j][c] = XTx[1][j + 2][c]; - XTx[2][j][c] = XTx[2][j + 2][c]; - XTx[3][j][c] = XTx[3][j + 2][c]; - } - } - for (int j = 2; j < 4; j++) { - for (int c = 0; c < proc_channels; c++) { - XTx[0][j][c] = x[0][j][c] - x[2][j][c]; - XTx[1][j][c] = x[1][j][c] + x[2][j][c]; - XTx[2][j][c] = -x[1][j][c] + x[2][j][c]; - XTx[3][j][c] = x[1][j][c] - x[3][j][c]; - } - } - } - - // Compute the matrix X.T x X. Note, can elide operations based on the - // padding. - for (int i = 0; i < 4; i++) { - for (int c = 0; c < proc_channels; c++) { - XTxX[i][0][c] = XTx[i][0][c] - XTx[i][2][c]; - XTxX[i][1][c] = XTx[i][1][c] + XTx[i][2][c]; - XTxX[i][2][c] = -XTx[i][1][c] + XTx[i][2][c]; - XTxX[i][3][c] = XTx[i][1][c] - XTx[i][3][c]; - } - } - - // Store the output matrix (X.T x X) - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 4; j++) { - // Get a pointer to the relevant output matrix - T *mptr = outptr + (i*4 + j)*matrix_stride; - - // Write out the channels - for (int c = 0; c < proc_channels; c++) { - mptr[c] = XTxX[i][j][c]; - } - } - } - - // Update the pointers - inptr += input_col_stride * ((tile_j == 0 && pad_left) ? 1 : 2); - outptr += matrix_row_stride; - } -} - -/*****************************************************************************/ -template -void winograd::Winograd2x2_3x3GemmInputChannelwise::execute( - const T *inptr, - const Tensor4DShape& input_shape, - const PaddingType padding_type, - const int tile_M, - const int tile_N, - T *outptr_base, - const int matrix_stride, - const int matrix_batch_stride, - const int matrix_row_stride -) { - const int n_channels = input_shape.n_channels; - const int input_col_stride = n_channels; - const int input_row_stride = input_shape.n_cols * input_col_stride; - - // Determine the padding and hence select appropriate methods for each tile. - tilefunc fs[3][3]; - - if (padding_type == PADDING_VALID) { - constexpr int pad_top = 0; - constexpr int pad_left = 0; - const int pad_right = input_shape.n_cols % 2 == 0; - - fs[0][0] = process_tile; - fs[0][1] = process_tile; - fs[0][2] = (pad_right) ? process_tile : process_tile; - - fs[1][0] = process_tile<0, pad_left, 0, 0>; - fs[1][1] = process_tile<0, 0, 0, 0>; - fs[1][2] = (pad_right) ? process_tile<0, 0, 0, 0> : process_tile<0, 0, 0, 1>; - - if (input_shape.n_rows % 2 == 0) { - constexpr int pad_bottom = 0; - fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>; - fs[2][1] = process_tile<0, 0, pad_bottom, 0>; - fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 0> : process_tile<0, 0, pad_bottom, 1>; - } else { - constexpr int pad_bottom = 1; - fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>; - fs[2][1] = process_tile<0, 0, pad_bottom, 0>; - fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 0> : process_tile<0, 0, pad_bottom, 1>; - } - } else { - constexpr int pad_top = 1; - constexpr int pad_left = 1; - const int pad_right = input_shape.n_cols % 2 == 0; - - fs[0][0] = process_tile; - fs[0][1] = process_tile; - fs[0][2] = (pad_right) ? process_tile : process_tile; - - fs[1][0] = process_tile<0, pad_left, 0, 0>; - fs[1][1] = process_tile<0, 0, 0, 0>; - fs[1][2] = (pad_right) ? 
process_tile<0, 0, 0, 1> : process_tile<0, 0, 0, 2>; - - if (input_shape.n_rows % 2 == 0) { - constexpr int pad_bottom = 1; - fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>; - fs[2][1] = process_tile<0, 0, pad_bottom, 0>; - fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 1> : process_tile<0, 0, pad_bottom, 2>; - } else { - constexpr int pad_bottom = 2; - fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>; - fs[2][1] = process_tile<0, 0, pad_bottom, 0>; - fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 1> : process_tile<0, 0, pad_bottom, 2>; - } - } - - // Process each tile in turn - for (int batch = 0; batch < input_shape.n_batches; batch++) { - const T* const input_base_batch = inptr + batch*input_shape.n_rows*input_shape.n_cols*n_channels; - - for (int tile_i = 0; tile_i < tile_M; tile_i++) { - const int row_offset = (tile_i == 0) ? 0 : ((padding_type == PADDING_VALID) ? 0 : 1); - const T* const input_base_row = input_base_batch + (2*tile_i - row_offset)*input_shape.n_cols*n_channels; - - // Select the set of functions for the row - const int fs_i = (tile_i == 0) ? 0 : ((tile_i < tile_M - 1) ? 1 : 2); - - for (int tile_j = 0; tile_j < tile_N; tile_j++) { - // Select the function for the column - const int fs_j = (tile_j == 0) ? 0 : ((tile_j < tile_N - 1) ? 1 : 2); - const auto f = fs[fs_i][fs_j]; - - // Get pointers into the input and outputs - const int col_offset = (tile_j == 0) ? 0 : ((padding_type == PADDING_VALID) ? 0 : 1); - const T* const input_base_col = input_base_row + (2*tile_j - col_offset)*n_channels; - T* const matrix_base = outptr_base + batch*matrix_batch_stride + (tile_i*tile_N + tile_j)*matrix_row_stride; - f(n_channels, input_base_col, input_row_stride, input_col_stride, - matrix_base, matrix_stride); - } - } - } -} - -template -template -void winograd::Winograd2x2_3x3GemmInputChannelwise::process_tile( - int n_channels, // Number of channels in the tile - const T* const input_base, - const int input_row_stride, - const int input_col_stride, - T* const matrix_base, - const int matrix_stride -) { - // Copy pointers - const T *inptr = input_base; - T *outptr = matrix_base; - - // Process channels (modifies inptr, outptr and n_channels) - _process_tile( - n_channels, inptr, input_row_stride, input_col_stride, - outptr, matrix_stride - ); - _process_tile( - n_channels, inptr, input_row_stride, input_col_stride, - outptr, matrix_stride - ); - _process_tile( - n_channels, inptr, input_row_stride, input_col_stride, - outptr, matrix_stride - ); -} - -template -template -void winograd::Winograd2x2_3x3GemmInputChannelwise::_process_tile( - int &n_channels, - const T* &inptr, const int input_row_stride, const int input_col_stride, - T* &outptr, const int matrix_stride -) { - // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three - // offsets to access the intermediate matrices. - T* outptrs[4] = { - outptr, - outptr + matrix_stride * 4, - outptr + matrix_stride * 8, - outptr + matrix_stride * 12 - }; - - // The matrix X; zeroed to account for padding. 
- T x[4][4]; - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 4; j++) { - x[i][j] = 0; - } - } - - // The matrices X.T x and U - T XTx[4][4], U[4][4]; - - // Now progress through each channel - for (; n_channels >= proc_channels; n_channels -= proc_channels) { - for (int n = 0; n < proc_channels; n++) { - // Load the matrix X - for (int cell_i = pad_top, i = 0; cell_i < 4 - pad_bottom; cell_i++, i++) { - for (int cell_j = pad_left, j = 0; cell_j < 4 - pad_right; cell_j++, j++) { - x[cell_i][cell_j] = inptr[i*input_row_stride + j*input_col_stride]; - } - } - inptr++; - - // Compute the matrix X.T - for (int j = 0; j < 4; j++) { - XTx[0][j] = x[0][j] - x[2][j]; - XTx[1][j] = x[1][j] + x[2][j]; - XTx[2][j] = x[2][j] - x[1][j]; - XTx[3][j] = x[1][j] - x[3][j]; - } - - // Hence compute the matrix U - for (int i = 0; i < 4; i++) { - U[i][0] = XTx[i][0] - XTx[i][2]; - U[i][1] = XTx[i][1] + XTx[i][2]; - U[i][2] = XTx[i][2] - XTx[i][1]; - U[i][3] = XTx[i][1] - XTx[i][3]; - } - - // Store the matrix U - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 4; j++) { - outptrs[i][j * matrix_stride] = U[i][j]; - } - outptrs[i]++; - } - } - } - - // Update the output pointer for future calls - outptr = outptrs[0]; -} diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp deleted file mode 100644 index a99cbe325b..0000000000 --- a/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp +++ /dev/null @@ -1,1498 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
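For reference, the element-wise expressions used throughout this header are the input transform of the Winograd F(2x2, 3x3) algorithm: each 4x4 input tile x is mapped to U = X.T x X where, consistent with the formulas above,

          / 1   0  -1   0 \
    X.T = | 0   1   1   0 |
          | 0  -1   1   0 |
          \ 0   1   0  -1 /

Every element of X.T x and of (X.T x) X is a single addition or subtraction, so one tile costs 32 add/sub operations per channel, which is the figure returned by flops_performed() in the channelwise variant. The row-wise variant reuses the last two columns of X.T x as the first two columns of the next tile's X.T x, so each tile after the first in a row recomputes only 8 of those 16 values, giving the 32 + 24*(tile_cols - 1) count quoted earlier.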
- */ -#pragma once -#include "../input_2x2_3x3.hpp" - -#ifdef __aarch64__ -namespace winograd { - -// Pad left by one column, pad right by one column, no upper or lower padding, 4 channels -template <> -template <> -inline void Winograd2x2_3x3GemmInput::process_tile_row<0, 1, 0, 1, 4>( - const int tile_N, // Number of tiles in the row - const float* const input, // Base input pointer (appropriate to batch, channel and row) - const int input_row_stride, // Stride between rows of the input - const int input_col_stride, // Stride between columns of the input - float* const matrix, // 1st output matrix (appropriate to batch, channel and row) - const int matrix_stride, // Stride between matrices - const int matrix_row_stride // Stride between rows of the output matrix -) { - /* SIMD register allocation - * ======================== - * - * In the following code we read 4x4 tiles of a matrix `x`, with which we - * compute another matrix `X.T x` where: - * - * / 1 0 0 0 \ - * X = | 0 1 -1 1 | - * | -1 1 1 0 | - * \ 0 0 0 -1 / - * - * Hence, `X.T` is a program which operates upon rows of the matrix `X`. - * We subsequently compute and store the matrix `U = (X.T x) X`. - * - * Importantly, each iteration of the loop below loads a new matrix `x'` - * where the final two columns of `x'` are the first two columns of the - * previous `x`. That is: - * - * x11 x12 x13 x14 - * x21 x22 x23 x24 - * x31 x32 x33 x34 - * x41 x42 x43 x44 - * - * x'11 x'12 x'13 x'14 - * x'21 x'22 x'23 x'24 - * x'31 x'32 x'33 x'34 - * x'41 x'42 x'43 x'44 - * - * Consequently, while the first iteration of the below loop must load 16 - * values for `x`, the second need load only 8. *Furthermore*, since we noted - * above that the operation `X.T x` was a program which operated upon *rows* - * of the matrix `x` it follows that that the relation that `x'[i][1] = - * x[i][3]` and `x'[i][2] = x[i][4]` applies also the matrices `X.T x'` and - * `X.T x`. That is: - * - * (X.T x)11 (X.T x)12 (X.T x)13 (X.T x)14 - * (X.T x)21 (X.T x)22 (X.T x)23 (X.T x)24 - * (X.T x)31 (X.T x)32 (X.T x)33 (X.T x)34 - * (X.T x)41 (X.T x)42 (X.T x)43 (X.T x)44 - * - * (X.T x')11 (X.T x')12 (X.T x')13 (X.T x')14 - * (X.T x')12 (X.T x')12 (X.T x')12 (X.T x')12 - * (X.T x')13 (X.T x')13 (X.T x')13 (X.T x')13 - * (X.T x')14 (X.T x')14 (X.T x')14 (X.T x')14 - * - * Hence, as well as not needing to load new values for x'[i][1..2] it is - * also unnecessary to recompute values for (X.T x')[i][1..2]. - * - * Following this we break the registers into blocks `A` and `B` used by the - * two stages of the unrolled loop. These registers named such that the - * latter columns of `A` become the earlier columns of `B` and vice-versa: - * - * AXTx11 AXTx12 > AXTx13 AXTx14 | - * AXTx21 AXTx22 > AXTx23 AXTx24 | - * AXTx31 AXTx32 > AXTx33 AXTx34 | - * AXTx41 AXTx42 > AXTx43 AXTx44 | - * - * BXTx13 BXTx14 | BXTx11 BXTx12 > - * BXTx23 BXTx24 | BXTx21 BXTx22 > - * BXTx33 BXTx34 | BXTx31 BXTx32 > - * BXTx43 BXTx44 | BXTx41 BXTx42 > - * - * These 32 named registers require only 16 architectural registers. 1 - * additional architectural register is used as scratch space and 8 - * architectural registers are used to load in the values x[1..4][3,4]. 
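For reference, the transform sketched in the comment above can be written out in full. With X as given (its transpose is the standard B^T of the Winograd F(2x2, 3x3) algorithm), each 4x4 input tile x is mapped to

    U = X^{\mathsf{T}} x\, X, \qquad
    X^{\mathsf{T}} =
    \begin{pmatrix}
      1 &  0 & -1 &  0 \\
      0 &  1 &  1 &  0 \\
      0 & -1 &  1 &  0 \\
      0 &  1 &  0 & -1
    \end{pmatrix}

so the row stage forms x_1 - x_3, x_2 + x_3, x_3 - x_2 and x_2 - x_4 (rows of x), and the column stage applies the same four combinations to the columns of X^T x. These are exactly the fsub/fadd pairs issued by the assembly below.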
- * - * Input and output addressing - * =========================== - * TODO Description - */ - const float *inptr0 = input; - const float *inptr1 = input + input_row_stride; - const float *inptr2 = input + input_row_stride * 2; - const float *inptr3 = input + input_row_stride * 3; - - float *outptr0 = matrix; - float *outptr4 = matrix + matrix_stride * 4; - float *outptr8 = matrix + matrix_stride * 8; - float *outptr12 = matrix + matrix_stride * 12; - - int tile_j = tile_N; // Tiles to process - - asm volatile ( - // Named SIMD registers according to the policy given above - // Registers into which to load the latter two columns of `x` - "x_13 .req v0\n qx_13 .req q0\n" "x_14 .req v4\n qx_14 .req q4\n" - "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n" - "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n" - "x_43 .req v3\n qx_43 .req q3\n" "x_44 .req v7\n qx_44 .req q7\n" - - // Registers for storing X.T x (both A and B halves) - "AXTx11 .req v8\n" "BXTx13 .req v8\n" - "AXTx12 .req v9\n" "BXTx14 .req v9\n" "qAXTx12 .req q9\n" - "AXTx21 .req v10\n" "BXTx23 .req v10\n" - "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n" - "AXTx31 .req v12\n" "BXTx33 .req v12\n" - "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n" - "AXTx41 .req v14\n" "BXTx43 .req v14\n" - "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n" - "AXTx13 .req v16\n" "BXTx11 .req v16\n" - "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n" - "AXTx23 .req v18\n" "BXTx21 .req v18\n" - "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n" - "AXTx33 .req v20\n" "BXTx31 .req v20\n" - "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n" - "AXTx43 .req v22\n" "BXTx41 .req v22\n" - "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n" - - // Result register (TODO Does using more registers yield better - // performance) - "U .req v24\n qU .req q24\n" - - // ---------------------------------------------------------------------- - // Head of loop - // Loads a complete 4x4 tile of x, computes X.T x, computes and stores - // `U = X.T x X`. Prepares for the 'A' half of the loop. - // NOTE: Since the first tile has the leftmost column padded we can - // skip 4 loads and 4 calculations for the matrix X.T x X. - - // Temporarily alias registers for computing the first (non-padded) - // column of x. - "x_12 .req v0\n qx_12 .req q0\n" - "x_22 .req v1\n qx_22 .req q1\n" - "x_32 .req v2\n qx_32 .req q2\n" - "x_42 .req v3\n qx_42 .req q3\n" - - "ldr qx_12, [%x[inptr0]]\n" - "ldr qx_22, [%x[inptr1]]\n" - "ldr qx_32, [%x[inptr2]]\n" - "ldr qx_42, [%x[inptr3]]\n" - - "fsub BXTx12.4s, x_12.4s, x_32.4s\n" - "fadd BXTx22.4s, x_22.4s, x_32.4s\n" - "fsub BXTx32.4s, x_32.4s, x_22.4s\n" - "fsub BXTx42.4s, x_22.4s, x_42.4s\n" - - ".unreq x_12\n .unreq qx_12\n" - ".unreq x_22\n .unreq qx_22\n" - ".unreq x_32\n .unreq qx_32\n" - ".unreq x_42\n .unreq qx_42\n" - - // Load and compute latter two columns of the first tile. Progress the - // input pointers (by three columns so that the each points are the - // second column of the next tile, that is, each points at the first - // column which must be read for the next tile. 
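A scalar sketch of the loop structure this kernel implements may help here. The sketch is illustrative only: the helper name and parameters are hypothetical, it handles a single channel with no column padding, and it uses explicit copies where the assembly instead alternates its 'A' and 'B' register banks. The real kernel vectorises over four channels and folds the padded head and tail tiles into the same scheme.

    // Hypothetical scalar sketch of one row of 4x4 tiles with column reuse:
    // columns 2,3 of tile j are columns 0,1 of tile j+1, so after the first
    // tile only two fresh input columns are loaded per tile.
    static void sketch_process_tile_row(
        const int    tile_N,            // number of 4x4 input tiles in the row
        const float *input,             // top-left element of the first tile
        const int    row_stride,        // elements between input rows
        const int    col_stride,        // elements between input columns
        float       *output,            // element 0 of matrix 0 for this tile row
        const int    matrix_stride,     // elements between the 16 output matrices
        const int    matrix_row_stride) // elements between tiles within one matrix
    {
      float x[4][4], XTx[4][4];

      // Head: load all four columns of the first tile.
      for (int i = 0; i < 4; i++)
        for (int j = 0; j < 4; j++)
          x[i][j] = input[i*row_stride + j*col_stride];

      for (int tile_j = 0; tile_j < tile_N; tile_j++)
      {
        // Row stage: XTx = X.T x.  After the first tile, columns 0 and 1 were
        // carried over from the previous tile, so only columns 2 and 3 are new.
        for (int j = (tile_j == 0) ? 0 : 2; j < 4; j++)
        {
          XTx[0][j] = x[0][j] - x[2][j];
          XTx[1][j] = x[1][j] + x[2][j];
          XTx[2][j] = x[2][j] - x[1][j];
          XTx[3][j] = x[1][j] - x[3][j];
        }

        // Column stage: U = (X.T x) X, scattered across the 16 matrices.
        float *const out = output + tile_j*matrix_row_stride;
        for (int i = 0; i < 4; i++)
        {
          out[(4*i + 0)*matrix_stride] = XTx[i][0] - XTx[i][2];
          out[(4*i + 1)*matrix_stride] = XTx[i][1] + XTx[i][2];
          out[(4*i + 2)*matrix_stride] = XTx[i][2] - XTx[i][1];
          out[(4*i + 3)*matrix_stride] = XTx[i][1] - XTx[i][3];
        }

        // Slide the window: columns 2,3 of this tile become columns 0,1 of the
        // next one, for x and for X.T x alike.
        if (tile_j + 1 < tile_N)
        {
          for (int i = 0; i < 4; i++)
          {
            x[i][0]   = x[i][2];   x[i][1]   = x[i][3];
            XTx[i][0] = XTx[i][2]; XTx[i][1] = XTx[i][3];
            x[i][2] = input[i*row_stride + (2*tile_j + 4)*col_stride];
            x[i][3] = input[i*row_stride + (2*tile_j + 5)*col_stride];
          }
        }
      }
    }

In the assembly the copies in the final loop are avoided entirely: the 'A' and 'B' register banks swap roles between iterations, which is why the loop body is doubled and why there are separate 'A' and 'B' tails.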
- "ldr qx_13, [%x[inptr0], %x[colstride1]]\n" - "ldr qx_23, [%x[inptr1], %x[colstride1]]\n" - "ldr qx_33, [%x[inptr2], %x[colstride1]]\n" - "ldr qx_43, [%x[inptr3], %x[colstride1]]\n" - - "fsub BXTx13.4s, x_13.4s, x_33.4s\n" - "ldr qx_14, [%x[inptr0], %x[colstride2]]\n" - - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride2]]\n" - - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride2]]\n" - - "fsub BXTx43.4s, x_23.4s, x_43.4s\n" - "ldr qx_44, [%x[inptr3], %x[colstride2]]\n" - - "fsub BXTx14.4s, x_14.4s, x_34.4s\n" - "add %x[inptr0], %x[inptr0], %x[colstride3]\n" - - "fadd BXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride3]\n" - - "fsub BXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride3]\n" - - "fsub BXTx44.4s, x_24.4s, x_44.4s\n" - "add %x[inptr3], %x[inptr3], %x[colstride3]\n" - - // Compute and store U for the first tile - // First row - "fneg U.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, BXTx12.4s, BXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fneg U.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, BXTx22.4s, BXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fneg U.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, BXTx32.4s, BXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row, simultaneously load the first column of inputs for the - // next tile. - "fneg U.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "ldr qx_13, [%x[inptr0]]\n" - - "fadd U.4s, BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, BXTx42.4s, BXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "ldr qx_43, [%x[inptr3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - - // Update the loop counter, subtract two to account for both the head and - // the tail. - "subs %x[tile_j], %x[tile_j], #2\n" - "beq 2f\n" // Jump to "A" tail if out of tiles - - // ---------------------------------------------------------------------- - "1:" - // Start part A - // Load last column of this tile (the first column has already been - // loaded) and compute latter two columns of X.T x. 
- "fsub AXTx13.4s, x_13.4s, x_33.4s\n" - "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" - "fadd AXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" - "fsub AXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" - "fsub AXTx43.4s, x_23.4s, x_43.4s\n" - "ldr qx_44, [%x[inptr3], %x[colstride1]]\n" - "fsub AXTx14.4s, x_14.4s, x_34.4s\n" - "add %x[inptr0], %x[inptr0], %x[colstride2]\n" - "fadd AXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride2]\n" - "fsub AXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride2]\n" - "fsub AXTx44.4s, x_24.4s, x_44.4s\n" - "add %x[inptr3], %x[inptr3], %x[colstride2]\n" - - // Compute and store U. - // First row - "fsub U.4s, AXTx11.4s, AXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, AXTx12.4s, AXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, AXTx13.4s, AXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, AXTx12.4s, AXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fsub U.4s, AXTx21.4s, AXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, AXTx22.4s, AXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, AXTx23.4s, AXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, AXTx22.4s, AXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fsub U.4s, AXTx31.4s, AXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, AXTx32.4s, AXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, AXTx33.4s, AXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, AXTx32.4s, AXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row - "fsub U.4s, AXTx41.4s, AXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "ldr qx_13, [%x[inptr0]]\n" - - "fadd U.4s, AXTx42.4s, AXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, AXTx43.4s, AXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, AXTx42.4s, AXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "ldr qx_43, [%x[inptr3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - - "subs %x[tile_j], %x[tile_j], #1\n" - "beq 3f\n" // Jump to 'B' tail - - // Start part B - // Load last column of this tile (the first column has already been - // loaded) and compute latter two columns of X.T x. - "fsub BXTx13.4s, x_13.4s, x_33.4s\n" - "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" - "fsub BXTx43.4s, x_23.4s, x_43.4s\n" - "ldr qx_44, [%x[inptr3], %x[colstride1]]\n" - "fsub BXTx14.4s, x_14.4s, x_34.4s\n" - "add %x[inptr0], %x[inptr0], %x[colstride2]\n" - "fadd BXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride2]\n" - "fsub BXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride2]\n" - "fsub BXTx44.4s, x_24.4s, x_44.4s\n" - "add %x[inptr3], %x[inptr3], %x[colstride2]\n" - - // Compute and store U. 
- // First row - "fsub U.4s, BXTx11.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, BXTx12.4s, BXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fsub U.4s, BXTx21.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, BXTx22.4s, BXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fsub U.4s, BXTx31.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, BXTx32.4s, BXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row - "fsub U.4s, BXTx41.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "ldr qx_13, [%x[inptr0]]\n" - - "fadd U.4s, BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, BXTx42.4s, BXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "ldr qx_43, [%x[inptr3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - "subs %x[tile_j], %x[tile_j], #1\n" - "bne 1b\n" // Continue loop, otherwise flow into 'A' tail - - // ---------------------------------------------------------------------- - "2:" - // 'A' tail - // Since the final column is padding and the last-but-one column has - // already been loaded just compute the 3rd column of `X.T x'. - "fsub AXTx13.4s, x_13.4s, x_33.4s\n" - "fadd AXTx23.4s, x_23.4s, x_33.4s\n" - "fsub AXTx33.4s, x_33.4s, x_23.4s\n" - "fsub AXTx43.4s, x_23.4s, x_43.4s\n" - - // Compute and store U. Modified to account for the final column of X.T - // x containing padding. Note, it is also unnecessary to update the - // output pointers. 
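In this tail the rightmost column of the tile is padding, i.e. the fourth column of x is zero, so the fourth column of X^T x is zero as well and the last column of U collapses to

    U_{i4} = (X^{\mathsf{T}} x)_{i2} - (X^{\mathsf{T}} x)_{i4} = (X^{\mathsf{T}} x)_{i2}.

That is why the code below stores qAXTx12, qAXTx22, qAXTx32 and qAXTx42 directly at offset mstride3 instead of computing a further fsub.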
- // First row - "fsub U.4s, AXTx11.4s, AXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, AXTx12.4s, AXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, AXTx13.4s, AXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "str qAXTx12, [%x[outptr0], %x[mstride3]]\n" - - // Second row - "fsub U.4s, AXTx21.4s, AXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, AXTx22.4s, AXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, AXTx23.4s, AXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "str qAXTx22, [%x[outptr4], %x[mstride3]]\n" - - // Third row - "fsub U.4s, AXTx31.4s, AXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, AXTx32.4s, AXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, AXTx33.4s, AXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "str qAXTx32, [%x[outptr8], %x[mstride3]]\n" - - // Fourth row - "fsub U.4s, AXTx41.4s, AXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "fadd U.4s, AXTx42.4s, AXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, AXTx43.4s, AXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "str qAXTx42, [%x[outptr12], %x[mstride3]]\n" - - "b 4f\n" // Jump to end of function - - // ---------------------------------------------------------------------- - "3:" - // 'B' tail - // Since the final column is padding and the last-but-one column has - // already been loaded just compute the 3rd column of `X.T x'. - "fsub BXTx13.4s, x_13.4s, x_33.4s\n" - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "fsub BXTx43.4s, x_23.4s, x_43.4s\n" - - // Compute and store U. Modified to account for the final column of X.T - // x containing padding. Note, it is also unnecessary to update the - // output pointers. - // First row - "fsub U.4s, BXTx11.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "str qBXTx12, [%x[outptr0], %x[mstride3]]\n" - - // Second row - "fsub U.4s, BXTx21.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "str qBXTx22, [%x[outptr4], %x[mstride3]]\n" - - // Third row - "fsub U.4s, BXTx31.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "str qBXTx32, [%x[outptr8], %x[mstride3]]\n" - - // Fourth row - "fsub U.4s, BXTx41.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "fadd U.4s, BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "str qBXTx42, [%x[outptr12], %x[mstride3]]\n" - - // ---------------------------------------------------------------------- - "4:" - // End of function - - // Clear names - ".unreq x_13\n" ".unreq qx_13\n" ".unreq x_14\n" ".unreq qx_14\n" - ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n" - ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n" - ".unreq x_43\n" ".unreq qx_43\n" ".unreq x_44\n" ".unreq qx_44\n" - ".unreq AXTx11\n" ".unreq BXTx13\n" - ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n" - ".unreq AXTx21\n" ".unreq BXTx23\n" - ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n" - ".unreq AXTx31\n" ".unreq BXTx33\n" - 
".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n" - ".unreq AXTx41\n" ".unreq BXTx43\n" - ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n" - ".unreq AXTx13\n" ".unreq BXTx11\n" - ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n" - ".unreq AXTx23\n" ".unreq BXTx21\n" - ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n" - ".unreq AXTx33\n" ".unreq BXTx31\n" - ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n" - ".unreq AXTx43\n" ".unreq BXTx41\n" - ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n" - ".unreq U\n" ".unreq qU\n" - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [inptr3] "+r" (inptr3), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr4), - [outptr8] "+r" (outptr8), - [outptr12] "+r" (outptr12), - [tile_j] "+r" (tile_j) // Tile counter - : [colstride1] "r" (1 * input_col_stride * sizeof(float)), - [colstride2] "r" (2 * input_col_stride * sizeof(float)), - [colstride3] "r" (3 * input_col_stride * sizeof(float)), - [mstride1] "r" (1 * matrix_stride * sizeof(float)), - [mstride2] "r" (2 * matrix_stride * sizeof(float)), - [mstride3] "r" (3 * matrix_stride * sizeof(float)), - [matrix_row_stride] "r" (matrix_row_stride * sizeof(float)) - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", - "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24" - ); -} - -// Pad top, left and right by 1. -template <> -template <> -inline void Winograd2x2_3x3GemmInput::process_tile_row<1, 1, 0, 1, 4>( - const int tile_N, - const float* const input, - const int input_row_stride, - const int input_col_stride, - float* const matrix, - const int matrix_stride, - const int matrix_row_stride -) { - const float *inptr0 = input; - const float *inptr1 = input + input_row_stride; - const float *inptr2 = input + input_row_stride * 2; - - float *outptr0 = matrix; - float *outptr4 = matrix + matrix_stride * 4; - float *outptr8 = matrix + matrix_stride * 8; - float *outptr12 = matrix + matrix_stride * 12; - - int tile_j = tile_N; // Tiles to process - - asm volatile ( - // Named SIMD registers according to the policy given above - // Registers into which to load the latter two columns of `x` - // NOTE: We need only load the latter three rows since we know that the - // first row is padded. 
- "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n" - "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n" - "x_43 .req v3\n qx_43 .req q3\n" "x_44 .req v7\n qx_44 .req q7\n" - - // Registers for storing X.T x (both A and B halves) - "AXTx11 .req v8\n" "BXTx13 .req v8\n" - "AXTx12 .req v9\n" "BXTx14 .req v9\n" "qAXTx12 .req q9\n" - "AXTx21 .req v10\n" "BXTx23 .req v10\n" - "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n" - "AXTx31 .req v12\n" "BXTx33 .req v12\n" - "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n" - "AXTx41 .req v14\n" "BXTx43 .req v14\n" - "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n" - "AXTx13 .req v16\n" "BXTx11 .req v16\n" - "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n" - "AXTx23 .req v18\n" "BXTx21 .req v18\n" - "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n" - "AXTx33 .req v20\n" "BXTx31 .req v20\n" - "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n" - "AXTx43 .req v22\n" "BXTx41 .req v22\n" - "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n" - - // Result register (TODO Does using more registers yield better - // performance) - "U .req v24\n qU .req q24\n" - - // ---------------------------------------------------------------------- - // Head of loop - // Loads a complete 4x4 tile of x, computes X.T x, computes and stores - // `U = X.T x X`. Prepares for the 'A' half of the loop. - // NOTE: Since the first tile has the leftmost column padded we can - // skip 4 loads and 4 calculations for the matrix X.T x X. - - // Temporarily alias registers for computing the first (non-padded) - // column of x. - "x_22 .req v1\n qx_22 .req q1\n" - "x_32 .req v2\n qx_32 .req q2\n" - "x_42 .req v3\n qx_42 .req q3\n" - - "ldr qx_22, [%x[inptr1]]\n" - "ldr qx_32, [%x[inptr2]]\n" - "ldr qx_42, [%x[inptr3]]\n" - - "fneg BXTx12.4s, x_32.4s\n" - "fadd BXTx22.4s, x_22.4s, x_32.4s\n" - "fsub BXTx32.4s, x_32.4s, x_22.4s\n" - "fsub BXTx42.4s, x_22.4s, x_42.4s\n" - - ".unreq x_22\n .unreq qx_22\n" - ".unreq x_32\n .unreq qx_32\n" - ".unreq x_42\n .unreq qx_42\n" - - // Load and compute latter two columns of the first tile. Progress the - // input pointers (by three columns so that the each points are the - // second column of the next tile, that is, each points at the first - // column which must be read for the next tile. 
- "ldr qx_23, [%x[inptr1], %x[colstride1]]\n" - "ldr qx_33, [%x[inptr2], %x[colstride1]]\n" - "ldr qx_43, [%x[inptr3], %x[colstride1]]\n" - - "fneg BXTx13.4s, x_33.4s\n" - - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride2]]\n" - - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride2]]\n" - - "fsub BXTx43.4s, x_23.4s, x_43.4s\n" - "ldr qx_44, [%x[inptr3], %x[colstride2]]\n" - - "fneg BXTx14.4s, x_34.4s\n" - - "fadd BXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride3]\n" - - "fsub BXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride3]\n" - - "fsub BXTx44.4s, x_24.4s, x_44.4s\n" - "add %x[inptr3], %x[inptr3], %x[colstride3]\n" - - // Compute and store U for the first tile - // First row - "fneg U.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, BXTx12.4s, BXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fneg U.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, BXTx22.4s, BXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fneg U.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, BXTx32.4s, BXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row, simultaneously load the first column of inputs for the - // next tile. - "fneg U.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - - "fadd U.4s, BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, BXTx42.4s, BXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "ldr qx_43, [%x[inptr3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - - // Update the loop counter, subtract two to account for both the head and - // the tail. - "subs %x[tile_j], %x[tile_j], #2\n" - "beq 2f\n" // Jump to "A" tail if out of tiles - - // ---------------------------------------------------------------------- - "1:" - // Start part A - // Load last column of this tile (the first column has already been - // loaded) and compute latter two columns of X.T x. - "fneg AXTx13.4s, x_33.4s\n" - "fadd AXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" - "fsub AXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" - "fsub AXTx43.4s, x_23.4s, x_43.4s\n" - "ldr qx_44, [%x[inptr3], %x[colstride1]]\n" - "fneg AXTx14.4s, x_34.4s\n" - "fadd AXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride2]\n" - "fsub AXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride2]\n" - "fsub AXTx44.4s, x_24.4s, x_44.4s\n" - "add %x[inptr3], %x[inptr3], %x[colstride2]\n" - - // Compute and store U. 
- // First row - "fsub U.4s, AXTx11.4s, AXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, AXTx12.4s, AXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, AXTx13.4s, AXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, AXTx12.4s, AXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fsub U.4s, AXTx21.4s, AXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, AXTx22.4s, AXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, AXTx23.4s, AXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, AXTx22.4s, AXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fsub U.4s, AXTx31.4s, AXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, AXTx32.4s, AXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, AXTx33.4s, AXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, AXTx32.4s, AXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row - "fsub U.4s, AXTx41.4s, AXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - - "fadd U.4s, AXTx42.4s, AXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, AXTx43.4s, AXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, AXTx42.4s, AXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "ldr qx_43, [%x[inptr3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - - "subs %x[tile_j], %x[tile_j], #1\n" - "beq 3f\n" // Jump to 'B' tail - - // Start part B - // Load last column of this tile (the first column has already been - // loaded) and compute latter two columns of X.T x. - "fneg BXTx13.4s, x_33.4s\n" - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" - "fsub BXTx43.4s, x_23.4s, x_43.4s\n" - "ldr qx_44, [%x[inptr3], %x[colstride1]]\n" - "fneg BXTx14.4s, x_34.4s\n" - "fadd BXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride2]\n" - "fsub BXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride2]\n" - "fsub BXTx44.4s, x_24.4s, x_44.4s\n" - "add %x[inptr3], %x[inptr3], %x[colstride2]\n" - - // Compute and store U. 
- // First row - "fsub U.4s, BXTx11.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, BXTx12.4s, BXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fsub U.4s, BXTx21.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, BXTx22.4s, BXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fsub U.4s, BXTx31.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, BXTx32.4s, BXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row - "fsub U.4s, BXTx41.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - - "fadd U.4s, BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, BXTx42.4s, BXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "ldr qx_43, [%x[inptr3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - "subs %x[tile_j], %x[tile_j], #1\n" - "bne 1b\n" // Continue loop, otherwise flow into 'A' tail - - // ---------------------------------------------------------------------- - "2:" - // 'A' tail - // Since the final column is padding and the last-but-one column has - // already been loaded just compute the 3rd column of `X.T x'. - "fneg AXTx13.4s, x_33.4s\n" - "fadd AXTx23.4s, x_23.4s, x_33.4s\n" - "fsub AXTx33.4s, x_33.4s, x_23.4s\n" - "fsub AXTx43.4s, x_23.4s, x_43.4s\n" - - // Compute and store U. Modified to account for the final column of X.T - // x containing padding. Note, it is also unnecessary to update the - // output pointers. 
- // First row - "fsub U.4s, AXTx11.4s, AXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, AXTx12.4s, AXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, AXTx13.4s, AXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "str qAXTx12, [%x[outptr0], %x[mstride3]]\n" - - // Second row - "fsub U.4s, AXTx21.4s, AXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, AXTx22.4s, AXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, AXTx23.4s, AXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "str qAXTx22, [%x[outptr4], %x[mstride3]]\n" - - // Third row - "fsub U.4s, AXTx31.4s, AXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, AXTx32.4s, AXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, AXTx33.4s, AXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "str qAXTx32, [%x[outptr8], %x[mstride3]]\n" - - // Fourth row - "fsub U.4s, AXTx41.4s, AXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "fadd U.4s, AXTx42.4s, AXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, AXTx43.4s, AXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "str qAXTx42, [%x[outptr12], %x[mstride3]]\n" - - "b 4f\n" // Jump to end of function - - // ---------------------------------------------------------------------- - "3:" - // 'B' tail - // Since the final column is padding and the last-but-one column has - // already been loaded just compute the 3rd column of `X.T x'. - "fneg BXTx13.4s, x_33.4s\n" - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "fsub BXTx43.4s, x_23.4s, x_43.4s\n" - - // Compute and store U. Modified to account for the final column of X.T - // x containing padding. Note, it is also unnecessary to update the - // output pointers. - // First row - "fsub U.4s, BXTx11.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "str qBXTx12, [%x[outptr0], %x[mstride3]]\n" - - // Second row - "fsub U.4s, BXTx21.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "str qBXTx22, [%x[outptr4], %x[mstride3]]\n" - - // Third row - "fsub U.4s, BXTx31.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "str qBXTx32, [%x[outptr8], %x[mstride3]]\n" - - // Fourth row - "fsub U.4s, BXTx41.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "fadd U.4s, BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "str qBXTx42, [%x[outptr12], %x[mstride3]]\n" - - // ---------------------------------------------------------------------- - "4:" - // End of function - - // Clear names - ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n" - ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n" - ".unreq x_43\n" ".unreq qx_43\n" ".unreq x_44\n" ".unreq qx_44\n" - ".unreq AXTx11\n" ".unreq BXTx13\n" - ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n" - ".unreq AXTx21\n" ".unreq BXTx23\n" - ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n" - ".unreq AXTx31\n" ".unreq BXTx33\n" - ".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n" - ".unreq AXTx41\n" 
".unreq BXTx43\n" - ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n" - ".unreq AXTx13\n" ".unreq BXTx11\n" - ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n" - ".unreq AXTx23\n" ".unreq BXTx21\n" - ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n" - ".unreq AXTx33\n" ".unreq BXTx31\n" - ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n" - ".unreq AXTx43\n" ".unreq BXTx41\n" - ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n" - ".unreq U\n" ".unreq qU\n" - : [inptr1] "+r" (inptr0), // Offset to account for padded row - [inptr2] "+r" (inptr1), // Offset to account for padded row - [inptr3] "+r" (inptr2), // Offset to account for padded row - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr4), - [outptr8] "+r" (outptr8), - [outptr12] "+r" (outptr12), - [tile_j] "+r" (tile_j) // Tile counter - : [colstride1] "r" (1 * input_col_stride * sizeof(float)), - [colstride2] "r" (2 * input_col_stride * sizeof(float)), - [colstride3] "r" (3 * input_col_stride * sizeof(float)), - [mstride1] "r" (1 * matrix_stride * sizeof(float)), - [mstride2] "r" (2 * matrix_stride * sizeof(float)), - [mstride3] "r" (3 * matrix_stride * sizeof(float)), - [matrix_row_stride] "r" (matrix_row_stride * sizeof(float)) - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", - "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24" - ); -} - -// Pad left, right and bottom by 1. -template <> -template <> -inline void Winograd2x2_3x3GemmInput::process_tile_row<0, 1, 1, 1, 4>( - const int tile_N, - const float* const input, - const int input_row_stride, - const int input_col_stride, - float* const matrix, - const int matrix_stride, - const int matrix_row_stride -) { - const float *inptr0 = input; - const float *inptr1 = input + input_row_stride; - const float *inptr2 = input + input_row_stride * 2; - - float *outptr0 = matrix; - float *outptr4 = matrix + matrix_stride * 4; - float *outptr8 = matrix + matrix_stride * 8; - float *outptr12 = matrix + matrix_stride * 12; - - int tile_j = tile_N; // Tiles to process - - asm volatile ( - // Named SIMD registers according to the policy given above - // Registers into which to load the latter two columns of `x` - // NOTE: Bottom row is not required since since it is padded. 
- "x_13 .req v0\n qx_13 .req q0\n" "x_14 .req v4\n qx_14 .req q4\n" - "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n" - "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n" - - // Registers for storing X.T x (both A and B halves) - "AXTx11 .req v8\n" "BXTx13 .req v8\n" - "AXTx12 .req v9\n" "BXTx14 .req v9\n" "qAXTx12 .req q9\n" - "AXTx21 .req v10\n" "BXTx23 .req v10\n" - "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n" - "AXTx31 .req v12\n" "BXTx33 .req v12\n" - "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n" - "AXTx41 .req v14\n" "BXTx43 .req v14\n" - "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n" - "AXTx13 .req v16\n" "BXTx11 .req v16\n" - "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n" - "AXTx23 .req v18\n" "BXTx21 .req v18\n" - "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n" - "AXTx33 .req v20\n" "BXTx31 .req v20\n" - "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n" - "AXTx43 .req v22\n" "BXTx41 .req v22\n" - "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n" - - // Result register (TODO Does using more registers yield better - // performance) - "U .req v24\n qU .req q24\n" - - // ---------------------------------------------------------------------- - // Head of loop - // Loads a complete 4x4 tile of x, computes X.T x, computes and stores - // `U = X.T x X`. Prepares for the 'A' half of the loop. - // NOTE: Since the first tile has the leftmost column padded we can - // skip 4 loads and 4 calculations for the matrix X.T x X. - - // Temporarily alias registers for computing the first (non-padded) - // column of x. - "x_12 .req v0\n qx_12 .req q0\n" - "x_22 .req v1\n qx_22 .req q1\n" - "x_32 .req v2\n qx_32 .req q2\n" - - "ldr qx_12, [%x[inptr0]]\n" - "ldr qx_22, [%x[inptr1]]\n" - "ldr qx_32, [%x[inptr2]]\n" - - "fsub BXTx12.4s, x_12.4s, x_32.4s\n" - "fadd BXTx22.4s, x_22.4s, x_32.4s\n" - "fsub BXTx32.4s, x_32.4s, x_22.4s\n" - "mov BXTx42.16b, x_22.16b\n" // Probably should do better - - ".unreq x_12\n .unreq qx_12\n" - ".unreq x_22\n .unreq qx_22\n" - ".unreq x_32\n .unreq qx_32\n" - - // Load and compute latter two columns of the first tile. Progress the - // input pointers (by three columns so that the each points are the - // second column of the next tile, that is, each points at the first - // column which must be read for the next tile. 
- "ldr qx_13, [%x[inptr0], %x[colstride1]]\n" - "ldr qx_23, [%x[inptr1], %x[colstride1]]\n" - "ldr qx_33, [%x[inptr2], %x[colstride1]]\n" - - "fsub BXTx13.4s, x_13.4s, x_33.4s\n" - "ldr qx_14, [%x[inptr0], %x[colstride2]]\n" - - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride2]]\n" - - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride2]]\n" - - "mov BXTx43.16b, x_23.16b\n" - "fsub BXTx14.4s, x_14.4s, x_34.4s\n" - "add %x[inptr0], %x[inptr0], %x[colstride3]\n" - - "fadd BXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride3]\n" - - "fsub BXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride3]\n" - - "mov BXTx44.16b, x_24.16b\n" - - // Compute and store U for the first tile - // First row - "fneg U.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, BXTx12.4s, BXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fneg U.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, BXTx22.4s, BXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fneg U.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, BXTx32.4s, BXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row, simultaneously load the first column of inputs for the - // next tile. - "fneg U.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "ldr qx_13, [%x[inptr0]]\n" - - "fadd U.4s, BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, BXTx42.4s, BXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - - // Update the loop counter, subtract two to account for both the head and - // the tail. - "subs %x[tile_j], %x[tile_j], #2\n" - "beq 2f\n" // Jump to "A" tail if out of tiles - - // ---------------------------------------------------------------------- - "1:" - // Start part A - // Load last column of this tile (the first column has already been - // loaded) and compute latter two columns of X.T x. - "fsub AXTx13.4s, x_13.4s, x_33.4s\n" - "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" - "fadd AXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" - "fsub AXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" - "mov AXTx43.16b, x_23.16b\n" - - "fsub AXTx14.4s, x_14.4s, x_34.4s\n" - "add %x[inptr0], %x[inptr0], %x[colstride2]\n" - "fadd AXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride2]\n" - "fsub AXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride2]\n" - "mov AXTx44.16b, x_24.16b\n" - - // Compute and store U. 
- // First row - "fsub U.4s, AXTx11.4s, AXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, AXTx12.4s, AXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, AXTx13.4s, AXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, AXTx12.4s, AXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fsub U.4s, AXTx21.4s, AXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, AXTx22.4s, AXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, AXTx23.4s, AXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, AXTx22.4s, AXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fsub U.4s, AXTx31.4s, AXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, AXTx32.4s, AXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, AXTx33.4s, AXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, AXTx32.4s, AXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row - "fsub U.4s, AXTx41.4s, AXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "ldr qx_13, [%x[inptr0]]\n" - - "fadd U.4s, AXTx42.4s, AXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, AXTx43.4s, AXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, AXTx42.4s, AXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - - "subs %x[tile_j], %x[tile_j], #1\n" - "beq 3f\n" // Jump to 'B' tail - - // Start part B - // Load last column of this tile (the first column has already been - // loaded) and compute latter two columns of X.T x. - "fsub BXTx13.4s, x_13.4s, x_33.4s\n" - "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" - "mov BXTx43.16b, x_23.16b\n" - - "fsub BXTx14.4s, x_14.4s, x_34.4s\n" - "add %x[inptr0], %x[inptr0], %x[colstride2]\n" - "fadd BXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride2]\n" - "fsub BXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride2]\n" - "mov BXTx44.16b, x_24.16b\n" - - // Compute and store U. 
- // First row - "fsub U.4s, BXTx11.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, BXTx12.4s, BXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fsub U.4s, BXTx21.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, BXTx22.4s, BXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fsub U.4s, BXTx31.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, BXTx32.4s, BXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row - "fsub U.4s, BXTx41.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "ldr qx_13, [%x[inptr0]]\n" - - "fadd U.4s, BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, BXTx42.4s, BXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - "subs %x[tile_j], %x[tile_j], #1\n" - "bne 1b\n" // Continue loop, otherwise flow into 'A' tail - - // ---------------------------------------------------------------------- - "2:" - // 'A' tail - // Since the final column is padding and the last-but-one column has - // already been loaded just compute the 3rd column of `X.T x'. - "fsub AXTx13.4s, x_13.4s, x_33.4s\n" - "fadd AXTx23.4s, x_23.4s, x_33.4s\n" - "fsub AXTx33.4s, x_33.4s, x_23.4s\n" - "mov AXTx43.16b, x_23.16b\n" - - // Compute and store U. Modified to account for the final column of X.T - // x containing padding. Note, it is also unnecessary to update the - // output pointers. 
- // First row - "fsub U.4s, AXTx11.4s, AXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, AXTx12.4s, AXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, AXTx13.4s, AXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "str qAXTx12, [%x[outptr0], %x[mstride3]]\n" - - // Second row - "fsub U.4s, AXTx21.4s, AXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, AXTx22.4s, AXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, AXTx23.4s, AXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "str qAXTx22, [%x[outptr4], %x[mstride3]]\n" - - // Third row - "fsub U.4s, AXTx31.4s, AXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, AXTx32.4s, AXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, AXTx33.4s, AXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "str qAXTx32, [%x[outptr8], %x[mstride3]]\n" - - // Fourth row - "fsub U.4s, AXTx41.4s, AXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "fadd U.4s, AXTx42.4s, AXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, AXTx43.4s, AXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "str qAXTx42, [%x[outptr12], %x[mstride3]]\n" - - "b 4f\n" // Jump to end of function - - // ---------------------------------------------------------------------- - "3:" - // 'B' tail - // Since the final column is padding and the last-but-one column has - // already been loaded just compute the 3rd column of `X.T x'. - "fsub BXTx13.4s, x_13.4s, x_33.4s\n" - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "mov BXTx43.16b, x_23.16b\n" - - // Compute and store U. Modified to account for the final column of X.T - // x containing padding. Note, it is also unnecessary to update the - // output pointers. - // First row - "fsub U.4s, BXTx11.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "str qBXTx12, [%x[outptr0], %x[mstride3]]\n" - - // Second row - "fsub U.4s, BXTx21.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "str qBXTx22, [%x[outptr4], %x[mstride3]]\n" - - // Third row - "fsub U.4s, BXTx31.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "str qBXTx32, [%x[outptr8], %x[mstride3]]\n" - - // Fourth row - "fsub U.4s, BXTx41.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "fadd U.4s, BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "str qBXTx42, [%x[outptr12], %x[mstride3]]\n" - - // ---------------------------------------------------------------------- - "4:" - // End of function - - // Clear names - ".unreq x_13\n" ".unreq qx_13\n" ".unreq x_14\n" ".unreq qx_14\n" - ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n" - ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n" - ".unreq AXTx11\n" ".unreq BXTx13\n" - ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n" - ".unreq AXTx21\n" ".unreq BXTx23\n" - ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n" - ".unreq AXTx31\n" ".unreq BXTx33\n" - ".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n" - ".unreq AXTx41\n" 
".unreq BXTx43\n" - ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n" - ".unreq AXTx13\n" ".unreq BXTx11\n" - ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n" - ".unreq AXTx23\n" ".unreq BXTx21\n" - ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n" - ".unreq AXTx33\n" ".unreq BXTx31\n" - ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n" - ".unreq AXTx43\n" ".unreq BXTx41\n" - ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n" - ".unreq U\n" ".unreq qU\n" - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr4), - [outptr8] "+r" (outptr8), - [outptr12] "+r" (outptr12), - [tile_j] "+r" (tile_j) // Tile counter - : [colstride1] "r" (1 * input_col_stride * sizeof(float)), - [colstride2] "r" (2 * input_col_stride * sizeof(float)), - [colstride3] "r" (3 * input_col_stride * sizeof(float)), - [mstride1] "r" (1 * matrix_stride * sizeof(float)), - [mstride2] "r" (2 * matrix_stride * sizeof(float)), - [mstride3] "r" (3 * matrix_stride * sizeof(float)), - [matrix_row_stride] "r" (matrix_row_stride * sizeof(float)) - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", - "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24" - ); -} -} -#endif // __aarch64__ diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp deleted file mode 100644 index ad1ad55291..0000000000 --- a/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp +++ /dev/null @@ -1,961 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once -#include "../input_2x2_3x3.hpp" - -#ifdef __aarch64__ - -namespace winograd { - -template <> -template <> -inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<0, 0, 0, 0, 4>( - int &n_channels, // Number of channels in the tile - const float* &inptr0, - const int input_row_stride, - const int input_col_stride, - float* &outptr0, - const int matrix_stride -) { - // We use 4 pointers to point to the starting position on each row and use - // three offsets to extract elements from each of the other 3 columns. 
- auto inptr1 = inptr0 + 1*input_row_stride; - auto inptr2 = inptr0 + 2*input_row_stride; - auto inptr3 = inptr0 + 3*input_row_stride; - - // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three - // offsets to access the intermediate matrices. - auto outptr1 = outptr0 + matrix_stride * 4; - auto outptr2 = outptr0 + matrix_stride * 8; - auto outptr3 = outptr0 + matrix_stride * 12; - - for (; n_channels > 3; n_channels -= 4) { - asm volatile ( - "X_11 .req v0\n" "qX_11 .req q0\n" - "X_12 .req v1\n" "qX_12 .req q1\n" - "X_13 .req v2\n" "qX_13 .req q2\n" - "X_14 .req v3\n" "qX_14 .req q3\n" - "X_21 .req v4\n" "qX_21 .req q4\n" - "X_22 .req v5\n" "qX_22 .req q5\n" - "X_23 .req v6\n" "qX_23 .req q6\n" - "X_24 .req v7\n" "qX_24 .req q7\n" - "X_31 .req v8\n" "qX_31 .req q8\n" - "X_32 .req v9\n" "qX_32 .req q9\n" - "X_33 .req v10\n" "qX_33 .req q10\n" - "X_34 .req v11\n" "qX_34 .req q11\n" - "X_41 .req v12\n" "qX_41 .req q12\n" - "X_42 .req v13\n" "qX_42 .req q13\n" - "X_43 .req v14\n" "qX_43 .req q14\n" - "X_44 .req v15\n" "qX_44 .req q15\n" - "xX_11 .req v16\n" - "xX_12 .req v17\n" - "xX_13 .req v18\n" - "xX_14 .req v19\n" - "xX_21 .req v20\n" - "xX_22 .req v21\n" - "xX_23 .req v22\n" - "xX_24 .req v23\n" - "xX_31 .req v24\n" - "xX_32 .req v25\n" - "xX_33 .req v26\n" - "xX_34 .req v27\n" - "xX_41 .req v28\n" - "xX_42 .req v29\n" - "xX_43 .req v30\n" - "xX_44 .req v31\n" - " U .req v0\n" - "qU .req q0\n" - - // Load the tile, and compute compute the matrix xX - "ldr qX_11, [%x[inptr0]]\n" - "ldr qX_12, [%x[inptr0], %x[colstride1]]\n" - "ldr qX_13, [%x[inptr0], %x[colstride2]]\n" - "ldr qX_14, [%x[inptr0], %x[colstride3]]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "ldr qX_21, [%x[inptr1]]\n" - "fsub xX_11.4s, x_11.4s, x_13.4s\n" - "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" - "fadd xX_12.4s, x_12.4s, x_13.4s\n" - "ldr qX_23, [%x[inptr1], %x[colstride2]]\n" - "fsub xX_13.4s, x_13.4s, x_12.4s\n" - "ldr qX_24, [%x[inptr1], %x[colstride3]]\n" - "fsub xX_14.4s, x_12.4s, x_14.4s\n" - "add %x[inptr1], %x[inptr1], #0x10\n" - - "ldr qX_31, [%x[inptr2]]\n" - "fsub xX_21.4s, x_21.4s, x_23.4s\n" - "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" - "fadd xX_22.4s, x_22.4s, x_23.4s\n" - "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" - "fsub xX_23.4s, x_23.4s, x_22.4s\n" - "ldr qX_34, [%x[inptr2], %x[colstride3]]\n" - "fsub xX_24.4s, x_22.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - "ldr qX_41, [%x[inptr3]]\n" - "fsub xX_31.4s, x_31.4s, x_33.4s\n" - "ldr qX_42, [%x[inptr3], %x[colstride1]]\n" - "fadd xX_32.4s, x_32.4s, x_33.4s\n" - "ldr qX_43, [%x[inptr3], %x[colstride2]]\n" - "fsub xX_33.4s, x_33.4s, x_32.4s\n" - "ldr qX_44, [%x[inptr3], %x[colstride3]]\n" - "fsub xX_34.4s, x_32.4s, x_34.4s\n" - "add %x[inptr3], %x[inptr3], #0x10\n" - - // Complete computing xX while beginning to compute and store - // $U = X.T x X$ - - "fsub xX_41.4s, x_41.4s, x_43.4s\n" - - "fsub U.4s, xX_11.4s, xX_31.4s\n" - "str qU, [%x[outptr0]]\n" - "fsub U.4s, xX_12.4s, xX_32.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, xX_13.4s, xX_33.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, xX_14.4s, xX_34.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd xX_42.4s, x_42.4s, x_43.4s\n" - - "fadd U.4s, xX_21.4s, xX_31.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, xX_22.4s, xX_32.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fadd U.4s, xX_23.4s, xX_33.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fadd U.4s, xX_24.4s, xX_34.4s\n" - 
"str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fsub xX_43.4s, x_43.4s, x_42.4s\n" - - "fsub U.4s, xX_31.4s, xX_21.4s\n" - "str qU, [%x[outptr8]]\n" - "fsub U.4s, xX_32.4s, xX_22.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, xX_33.4s, xX_23.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, xX_34.4s, xX_24.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], #0x10\n" - - "fsub xX_44.4s, x_42.4s, x_44.4s\n" - - "fsub U.4s, xX_21.4s, xX_41.4s\n" - "str qU, [%x[outptr12]]\n" - "fsub U.4s, xX_22.4s, xX_42.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, xX_23.4s, xX_43.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "fsub U.4s, xX_24.4s, xX_44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - ".unreq qU\n" - ".unreq U\n" - ".unreq X_11\n" ".unreq qX_11\n" - ".unreq X_12\n" ".unreq qX_12\n" - ".unreq X_13\n" ".unreq qX_13\n" - ".unreq X_14\n" ".unreq qX_14\n" - ".unreq X_21\n" ".unreq qX_21\n" - ".unreq X_22\n" ".unreq qX_22\n" - ".unreq X_23\n" ".unreq qX_23\n" - ".unreq X_24\n" ".unreq qX_24\n" - ".unreq X_31\n" ".unreq qX_31\n" - ".unreq X_32\n" ".unreq qX_32\n" - ".unreq X_33\n" ".unreq qX_33\n" - ".unreq X_34\n" ".unreq qX_34\n" - ".unreq X_41\n" ".unreq qX_41\n" - ".unreq X_42\n" ".unreq qX_42\n" - ".unreq X_43\n" ".unreq qX_43\n" - ".unreq X_44\n" ".unreq qX_44\n" - ".unreq xX_11\n" - ".unreq xX_12\n" - ".unreq xX_13\n" - ".unreq xX_14\n" - ".unreq xX_21\n" - ".unreq xX_22\n" - ".unreq xX_23\n" - ".unreq xX_24\n" - ".unreq xX_31\n" - ".unreq xX_32\n" - ".unreq xX_33\n" - ".unreq xX_34\n" - ".unreq xX_41\n" - ".unreq xX_42\n" - ".unreq xX_43\n" - ".unreq xX_44\n" - - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [inptr3] "+r" (inptr3), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr1), - [outptr8] "+r" (outptr2), - [outptr12] "+r" (outptr3) - : [colstride1] "r" (input_col_stride * sizeof(float)), - [colstride2] "r" (input_col_stride * sizeof(float) * 2), - [colstride3] "r" (input_col_stride * sizeof(float) * 3), - [mstride1] "r" (matrix_stride * sizeof(float)), - [mstride2] "r" (matrix_stride * sizeof(float) * 2), - [mstride3] "r" (matrix_stride * sizeof(float) * 3) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", - "v30", "v31" - ); - } -} - -// Pad top by 1 -template <> -template <> -inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<1, 0, 0, 0, 4>( - int &n_channels, // Number of channels in the tile - const float* &inptr0, - const int input_row_stride, - const int input_col_stride, - float* &outptr0, - const int matrix_stride -) { - // We use 4 pointers to point to the starting position on each row and use - // three offsets to extract elements from each of the other 3 columns. - auto inptr1 = inptr0 + 0*input_row_stride; - auto inptr2 = inptr0 + 1*input_row_stride; - - // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three - // offsets to access the intermediate matrices. 
- auto outptr1 = outptr0 + matrix_stride * 4; - auto outptr2 = outptr0 + matrix_stride * 8; - auto outptr3 = outptr0 + matrix_stride * 12; - - for (; n_channels > 3; n_channels -= 4) { - asm volatile ( - "X_21 .req v4\n" "qX_21 .req q4\n" - "X_22 .req v5\n" "qX_22 .req q5\n" - "X_23 .req v6\n" "qX_23 .req q6\n" - "X_24 .req v7\n" "qX_24 .req q7\n" - "X_31 .req v8\n" "qX_31 .req q8\n" - "X_32 .req v9\n" "qX_32 .req q9\n" - "X_33 .req v10\n" "qX_33 .req q10\n" - "X_34 .req v11\n" "qX_34 .req q11\n" - "X_41 .req v12\n" "qX_41 .req q12\n" - "X_42 .req v13\n" "qX_42 .req q13\n" - "X_43 .req v14\n" "qX_43 .req q14\n" - "X_44 .req v15\n" "qX_44 .req q15\n" - "xX_21 .req v20\n" - "xX_22 .req v21\n" - "xX_23 .req v22\n" - "xX_24 .req v23\n" - "xX_31 .req v24\n" - "xX_32 .req v25\n" - "xX_33 .req v26\n" - "xX_34 .req v27\n" - "xX_41 .req v28\n" - "xX_42 .req v29\n" - "xX_43 .req v30\n" - "xX_44 .req v31\n" - " U .req v0\n" - "qU .req q0\n" - - // Load the tile, and compute compute the matrix xX - "ldr qX_21, [%x[inptr1]]\n" - "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" - "ldr qX_23, [%x[inptr1], %x[colstride2]]\n" - "ldr qX_24, [%x[inptr1], %x[colstride3]]\n" - "add %x[inptr1], %x[inptr1], #0x10\n" - - "ldr qX_31, [%x[inptr2]]\n" - "fsub xX_21.4s, x_21.4s, x_23.4s\n" - "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" - "fadd xX_22.4s, x_22.4s, x_23.4s\n" - "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" - "fsub xX_23.4s, x_23.4s, x_22.4s\n" - "ldr qX_34, [%x[inptr2], %x[colstride3]]\n" - "fsub xX_24.4s, x_22.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - "ldr qX_41, [%x[inptr3]]\n" - "fsub xX_31.4s, x_31.4s, x_33.4s\n" - "ldr qX_42, [%x[inptr3], %x[colstride1]]\n" - "fadd xX_32.4s, x_32.4s, x_33.4s\n" - "ldr qX_43, [%x[inptr3], %x[colstride2]]\n" - "fsub xX_33.4s, x_33.4s, x_32.4s\n" - "ldr qX_44, [%x[inptr3], %x[colstride3]]\n" - "fsub xX_34.4s, x_32.4s, x_34.4s\n" - "add %x[inptr3], %x[inptr3], #0x10\n" - - // Complete computing xX while beginning to compute and store - // $U = X.T x X$ - - "fsub xX_41.4s, x_41.4s, x_43.4s\n" - - "fneg U.4s, xX_31.4s\n" - "str qU, [%x[outptr0]]\n" - "fneg U.4s, xX_32.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fneg U.4s, xX_33.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fneg U.4s, xX_34.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd xX_42.4s, x_42.4s, x_43.4s\n" - - "fadd U.4s, xX_21.4s, xX_31.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, xX_22.4s, xX_32.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fadd U.4s, xX_23.4s, xX_33.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fadd U.4s, xX_24.4s, xX_34.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fsub xX_43.4s, x_43.4s, x_42.4s\n" - - "fsub U.4s, xX_31.4s, xX_21.4s\n" - "str qU, [%x[outptr8]]\n" - "fsub U.4s, xX_32.4s, xX_22.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, xX_33.4s, xX_23.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, xX_34.4s, xX_24.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], #0x10\n" - - "fsub xX_44.4s, x_42.4s, x_44.4s\n" - - "fsub U.4s, xX_21.4s, xX_41.4s\n" - "str qU, [%x[outptr12]]\n" - "fsub U.4s, xX_22.4s, xX_42.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, xX_23.4s, xX_43.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "fsub U.4s, xX_24.4s, xX_44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - ".unreq qU\n" - ".unreq U\n" - ".unreq 
X_21\n" ".unreq qX_21\n" - ".unreq X_22\n" ".unreq qX_22\n" - ".unreq X_23\n" ".unreq qX_23\n" - ".unreq X_24\n" ".unreq qX_24\n" - ".unreq X_31\n" ".unreq qX_31\n" - ".unreq X_32\n" ".unreq qX_32\n" - ".unreq X_33\n" ".unreq qX_33\n" - ".unreq X_34\n" ".unreq qX_34\n" - ".unreq X_41\n" ".unreq qX_41\n" - ".unreq X_42\n" ".unreq qX_42\n" - ".unreq X_43\n" ".unreq qX_43\n" - ".unreq X_44\n" ".unreq qX_44\n" - ".unreq xX_21\n" - ".unreq xX_22\n" - ".unreq xX_23\n" - ".unreq xX_24\n" - ".unreq xX_31\n" - ".unreq xX_32\n" - ".unreq xX_33\n" - ".unreq xX_34\n" - ".unreq xX_41\n" - ".unreq xX_42\n" - ".unreq xX_43\n" - ".unreq xX_44\n" - - : [inptr1] "+r" (inptr0), // Offset for missing row - [inptr2] "+r" (inptr1), // Offset for missing row - [inptr3] "+r" (inptr2), // Offset for missing row - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr1), - [outptr8] "+r" (outptr2), - [outptr12] "+r" (outptr3) - : [colstride1] "r" (input_col_stride * sizeof(float)), - [colstride2] "r" (input_col_stride * sizeof(float) * 2), - [colstride3] "r" (input_col_stride * sizeof(float) * 3), - [mstride1] "r" (matrix_stride * sizeof(float)), - [mstride2] "r" (matrix_stride * sizeof(float) * 2), - [mstride3] "r" (matrix_stride * sizeof(float) * 3) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", - "v30", "v31" - ); - } -} - -// Pad left by 1 -template <> -template <> -inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<0, 1, 0, 0, 4>( - int &n_channels, // Number of channels in the tile - const float* &inptr0, - const int input_row_stride, - const int input_col_stride, - float* &outptr0, - const int matrix_stride -) { - // We use 4 pointers to point to the starting position on each row and use - // three offsets to extract elements from each of the other 3 columns. - auto inptr1 = inptr0 + 1*input_row_stride; - auto inptr2 = inptr0 + 2*input_row_stride; - auto inptr3 = inptr0 + 3*input_row_stride; - - // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three - // offsets to access the intermediate matrices. 
- auto outptr1 = outptr0 + matrix_stride * 4; - auto outptr2 = outptr0 + matrix_stride * 8; - auto outptr3 = outptr0 + matrix_stride * 12; - - for (; n_channels > 3; n_channels -= 4) { - asm volatile ( - "X_12 .req v1\n" "qX_12 .req q1\n" - "X_13 .req v2\n" "qX_13 .req q2\n" - "X_14 .req v3\n" "qX_14 .req q3\n" - "X_22 .req v5\n" "qX_22 .req q5\n" - "X_23 .req v6\n" "qX_23 .req q6\n" - "X_24 .req v7\n" "qX_24 .req q7\n" - "X_32 .req v9\n" "qX_32 .req q9\n" - "X_33 .req v10\n" "qX_33 .req q10\n" - "X_34 .req v11\n" "qX_34 .req q11\n" - "X_42 .req v13\n" "qX_42 .req q13\n" - "X_43 .req v14\n" "qX_43 .req q14\n" - "X_44 .req v15\n" "qX_44 .req q15\n" - "xX_11 .req v16\n" - "xX_12 .req v17\n" - "xX_13 .req v18\n" - "xX_14 .req v19\n" - "xX_21 .req v20\n" - "xX_22 .req v21\n" - "xX_23 .req v22\n" - "xX_24 .req v23\n" - "xX_31 .req v24\n" - "xX_32 .req v25\n" - "xX_33 .req v26\n" - "xX_34 .req v27\n" - "xX_41 .req v28\n" - "xX_42 .req v29\n" - "xX_43 .req v30\n" - "xX_44 .req v31\n" - " U .req v0\n" - "qU .req q0\n" - - // Load the tile, and compute compute the matrix xX - "ldr qX_12, [%x[inptr0]]\n" - "ldr qX_13, [%x[inptr0], %x[colstride1]]\n" - "ldr qX_14, [%x[inptr0], %x[colstride2]]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "fneg xX_11.4s, x_13.4s\n" - "ldr qX_22, [%x[inptr1]]\n" - "fadd xX_12.4s, x_12.4s, x_13.4s\n" - "ldr qX_23, [%x[inptr1], %x[colstride1]]\n" - "fsub xX_13.4s, x_13.4s, x_12.4s\n" - "ldr qX_24, [%x[inptr1], %x[colstride2]]\n" - "fsub xX_14.4s, x_12.4s, x_14.4s\n" - "add %x[inptr1], %x[inptr1], #0x10\n" - - "fneg xX_21.4s, x_23.4s\n" - "ldr qX_32, [%x[inptr2]]\n" - "fadd xX_22.4s, x_22.4s, x_23.4s\n" - "ldr qX_33, [%x[inptr2], %x[colstride1]]\n" - "fsub xX_23.4s, x_23.4s, x_22.4s\n" - "ldr qX_34, [%x[inptr2], %x[colstride2]]\n" - "fsub xX_24.4s, x_22.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - "fneg xX_31.4s, x_33.4s\n" - "ldr qX_42, [%x[inptr3]]\n" - "fadd xX_32.4s, x_32.4s, x_33.4s\n" - "ldr qX_43, [%x[inptr3], %x[colstride1]]\n" - "fsub xX_33.4s, x_33.4s, x_32.4s\n" - "ldr qX_44, [%x[inptr3], %x[colstride2]]\n" - "fsub xX_34.4s, x_32.4s, x_34.4s\n" - "add %x[inptr3], %x[inptr3], #0x10\n" - - // Complete computing xX while beginning to compute and store - // $U = X.T x X$ - - "fneg xX_41.4s, x_43.4s\n" - - "fsub U.4s, xX_11.4s, xX_31.4s\n" - "str qU, [%x[outptr0]]\n" - "fsub U.4s, xX_12.4s, xX_32.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, xX_13.4s, xX_33.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, xX_14.4s, xX_34.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd xX_42.4s, x_42.4s, x_43.4s\n" - - "fadd U.4s, xX_21.4s, xX_31.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, xX_22.4s, xX_32.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fadd U.4s, xX_23.4s, xX_33.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fadd U.4s, xX_24.4s, xX_34.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fsub xX_43.4s, x_43.4s, x_42.4s\n" - - "fsub U.4s, xX_31.4s, xX_21.4s\n" - "str qU, [%x[outptr8]]\n" - "fsub U.4s, xX_32.4s, xX_22.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, xX_33.4s, xX_23.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, xX_34.4s, xX_24.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], #0x10\n" - - "fsub xX_44.4s, x_42.4s, x_44.4s\n" - - "fsub U.4s, xX_21.4s, xX_41.4s\n" - "str qU, [%x[outptr12]]\n" - "fsub U.4s, xX_22.4s, xX_42.4s\n" - "str qU, 
[%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, xX_23.4s, xX_43.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "fsub U.4s, xX_24.4s, xX_44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - ".unreq X_12\n" ".unreq qX_12\n" - ".unreq X_13\n" ".unreq qX_13\n" - ".unreq X_14\n" ".unreq qX_14\n" - ".unreq X_22\n" ".unreq qX_22\n" - ".unreq X_23\n" ".unreq qX_23\n" - ".unreq X_24\n" ".unreq qX_24\n" - ".unreq X_32\n" ".unreq qX_32\n" - ".unreq X_33\n" ".unreq qX_33\n" - ".unreq X_34\n" ".unreq qX_34\n" - ".unreq X_42\n" ".unreq qX_42\n" - ".unreq X_43\n" ".unreq qX_43\n" - ".unreq X_44\n" ".unreq qX_44\n" - ".unreq xX_11\n" - ".unreq xX_12\n" - ".unreq xX_13\n" - ".unreq xX_14\n" - ".unreq xX_21\n" - ".unreq xX_22\n" - ".unreq xX_23\n" - ".unreq xX_24\n" - ".unreq xX_31\n" - ".unreq xX_32\n" - ".unreq xX_33\n" - ".unreq xX_34\n" - ".unreq xX_41\n" - ".unreq xX_42\n" - ".unreq xX_43\n" - ".unreq xX_44\n" - ".unreq U\n" - ".unreq qU\n" - - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [inptr3] "+r" (inptr3), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr1), - [outptr8] "+r" (outptr2), - [outptr12] "+r" (outptr3) - : [colstride1] "r" (input_col_stride * sizeof(float)), - [colstride2] "r" (input_col_stride * sizeof(float) * 2), - [mstride1] "r" (matrix_stride * sizeof(float)), - [mstride2] "r" (matrix_stride * sizeof(float) * 2), - [mstride3] "r" (matrix_stride * sizeof(float) * 3) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", - "v30", "v31" - ); - } -} - -// Pad bottom by 1 -template <> -template <> -inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<0, 0, 1, 0, 4>( - int &n_channels, // Number of channels in the tile - const float* &inptr0, - const int input_row_stride, - const int input_col_stride, - float* &outptr0, - const int matrix_stride -) { - // We use 4 pointers to point to the starting position on each row and use - // three offsets to extract elements from each of the other 3 columns. - auto inptr1 = inptr0 + 1*input_row_stride; - auto inptr2 = inptr0 + 2*input_row_stride; - - // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three - // offsets to access the intermediate matrices. 
- auto outptr1 = outptr0 + matrix_stride * 4; - auto outptr2 = outptr0 + matrix_stride * 8; - auto outptr3 = outptr0 + matrix_stride * 12; - - for (; n_channels > 3; n_channels -= 4) { - asm volatile ( - "X_11 .req v0\n" "qX_11 .req q0\n" - "X_12 .req v1\n" "qX_12 .req q1\n" - "X_13 .req v2\n" "qX_13 .req q2\n" - "X_14 .req v3\n" "qX_14 .req q3\n" - "X_21 .req v4\n" "qX_21 .req q4\n" - "X_22 .req v5\n" "qX_22 .req q5\n" - "X_23 .req v6\n" "qX_23 .req q6\n" - "X_24 .req v7\n" "qX_24 .req q7\n" - "X_31 .req v8\n" "qX_31 .req q8\n" - "X_32 .req v9\n" "qX_32 .req q9\n" - "X_33 .req v10\n" "qX_33 .req q10\n" - "X_34 .req v11\n" "qX_34 .req q11\n" - "xX_11 .req v16\n" - "xX_12 .req v17\n" - "xX_13 .req v18\n" - "xX_14 .req v19\n" - "xX_21 .req v20\n" "qxX_21 .req q20\n" - "xX_22 .req v21\n" "qxX_22 .req q21\n" - "xX_23 .req v22\n" "qxX_23 .req q22\n" - "xX_24 .req v23\n" "qxX_24 .req q23\n" - "xX_31 .req v24\n" - "xX_32 .req v25\n" - "xX_33 .req v26\n" - "xX_34 .req v27\n" - " U .req v0\n" - "qU .req q0\n" - - // Load the tile, and compute compute the matrix xX - "ldr qX_11, [%x[inptr0]]\n" - "ldr qX_12, [%x[inptr0], %x[colstride1]]\n" - "ldr qX_13, [%x[inptr0], %x[colstride2]]\n" - "ldr qX_14, [%x[inptr0], %x[colstride3]]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "ldr qX_21, [%x[inptr1]]\n" - "fsub xX_11.4s, x_11.4s, x_13.4s\n" - "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" - "fadd xX_12.4s, x_12.4s, x_13.4s\n" - "ldr qX_23, [%x[inptr1], %x[colstride2]]\n" - "fsub xX_13.4s, x_13.4s, x_12.4s\n" - "ldr qX_24, [%x[inptr1], %x[colstride3]]\n" - "fsub xX_14.4s, x_12.4s, x_14.4s\n" - "add %x[inptr1], %x[inptr1], #0x10\n" - - "ldr qX_31, [%x[inptr2]]\n" - "fsub xX_21.4s, x_21.4s, x_23.4s\n" - "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" - "fadd xX_22.4s, x_22.4s, x_23.4s\n" - "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" - "fsub xX_23.4s, x_23.4s, x_22.4s\n" - "ldr qX_34, [%x[inptr2], %x[colstride3]]\n" - "fsub xX_24.4s, x_22.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - "fsub xX_31.4s, x_31.4s, x_33.4s\n" - "fadd xX_32.4s, x_32.4s, x_33.4s\n" - "fsub xX_33.4s, x_33.4s, x_32.4s\n" - "fsub xX_34.4s, x_32.4s, x_34.4s\n" - - // Complete computing xX while beginning to compute and store - // $U = X.T x X$ - - "fsub U.4s, xX_11.4s, xX_31.4s\n" - "str qU, [%x[outptr0]]\n" - "fsub U.4s, xX_12.4s, xX_32.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, xX_13.4s, xX_33.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, xX_14.4s, xX_34.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd U.4s, xX_21.4s, xX_31.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, xX_22.4s, xX_32.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fadd U.4s, xX_23.4s, xX_33.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fadd U.4s, xX_24.4s, xX_34.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fsub U.4s, xX_31.4s, xX_21.4s\n" - "str qU, [%x[outptr8]]\n" - "fsub U.4s, xX_32.4s, xX_22.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, xX_33.4s, xX_23.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, xX_34.4s, xX_24.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], #0x10\n" - - "str qxX_21, [%x[outptr12]]\n" - "str qxX_22, [%x[outptr12], %x[mstride1]]\n" - "str qxX_23, [%x[outptr12], %x[mstride2]]\n" - "str qxX_24, [%x[outptr12], %x[mstride3]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - ".unreq qU\n" - ".unreq U\n" - ".unreq X_11\n" ".unreq 
qX_11\n" - ".unreq X_12\n" ".unreq qX_12\n" - ".unreq X_13\n" ".unreq qX_13\n" - ".unreq X_14\n" ".unreq qX_14\n" - ".unreq X_21\n" ".unreq qX_21\n" - ".unreq X_22\n" ".unreq qX_22\n" - ".unreq X_23\n" ".unreq qX_23\n" - ".unreq X_24\n" ".unreq qX_24\n" - ".unreq X_31\n" ".unreq qX_31\n" - ".unreq X_32\n" ".unreq qX_32\n" - ".unreq X_33\n" ".unreq qX_33\n" - ".unreq X_34\n" ".unreq qX_34\n" - ".unreq xX_11\n" - ".unreq xX_12\n" - ".unreq xX_13\n" - ".unreq xX_14\n" - ".unreq xX_21\n" ".unreq qxX_21\n" - ".unreq xX_22\n" ".unreq qxX_22\n" - ".unreq xX_23\n" ".unreq qxX_23\n" - ".unreq xX_24\n" ".unreq qxX_24\n" - ".unreq xX_31\n" - ".unreq xX_32\n" - ".unreq xX_33\n" - ".unreq xX_34\n" - - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr1), - [outptr8] "+r" (outptr2), - [outptr12] "+r" (outptr3) - : [colstride1] "r" (input_col_stride * sizeof(float)), - [colstride2] "r" (input_col_stride * sizeof(float) * 2), - [colstride3] "r" (input_col_stride * sizeof(float) * 3), - [mstride1] "r" (matrix_stride * sizeof(float)), - [mstride2] "r" (matrix_stride * sizeof(float) * 2), - [mstride3] "r" (matrix_stride * sizeof(float) * 3) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", - "v30", "v31" - ); - } -} - -// Pad right by 1 -template <> -template <> -inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<0, 0, 0, 1, 4>( - int &n_channels, // Number of channels in the tile - const float* &inptr0, - const int input_row_stride, - const int input_col_stride, - float* &outptr0, - const int matrix_stride -) { - // We use 4 pointers to point to the starting position on each row and use - // three offsets to extract elements from each of the other 3 columns. - auto inptr1 = inptr0 + 1*input_row_stride; - auto inptr2 = inptr0 + 2*input_row_stride; - auto inptr3 = inptr0 + 3*input_row_stride; - - // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three - // offsets to access the intermediate matrices. 
- auto outptr1 = outptr0 + matrix_stride * 4; - auto outptr2 = outptr0 + matrix_stride * 8; - auto outptr3 = outptr0 + matrix_stride * 12; - - for (; n_channels > 3; n_channels -= 4) { - asm volatile ( - "X_11 .req v0\n" "qX_11 .req q0\n" - "X_12 .req v1\n" "qX_12 .req q1\n" - "X_13 .req v2\n" "qX_13 .req q2\n" - "X_21 .req v4\n" "qX_21 .req q4\n" - "X_22 .req v5\n" "qX_22 .req q5\n" - "X_23 .req v6\n" "qX_23 .req q6\n" - "X_31 .req v8\n" "qX_31 .req q8\n" - "X_32 .req v9\n" "qX_32 .req q9\n" - "X_33 .req v10\n" "qX_33 .req q10\n" - "X_41 .req v12\n" "qX_41 .req q12\n" - "X_42 .req v13\n" "qX_42 .req q13\n" - "X_43 .req v14\n" "qX_43 .req q14\n" - "xX_11 .req v16\n" - "xX_12 .req v17\n" - "xX_13 .req v18\n" - "xX_14 .req x_12\n" - "xX_21 .req v20\n" - "xX_22 .req v21\n" - "xX_23 .req v22\n" - "xX_24 .req x_22\n" - "xX_31 .req v24\n" - "xX_32 .req v25\n" - "xX_33 .req v26\n" - "xX_34 .req x_32\n" - "xX_41 .req v28\n" - "xX_42 .req v29\n" - "xX_43 .req v30\n" - "xX_44 .req x_42\n" - " U .req v0\n" - "qU .req q0\n" - - // Load the tile, and compute compute the matrix xX - "ldr qX_11, [%x[inptr0]]\n" - "ldr qX_12, [%x[inptr0], %x[colstride1]]\n" - "ldr qX_13, [%x[inptr0], %x[colstride2]]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "ldr qX_21, [%x[inptr1]]\n" - "fsub xX_11.4s, x_11.4s, x_13.4s\n" - "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" - "fadd xX_12.4s, x_12.4s, x_13.4s\n" - "ldr qX_23, [%x[inptr1], %x[colstride2]]\n" - "fsub xX_13.4s, x_13.4s, x_12.4s\n" - "add %x[inptr1], %x[inptr1], #0x10\n" - - "ldr qX_31, [%x[inptr2]]\n" - "fsub xX_21.4s, x_21.4s, x_23.4s\n" - "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" - "fadd xX_22.4s, x_22.4s, x_23.4s\n" - "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" - "fsub xX_23.4s, x_23.4s, x_22.4s\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - "ldr qX_41, [%x[inptr3]]\n" - "fsub xX_31.4s, x_31.4s, x_33.4s\n" - "ldr qX_42, [%x[inptr3], %x[colstride1]]\n" - "fadd xX_32.4s, x_32.4s, x_33.4s\n" - "ldr qX_43, [%x[inptr3], %x[colstride2]]\n" - "fsub xX_33.4s, x_33.4s, x_32.4s\n" - "add %x[inptr3], %x[inptr3], #0x10\n" - - // Complete computing xX while beginning to compute and store - // $U = X.T x X$ - - "fsub xX_41.4s, x_41.4s, x_43.4s\n" - - "fsub U.4s, xX_11.4s, xX_31.4s\n" - "str qU, [%x[outptr0]]\n" - "fsub U.4s, xX_12.4s, xX_32.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, xX_13.4s, xX_33.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, xX_14.4s, xX_34.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd xX_42.4s, x_42.4s, x_43.4s\n" - - "fadd U.4s, xX_21.4s, xX_31.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, xX_22.4s, xX_32.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fadd U.4s, xX_23.4s, xX_33.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fadd U.4s, xX_24.4s, xX_34.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fsub xX_43.4s, x_43.4s, x_42.4s\n" - - "fsub U.4s, xX_31.4s, xX_21.4s\n" - "str qU, [%x[outptr8]]\n" - "fsub U.4s, xX_32.4s, xX_22.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, xX_33.4s, xX_23.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, xX_34.4s, xX_24.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], #0x10\n" - - "fsub U.4s, xX_21.4s, xX_41.4s\n" - "str qU, [%x[outptr12]]\n" - "fsub U.4s, xX_22.4s, xX_42.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, xX_23.4s, xX_43.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "fsub U.4s, 
xX_24.4s, xX_44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - ".unreq qU\n" - ".unreq U\n" - ".unreq X_11\n" ".unreq qX_11\n" - ".unreq X_12\n" ".unreq qX_12\n" - ".unreq X_13\n" ".unreq qX_13\n" - ".unreq X_21\n" ".unreq qX_21\n" - ".unreq X_22\n" ".unreq qX_22\n" - ".unreq X_23\n" ".unreq qX_23\n" - ".unreq X_31\n" ".unreq qX_31\n" - ".unreq X_32\n" ".unreq qX_32\n" - ".unreq X_33\n" ".unreq qX_33\n" - ".unreq X_41\n" ".unreq qX_41\n" - ".unreq X_42\n" ".unreq qX_42\n" - ".unreq X_43\n" ".unreq qX_43\n" - ".unreq xX_11\n" - ".unreq xX_12\n" - ".unreq xX_13\n" - ".unreq xX_14\n" - ".unreq xX_21\n" - ".unreq xX_22\n" - ".unreq xX_23\n" - ".unreq xX_24\n" - ".unreq xX_31\n" - ".unreq xX_32\n" - ".unreq xX_33\n" - ".unreq xX_34\n" - ".unreq xX_41\n" - ".unreq xX_42\n" - ".unreq xX_43\n" - ".unreq xX_44\n" - - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [inptr3] "+r" (inptr3), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr1), - [outptr8] "+r" (outptr2), - [outptr12] "+r" (outptr3) - : [colstride1] "r" (input_col_stride * sizeof(float)), - [colstride2] "r" (input_col_stride * sizeof(float) * 2), - [mstride1] "r" (matrix_stride * sizeof(float)), - [mstride2] "r" (matrix_stride * sizeof(float) * 2), - [mstride3] "r" (matrix_stride * sizeof(float) * 3) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", - "v30", "v31" - ); - } -} -} -#endif diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp deleted file mode 100644 index 033442aa14..0000000000 --- a/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -namespace winograd { - /* Transform a kernel into the Winograd domain. - * - * NOTE: It is assumed that the kernel is in the form [height x width x - * input_channels x output_channel]. 
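 *
 * For illustration (a sketch consistent with the strides used below, where
 * kernel_col_stride = n_input_channels * n_output_channels and
 * kernel_row_stride = 3 * kernel_col_stride), element (r, c, ic, oc) of a
 * 3x3 kernel under this layout is read from
 *
 *   kernel[((r * 3 + c) * n_input_channels + ic) * n_output_channels + oc]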
- */ - template - struct winograd2x2_3x3_gemm_kernel_transform_impl{ - static void execute( - const KernelShape &shape, - const T* const kernel, - T* const matrix_base, - const int matrix_stride, - const int matrix_row_stride - ); - - protected: - template - static void transform_kernel( - const T* const kernel, - const int n_input_channels, - const int n_output_channels, - T* const matrix_base, - const int matrix_stride, - const int matrix_row_stride - ); - }; -} - -/*****************************************************************************/ -/* Transform a fp32 kernel into the Winograd domain. - */ -#include "kernel_2x2_3x3/a64_float.hpp" // AArch64 specialisations - -namespace winograd -{ -template <> -inline void winograd2x2_3x3_gemm_kernel_transform_impl::execute( - const KernelShape &shape, - const float* const kernel, - float* const matrix_base, - const int matrix_stride, - const int matrix_row_stride -) { - // Delegate based on tail size - const int n_input_channels = shape.n_input_channels; - const int n_output_channels = shape.n_output_channels; - - switch (n_output_channels % 4) { - case 0: - transform_kernel<0>( - kernel, n_input_channels, n_output_channels, - matrix_base, matrix_stride, matrix_row_stride - ); - break; - case 1: - transform_kernel<1>( - kernel, n_input_channels, n_output_channels, - matrix_base, matrix_stride, matrix_row_stride - ); - break; - case 2: - transform_kernel<2>( - kernel, n_input_channels, n_output_channels, - matrix_base, matrix_stride, matrix_row_stride - ); - break; - case 3: - transform_kernel<3>( - kernel, n_input_channels, n_output_channels, - matrix_base, matrix_stride, matrix_row_stride - ); - break; - default: - ARM_COMPUTE_ERROR("Cannot happen"); - break; - } -} - -template <> -template -inline void winograd2x2_3x3_gemm_kernel_transform_impl::transform_kernel( - const float* const kernel, - const int n_input_channels, - const int n_output_channels, - float* const matrix_base, - const int mstride, - const int matrix_row_stride -) { - // Use one input pointer for each row of the kernel, use two additional - // offsets to extract columns. - const int kernel_col_stride = n_input_channels * n_output_channels; - const int kernel_row_stride = 3 * kernel_col_stride; - const float *inptr0 = kernel; - const float *inptr1 = kernel + kernel_row_stride; - const float *inptr2 = kernel + kernel_row_stride*2; - - // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three - // offsets to extract further matrices. - float *outptr0 = matrix_base; - float *outptr4 = matrix_base + mstride * 4; - float *outptr8 = matrix_base + mstride * 8; - float *outptr12 = matrix_base + mstride * 12; - - // For every input channel - for (int in_c = 0; in_c < n_input_channels; in_c++) { - // For every output channel - for (int c = 0; c < n_output_channels; c++) { - // Read in the kernel - float w11 = inptr0[0], w12 = inptr0[kernel_col_stride], w13 = inptr0[kernel_col_stride*2]; - float w21 = inptr1[0], w22 = inptr1[kernel_col_stride], w23 = inptr1[kernel_col_stride*2]; - float w31 = inptr2[0], w32 = inptr2[kernel_col_stride], w33 = inptr2[kernel_col_stride*2]; - - // Progress input pointers - inptr0++; - inptr1++; - inptr2++; - - // Compute the kernel W w, note we need only compute the middle two rows - // (2 and 3) because the first and last rows are merely copies of values - // from the matrix w. 
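      // For reference (a sketch matching the 0.5 factors used just below):
      // the transform being applied is U = W w W.T with the F(2x2, 3x3)
      // kernel-transform matrix
      //
      //   W = [ 1    0    0  ]
      //       [ 1/2  1/2  1/2]
      //       [ 1/2 -1/2  1/2]
      //       [ 0    0    1  ]
      //
      // so Ww = W w copies w's first and last rows into rows 1 and 4, and
      // U = (W w) W.T likewise copies Ww's first and last columns into
      // columns 1 and 4 of U.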
- float Ww11 = w11, Ww12 = w12, Ww13 = w13; - float Ww21 = 0.5*(w11 + w21 + w31), Ww22 = 0.5*(w12 + w22 + w32), Ww23 = 0.5*(w13 + w23 + w33); - float Ww31 = 0.5*(w11 - w21 + w31), Ww32 = 0.5*(w12 - w22 + w32), Ww33 = 0.5*(w13 - w23 + w33); - float Ww41 = w31, Ww42 = w32, Ww43 = w33; - - // Hence compute W w W.T; again note we need compute only the middle two - // columns since the first and last columns are copies of the first and - // last columns of the previous matrix. - float WwWT11 = Ww11, WwWT12 = 0.5*(Ww11 + Ww12 + Ww13), WwWT13 = 0.5*(Ww11 - Ww12 + Ww13), WwWT14 = Ww13; - float WwWT21 = Ww21, WwWT22 = 0.5*(Ww21 + Ww22 + Ww23), WwWT23 = 0.5*(Ww21 - Ww22 + Ww23), WwWT24 = Ww23; - float WwWT31 = Ww31, WwWT32 = 0.5*(Ww31 + Ww32 + Ww33), WwWT33 = 0.5*(Ww31 - Ww32 + Ww33), WwWT34 = Ww33; - float WwWT41 = Ww41, WwWT42 = 0.5*(Ww41 + Ww42 + Ww43), WwWT43 = 0.5*(Ww41 - Ww42 + Ww43), WwWT44 = Ww43; - - // Store the computed weights - outptr0[0 * mstride] = WwWT11; - outptr0[1 * mstride] = WwWT12; - outptr0[2 * mstride] = WwWT13; - outptr0[3 * mstride] = WwWT14; - - outptr4[0 * mstride] = WwWT21; - outptr4[1 * mstride] = WwWT22; - outptr4[2 * mstride] = WwWT23; - outptr4[3 * mstride] = WwWT24; - - outptr8[0 * mstride] = WwWT31; - outptr8[1 * mstride] = WwWT32; - outptr8[2 * mstride] = WwWT33; - outptr8[3 * mstride] = WwWT34; - - outptr12[0 * mstride] = WwWT41; - outptr12[1 * mstride] = WwWT42; - outptr12[2 * mstride] = WwWT43; - outptr12[3 * mstride] = WwWT44; - - // Progress output pointers - outptr0++; - outptr4++; - outptr8++; - outptr12++; - } - - // Progression to complete stride - outptr0 += matrix_row_stride - n_output_channels; - outptr4 += matrix_row_stride - n_output_channels; - outptr8 += matrix_row_stride - n_output_channels; - outptr12 += matrix_row_stride - n_output_channels; - } -} -} diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp deleted file mode 100644 index 3dd62d1ac1..0000000000 --- a/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp +++ /dev/null @@ -1,822 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- */ -#pragma once - -#ifdef __aarch64__ -namespace winograd { -template <> -template <> -inline void winograd2x2_3x3_gemm_kernel_transform_impl::transform_kernel<0>( - const float* const kernel, - const int n_input_channels, - const int n_output_channels, - float* const matrix_base, - const int mstride, - const int matrix_row_stride -) { - // Use one input pointer for each row of the kernel, use two additional - // offsets to extract columns. - const int kernel_col_stride = n_input_channels * n_output_channels; - const int kernel_row_stride = 3 * kernel_col_stride; - const float *inptr0 = kernel; - const float *inptr1 = kernel + kernel_row_stride; - const float *inptr2 = kernel + kernel_row_stride*2; - - // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three - // offsets to extract further matrices. - float *outptr0 = matrix_base; - float *outptr4 = matrix_base + mstride * 4; - float *outptr8 = matrix_base + mstride * 8; - float *outptr12 = matrix_base + mstride * 12; - - // For every input channel - for (int in_c = 0; in_c < n_input_channels; in_c++) { - int n_remaining_channels = n_output_channels; - - asm volatile ( - // Registers into which to read the kernel - "w_11 .req v0\n" "qw_11 .req q0\n" - "w_12 .req v1\n" "qw_12 .req q1\n" - "w_13 .req v2\n" "qw_13 .req q2\n" - "w_21 .req v3\n" "qw_21 .req q3\n" - "w_22 .req v4\n" "qw_22 .req q4\n" - "w_23 .req v5\n" "qw_23 .req q5\n" - "w_31 .req v6\n" "qw_31 .req q6\n" - "w_32 .req v7\n" "qw_32 .req q7\n" - "w_33 .req v8\n" "qw_33 .req q8\n" - - // Transformed matrix Ww - "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n" - "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n" - "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n" - "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n" - - // Output matrix U = WwWT - "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n" - "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n" - "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n" - "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n" - - // Storage view of output matrices - "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n" - "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n" - "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n" - "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n" - - "half .req v23\n" // {0.5, ..., 0.5} - "dup half.4s, %w[one_half]\n" - "scratch .req v24\n" - - "1:" - // Load tile of the kernel - "ldr qw_11, [%x[inptr0]]\n" - "str qU11, [%x[outptr0]]\n" - "ldr qw_12, [%x[inptr0], %x[colstride1]]\n" - "ldr qw_13, [%x[inptr0], %x[colstride2]]\n" - "str qU14, [%x[outptr0], %x[mstride3]]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "ldr qw_21, [%x[inptr1]]\n" - "ldr qw_22, [%x[inptr1], %x[colstride1]]\n" - "ldr qw_23, [%x[inptr1], %x[colstride2]]\n" - "add %x[inptr1], %x[inptr1], #0x10\n" - - "ldr qw_31, [%x[inptr2]]\n" - "str qU41, [%x[outptr12]]\n" - "ldr qw_32, [%x[inptr2], %x[colstride1]]\n" - "ldr qw_33, [%x[inptr2], %x[colstride2]]\n" - "str qU44, [%x[outptr12], %x[mstride3]]\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - // Compute 2nd and 3rd rows of Ww - "fadd scratch.4s, w_11.4s, w_31.4s\n" - "fmul Ww21.4s, scratch.4s, half.4s\n" - "fmla Ww21.4s, w_21.4s, half.4s\n" - "str qU21, [%x[outptr4]]\n" - "fmul Ww31.4s, scratch.4s, half.4s\n" - "fmls Ww31.4s, w_21.4s, half.4s\n" - "str qU31, [%x[outptr8]]\n" - - "fadd scratch.4s, w_12.4s, w_32.4s\n" - 
"fmul Ww22.4s, scratch.4s, half.4s\n" - "fmla Ww22.4s, w_22.4s, half.4s\n" - "fmul Ww32.4s, scratch.4s, half.4s\n" - "fmls Ww32.4s, w_22.4s, half.4s\n" - - "fadd scratch.4s, w_13.4s, w_33.4s\n" - "fmul Ww23.4s, scratch.4s, half.4s\n" - "fmla Ww23.4s, w_23.4s, half.4s\n" - "str qU24, [%x[outptr4], %x[mstride3]]\n" - "fmul Ww33.4s, scratch.4s, half.4s\n" - "fmls Ww33.4s, w_23.4s, half.4s\n" - "str qU34, [%x[outptr8], %x[mstride3]]\n" - - // Compute and store U, only need to compute the 2nd and 3rd columns - // of U and update output pointers - "fadd scratch.4s, Ww11.4s, Ww13.4s\n" - "fmul U12.4s, scratch.4s, half.4s\n" - "fmla U12.4s, Ww12.4s, half.4s\n" - "str qU12, [%x[outptr0], %x[mstride1]]\n" - "fmul U13.4s, scratch.4s, half.4s\n" - "fmls U13.4s, Ww12.4s, half.4s\n" - "str qU13, [%x[outptr0], %x[mstride2]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd scratch.4s, Ww21.4s, Ww23.4s\n" - "fmul U22.4s, scratch.4s, half.4s\n" - "fmla U22.4s, Ww22.4s, half.4s\n" - "str qU22, [%x[outptr4], %x[mstride1]]\n" - "fmul U23.4s, scratch.4s, half.4s\n" - "fmls U23.4s, Ww22.4s, half.4s\n" - "str qU23, [%x[outptr4], %x[mstride2]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fadd scratch.4s, Ww31.4s, Ww33.4s\n" - "fmul U32.4s, scratch.4s, half.4s\n" - "fmla U32.4s, Ww32.4s, half.4s\n" - "str qU32, [%x[outptr8], %x[mstride1]]\n" - "fmul U33.4s, scratch.4s, half.4s\n" - "fmls U33.4s, Ww32.4s, half.4s\n" - "str qU33, [%x[outptr8], %x[mstride2]]\n" - "add %x[outptr8], %x[outptr8], #0x10\n" - - "fadd scratch.4s, Ww41.4s, Ww43.4s\n" - "fmul U42.4s, scratch.4s, half.4s\n" - "fmla U42.4s, Ww42.4s, half.4s\n" - "str qU42, [%x[outptr12], %x[mstride1]]\n" - "fmul U43.4s, scratch.4s, half.4s\n" - "fmls U43.4s, Ww42.4s, half.4s\n" - "str qU43, [%x[outptr12], %x[mstride2]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n" - "bne 1b\n" - - // Clear aliases - ".unreq half\n" - ".unreq scratch\n" - ".unreq w_11\n" ".unreq qw_11\n" - ".unreq w_12\n" ".unreq qw_12\n" - ".unreq w_13\n" ".unreq qw_13\n" - ".unreq w_21\n" ".unreq qw_21\n" - ".unreq w_22\n" ".unreq qw_22\n" - ".unreq w_23\n" ".unreq qw_23\n" - ".unreq w_31\n" ".unreq qw_31\n" - ".unreq w_32\n" ".unreq qw_32\n" - ".unreq w_33\n" ".unreq qw_33\n" - ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n" - ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n" - ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n" - ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n" - ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n" - ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n" - ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n" - ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n" - ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n" - ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n" - ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n" - ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq qU44\n" - - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr4), - [outptr8] "+r" (outptr8), - [outptr12] "+r" (outptr12), - [n_remaining_channels] "+r" (n_remaining_channels) - : [mstride1] "r" (sizeof(float) * mstride), - [mstride2] "r" (sizeof(float) * mstride * 2), - [mstride3] "r" (sizeof(float) * mstride * 3), - [colstride1] "r" (sizeof(float) * kernel_col_stride), - [colstride2] "r" (sizeof(float) * kernel_col_stride * 2), - [one_half] "r" (0.5f) - : "v0", "v1", 
"v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24" - ); - - // Progression to complete stride - outptr0 += matrix_row_stride - n_output_channels; - outptr4 += matrix_row_stride - n_output_channels; - outptr8 += matrix_row_stride - n_output_channels; - outptr12 += matrix_row_stride - n_output_channels; - } -} - -template <> -template <> -inline void winograd2x2_3x3_gemm_kernel_transform_impl::transform_kernel<2>( - const float* const kernel, - const int n_input_channels, - const int n_output_channels, - float* const matrix_base, - const int mstride, - const int matrix_row_stride -) { - // Use one input pointer for each row of the kernel, use two additional - // offsets to extract columns. - const int kernel_col_stride = n_input_channels * n_output_channels; - const int kernel_row_stride = 3 * kernel_col_stride; - const float *inptr0 = kernel; - const float *inptr1 = kernel + kernel_row_stride; - const float *inptr2 = kernel + kernel_row_stride*2; - - // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three - // offsets to extract further matrices. - float *outptr0 = matrix_base; - float *outptr4 = matrix_base + mstride * 4; - float *outptr8 = matrix_base + mstride * 8; - float *outptr12 = matrix_base + mstride * 12; - - // For every input channel - for (int in_c = 0; in_c < n_input_channels; in_c++) { - int n_remaining_channels = n_output_channels; - - asm volatile ( - // Registers into which to read the kernel - "w_11 .req v0\n" "qw_11 .req q0\n" "dw_11 .req d0\n" - "w_12 .req v1\n" "qw_12 .req q1\n" "dw_12 .req d1\n" - "w_13 .req v2\n" "qw_13 .req q2\n" "dw_13 .req d2\n" - "w_21 .req v3\n" "qw_21 .req q3\n" "dw_21 .req d3\n" - "w_22 .req v4\n" "qw_22 .req q4\n" "dw_22 .req d4\n" - "w_23 .req v5\n" "qw_23 .req q5\n" "dw_23 .req d5\n" - "w_31 .req v6\n" "qw_31 .req q6\n" "dw_31 .req d6\n" - "w_32 .req v7\n" "qw_32 .req q7\n" "dw_32 .req d7\n" - "w_33 .req v8\n" "qw_33 .req q8\n" "dw_33 .req d8\n" - - // Transformed matrix Ww - "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n" - "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n" - "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n" - "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n" - - // Output matrix U = WwWT - "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n" - "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n" - "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n" - "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n" - - // Storage view of output matrices - "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n" - "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n" - "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n" - "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n" - - "dU11 .req d0\n" "dU12 .req d15\n" "dU13 .req d16\n" "dU14 .req d2\n" - "dU21 .req d9\n" "dU22 .req d17\n" "dU23 .req d18\n" "dU24 .req d11\n" - "dU31 .req d12\n" "dU32 .req d19\n" "dU33 .req d20\n" "dU34 .req d14\n" - "dU41 .req d6\n" "dU42 .req d21\n" "dU43 .req d22\n" "dU44 .req d8\n" - - "half .req v23\n" // {0.5, ..., 0.5} - "dup half.4s, %w[one_half]\n" - "scratch .req v24\n" - - // Subtract the tail from the number of remaining channels and jump to - // the tail if necessary. 
- "subs %x[n_remaining_channels], %x[n_remaining_channels], #2\n" - "beq 2f\n" - - "1:" - // Load tile of the kernel - "ldr qw_11, [%x[inptr0]]\n" - "str qU11, [%x[outptr0]]\n" - "ldr qw_12, [%x[inptr0], %x[colstride1]]\n" - "ldr qw_13, [%x[inptr0], %x[colstride2]]\n" - "str qU14, [%x[outptr0], %x[mstride3]]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "ldr qw_21, [%x[inptr1]]\n" - "ldr qw_22, [%x[inptr1], %x[colstride1]]\n" - "ldr qw_23, [%x[inptr1], %x[colstride2]]\n" - "add %x[inptr1], %x[inptr1], #0x10\n" - - "ldr qw_31, [%x[inptr2]]\n" - "str qU41, [%x[outptr12]]\n" - "ldr qw_32, [%x[inptr2], %x[colstride1]]\n" - "ldr qw_33, [%x[inptr2], %x[colstride2]]\n" - "str qU44, [%x[outptr12], %x[mstride3]]\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - // Compute 2nd and 3rd rows of Ww - "fadd scratch.4s, w_11.4s, w_31.4s\n" - "fmul Ww21.4s, scratch.4s, half.4s\n" - "fmla Ww21.4s, w_21.4s, half.4s\n" - "str qU21, [%x[outptr4]]\n" - "fmul Ww31.4s, scratch.4s, half.4s\n" - "fmls Ww31.4s, w_21.4s, half.4s\n" - "str qU31, [%x[outptr8]]\n" - - "fadd scratch.4s, w_12.4s, w_32.4s\n" - "fmul Ww22.4s, scratch.4s, half.4s\n" - "fmla Ww22.4s, w_22.4s, half.4s\n" - "fmul Ww32.4s, scratch.4s, half.4s\n" - "fmls Ww32.4s, w_22.4s, half.4s\n" - - "fadd scratch.4s, w_13.4s, w_33.4s\n" - "fmul Ww23.4s, scratch.4s, half.4s\n" - "fmla Ww23.4s, w_23.4s, half.4s\n" - "str qU24, [%x[outptr4], %x[mstride3]]\n" - "fmul Ww33.4s, scratch.4s, half.4s\n" - "fmls Ww33.4s, w_23.4s, half.4s\n" - "str qU34, [%x[outptr8], %x[mstride3]]\n" - - // Compute and store U, only need to compute the 2nd and 3rd columns - // of U and update output pointers - "fadd scratch.4s, Ww11.4s, Ww13.4s\n" - "fmul U12.4s, scratch.4s, half.4s\n" - "fmla U12.4s, Ww12.4s, half.4s\n" - "str qU12, [%x[outptr0], %x[mstride1]]\n" - "fmul U13.4s, scratch.4s, half.4s\n" - "fmls U13.4s, Ww12.4s, half.4s\n" - "str qU13, [%x[outptr0], %x[mstride2]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd scratch.4s, Ww21.4s, Ww23.4s\n" - "fmul U22.4s, scratch.4s, half.4s\n" - "fmla U22.4s, Ww22.4s, half.4s\n" - "str qU22, [%x[outptr4], %x[mstride1]]\n" - "fmul U23.4s, scratch.4s, half.4s\n" - "fmls U23.4s, Ww22.4s, half.4s\n" - "str qU23, [%x[outptr4], %x[mstride2]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fadd scratch.4s, Ww31.4s, Ww33.4s\n" - "fmul U32.4s, scratch.4s, half.4s\n" - "fmla U32.4s, Ww32.4s, half.4s\n" - "str qU32, [%x[outptr8], %x[mstride1]]\n" - "fmul U33.4s, scratch.4s, half.4s\n" - "fmls U33.4s, Ww32.4s, half.4s\n" - "str qU33, [%x[outptr8], %x[mstride2]]\n" - "add %x[outptr8], %x[outptr8], #0x10\n" - - "fadd scratch.4s, Ww41.4s, Ww43.4s\n" - "fmul U42.4s, scratch.4s, half.4s\n" - "fmla U42.4s, Ww42.4s, half.4s\n" - "str qU42, [%x[outptr12], %x[mstride1]]\n" - "fmul U43.4s, scratch.4s, half.4s\n" - "fmls U43.4s, Ww42.4s, half.4s\n" - "str qU43, [%x[outptr12], %x[mstride2]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n" - "bne 1b\n" - - // Tail size 2 - "2:" - // Load tile of the kernel - "ldr dw_11, [%x[inptr0]]\n" - "str dU11, [%x[outptr0]]\n" - "ldr dw_12, [%x[inptr0], %x[colstride1]]\n" - "ldr dw_13, [%x[inptr0], %x[colstride2]]\n" - "str dU14, [%x[outptr0], %x[mstride3]]\n" - "add %x[inptr0], %x[inptr0], #0x08\n" - - "ldr dw_21, [%x[inptr1]]\n" - "ldr dw_22, [%x[inptr1], %x[colstride1]]\n" - "ldr dw_23, [%x[inptr1], %x[colstride2]]\n" - "add %x[inptr1], %x[inptr1], #0x08\n" - - "ldr dw_31, [%x[inptr2]]\n" - "str dU41, [%x[outptr12]]\n" - "ldr dw_32, [%x[inptr2], 
%x[colstride1]]\n" - "ldr dw_33, [%x[inptr2], %x[colstride2]]\n" - "str dU44, [%x[outptr12], %x[mstride3]]\n" - "add %x[inptr2], %x[inptr2], #0x08\n" - - // Compute 2nd and 3rd rows of Ww - "fadd scratch.2s, w_11.2s, w_31.2s\n" - "fmul Ww21.2s, scratch.2s, half.2s\n" - "fmla Ww21.2s, w_21.2s, half.2s\n" - "str dU21, [%x[outptr4]]\n" - "fmul Ww31.2s, scratch.2s, half.2s\n" - "fmls Ww31.2s, w_21.2s, half.2s\n" - "str dU31, [%x[outptr8]]\n" - - "fadd scratch.2s, w_12.2s, w_32.2s\n" - "fmul Ww22.2s, scratch.2s, half.2s\n" - "fmla Ww22.2s, w_22.2s, half.2s\n" - "fmul Ww32.2s, scratch.2s, half.2s\n" - "fmls Ww32.2s, w_22.2s, half.2s\n" - - "fadd scratch.2s, w_13.2s, w_33.2s\n" - "fmul Ww23.2s, scratch.2s, half.2s\n" - "fmla Ww23.2s, w_23.2s, half.2s\n" - "str dU24, [%x[outptr4], %x[mstride3]]\n" - "fmul Ww33.2s, scratch.2s, half.2s\n" - "fmls Ww33.2s, w_23.2s, half.2s\n" - "str dU34, [%x[outptr8], %x[mstride3]]\n" - - // Compute and store U, only need to compute the 2nd and 3rd columns of - // U and update output pointers - "fadd scratch.2s, Ww11.2s, Ww13.2s\n" - "fmul U12.2s, scratch.2s, half.2s\n" - "fmla U12.2s, Ww12.2s, half.2s\n" - "str dU12, [%x[outptr0], %x[mstride1]]\n" - "fmul U13.2s, scratch.2s, half.2s\n" - "fmls U13.2s, Ww12.2s, half.2s\n" - "str dU13, [%x[outptr0], %x[mstride2]]\n" - "add %x[outptr0], %x[outptr0], #0x08\n" - - "fadd scratch.2s, Ww21.2s, Ww23.2s\n" - "fmul U22.2s, scratch.2s, half.2s\n" - "fmla U22.2s, Ww22.2s, half.2s\n" - "str dU22, [%x[outptr4], %x[mstride1]]\n" - "fmul U23.2s, scratch.2s, half.2s\n" - "fmls U23.2s, Ww22.2s, half.2s\n" - "str dU23, [%x[outptr4], %x[mstride2]]\n" - "add %x[outptr4], %x[outptr4], #0x08\n" - - "fadd scratch.2s, Ww31.2s, Ww33.2s\n" - "fmul U32.2s, scratch.2s, half.2s\n" - "fmla U32.2s, Ww32.2s, half.2s\n" - "str dU32, [%x[outptr8], %x[mstride1]]\n" - "fmul U33.2s, scratch.2s, half.2s\n" - "fmls U33.2s, Ww32.2s, half.2s\n" - "str dU33, [%x[outptr8], %x[mstride2]]\n" - "add %x[outptr8], %x[outptr8], #0x08\n" - - "fadd scratch.2s, Ww41.2s, Ww43.2s\n" - "fmul U42.2s, scratch.2s, half.2s\n" - "fmla U42.2s, Ww42.2s, half.2s\n" - "str dU42, [%x[outptr12], %x[mstride1]]\n" - "fmul U43.2s, scratch.2s, half.2s\n" - "fmls U43.2s, Ww42.2s, half.2s\n" - "str dU43, [%x[outptr12], %x[mstride2]]\n" - "add %x[outptr12], %x[outptr12], #0x08\n" - - // Clear aliases - ".unreq half\n" - ".unreq scratch\n" - ".unreq w_11\n" ".unreq qw_11\n" ".unreq dw_11\n" - ".unreq w_12\n" ".unreq qw_12\n" ".unreq dw_12\n" - ".unreq w_13\n" ".unreq qw_13\n" ".unreq dw_13\n" - ".unreq w_21\n" ".unreq qw_21\n" ".unreq dw_21\n" - ".unreq w_22\n" ".unreq qw_22\n" ".unreq dw_22\n" - ".unreq w_23\n" ".unreq qw_23\n" ".unreq dw_23\n" - ".unreq w_31\n" ".unreq qw_31\n" ".unreq dw_31\n" - ".unreq w_32\n" ".unreq qw_32\n" ".unreq dw_32\n" - ".unreq w_33\n" ".unreq qw_33\n" ".unreq dw_33\n" - ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n" - ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n" - ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n" - ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n" - ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n" - ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n" - ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n" - ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n" - ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n" - ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n" - ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n" - ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq 
qU44\n" - ".unreq dU11\n" ".unreq dU12\n" ".unreq dU13\n" ".unreq dU14\n" - ".unreq dU21\n" ".unreq dU22\n" ".unreq dU23\n" ".unreq dU24\n" - ".unreq dU31\n" ".unreq dU32\n" ".unreq dU33\n" ".unreq dU34\n" - ".unreq dU41\n" ".unreq dU42\n" ".unreq dU43\n" ".unreq dU44\n" - - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr4), - [outptr8] "+r" (outptr8), - [outptr12] "+r" (outptr12), - [n_remaining_channels] "+r" (n_remaining_channels) - : [mstride1] "r" (sizeof(float) * mstride), - [mstride2] "r" (sizeof(float) * mstride * 2), - [mstride3] "r" (sizeof(float) * mstride * 3), - [colstride1] "r" (sizeof(float) * kernel_col_stride), - [colstride2] "r" (sizeof(float) * kernel_col_stride * 2), - [one_half] "r" (0.5f) - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24" - ); - - // Progression to complete stride - outptr0 += matrix_row_stride - n_output_channels; - outptr4 += matrix_row_stride - n_output_channels; - outptr8 += matrix_row_stride - n_output_channels; - outptr12 += matrix_row_stride - n_output_channels; - } -} - -template <> -template <> -inline void winograd2x2_3x3_gemm_kernel_transform_impl::transform_kernel<1>( - const float* const kernel, - const int n_input_channels, - const int n_output_channels, - float* const matrix_base, - const int mstride, - const int matrix_row_stride -) { - // Use one input pointer for each row of the kernel, use two additional - // offsets to extract columns. - const int kernel_col_stride = n_input_channels * n_output_channels; - const int kernel_row_stride = 3 * kernel_col_stride; - const float *inptr0 = kernel; - const float *inptr1 = kernel + kernel_row_stride; - const float *inptr2 = kernel + kernel_row_stride*2; - - // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three - // offsets to extract further matrices. 
- float *outptr0 = matrix_base; - float *outptr4 = matrix_base + mstride * 4; - float *outptr8 = matrix_base + mstride * 8; - float *outptr12 = matrix_base + mstride * 12; - - // For every input channel - for (int in_c = 0; in_c < n_input_channels; in_c++) { - int n_remaining_channels = n_output_channels; - - asm volatile ( - // Registers into which to read the kernel - "w_11 .req v0\n" "qw_11 .req q0\n" "sw_11 .req s0\n" - "w_12 .req v1\n" "qw_12 .req q1\n" "sw_12 .req s1\n" - "w_13 .req v2\n" "qw_13 .req q2\n" "sw_13 .req s2\n" - "w_21 .req v3\n" "qw_21 .req q3\n" "sw_21 .req s3\n" - "w_22 .req v4\n" "qw_22 .req q4\n" "sw_22 .req s4\n" - "w_23 .req v5\n" "qw_23 .req q5\n" "sw_23 .req s5\n" - "w_31 .req v6\n" "qw_31 .req q6\n" "sw_31 .req s6\n" - "w_32 .req v7\n" "qw_32 .req q7\n" "sw_32 .req s7\n" - "w_33 .req v8\n" "qw_33 .req q8\n" "sw_33 .req s8\n" - - // Transformed matrix Ww - "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n" - "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n" - "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n" - "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n" - - // Output matrix U = WwWT - "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n" - "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n" - "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n" - "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n" - - // Storage view of output matrices - "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n" - "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n" - "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n" - "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n" - - "sU11 .req s0\n" "sU12 .req s15\n" "sU13 .req s16\n" "sU14 .req s2\n" - "sU21 .req s9\n" "sU22 .req s17\n" "sU23 .req s18\n" "sU24 .req s11\n" - "sU31 .req s12\n" "sU32 .req s19\n" "sU33 .req s20\n" "sU34 .req s14\n" - "sU41 .req s6\n" "sU42 .req s21\n" "sU43 .req s22\n" "sU44 .req s8\n" - - "half .req v23\n" // {0.5, ..., 0.5} - "dup half.4s, %w[one_half]\n" - "scratch .req v24\n" - - // Subtract the tail from the number of remaining channels and jump to - // the tail if necessary. 
- "subs %x[n_remaining_channels], %x[n_remaining_channels], #1\n" - "beq 2f\n" - - "1:" - // Load tile of the kernel - "ldr qw_11, [%x[inptr0]]\n" - "str qU11, [%x[outptr0]]\n" - "ldr qw_12, [%x[inptr0], %x[colstride1]]\n" - "ldr qw_13, [%x[inptr0], %x[colstride2]]\n" - "str qU14, [%x[outptr0], %x[mstride3]]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "ldr qw_21, [%x[inptr1]]\n" - "ldr qw_22, [%x[inptr1], %x[colstride1]]\n" - "ldr qw_23, [%x[inptr1], %x[colstride2]]\n" - "add %x[inptr1], %x[inptr1], #0x10\n" - - "ldr qw_31, [%x[inptr2]]\n" - "str qU41, [%x[outptr12]]\n" - "ldr qw_32, [%x[inptr2], %x[colstride1]]\n" - "ldr qw_33, [%x[inptr2], %x[colstride2]]\n" - "str qU44, [%x[outptr12], %x[mstride3]]\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - // Compute 2nd and 3rd rows of Ww - "fadd scratch.4s, w_11.4s, w_31.4s\n" - "fmul Ww21.4s, scratch.4s, half.4s\n" - "fmla Ww21.4s, w_21.4s, half.4s\n" - "str qU21, [%x[outptr4]]\n" - "fmul Ww31.4s, scratch.4s, half.4s\n" - "fmls Ww31.4s, w_21.4s, half.4s\n" - "str qU31, [%x[outptr8]]\n" - - "fadd scratch.4s, w_12.4s, w_32.4s\n" - "fmul Ww22.4s, scratch.4s, half.4s\n" - "fmla Ww22.4s, w_22.4s, half.4s\n" - "fmul Ww32.4s, scratch.4s, half.4s\n" - "fmls Ww32.4s, w_22.4s, half.4s\n" - - "fadd scratch.4s, w_13.4s, w_33.4s\n" - "fmul Ww23.4s, scratch.4s, half.4s\n" - "fmla Ww23.4s, w_23.4s, half.4s\n" - "str qU24, [%x[outptr4], %x[mstride3]]\n" - "fmul Ww33.4s, scratch.4s, half.4s\n" - "fmls Ww33.4s, w_23.4s, half.4s\n" - "str qU34, [%x[outptr8], %x[mstride3]]\n" - - // Compute and store U, only need to compute the 2nd and 3rd columns - // of U and update output pointers - "fadd scratch.4s, Ww11.4s, Ww13.4s\n" - "fmul U12.4s, scratch.4s, half.4s\n" - "fmla U12.4s, Ww12.4s, half.4s\n" - "str qU12, [%x[outptr0], %x[mstride1]]\n" - "fmul U13.4s, scratch.4s, half.4s\n" - "fmls U13.4s, Ww12.4s, half.4s\n" - "str qU13, [%x[outptr0], %x[mstride2]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd scratch.4s, Ww21.4s, Ww23.4s\n" - "fmul U22.4s, scratch.4s, half.4s\n" - "fmla U22.4s, Ww22.4s, half.4s\n" - "str qU22, [%x[outptr4], %x[mstride1]]\n" - "fmul U23.4s, scratch.4s, half.4s\n" - "fmls U23.4s, Ww22.4s, half.4s\n" - "str qU23, [%x[outptr4], %x[mstride2]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fadd scratch.4s, Ww31.4s, Ww33.4s\n" - "fmul U32.4s, scratch.4s, half.4s\n" - "fmla U32.4s, Ww32.4s, half.4s\n" - "str qU32, [%x[outptr8], %x[mstride1]]\n" - "fmul U33.4s, scratch.4s, half.4s\n" - "fmls U33.4s, Ww32.4s, half.4s\n" - "str qU33, [%x[outptr8], %x[mstride2]]\n" - "add %x[outptr8], %x[outptr8], #0x10\n" - - "fadd scratch.4s, Ww41.4s, Ww43.4s\n" - "fmul U42.4s, scratch.4s, half.4s\n" - "fmla U42.4s, Ww42.4s, half.4s\n" - "str qU42, [%x[outptr12], %x[mstride1]]\n" - "fmul U43.4s, scratch.4s, half.4s\n" - "fmls U43.4s, Ww42.4s, half.4s\n" - "str qU43, [%x[outptr12], %x[mstride2]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n" - "bne 1b\n" - - // Tail size 1 - "2:" - // Load tile of the kernel - "ldr sw_11, [%x[inptr0]]\n" - "str sU11, [%x[outptr0]]\n" - "ldr sw_12, [%x[inptr0], %x[colstride1]]\n" - "ldr sw_13, [%x[inptr0], %x[colstride2]]\n" - "str sU14, [%x[outptr0], %x[mstride3]]\n" - "add %x[inptr0], %x[inptr0], #0x04\n" - - "ldr sw_21, [%x[inptr1]]\n" - "ldr sw_22, [%x[inptr1], %x[colstride1]]\n" - "ldr sw_23, [%x[inptr1], %x[colstride2]]\n" - "add %x[inptr1], %x[inptr1], #0x04\n" - - "ldr sw_31, [%x[inptr2]]\n" - "str sU41, [%x[outptr12]]\n" - "ldr sw_32, [%x[inptr2], 
%x[colstride1]]\n" - "ldr sw_33, [%x[inptr2], %x[colstride2]]\n" - "str sU44, [%x[outptr12], %x[mstride3]]\n" - "add %x[inptr2], %x[inptr2], #0x04\n" - - // Compute 2nd and 3rd rows of Ww - "fadd scratch.2s, w_11.2s, w_31.2s\n" - "fmul Ww21.2s, scratch.2s, half.2s\n" - "fmla Ww21.2s, w_21.2s, half.2s\n" - "str sU21, [%x[outptr4]]\n" - "fmul Ww31.2s, scratch.2s, half.2s\n" - "fmls Ww31.2s, w_21.2s, half.2s\n" - "str sU31, [%x[outptr8]]\n" - - "fadd scratch.2s, w_12.2s, w_32.2s\n" - "fmul Ww22.2s, scratch.2s, half.2s\n" - "fmla Ww22.2s, w_22.2s, half.2s\n" - "fmul Ww32.2s, scratch.2s, half.2s\n" - "fmls Ww32.2s, w_22.2s, half.2s\n" - - "fadd scratch.2s, w_13.2s, w_33.2s\n" - "fmul Ww23.2s, scratch.2s, half.2s\n" - "fmla Ww23.2s, w_23.2s, half.2s\n" - "str sU24, [%x[outptr4], %x[mstride3]]\n" - "fmul Ww33.2s, scratch.2s, half.2s\n" - "fmls Ww33.2s, w_23.2s, half.2s\n" - "str sU34, [%x[outptr8], %x[mstride3]]\n" - - // Compute and store U, only need to compute the 2nd and 3rd columns of - // U and update output pointers - "fadd scratch.2s, Ww11.2s, Ww13.2s\n" - "fmul U12.2s, scratch.2s, half.2s\n" - "fmla U12.2s, Ww12.2s, half.2s\n" - "str sU12, [%x[outptr0], %x[mstride1]]\n" - "fmul U13.2s, scratch.2s, half.2s\n" - "fmls U13.2s, Ww12.2s, half.2s\n" - "str sU13, [%x[outptr0], %x[mstride2]]\n" - "add %x[outptr0], %x[outptr0], #0x04\n" - - "fadd scratch.2s, Ww21.2s, Ww23.2s\n" - "fmul U22.2s, scratch.2s, half.2s\n" - "fmla U22.2s, Ww22.2s, half.2s\n" - "str sU22, [%x[outptr4], %x[mstride1]]\n" - "fmul U23.2s, scratch.2s, half.2s\n" - "fmls U23.2s, Ww22.2s, half.2s\n" - "str sU23, [%x[outptr4], %x[mstride2]]\n" - "add %x[outptr4], %x[outptr4], #0x04\n" - - "fadd scratch.2s, Ww31.2s, Ww33.2s\n" - "fmul U32.2s, scratch.2s, half.2s\n" - "fmla U32.2s, Ww32.2s, half.2s\n" - "str sU32, [%x[outptr8], %x[mstride1]]\n" - "fmul U33.2s, scratch.2s, half.2s\n" - "fmls U33.2s, Ww32.2s, half.2s\n" - "str sU33, [%x[outptr8], %x[mstride2]]\n" - "add %x[outptr8], %x[outptr8], #0x04\n" - - "fadd scratch.2s, Ww41.2s, Ww43.2s\n" - "fmul U42.2s, scratch.2s, half.2s\n" - "fmla U42.2s, Ww42.2s, half.2s\n" - "str sU42, [%x[outptr12], %x[mstride1]]\n" - "fmul U43.2s, scratch.2s, half.2s\n" - "fmls U43.2s, Ww42.2s, half.2s\n" - "str sU43, [%x[outptr12], %x[mstride2]]\n" - "add %x[outptr12], %x[outptr12], #0x04\n" - - // Clear aliases - ".unreq half\n" - ".unreq scratch\n" - ".unreq w_11\n" ".unreq qw_11\n" ".unreq sw_11\n" - ".unreq w_12\n" ".unreq qw_12\n" ".unreq sw_12\n" - ".unreq w_13\n" ".unreq qw_13\n" ".unreq sw_13\n" - ".unreq w_21\n" ".unreq qw_21\n" ".unreq sw_21\n" - ".unreq w_22\n" ".unreq qw_22\n" ".unreq sw_22\n" - ".unreq w_23\n" ".unreq qw_23\n" ".unreq sw_23\n" - ".unreq w_31\n" ".unreq qw_31\n" ".unreq sw_31\n" - ".unreq w_32\n" ".unreq qw_32\n" ".unreq sw_32\n" - ".unreq w_33\n" ".unreq qw_33\n" ".unreq sw_33\n" - ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n" - ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n" - ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n" - ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n" - ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n" - ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n" - ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n" - ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n" - ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n" - ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n" - ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n" - ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq 
qU44\n" - ".unreq sU11\n" ".unreq sU12\n" ".unreq sU13\n" ".unreq sU14\n" - ".unreq sU21\n" ".unreq sU22\n" ".unreq sU23\n" ".unreq sU24\n" - ".unreq sU31\n" ".unreq sU32\n" ".unreq sU33\n" ".unreq sU34\n" - ".unreq sU41\n" ".unreq sU42\n" ".unreq sU43\n" ".unreq sU44\n" - - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr4), - [outptr8] "+r" (outptr8), - [outptr12] "+r" (outptr12), - [n_remaining_channels] "+r" (n_remaining_channels) - : [mstride1] "r" (sizeof(float) * mstride), - [mstride2] "r" (sizeof(float) * mstride * 2), - [mstride3] "r" (sizeof(float) * mstride * 3), - [colstride1] "r" (sizeof(float) * kernel_col_stride), - [colstride2] "r" (sizeof(float) * kernel_col_stride * 2), - [one_half] "r" (0.5f) - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24" - ); - - // Progression to complete stride - outptr0 += matrix_row_stride - n_output_channels; - outptr4 += matrix_row_stride - n_output_channels; - outptr8 += matrix_row_stride - n_output_channels; - outptr12 += matrix_row_stride - n_output_channels; - } -} -} -#endif // __aarch64__ diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp deleted file mode 100644 index 0992c0bb44..0000000000 --- a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp +++ /dev/null @@ -1,356 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -namespace winograd { - /* Transform from the Winograd domain back to the spatial domain. - */ - template - struct Winograd2x2_3x3GemmOutput { - static void execute( - const Tensor4DShape &output_shape, - T* const matrix_base, - const int matrix_stride, - const int matrix_row_stride, - T* const output - ); - - protected: - /* Specialised implementation method. */ - template - static void _execute( - const Tensor4DShape &output_shape, - T *output, - const T *input, - const int matrix_stride, - const int matrix_row_stride - ); - }; - - /* Two-stage implementation of the transformation from the Winograd domain. - * - * First computes Z.F and then computes (Z.F).Z^T. 
- */ - template - struct Winograd2x2_3x3GemmOutput_TwoStage { - static void execute( - const Tensor4DShape &output_shape, - T* const matrix_base, - const int matrix_stride, - const int matrix_row_stride, - T* const output - ); - - protected: - template - static void compute_zf( - const int n_rows, const int n_channels, - T* const zf, const T* const input[16] - ); - - template - static void compute_zfzT( - const Tensor4DShape &output_shape, - T* const output, const T* const zf - ); - }; -} - -#include "output_2x2_3x3/a64_float.hpp" -// #include "output_2x2_3x3/a64_float_two_stage.hpp" - -/*****************************************************************************/ -/* -template -void winograd::Winograd2x2_3x3GemmOutput::execute( - const Tensor4DShape &output_shape, - const int tile_M, - const int tile_N, - T* const matrix_base, - const int matrix_stride, - const int matrix_row_stride, - T* const output -) { - T* const antipadding = reinterpret_cast(malloc(sizeof(T) * output_shape.n_channels)); - - // Get input pointers - const T* inptrs[16]; - for (int i = 0; i < 16; i++) { - inptrs[i] = matrices[i]; - } - - for (int batch = 0; batch < output_shape.n_batches; batch++) { - for (int tile_i = 0; tile_i < tile_M; tile_i++) { - for (int tile_j = 0; tile_j < tile_N; tile_j++) { - // Get pointers for each of the 4 output cells required for this computation - T* outptrs[4]; - for (int cell_i = 0, c = 0; cell_i < 2; cell_i++) { - for (int cell_j = 0; cell_j < 2; cell_j++, c++) { - const int i = tile_i*2 + cell_i; - const int j = tile_j*2 + cell_j; - - if (i < output_shape.n_rows && j < output_shape.n_cols) { - outptrs[c] = output + ( - (batch*output_shape.n_rows + i) * output_shape.n_cols + - j) * output_shape.n_channels; - } else { - outptrs[c] = antipadding; - } - } // cell_j - } // cell_i - - for (int n = 0; n < output_shape.n_channels; n++) { - // Read 16 values and progress pointers - T v[16]; - for (int i = 0; i < 16; i++) { - v[i] = *(inptrs[i]++); - } - - // Compute output for 4 pixels - *(outptrs[0]++) = v[ 0] + v[ 1] + v[ 2] + - v[ 4] + v[ 5] + v[ 6] + - v[ 8] + v[ 9] + v[10]; - *(outptrs[1]++) = v[ 1] - v[ 2] - v[ 3] + - v[ 5] - v[ 6] - v[ 7] + - v[ 9] - v[10] - v[11]; - *(outptrs[2]++) = v[ 4] + v[ 5] + v[ 6] - - v[ 8] - v[ 9] - v[10] - - v[12] - v[13] - v[14]; - *(outptrs[3]++) = v[ 5] - v[ 6] - v[ 7] - - v[ 9] + v[10] + v[11] - - v[13] + v[14] + v[15]; - } // output_channel - } // tile_j - } // tile_i - } // batch - - free(antipadding); -} -*/ - -/*****************************************************************************/ -/* -template -void winograd::Winograd2x2_3x3GemmOutput_TwoStage::execute( - const Tensor4DShape &output_shape, - T* const matrices[16], T* const output -) { - // Allocate memory for the intermediate matrices - const int tile_M = iceildiv(output_shape.n_rows, 2); - const int tile_N = iceildiv(output_shape.n_cols, 2); - const int n_rows = output_shape.n_batches * tile_M * tile_N; - const int n_channels = output_shape.n_channels; - T* matrices_zf = reinterpret_cast( - calloc(8 * n_rows * n_channels, sizeof(T)) - ); - - // Perform the first stage transform, computing ZF. - // Specializations should dispatch to different methods based on tail size. - compute_zf<0>(n_rows, n_channels, matrices_zf, matrices); - - // Perform the second stage transform, finishing Z F Z^T - variable dispatch - // based on size of the output. Specialisations can also dispatch based on - // the tail-size of the channel. 
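(For reference: the two stages described above collapse the sixteen Winograd-domain values of a tile first to eight intermediate values (the Z.F stage), then to the 2x2 spatial output tile (the (Z.F).Z^T stage). A minimal scalar sketch of that arithmetic for one tile and one channel, assuming the sixteen values are laid out row-major in v; the function name is illustrative and not part of this patch:

    // Stage 1: reduce the 16 Winograd-domain values to the 8 intermediate values.
    // Stage 2: reduce those 8 values to the 2x2 spatial output tile.
    inline void two_stage_output_tile_2x2_3x3(const float v[16], float out[2][2])
    {
      float zf[8];
      for (int i = 0; i < 4; i++)
      {
        zf[2 * i + 0] = v[4 * i + 0] + v[4 * i + 1] + v[4 * i + 2];
        zf[2 * i + 1] = v[4 * i + 1] - v[4 * i + 2] - v[4 * i + 3];
      }
      for (int j = 0; j < 2; j++)
      {
        out[0][j] = zf[0 + j] + zf[2 + j] + zf[4 + j];
        out[1][j] = zf[2 + j] - zf[4 + j] - zf[6 + j];
      }
    }
)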
- if (output_shape.n_rows % 2 && output_shape.n_cols % 2) { - compute_zfzT(output_shape, output, matrices_zf); - } else if (output_shape.n_rows % 2) { - compute_zfzT(output_shape, output, matrices_zf); - } else if (output_shape.n_cols % 2) { - compute_zfzT(output_shape, output, matrices_zf); - } else { - compute_zfzT(output_shape, output, matrices_zf); - } - - free(reinterpret_cast(matrices_zf)); -} - -template -template -void winograd::Winograd2x2_3x3GemmOutput_TwoStage::compute_zf( - const int n_rows, const int n_channels, - T* output, const T* const input[16] -) { - // Extract 8 output pointers - T* outptr[8]; - for (int i = 0; i < 8; i++) { - outptr[i] = output + i*n_rows*n_channels; - } - - // Copy the 16 input pointers - const T* inptr[16]; - for (int i = 0; i < 16; i++) { - inptr[i] = input[i]; - } - - // For every row of the matrices - for (int i = 0; i < n_rows; i++) { - // For every channel - for (int j = 0; j < n_channels; j++) { - // Extract values from the input matrices - T val[16]; - for (int n = 0; n < 16; n++) { - val[n] = *(inptr[n]++); - } - - // Compute output values - *(outptr[0]++) = val[0] + val[1] + val[2]; - *(outptr[1]++) = val[1] - val[2] - val[3]; - *(outptr[2]++) = val[4] + val[5] + val[6]; - *(outptr[3]++) = val[5] - val[6] - val[7]; - *(outptr[4]++) = val[8] + val[9] + val[10]; - *(outptr[5]++) = val[9] - val[10] - val[11]; - *(outptr[6]++) = val[12] + val[13] + val[14]; - *(outptr[7]++) = val[13] - val[14] - val[15]; - } - } -} - -template -template -void winograd::Winograd2x2_3x3GemmOutput_TwoStage::compute_zfzT( - const Tensor4DShape &output_shape, - T* const output, const T* const input -) { - // Sizing information - const int tile_M = output_shape.n_rows / 2; - const int tile_N = output_shape.n_cols / 2; - - const int n_rows = (output_shape.n_batches * - (tile_M + (tail_M ? 1 : 0)) * - (tile_N + (tail_N ? 1 : 0))); - const int n_channels = output_shape.n_channels; - - // Extract 8 input pointers - const T* inptr[8]; - for (int i = 0; i < 8; i++) { - inptr[i] = input + i*n_rows*n_channels; - } - - // Extract 4 output pointers - T* outptr00 = output; - T* outptr01 = outptr00 + n_channels; - T* outptr10 = outptr00 + output_shape.n_cols * n_channels; - T* outptr11 = outptr10 + n_channels; - - // Progress over the output tiles, generating output values. - for (int batch = 0; batch < output_shape.n_batches; batch++) { - for (int tile_i = 0; tile_i < tile_M; tile_i++) { - for (int tile_j = 0; tile_j < tile_N; tile_j++) { - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - T v[8]; - for (int i = 0; i < 8; i++) { - v[i] = *(inptr[i]++); - } - - // Compute the output values and progress the output pointers. - *(outptr00++) = v[0] + v[2] + v[4]; - *(outptr01++) = v[1] + v[3] + v[5]; - *(outptr10++) = v[2] - v[4] - v[6]; - *(outptr11++) = v[3] - v[5] - v[7]; - } - - // Progress the output pointers to the next column - outptr00 += n_channels; - outptr01 += n_channels; - outptr10 += n_channels; - outptr11 += n_channels; - } - - if (tail_N) { - // Only evaluate the left-most columns of the output - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - T v[8]; - for (int i = 0; i < 4; i++) { - v[i * 2] = *inptr[i * 2]; - } - for (int i = 0; i < 8; i++) { - inptr[i]++; - } - - // Compute the output values and progress the output pointers. 
- *(outptr00++) = v[0] + v[2] + v[4]; - *(outptr10++) = v[2] - v[4] - v[6]; - } - - // Progress the output pointers to the next column - outptr01 += n_channels; // Account for being skipped above - outptr11 += n_channels; // Account for being skipped above - } - - // Progress the output pointers to the next row - outptr00 += output_shape.n_cols * n_channels; - outptr01 += output_shape.n_cols * n_channels; - outptr10 += output_shape.n_cols * n_channels; - outptr11 += output_shape.n_cols * n_channels; - } - - if (tail_M) { - // Only work on the upper row of the output - for (int tile_j = 0; tile_j < tile_N; tile_j++) { - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - T v[8]; - for (int i = 0; i < 8; i++) { - v[i] = *(inptr[i]++); - } - - // Compute the output values and progress the output pointers. - *(outptr00++) = v[0] + v[2] + v[4]; - *(outptr01++) = v[1] + v[3] + v[5]; - } - - // Progress the output pointers to the next column - outptr00 += n_channels; - outptr01 += n_channels; - outptr10 += 2 * n_channels; // Account for being skipped above - outptr11 += 2 * n_channels; // Account for being skipped above - } - - if (tail_N) { - // Only evaluate the upper-left cell of the output - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - T v[8]; - for (int i = 0; i < 3; i++) { - v[i * 2] = *inptr[i * 2]; - } - for (int i = 0; i < 8; i++) { - inptr[i]++; - } - - // Compute the output values and progress the output pointers. - *(outptr00++) = v[0] + v[2] + v[4]; - } - - // Progress the output pointers to the next column - outptr01 += n_channels; // Account for being skipped above - outptr10 += n_channels; // Account for being skipped above - outptr11 += n_channels; // Account for being skipped above - } - } - } -} -*/ diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp deleted file mode 100644 index 5925f9d569..0000000000 --- a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp +++ /dev/null @@ -1,650 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -/* Float implementation for AArch64. 
- */ -#ifdef __aarch64__ -namespace winograd { - - -template <> -template <> -inline void Winograd2x2_3x3GemmOutput::_execute( - const Tensor4DShape &output_shape, - float *output, - const float *input, - const int mstride, - const int matrix_row_stride -) { - const int tile_M = output_shape.n_rows / 2; - const int tile_N = output_shape.n_cols / 2; - int batch = output_shape.n_batches; - float *outptr = output; - - const float *inptr0 = input; - const float *inptr4 = input + 4 * mstride; - const float *inptr8 = input + 8 * mstride; - const float *inptr12 = input + 12 * mstride; - - const size_t col_stride = sizeof(float) * output_shape.n_channels; - const size_t row_stride = col_stride * tile_N * 2; - - asm volatile ( - // Aliases for elements of the input matrix `F` - // V-register Q-register - "F11 .req v0\n" "qF11 .req q0\n" - "F12 .req v1\n" "qF12 .req q1\n" - "F13 .req v2\n" "qF13 .req q2\n" - "F14 .req v3\n" "qF14 .req q3\n" - "F21 .req v4\n" "qF21 .req q4\n" - "F22 .req v5\n" "qF22 .req q5\n" - "F23 .req v6\n" "qF23 .req q6\n" - "F24 .req v7\n" "qF24 .req q7\n" - "F31 .req v8\n" "qF31 .req q8\n" - "F32 .req v9\n" "qF32 .req q9\n" - "F33 .req v10\n" "qF33 .req q10\n" - "F34 .req v11\n" "qF34 .req q11\n" - "F41 .req v12\n" "qF41 .req q12\n" - "F42 .req v13\n" "qF42 .req q13\n" - "F43 .req v14\n" "qF43 .req q14\n" - "F44 .req v15\n" "qF44 .req q15\n" - - // Aliases for elements of the intermediate matrix `FZ` - "FZ11 .req v16\n" - "FZ12 .req v17\n" - "FZ21 .req v18\n" - "FZ22 .req v19\n" - "FZ31 .req v20\n" - "FZ32 .req v21\n" - "FZ41 .req v22\n" - "FZ42 .req v23\n" - - // Aliases for elements of the output matrix `f` (called `g` due to case - // insensitivity of aliases). - " g11 .req v24\n" - "qg11 .req q24\n" - " g12 .req v25\n" - "qg12 .req q25\n" - " g21 .req v26\n" - "qg21 .req q26\n" - " g22 .req v27\n" - "qg22 .req q27\n" - - // Prepare the various strides - "col_stride .req %x[col_stride]\n" - "row_stride .req %x[row_stride]\n" - "row_plus_col_stride .req %x[row_plus_col_stride]\n" - - "mstride1 .req %x[mstride1]\n" - "mstride2 .req %x[mstride2]\n" - "mstride3 .req %x[mstride3]\n" - - "tile_i .req x19\n" // Tile row counter - "tile_j .req x20\n" // Tile column counter - "channel .req x21\n" // Channel counter - - "1:" // Loop over batches - "mov tile_i, %x[tile_M]\n" // Reset tile row counter - - "2:" // Loop over rows of tiles - "mov tile_j, %x[tile_N]\n" // Reset tile column counter - - "3:" // Loop over columns of tiles - // Perform initial loads of the matrix `F` - "ldr qF11, [%x[inptr0]]\n" - "ldr qF12, [%x[inptr0], mstride1]\n" - "ldr qF13, [%x[inptr0], mstride2]\n" - "ldr qF14, [%x[inptr0], mstride3]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - "ldr qF21, [%x[inptr4]]\n" - "ldr qF22, [%x[inptr4], mstride1]\n" - "subs channel, %x[n_channels], #4\n" // Reset channel counter - - "ldr qF23, [%x[inptr4], mstride2]\n" - "ldr qF24, [%x[inptr4], mstride3]\n" - "add %x[inptr4], %x[inptr4], #0x10\n" - "beq 5f\n" // Jump straight to tail if necessary - - "4:" // Loop over channels - "ldr qF31, [%x[inptr8]]\n" - "fadd FZ11.4s, F11.4s, F12.4s\n" - - "ldr qF32, [%x[inptr8], mstride1]\n" - "fsub FZ12.4s, F12.4s, F13.4s\n" - - "ldr qF33, [%x[inptr8], mstride2]\n" - "fadd FZ11.4s, FZ11.4s, F13.4s\n" - - "ldr qF34, [%x[inptr8], mstride3]\n" - "fsub FZ12.4s, FZ12.4s, F14.4s\n" - - "ldr qF41, [%x[inptr12]]\n" - "fadd FZ21.4s, F21.4s, F22.4s\n" - - "ldr qF42, [%x[inptr12], mstride1]\n" - "fsub FZ22.4s, F22.4s, F23.4s\n" - - "ldr qF43, [%x[inptr12], mstride2]\n" - "fadd FZ21.4s, FZ21.4s, 
F23.4s\n" - - "ldr qF44, [%x[inptr12], mstride3]\n" - "fsub FZ22.4s, FZ22.4s, F24.4s\n" - - "fadd FZ31.4s, F31.4s, F32.4s\n" - "add %x[inptr8], %x[inptr8], #0x10\n" - - "fsub FZ32.4s, F32.4s, F33.4s\n" - "add %x[inptr12], %x[inptr12], #0x10\n" - - "fadd FZ31.4s, FZ31.4s, F33.4s\n" - - "fsub FZ32.4s, FZ32.4s, F34.4s\n" - - "fadd g11.4s, FZ11.4s, FZ21.4s\n" - - "fadd g12.4s, FZ12.4s, FZ22.4s\n" - - "fadd g11.4s, g11.4s, FZ31.4s\n" - - "fadd g12.4s, g12.4s, FZ32.4s\n" - - "ldr qF11, [%x[inptr0]]\n" - "fadd FZ41.4s, F41.4s, F42.4s\n" - - "ldr qF12, [%x[inptr0], mstride1]\n" - "fsub g21.4s, FZ21.4s, FZ31.4s\n" - - "ldr qF13, [%x[inptr0], mstride2]\n" - "fsub FZ42.4s, F42.4s, F43.4s\n" - - "ldr qF14, [%x[inptr0], mstride3]\n" - "str qg11, [%x[outptr]]\n" - - "ldr qF21, [%x[inptr4]]\n" - "fadd FZ41.4s, FZ41.4s, F43.4s\n" - - "ldr qF22, [%x[inptr4], mstride1]\n" - "str qg12, [%x[outptr], col_stride]\n" - - "ldr qF23, [%x[inptr4], mstride2]\n" - "fsub FZ42.4s, FZ42.4s, F44.4s\n" - - "ldr qF24, [%x[inptr4], mstride3]\n" - "fsub g22.4s, FZ22.4s, FZ32.4s\n" - - "fsub g21.4s, g21.4s, FZ41.4s\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "fsub g22.4s, g22.4s, FZ42.4s\n" - "add %x[inptr4], %x[inptr4], #0x10\n" - - "subs channel, channel, #4\n" - - "str qg21, [%x[outptr], row_stride]\n" - - "str qg22, [%x[outptr], row_plus_col_stride]\n" - - "add %x[outptr], %x[outptr], #0x10\n" - - "bne 4b\n" - - "5:" // Channel tail - "ldr qF31, [%x[inptr8]]\n" - "fadd FZ11.4s, F11.4s, F12.4s\n" - - "ldr qF32, [%x[inptr8], mstride1]\n" - "fsub FZ12.4s, F12.4s, F13.4s\n" - - "ldr qF33, [%x[inptr8], mstride2]\n" - "fadd FZ11.4s, FZ11.4s, F13.4s\n" - - "ldr qF34, [%x[inptr8], mstride3]\n" - "fsub FZ12.4s, FZ12.4s, F14.4s\n" - - "ldr qF41, [%x[inptr12]]\n" - "fadd FZ21.4s, F21.4s, F22.4s\n" - - "ldr qF42, [%x[inptr12], mstride1]\n" - "fsub FZ22.4s, F22.4s, F23.4s\n" - - "ldr qF43, [%x[inptr12], mstride2]\n" - "fadd FZ21.4s, FZ21.4s, F23.4s\n" - - "ldr qF44, [%x[inptr12], mstride3]\n" - "fsub FZ22.4s, FZ22.4s, F24.4s\n" - - "fadd FZ31.4s, F31.4s, F32.4s\n" - "add %x[inptr8], %x[inptr8], #0x10\n" - - "fsub FZ32.4s, F32.4s, F33.4s\n" - "add %x[inptr12], %x[inptr12], #0x10\n" - - "fadd FZ31.4s, FZ31.4s, F33.4s\n" - - "fsub FZ32.4s, FZ32.4s, F34.4s\n" - - "fadd g11.4s, FZ11.4s, FZ21.4s\n" - - "fadd g12.4s, FZ12.4s, FZ22.4s\n" - - "fadd g11.4s, g11.4s, FZ31.4s\n" - - "fadd g12.4s, g12.4s, FZ32.4s\n" - - "fadd FZ41.4s, F41.4s, F42.4s\n" - - "fsub g21.4s, FZ21.4s, FZ31.4s\n" - - "fsub FZ42.4s, F42.4s, F43.4s\n" - - "str qg11, [%x[outptr]]\n" - - "fadd FZ41.4s, FZ41.4s, F43.4s\n" - - "str qg12, [%x[outptr], col_stride]\n" - - "fsub FZ42.4s, FZ42.4s, F44.4s\n" - - "fsub g22.4s, FZ22.4s, FZ32.4s\n" - - "fsub g21.4s, g21.4s, FZ41.4s\n" - - "fsub g22.4s, g22.4s, FZ42.4s\n" - - "subs channel, channel, #4\n" - - "str qg21, [%x[outptr], row_stride]\n" - - // Progress input pointers to the next row of the matrix - "add %x[inptr0], %x[inptr0], %x[mrowpad]\n" - "add %x[inptr4], %x[inptr4], %x[mrowpad]\n" - "add %x[inptr8], %x[inptr8], %x[mrowpad]\n" - "add %x[inptr12], %x[inptr12], %x[mrowpad]\n" - - "str qg22, [%x[outptr], row_plus_col_stride]\n" - - "add %x[outptr], %x[outptr], #0x10\n" - - - "add %x[outptr], %x[outptr], col_stride\n" - "subs tile_j, tile_j, #1\n" - "bne 3b\n" - - "add %x[outptr], %x[outptr], row_stride\n" - "subs tile_i, tile_i, #1\n" - "bne 2b\n" - - "subs %[batch], %[batch], #1\n" - "bne 1b\n" - - ".unreq F11\n" ".unreq qF11\n" - ".unreq F12\n" ".unreq qF12\n" - ".unreq F13\n" ".unreq qF13\n" - ".unreq F14\n" ".unreq 
qF14\n" - ".unreq F21\n" ".unreq qF21\n" - ".unreq F22\n" ".unreq qF22\n" - ".unreq F23\n" ".unreq qF23\n" - ".unreq F24\n" ".unreq qF24\n" - ".unreq F31\n" ".unreq qF31\n" - ".unreq F32\n" ".unreq qF32\n" - ".unreq F33\n" ".unreq qF33\n" - ".unreq F34\n" ".unreq qF34\n" - ".unreq F41\n" ".unreq qF41\n" - ".unreq F42\n" ".unreq qF42\n" - ".unreq F43\n" ".unreq qF43\n" - ".unreq F44\n" ".unreq qF44\n" - - ".unreq FZ11\n" ".unreq FZ12\n" - ".unreq FZ21\n" ".unreq FZ22\n" - ".unreq FZ31\n" ".unreq FZ32\n" - ".unreq FZ41\n" ".unreq FZ42\n" - - ".unreq g11\n" ".unreq qg11\n" - ".unreq g12\n" ".unreq qg12\n" - ".unreq g21\n" ".unreq qg21\n" - ".unreq g22\n" ".unreq qg22\n" - - ".unreq col_stride\n" - ".unreq row_stride\n" - ".unreq row_plus_col_stride\n" - - ".unreq mstride1\n" - ".unreq mstride2\n" - ".unreq mstride3\n" - - ".unreq tile_i \n" - ".unreq tile_j \n" - ".unreq channel\n" - - : [batch] "+r" (batch), - [outptr] "+r" (outptr), - [inptr0] "+r" (inptr0), - [inptr4] "+r" (inptr4), - [inptr8] "+r" (inptr8), - [inptr12] "+r" (inptr12) - : [tile_M] "r" (tile_M), - [tile_N] "r" (tile_N), - [n_channels] "r" (output_shape.n_channels), - [col_stride] "r" (col_stride), - [row_stride] "r" (row_stride), - [row_plus_col_stride] "r" (row_stride + col_stride), - [mstride1] "r" (mstride * sizeof(float)), - [mstride2] "r" (2 * mstride * sizeof(float)), - [mstride3] "r" (3 * mstride * sizeof(float)), - [mrowpad] "r" ((matrix_row_stride - output_shape.n_channels) * sizeof(float)) - : "x19", "x20", "x21", - "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", - "q12", "q13", "q14", "q15", "q16", "q17", "q18", "q19", "q20", "q21", - "q22", "q23", "q24", "q25", "q26", "q27", - "cc", "memory" - ); -} - -template <> -template -inline void Winograd2x2_3x3GemmOutput::_execute( - const Tensor4DShape &output_shape, - float *output, - const float *input, - const int mstride, - const int matrix_row_stride -) { - // Compute basic information about the shape of the matrices - const int tile_M = output_shape.n_rows / 2; - const int tile_N = output_shape.n_cols / 2; - const int n_channels = output_shape.n_channels; - - // Extract 16 input pointers - const float* inptr[16]; - for (int i = 0; i < 16; i++) { - inptr[i] = input + i*mstride; - } - - // Extract 4 output pointers - float *outptr00 = output; - float *outptr01 = outptr00 + n_channels; - float *outptr10 = outptr00 + output_shape.n_cols * n_channels; - float *outptr11 = outptr10 + n_channels; - - // Progress over the output tiles, generating output values. - for (int batch = 0; batch < output_shape.n_batches; batch++) { - for (int tile_i = 0; tile_i < tile_M; tile_i++) { - for (int tile_j = 0; tile_j < tile_N; tile_j++) { - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - float F[4][4]; - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 4; j++) { - F[i][j] = *(inptr[i*4 + j]++); - } - } - - // Compute the matrix F.Z - float ZF[4][2]; - ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; - ZF[0][1] = F[0][1] - F[0][2] - F[0][3]; - ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; - ZF[1][1] = F[1][1] - F[1][2] - F[1][3]; - ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; - ZF[2][1] = F[2][1] - F[2][2] - F[2][3]; - ZF[3][0] = F[3][0] + F[3][1] + F[3][2]; - ZF[3][1] = F[3][1] - F[3][2] - F[3][3]; - - // Hence compute the output matrix Z^T . 
(F.Z) - *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; - *(outptr01++) = ZF[0][1] + ZF[1][1] + ZF[2][1]; - *(outptr10++) = ZF[1][0] - ZF[2][0] - ZF[3][0]; - *(outptr11++) = ZF[1][1] - ZF[2][1] - ZF[3][1]; - } - - // Progress the input pointers to the next row - for (int i = 0; i < 16; i++) { - inptr[i] += matrix_row_stride - n_channels; - } - - // Progress the output pointers to the next column - outptr00 += n_channels; - outptr01 += n_channels; - outptr10 += n_channels; - outptr11 += n_channels; - } - - if (tail_N) { - // Only evaluate the left-most columns of the output - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - float F[4][3]; - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 3; j++) { - F[i][j] = *(inptr[i*4 + j]++); - } - } - for (int i = 0; i < 4; i++) { - inptr[i*4 + 3]++; - } - - // Compute the matrix F.Z - float ZF[4][1]; - ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; - ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; - ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; - ZF[3][0] = F[3][0] + F[3][1] + F[3][2]; - - // Hence compute the output matrix Z^T . (F.Z) - *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; - *(outptr10++) = ZF[1][0] - ZF[2][0] - ZF[3][0]; - } - - // Progress the input pointers to the next row - for (int i = 0; i < 16; i++) { - inptr[i] += matrix_row_stride - n_channels; - } - - // Progress the output pointers to the next column - outptr01 += n_channels; // Account for being skipped above - outptr11 += n_channels; // Account for being skipped above - } - - // Progress the output pointers to the next row - outptr00 += output_shape.n_cols * n_channels; - outptr01 += output_shape.n_cols * n_channels; - outptr10 += output_shape.n_cols * n_channels; - outptr11 += output_shape.n_cols * n_channels; - } - - if (tail_M) { - // Only work on the upper row of the output - for (int tile_j = 0; tile_j < tile_N; tile_j++) { - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - float F[3][4]; - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 4; j++) { - F[i][j] = *(inptr[i*4 + j]++); - } - } - for (int j = 0; j < 4; j++) { - inptr[12 + j]++; - } - - // Compute the matrix F.Z - float ZF[3][2]; - ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; - ZF[0][1] = F[0][1] - F[0][2] - F[0][3]; - ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; - ZF[1][1] = F[1][1] - F[1][2] - F[1][3]; - ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; - ZF[2][1] = F[2][1] - F[2][2] - F[2][3]; - - // Hence compute the output matrix Z^T . 
(F.Z) - *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; - *(outptr01++) = ZF[0][1] + ZF[1][1] + ZF[2][1]; - } - - // Progress the input pointers to the next row - for (int i = 0; i < 16; i++) { - inptr[i] += matrix_row_stride - n_channels; - } - - // Progress the output pointers to the next column - outptr00 += n_channels; - outptr01 += n_channels; - outptr10 += 2 * n_channels; // Account for being skipped above - outptr11 += 2 * n_channels; // Account for being skipped above - } - - if (tail_N) { - // Only evaluate the upper-left cell of the output - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - float F[3][3]; - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 3; j++) { - F[i][j] = *(inptr[i*4 + j]); - } - } - for (int i = 0; i < 16; i++) { - inptr[i]++; - } - - // Compute the matrix F.Z - float ZF[3][1]; - ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; - ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; - ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; - - // Hence compute the output matrix Z^T . (F.Z) - *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; - } - - // Progress the input pointers to the next row - for (int i = 0; i < 16; i++) { - inptr[i] += matrix_row_stride - n_channels; - } - - // Progress the output pointers to the next column - outptr01 += n_channels; // Account for being skipped above - outptr10 += n_channels; // Account for being skipped above - outptr11 += n_channels; // Account for being skipped above - } - } - } -} - -/*****************************************************************************/ -template <> -inline void Winograd2x2_3x3GemmOutput::execute( - const Tensor4DShape &output_shape, - float* const matrix_base, - const int matrix_stride, - const int matrix_row_stride, - float* const output -) { - // Dispatch to an appropriate implementation based on the shape of the output - // tensor. 
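(Because each specialisation writes complete 2x2 output tiles, and the vectorised paths process four channels per iteration, the dispatch below has to distinguish sixteen cases: whether the output has an odd number of rows (tail_M), an odd number of columns (tail_N), and the channel count modulo four. A sketch of the selectors, with names chosen for illustration only:

    // Illustrative only: the quantities that drive the specialisation choice.
    const int  tile_M = output_shape.n_rows / 2;        // complete 2x2 tiles per column
    const int  tile_N = output_shape.n_cols / 2;        // complete 2x2 tiles per row
    const bool tail_M = (output_shape.n_rows % 2) != 0; // one trailing output row to finish
    const bool tail_N = (output_shape.n_cols % 2) != 0; // one trailing output column to finish
    const int  c_tail = output_shape.n_channels % 4;    // channels left after the 4-wide loop
)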
- if (output_shape.n_rows % 2 && output_shape.n_cols % 2) { - constexpr bool tail_M = true, tail_N = true; - switch (output_shape.n_channels % 4) { - case 0: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 1: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 2: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 3: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - default: - assert(0); - break; - } - } else if (output_shape.n_rows % 2) { - constexpr bool tail_M = true, tail_N = false; - switch (output_shape.n_channels % 4) { - case 0: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 1: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 2: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 3: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - default: - assert(0); - break; - } - } else if (output_shape.n_cols % 2) { - constexpr bool tail_M = false, tail_N = true; - switch (output_shape.n_channels % 4) { - case 0: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 1: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 2: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 3: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - default: - assert(0); - break; - - } - } else { - constexpr bool tail_M = false, tail_N = false; - switch (output_shape.n_channels % 4) { - case 0: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 1: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 2: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 3: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - default: - assert(0); - break; - - } - } -} -/*****************************************************************************/ - -} // namespace winograd -#endif // __aarch64__ diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp deleted file mode 100644 index f551b12b52..0000000000 --- a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp +++ /dev/null @@ -1,655 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. 
- * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -#ifdef __aarch64__ - -/*****************************************************************************/ -// Compute ZF specializations - -template <> -template <> -inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage::compute_zf<0>( - const int n_rows, const int n_channels, - float* output, const float* const input[16] -) { - // Make copies of some variables - int row = n_rows; - float* outptr = output; - const float* inptr = input[0]; - - // Perform the transformation - asm volatile ( - // "inptr0 .req %x[inptr]\n" - "inptr1 .req x0\n" - "inptr2 .req x1\n" - "inptr3 .req x2\n" - "inptr4 .req x3\n" - "inptr5 .req x4\n" - "inptr6 .req x5\n" - "inptr7 .req x6\n" - "inptr8 .req x7\n" - "inptr9 .req x8\n" - "inptr10 .req x9\n" - "inptr11 .req x10\n" - "inptr12 .req x11\n" - "inptr13 .req x12\n" - "inptr14 .req x13\n" - "inptr15 .req x14\n" - - // "outptr0 .req %x[outptr]\n" - "outptr1 .req x15\n" - "outptr2 .req x16\n" - "outptr3 .req x17\n" - "outptr4 .req x18\n" - "outptr5 .req x19\n" - "outptr6 .req x20\n" - "outptr7 .req x21\n" - - // Compute additional pointers into the input and output matrices. - "mstride .req x22\n" // Matrix stride - "mul mstride, %x[row], %x[n_channels]\n" - "lsl mstride, mstride, #2\n" // * sizeof(float) - - "add inptr1, %x[inptr], mstride\n" - "add inptr2, %x[inptr], mstride, LSL #1\n" - "add inptr3, inptr2, mstride\n" - "add inptr4, inptr3, mstride\n" - "add inptr5, inptr4, mstride\n" - "add inptr6, inptr5, mstride\n" - "add inptr7, inptr6, mstride\n" - "add inptr8, inptr7, mstride\n" - "add inptr9, inptr8, mstride\n" - "add inptr10, inptr9, mstride\n" - "add inptr11, inptr10, mstride\n" - "add inptr12, inptr11, mstride\n" - "add inptr13, inptr12, mstride\n" - "add inptr14, inptr13, mstride\n" - "add inptr15, inptr14, mstride\n" - - "add outptr1, %[outptr], mstride\n" - "add outptr2, outptr1, mstride\n" - "add outptr3, outptr2, mstride\n" - "add outptr4, outptr3, mstride\n" - "add outptr5, outptr4, mstride\n" - "add outptr6, outptr5, mstride\n" - "add outptr7, outptr6, mstride\n" - - ".unreq mstride\n" - - "column .req x22\n" // Column loop counter - - "1:" // Loop over rows - "ldr q0, [%x[inptr]], #0x10\n" - "ldr q1, [inptr1], #0x10\n" - "ldr q2, [inptr2], #0x10\n" - "ldr q3, [inptr3], #0x10\n" - "ldr q4, [inptr4], #0x10\n" - "ldr q5, [inptr5], #0x10\n" - "ldr q6, [inptr6], #0x10\n" - "ldr q7, [inptr7], #0x10\n" - "subs column, %x[n_channels], #0x4\n" - "beq 3f\n" - - "2:" // Loop over columns - "ldr q8, [inptr8], #0x10\n" - "prfm pldl1keep, [%x[inptr], #196]\n" - "fadd v16.4s, v0.4s, v1.4s\n" - - "ldr q9, [inptr9], #0x10\n" - "prfm pldl1keep, [inptr1, #196]\n" - "fsub v17.4s, v1.4s, v2.4s\n" - - "ldr q10, [inptr10], #0x10\n" - "prfm pldl1keep, [inptr2, #196]\n" - "fadd v16.4s, v16.4s, v2.4s\n" - - "ldr q11, [inptr11], #0x10\n" - "prfm pldl1keep, [inptr3, #196]\n" - "fsub v17.4s, v17.4s, v3.4s\n" - - "ldr q12, [inptr12], #0x10\n" - "prfm pldl1keep, [inptr4, #196]\n" - "str q16, [%x[outptr]], #0x10\n" - - "ldr q13, [inptr13], #0x10\n" - "prfm pldl1keep, 
[inptr5, #196]\n" - "str q17, [outptr1], #0x10\n" - - "ldr q14, [inptr14], #0x10\n" - "prfm pldl1keep, [inptr6, #196]\n" - "fadd v16.4s, v4.4s, v5.4s\n" - - "ldr q15, [inptr15], #0x10\n" - "prfm pldl1keep, [inptr7, #196]\n" - "fsub v17.4s, v5.4s, v6.4s\n" - - "ldr q0, [%x[inptr]], #0x10\n" - "prfm pldl1keep, [inptr8, #196]\n" - "fadd v16.4s, v16.4s, v6.4s\n" - - "ldr q1, [inptr1], #0x10\n" - "prfm pldl1keep, [inptr9, #196]\n" - "fsub v17.4s, v17.4s, v7.4s\n" - - "ldr q2, [inptr2], #0x10\n" - "prfm pldl1keep, [inptr10, #196]\n" - "str q16, [outptr2], #0x10\n" - - "ldr q3, [inptr3], #0x10\n" - "prfm pldl1keep, [inptr11, #196]\n" - "str q17, [outptr3], #0x10\n" - - "ldr q4, [inptr4], #0x10\n" - "prfm pldl1keep, [inptr12, #196]\n" - "fadd v16.4s, v8.4s, v9.4s\n" - - "ldr q5, [inptr5], #0x10\n" - "prfm pldl1keep, [inptr13, #196]\n" - "fsub v17.4s, v9.4s, v10.4s\n" - - "ldr q6, [inptr6], #0x10\n" - "prfm pldl1keep, [inptr14, #196]\n" - "fadd v16.4s, v16.4s, v10.4s\n" - - "ldr q7, [inptr7], #0x10\n" - "prfm pldl1keep, [inptr15, #196]\n" - "fsub v17.4s, v17.4s, v11.4s\n" - - "str q16, [outptr4], #0x10\n" - "fadd v16.4s, v12.4s, v13.4s\n" - "fsub v18.4s, v13.4s, v14.4s\n" - - "str q17, [outptr5], #0x10\n" - "fadd v16.4s, v16.4s, v14.4s\n" - "fsub v18.4s, v18.4s, v15.4s\n" - - "str q16, [outptr6], #0x10\n" - "subs column, column, #0x4\n" - - "str q18, [outptr7], #0x10\n" - "bne 2b\n" - - "3:" // Tail - "ldr q8, [inptr8], #0x10\n" - "prfm pldl1keep, [%x[inptr], #196]\n" - "fadd v16.4s, v0.4s, v1.4s\n" - - "ldr q9, [inptr9], #0x10\n" - "prfm pldl1keep, [inptr1, #196]\n" - "fsub v17.4s, v1.4s, v2.4s\n" - - "ldr q10, [inptr10], #0x10\n" - "prfm pldl1keep, [inptr2, #196]\n" - "fadd v16.4s, v16.4s, v2.4s\n" - - "ldr q11, [inptr11], #0x10\n" - "prfm pldl1keep, [inptr3, #196]\n" - "fsub v17.4s, v17.4s, v3.4s\n" - - "ldr q12, [inptr12], #0x10\n" - "prfm pldl1keep, [inptr4, #196]\n" - "str q16, [%x[outptr]], #0x10\n" - - "ldr q13, [inptr13], #0x10\n" - "prfm pldl1keep, [inptr5, #196]\n" - "str q17, [outptr1], #0x10\n" - - "ldr q14, [inptr14], #0x10\n" - "prfm pldl1keep, [inptr6, #196]\n" - "fadd v16.4s, v4.4s, v5.4s\n" - - "ldr q15, [inptr15], #0x10\n" - "prfm pldl1keep, [inptr7, #196]\n" - "fsub v17.4s, v5.4s, v6.4s\n" - - "prfm pldl1keep, [inptr8, #196]\n" - "prfm pldl1keep, [inptr9, #196]\n" - "fadd v16.4s, v16.4s, v6.4s\n" - - "prfm pldl1keep, [inptr10, #196]\n" - "prfm pldl1keep, [inptr11, #196]\n" - "fsub v17.4s, v17.4s, v7.4s\n" - - "prfm pldl1keep, [inptr12, #196]\n" - "prfm pldl1keep, [inptr13, #196]\n" - "str q16, [outptr2], #0x10\n" - - "prfm pldl1keep, [inptr14, #196]\n" - "prfm pldl1keep, [inptr15, #196]\n" - "str q17, [outptr3], #0x10\n" - - "fadd v16.4s, v8.4s, v9.4s\n" - "fsub v17.4s, v9.4s, v10.4s\n" - - "fadd v16.4s, v16.4s, v10.4s\n" - "fsub v17.4s, v17.4s, v11.4s\n" - - "str q16, [outptr4], #0x10\n" - "fadd v16.4s, v12.4s, v13.4s\n" - "fsub v18.4s, v13.4s, v14.4s\n" - - "str q17, [outptr5], #0x10\n" - "fadd v16.4s, v16.4s, v14.4s\n" - "fsub v18.4s, v18.4s, v15.4s\n" - - "str q16, [outptr6], #0x10\n" - "str q18, [outptr7], #0x10\n" - - "subs %x[row], %x[row], #0x1\n" - "bne 1b\n" - - ".unreq inptr1\n" - ".unreq inptr2\n" - ".unreq inptr3\n" - ".unreq inptr4\n" - ".unreq inptr5\n" - ".unreq inptr6\n" - ".unreq inptr7\n" - ".unreq inptr8\n" - ".unreq inptr9\n" - ".unreq inptr10\n" - ".unreq inptr11\n" - ".unreq inptr12\n" - ".unreq inptr13\n" - ".unreq inptr14\n" - ".unreq inptr15\n" - ".unreq outptr1\n" - ".unreq outptr2\n" - ".unreq outptr3\n" - ".unreq outptr4\n" - ".unreq outptr5\n" - 
".unreq outptr6\n" - ".unreq outptr7\n" - - : [row] "+r" (row), - [inptr] "+r" (inptr), - [outptr] "+r" (outptr) - : [n_channels] "r" (n_channels), - [sizeof_float] "i" (sizeof(float)) - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", - "q12", "q13", "q14", "q15", "q16", "q17", "x0", "x1", "x2", "x3", "x4", - "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", - "x16", "x17", "x18", "x19", "x20", "x21", "x22", "cc", "memory" - ); -} - -/*****************************************************************************/ -// Compute ZFZ^T specializations - -template <> -template <> -inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage::compute_zfzT( - const Tensor4DShape &output_shape, - float* const output, const float* const input -) { - const int tile_M = output_shape.n_rows / 2; - const int tile_N = output_shape.n_cols / 2; - int batch = output_shape.n_batches; - float *outptr = output; - const float *inptr = input; - - asm volatile ( - // Compute input pointers - "inptr1 .req x0\n" - "inptr2 .req x1\n" - "inptr3 .req x2\n" - "inptr4 .req x3\n" - "inptr5 .req x4\n" - "inptr6 .req x5\n" - "inptr7 .req x6\n" - "inptr8 .req x7\n" - - "mstride .req x8\n" - "mul mstride, %x[tile_M], %x[tile_N]\n" - "mul mstride, mstride, %x[n_channels]\n" - "lsl mstride, mstride, #2\n" // * sizeof(float) - - "add inptr1, %[inptr], mstride\n" - "add inptr2, inptr1, mstride\n" - "add inptr3, inptr2, mstride\n" - "add inptr4, inptr3, mstride\n" - "add inptr5, inptr4, mstride\n" - "add inptr6, inptr5, mstride\n" - "add inptr7, inptr6, mstride\n" - "add inptr8, inptr7, mstride\n" - - ".unreq mstride\n" - - // Compute initial output pointers - "outptr01 .req x8\n" - "outptr10 .req x9\n" - "outptr11 .req x10\n" - - "add outptr01, %x[outptr], %x[n_channels], LSL #2\n" - "add outptr10, %x[outptr], %x[row_stride], LSL #2\n" - "add outptr11, outptr10, %x[n_channels], LSL #2\n" - - "tile_i .req x11\n" - "tile_j .req x12\n" - "channel .req x13\n" - - "1:" // Loop over batches - "mov tile_i, %x[tile_M]\n" - - "2:" // Loop over rows of output tiles - "mov tile_j, %x[tile_N]\n" - - "3:" // Loop over columns of output tiles - "ldr q0, [%x[inptr]], #0x10\n" - "ldr q2, [inptr2], #0x10\n" - "subs channel, %x[n_channels], #0x4\n" - - "ldr q1, [inptr1], #0x10\n" - "ldr q3, [inptr3], #0x10\n" - "beq 6f\n" - - "4:" - "ldr q4, [inptr4], #0x10\n" - "ldr q5, [inptr5], #0x10\n" - "fadd v16.4s, v0.4s, v2.4s\n" - - "ldr q6, [inptr6], #0x10\n" - "ldr q7, [inptr7], #0x10\n" - "fadd v17.4s, v1.4s, v3.4s\n" - - "ldr q8, [%x[inptr]], #0x10\n" - "ldr q10, [inptr2], #0x10\n" - "fadd v16.4s, v16.4s, v4.4s\n" - - "ldr q9, [inptr1], #0x10\n" - "ldr q11, [inptr3], #0x10\n" - "fadd v17.4s, v17.4s, v5.4s\n" - - "str q16, [%x[outptr]], #0x10\n" - "prfm pldl1strm, [%x[inptr], #196]\n" - "fsub v18.4s, v2.4s, v4.4s\n" - - "str q17, [outptr01], #0x10\n" - "prfm pldl1strm, [inptr2, #196]\n" - "fsub v19.4s, v3.4s, v5.4s\n" - - "prfm pldl1strm, [inptr1, #196]\n" - "prfm pldl1strm, [inptr3, #196]\n" - "fsub v18.4s, v18.4s, v6.4s\n" - - "prfm pldl1strm, [inptr4, #196]\n" - "prfm pldl1strm, [inptr5, #196]\n" - "fsub v19.4s, v19.4s, v7.4s\n" - - "str q18, [outptr10], #0x10\n" - "prfm pldl1strm, [inptr6, #196]\n" - "prfm pldl1strm, [inptr7, #196]\n" - - "subs channel, channel, #0x4\n" - - "str q19, [outptr11], #0x10\n" - "beq 6f\n" // Branch to tail - - "ldr q12, [inptr4], #0x10\n" - "ldr q13, [inptr5], #0x10\n" - "fadd v16.4s, v8.4s, v10.4s\n" - - "ldr q14, [inptr6], #0x10\n" - "ldr q15, [inptr7], #0x10\n" - "fadd 
v17.4s, v9.4s, v11.4s\n" - - "ldr q0, [%x[inptr]], #0x10\n" - "ldr q2, [inptr2], #0x10\n" - "fadd v16.4s, v16.4s, v12.4s\n" - - "ldr q1, [inptr1], #0x10\n" - "ldr q3, [inptr3], #0x10\n" - "fadd v17.4s, v17.4s, v13.4s\n" - - "str q16, [%x[outptr]], #0x10\n" - "prfm pldl1strm, [%x[inptr], #196]\n" - "fsub v18.4s, v10.4s, v12.4s\n" - - "str q17, [outptr01], #0x10\n" - "prfm pldl1strm, [inptr2, #196]\n" - "fsub v19.4s, v11.4s, v13.4s\n" - - "prfm pldl1strm, [inptr1, #196]\n" - "prfm pldl1strm, [inptr3, #196]\n" - "fsub v18.4s, v18.4s, v14.4s\n" - - "prfm pldl1strm, [inptr4, #196]\n" - "prfm pldl1strm, [inptr5, #196]\n" - "fsub v19.4s, v19.4s, v15.4s\n" - - "str q18, [outptr10], #0x10\n" - "prfm pldl1strm, [inptr6, #196]\n" - "prfm pldl1strm, [inptr7, #196]\n" - - "subs channel, channel, #0x4\n" - - "str q19, [outptr11], #0x10\n" - "bne 4b\n" // Continue loop - - "5:" // Tail - "ldr q12, [inptr4], #0x10\n" - "ldr q13, [inptr5], #0x10\n" - "fadd v16.4s, v8.4s, v10.4s\n" - - "ldr q14, [inptr6], #0x10\n" - "ldr q15, [inptr7], #0x10\n" - "fadd v17.4s, v9.4s, v11.4s\n" - - "fadd v16.4s, v16.4s, v12.4s\n" - - "fadd v17.4s, v17.4s, v13.4s\n" - - "str q16, [%x[outptr]], #0x10\n" - "fsub v18.4s, v10.4s, v12.4s\n" - "fsub v19.4s, v11.4s, v13.4s\n" - - "str q17, [outptr01], #0x10\n" - "fsub v18.4s, v18.4s, v14.4s\n" - "fsub v19.4s, v19.4s, v15.4s\n" - - "str q18, [outptr10], #0x10\n" - "str q19, [outptr11], #0x10\n" - "b 7f\n" - - "6:" // Tail - "ldr q4, [inptr4], #0x10\n" - "ldr q5, [inptr5], #0x10\n" - "fadd v16.4s, v0.4s, v2.4s\n" - - "ldr q6, [inptr6], #0x10\n" - "ldr q7, [inptr7], #0x10\n" - "fadd v17.4s, v1.4s, v3.4s\n" - - "fadd v16.4s, v16.4s, v4.4s\n" - - "fadd v17.4s, v17.4s, v5.4s\n" - - "str q16, [%x[outptr]], #0x10\n" - "fsub v18.4s, v2.4s, v4.4s\n" - "fsub v19.4s, v3.4s, v5.4s\n" - - "str q17, [outptr01], #0x10\n" - "fsub v18.4s, v18.4s, v6.4s\n" - "fsub v19.4s, v19.4s, v7.4s\n" - - "str q18, [outptr10], #0x10\n" - "str q19, [outptr11], #0x10\n" - - "7:" - "add %x[outptr], %x[outptr], %x[n_channels], LSL #2\n" - "add outptr01, outptr01, %x[n_channels], LSL #2\n" - "add outptr10, outptr10, %x[n_channels], LSL #2\n" - "add outptr11, outptr11, %x[n_channels], LSL #2\n" - - "subs tile_j, tile_j, #1\n" - "bne 3b\n" - - // Progress the output pointers to the new row - "add %x[outptr], %x[outptr], %x[row_stride], LSL #2\n" - "add outptr01, outptr01, %x[row_stride], LSL #2\n" - "add outptr10, outptr10, %x[row_stride], LSL #2\n" - "add outptr11, outptr11, %x[row_stride], LSL #2\n" - - "subs tile_i, tile_i, #1\n" - "bne 2b\n" - - "subs %[batch], %[batch], #1\n" - "bne 1b\n" - "5:" - - ".unreq inptr1\n" - ".unreq inptr2\n" - ".unreq inptr3\n" - ".unreq inptr4\n" - ".unreq inptr5\n" - ".unreq inptr6\n" - ".unreq inptr7\n" - ".unreq inptr8\n" - ".unreq outptr01\n" - ".unreq outptr10\n" - ".unreq outptr11\n" - : [batch] "+r" (batch), - [outptr] "+r" (outptr), - [inptr] "+r" (inptr) - : [tile_M] "r" (tile_M), - [tile_N] "r" (tile_N), - [n_channels] "r" (output_shape.n_channels), - [row_stride] "r" (output_shape.n_cols * output_shape.n_channels) - : "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", - "x12", "x13", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", - "cc", "memory" - ); -} -/*****************************************************************************/ - -/*****************************************************************************/ -template <> -inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage::execute( - const Tensor4DShape &output_shape, - float* 
const matrices[16], float* const output -) { - // profiler prof; - - // Allocate memory for the intermediate matrices - const int tile_M = iceildiv(output_shape.n_rows, 2); - const int tile_N = iceildiv(output_shape.n_cols, 2); - const int n_rows = output_shape.n_batches * tile_M * tile_N; - const int n_channels = output_shape.n_channels; - float* matrices_zf = reinterpret_cast( - calloc(8 * n_rows * n_channels, sizeof(float)) - ); - - // Perform the first stage transform, computing ZF. - const auto f_compute_zf = [&] () { - switch (n_channels % 4) { - case 0: - compute_zf<0>(n_rows, n_channels, matrices_zf, matrices); - break; - case 1: - compute_zf<1>(n_rows, n_channels, matrices_zf, matrices); - break; - case 2: - compute_zf<2>(n_rows, n_channels, matrices_zf, matrices); - break; - case 3: - compute_zf<3>(n_rows, n_channels, matrices_zf, matrices); - }; - }; - // prof("Compute ZF", f_compute_zf, 16 * n_rows * n_channels * sizeof(float), 0, 8 * n_rows * n_channels * sizeof(float)); - f_compute_zf(); - - // Perform the second stage transform, finishing Z F Z^T - variable dispatch - // based on size of the output and the channel tail. - const auto f_compute_zfzT = [&] () { - if (output_shape.n_rows % 2 && output_shape.n_cols % 2) { - constexpr bool tail_M = true, tail_N = true; - switch (n_channels % 4) { - case 0: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 1: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 2: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 3: - compute_zfzT(output_shape, output, matrices_zf); - } - } else if (output_shape.n_rows % 2) { - constexpr bool tail_M = true, tail_N = false; - switch (n_channels % 4) { - case 0: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 1: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 2: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 3: - compute_zfzT(output_shape, output, matrices_zf); - } - } else if (output_shape.n_cols % 2) { - constexpr bool tail_M = false, tail_N = true; - switch (n_channels % 4) { - case 0: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 1: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 2: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 3: - compute_zfzT(output_shape, output, matrices_zf); - } - } else { - constexpr bool tail_M = false, tail_N = false; - switch (n_channels % 4) { - case 0: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 1: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 2: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 3: - compute_zfzT(output_shape, output, matrices_zf); - } - } - }; - // prof("Compute ZFZT", f_compute_zfzT, 8 * n_rows * n_channels * sizeof(float), 0, 4 * n_rows * n_channels * sizeof(float)); - f_compute_zfzT(); - - free(reinterpret_cast(matrices_zf)); -} -/*****************************************************************************/ - -#endif // __aarch64__ diff --git a/arm_compute/core/NEON/kernels/winograd/utils.hpp b/arm_compute/core/NEON/kernels/winograd/utils.hpp deleted file mode 100644 index 14e709f028..0000000000 --- a/arm_compute/core/NEON/kernels/winograd/utils.hpp +++ /dev/null @@ -1,55 +0,0 @@ - -/* - * Copyright (c) 2017 ARM Limited. 
- * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once -#include - -inline double TimeInUs(void) { -#ifdef CYCLE_PROFILING - timespec t; - clock_gettime(CLOCK_THREAD_CPUTIME_ID, &t); - return 1e6*t.tv_sec + 1e-3*t.tv_nsec; -#else - return 0; -#endif -} - -inline int iceildiv(const int a, const int b) { - return (a + b - 1) / b; -} - -template -inline T roundup(const T a, const T b) { - return a + b - (a % b); -} - -inline void PrintMatrix(const float* const m, const int M, const int N, const int row_stride) { - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - printf("%.3f ", m[i*row_stride + j]); - } - printf("\n"); - } - printf("\n"); -} diff --git a/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp b/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp deleted file mode 100644 index c990cd0252..0000000000 --- a/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp +++ /dev/null @@ -1,346 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- */ -#pragma once -#include -#include -#include - -#include "alloc.hpp" -#include "gemm.hpp" -#include "profiler.hpp" -#include "utils.hpp" -#include "shims.hpp" - -#include "transforms.hpp" - -namespace winograd { - /***************************************************************************/ - /* Implementation of the Winograd F(2x2, 3x3, 4x4) algorithm using GEMM - * internally. - */ - template - class Winograd2x2_3x3GEMM { - public: - /* Instantiate a new Winograd operator. - */ - Winograd2x2_3x3GEMM(const KernelShape &kernel_shape, const Tensor4DShape input_shape, const PaddingType padding_type, void *kernel_storage); - virtual ~Winograd2x2_3x3GEMM(); - - /** Transform the weights into the Winograd domain. - */ - template > - void transform_weights(const TIn* const kernel, void *transform_working_space); - - /* Initializes matrices pointers, to be called once before execute() - */ - template > - void reshape_input(const Tensor4DShape &input_shape, const PaddingType padding_type, const TIn* const input, void* working_space); - - /* Apply the Winograd operator to some input. - */ - template > - void reshape_output(const Tensor4DShape& input_shape, const PaddingType padding_type, TOut* const output); - - - /* Apply the Winograd operator to some input. - */ - void execute(size_t first, size_t last); - - /* Get the memory required to transform the kernel. - */ - static inline size_t get_kernel_transform_working_size(const KernelShape &shape); - - /* Get the output shape of a convolution. - */ - static Tensor4DShape get_output_shape(const Tensor4DShape &input_shape, const KernelShape &k_shape, - const PaddingType padding_type); - - /* Get the memory required to instantiate a new Winograd operator. - */ - static size_t get_kernel_storage_size(const KernelShape &shape); - - /* Get the memory required to apply a Winograd operator to some input. - */ - static size_t get_working_space_size(const Tensor4DShape &input_shape,const KernelShape &k_shape, - const PaddingType padding); - - - Winograd2x2_3x3GEMM(const Winograd2x2_3x3GEMM &) = delete; - /** Prevent instances of this class from being copied (As this class contains pointers) */ - Winograd2x2_3x3GEMM &operator=(const Winograd2x2_3x3GEMM &) = delete; - /** Allow instances of this class to be moved */ - Winograd2x2_3x3GEMM(Winograd2x2_3x3GEMM &&) = default; - /** Allow instances of this class to be moved */ - Winograd2x2_3x3GEMM &operator=(Winograd2x2_3x3GEMM &&) = default; - - protected: - /* Get the memory required by a single "input" matrix. - */ - static size_t get_input_matrix_size(const Tensor4DShape &input_shape,const KernelShape &k_shape, - const PaddingType padding); - - /* Get the memory required by a single "output" matrix. - */ - static size_t get_output_matrix_size(const Tensor4DShape &input_shape, const KernelShape &k_shape, - const PaddingType padding); - - /* Get the memory required by a single "kernel" matrix. 
- */ - static size_t get_kernel_matrix_size(const KernelShape &shape); - - const KernelShape kernel_shape; // Shape of applied kernel - const Tensor4DShape in_shape; - const PaddingType padding; - - const int kernel_matrix_row_stride; // Stride within kernel matrix - - const bool manage_kernel_storage; // Free kernel storage when done - void* const _kernel_storage; // Base pointer for kernel matrices - - profiler prof; // Profiler - - TIn *kernel_matrices[16]; // Prepared form of kernel - TIn *input_matrices[16]; - TOut *output_matrices[16]; - - - static const int M_BLOCK = 4; - static const int N_BLOCK = 16; - }; -} // namespace winograd - -template -size_t winograd::Winograd2x2_3x3GEMM::get_kernel_transform_working_size( - const KernelShape &shape -) -{ - // Need to re-order the kernel into HWIO form, require enough space to - // represent the tensor. - return sizeof(TIn) * shape.size(); -} - - -template -template -void winograd::Winograd2x2_3x3GEMM::transform_weights( - const TIn* const kernel, - void *transform_working_space -) -{ - const int kernel_matrix_size_bytes = get_kernel_matrix_size(kernel_shape); - int8_t* const ks_bytes = reinterpret_cast(_kernel_storage); - for (int i = 0; i < 16; i++) { - kernel_matrices[i] = reinterpret_cast( - ks_bytes + i*kernel_matrix_size_bytes); - } - - const TIn *kernel_hwio = kernel; - if( transform_working_space) - { - kernel_hwio = reinterpret_cast(transform_working_space); - ofm_ifm_h_w_to_h_w_ifm_ofm( - kernel, const_cast(kernel_hwio), - kernel_shape.n_output_channels, - kernel_shape.n_input_channels, - kernel_shape.n_rows, - kernel_shape.n_cols - ); - } - KernelTransform::execute( - kernel_shape, kernel_hwio, kernel_matrices[0], - kernel_matrix_size_bytes / sizeof(TIn), - kernel_matrix_row_stride - ); -} - -template -winograd::Winograd2x2_3x3GEMM::Winograd2x2_3x3GEMM( const KernelShape &kernel_shape, const Tensor4DShape input_shape, - const PaddingType padding_type, void *kernel_storage) - : kernel_shape(kernel_shape), in_shape(input_shape), padding(padding_type),kernel_matrix_row_stride(roundup(kernel_shape.n_output_channels, N_BLOCK)), manage_kernel_storage(false), - _kernel_storage(kernel_storage), prof() { - memset(kernel_matrices, 0x00, sizeof(TIn)*16); - memset(input_matrices, 0x00, sizeof(TIn)*16); - memset(output_matrices, 0x00, sizeof(TOut)*16); -} - -/*****************************************************************************/ -template -winograd::Winograd2x2_3x3GEMM::~Winograd2x2_3x3GEMM() {} - -/*****************************************************************************/ -template -template -void winograd::Winograd2x2_3x3GEMM::reshape_input( - const Tensor4DShape& input_shape, - const PaddingType padding_type, - const TIn* const input, - void *working_space -) { - assert(working_space); - int8_t* const ws_bytes = reinterpret_cast(working_space); - // Split the working space into that required for 16 input matrices and - // output matrices. 
- const int in_matrix_stride_bytes = get_input_matrix_size(input_shape, kernel_shape, padding_type); - const int out_matrix_stride_bytes = get_output_matrix_size(input_shape, kernel_shape, padding_type); - - for (int i = 0; i < 16; i++) { - input_matrices[i] = reinterpret_cast( - ws_bytes + i*in_matrix_stride_bytes); - output_matrices[i] = reinterpret_cast( - ws_bytes + 16*in_matrix_stride_bytes + i*out_matrix_stride_bytes); - } - - // Compute shape for the GEMM - const auto output_shape = get_output_shape(input_shape,kernel_shape, padding_type); - const int tile_rows = iceildiv(output_shape.n_rows, 2); - const int tile_cols = iceildiv(output_shape.n_cols, 2); - const int K = kernel_shape.n_input_channels; - - const int in_matrix_row_stride = K; - const int in_matrix_batch_stride = tile_rows*tile_cols*in_matrix_row_stride; - - // Transform the input tensor into an appropriate form - auto input_prep = [&] () { - InputTransform::execute( - input, input_shape, padding_type, tile_rows, tile_cols, - input_matrices[0], in_matrix_stride_bytes / sizeof(TIn), - in_matrix_batch_stride, in_matrix_row_stride - ); - }; - prof( - "Input Prep", input_prep, - InputTransform::bytes_read(input_shape, output_shape), - InputTransform::flops_performed(input_shape, output_shape), - InputTransform::bytes_written(input_shape, output_shape) - ); - -} - -/*****************************************************************************/ -template -template -void winograd::Winograd2x2_3x3GEMM::reshape_output(const Tensor4DShape& input_shape, const PaddingType padding_type, TOut* const output) { - assert(output_matrices[0]); - const int out_matrix_stride_bytes = get_output_matrix_size(input_shape, kernel_shape, padding_type); - const auto output_shape = get_output_shape(input_shape,kernel_shape, padding_type); - const int out_matrix_row_stride = kernel_matrix_row_stride; - - // Transform the output tensor into an appropriate form - OutputTransform::execute( - output_shape, - output_matrices[0], - out_matrix_stride_bytes / sizeof(TOut), - out_matrix_row_stride, - output - ); -} - - -/*****************************************************************************/ -template -void winograd::Winograd2x2_3x3GEMM::execute( size_t first, size_t last ) { - assert(input_matrices[0] && kernel_matrices[0] && output_matrices[0]); - assert(first < 16 && last < 16 && first < last); - // Compute shape for the GEMM - const auto output_shape = get_output_shape(in_shape,kernel_shape, padding); - const int tile_rows = iceildiv(output_shape.n_rows, 2); - const int tile_cols = iceildiv(output_shape.n_cols, 2); - const int M = in_shape.n_batches * tile_rows * tile_cols; - const int K = kernel_shape.n_input_channels; - const int N = kernel_shape.n_output_channels; - - const int in_matrix_row_stride = K; - const int out_matrix_row_stride = kernel_matrix_row_stride; - // Perform the GEMMs - for (size_t i = first; i <= last; i++) { - BlockedGemm( - input_matrices[i], kernel_matrices[i], output_matrices[i], M, K, N, - in_matrix_row_stride, kernel_matrix_row_stride, out_matrix_row_stride - ); -// prof("GEMM", perform_gemm, 0, 2*M*K*N, 0); // TODO Memory - } - -} - -/*****************************************************************************/ -template -Tensor4DShape winograd::Winograd2x2_3x3GEMM::get_output_shape( - const Tensor4DShape &in_shape, const KernelShape &k_shape, const PaddingType padding) { - return Tensor4DShape { - in_shape.n_batches, - (padding == PADDING_SAME) ? in_shape.n_rows : in_shape.n_rows - 2, - (padding == PADDING_SAME) ? 
in_shape.n_cols : in_shape.n_cols - 2, - k_shape.n_output_channels - }; -} - -template -size_t winograd::Winograd2x2_3x3GEMM::get_kernel_storage_size( - const KernelShape &shape) { - return 16 * get_kernel_matrix_size(shape); -} - -template -size_t winograd::Winograd2x2_3x3GEMM::get_kernel_matrix_size( - const KernelShape &shape) { - const int K = shape.n_input_channels; - const int N = roundup(shape.n_output_channels, N_BLOCK); - return sizeof(TIn) * K * N; -} - -template -size_t winograd::Winograd2x2_3x3GEMM::get_working_space_size( - const Tensor4DShape& input_shape, const KernelShape &k_shape, const PaddingType padding_type -) { - return 16 * get_input_matrix_size(input_shape, k_shape, padding_type) + - 16 * get_output_matrix_size(input_shape, k_shape, padding_type); -} - -template -size_t winograd::Winograd2x2_3x3GEMM::get_input_matrix_size( - const Tensor4DShape& input_shape, const KernelShape &k_shape, const PaddingType padding_type -) { - // Compute shape for the GEMM - const auto output_shape = get_output_shape(input_shape, k_shape, padding_type); - const int tile_rows = iceildiv(output_shape.n_rows, 2); - const int tile_cols = iceildiv(output_shape.n_cols, 2); - const int M = roundup(tile_rows * tile_cols, M_BLOCK); - const int K = k_shape.n_input_channels; - - return input_shape.n_batches * M * K * sizeof(TIn); -} - -template -size_t winograd::Winograd2x2_3x3GEMM::get_output_matrix_size( - const Tensor4DShape& input_shape, const KernelShape &k_shape,const PaddingType padding_type -) { - // Compute shape for the GEMM - const auto output_shape = get_output_shape(input_shape, k_shape, padding_type); - const int tile_rows = iceildiv(output_shape.n_rows, 2); - const int tile_cols = iceildiv(output_shape.n_cols, 2); - const int M = roundup(tile_rows * tile_cols, M_BLOCK); - const int N = roundup(k_shape.n_output_channels, N_BLOCK); - - return input_shape.n_batches * M * N * sizeof(TOut); -} diff --git a/arm_compute/core/NEON/kernels/winograd/winograd_shim_nchw.hpp b/arm_compute/core/NEON/kernels/winograd/winograd_shim_nchw.hpp deleted file mode 100644 index 4c7e291c58..0000000000 --- a/arm_compute/core/NEON/kernels/winograd/winograd_shim_nchw.hpp +++ /dev/null @@ -1,192 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
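[Editor's example] The size queries above combine as follows; a worked example with purely illustrative numbers (1x6x6x32 input, 3x3 kernel with 32 input and 24 output channels, SAME padding, float data, M_BLOCK = 4, N_BLOCK = 16 as defined in the class):

//   output shape      : 1 x 6 x 6 x 24
//   tiles             : iceildiv(6,2) * iceildiv(6,2)  = 3 * 3 = 9
//   M (rounded)       : roundup(9, M_BLOCK)            = 12
//   input matrix size : 1 * 12 * 32 * sizeof(float)    = 1536 bytes
//   N (rounded)       : roundup(24, N_BLOCK)           = 32
//   output matrix size: 1 * 12 * 32 * sizeof(float)    = 1536 bytes
//   working space     : 16 * (1536 + 1536)             = 49152 bytes
//   kernel storage    : 16 * (32 * 32 * sizeof(float)) = 65536 bytes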
- */ -#pragma once -#include -#include - -#include "alloc.hpp" -#include "gemm.hpp" -#include "profiler.hpp" -#include "utils.hpp" -#include "shims.hpp" -#include "winograd_gemm.hpp" - -#include "transforms.hpp" - -#ifndef ALLOC_ALIGN -#define ALLOC_ALIGN 64 -#endif // ALLOC_ALIGN - - -namespace winograd_shim_nchw { - /***************************************************************************/ - /* Implementation of the Winograd F(2x2, 3x3, 4x4) algorithm using GEMM - * internally. - */ - template - class Winograd2x2_3x3GEMM : public winograd::Winograd2x2_3x3GEMM { - public: - /* Instantiate a new Winograd operator. - */ - Winograd2x2_3x3GEMM(const KernelShape &kernel_shape, const Tensor4DShape input_shape, const PaddingType padding_type, void *kernel_storage); - - void nchw2nhwc( const Tensor4DShape& input_shape, const PaddingType padding_type, void *working_space, const TIn* const input); - void nhwc2nchw( const Tensor4DShape& input_shape, const PaddingType padding_type, void *working_space, TOut* const output); - - - std::pair get_nhwc_ptrs(const Tensor4DShape& input_shape,const PaddingType padding_type,void *working_space); - - static size_t get_working_space_size(const Tensor4DShape &input_shape,const KernelShape &k_shape, const PaddingType padding); - protected: - /* Get the memory required to store an NHWC copy of the input tensor. */ - static size_t get_working_nhwc_input_size(const Tensor4DShape &input_shape); - - /* Get the memory required to store an NHWC copy of the input tensor. */ - static size_t get_working_nhwc_output_size(const Tensor4DShape &output_shape, const KernelShape &k_shape, const PaddingType padding) ; - }; -} // namespace winograd - -/*****************************************************************************/ -template -winograd_shim_nchw::Winograd2x2_3x3GEMM::Winograd2x2_3x3GEMM( - const KernelShape &kernel_shape, const Tensor4DShape input_shape, - const PaddingType padding_type, void *kernel_storage -) : winograd::Winograd2x2_3x3GEMM(kernel_shape,input_shape,padding_type,kernel_storage) { -} - -/*****************************************************************************/ -template -void winograd_shim_nchw::Winograd2x2_3x3GEMM::nchw2nhwc(const Tensor4DShape& input_shape, const PaddingType padding_type, void *working_space, const TIn* const input) { - assert(working_space); - int8_t* const ws_bytes = reinterpret_cast(working_space); - - // Extract the top chunk of the working space to store the input and output - // tensors in NHWC format. 
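[Editor's example] nchw2nhwc below hands the NCHW input to nchw_to_nhwc so that the rest of the pipeline can work on an NHWC copy held in the working space. The re-layout it performs, spelled out as a plain loop nest; the loop order and function name are illustrative only:

void nchw_to_nhwc_example(const float *src, float *dst,
                          const int batches, const int channels, const int rows, const int cols)
{
    for (int n = 0; n < batches; n++)
        for (int ch = 0; ch < channels; ch++)
            for (int r = 0; r < rows; r++)
                for (int c = 0; c < cols; c++)
                    // src is [n][ch][r][c] (NCHW), dst is [n][r][c][ch] (NHWC)
                    dst[((n * rows + r) * cols + c) * channels + ch] =
                        src[((n * channels + ch) * rows + r) * cols + c];
}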
- const int in_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM::get_input_matrix_size(input_shape, this->kernel_shape, padding_type); - const int out_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM::get_output_matrix_size(input_shape, this->kernel_shape, padding_type); - - // Allocate working space for the input and output in NHWC format - TIn* const input_nhwc = reinterpret_cast( - ws_bytes + 16*(in_matrix_stride_bytes + out_matrix_stride_bytes) - ); - - // Re-order the input tensor - this->prof( - "NCHW -> NHWC", - [input, input_shape, input_nhwc] () { - nchw_to_nhwc( - input, input_nhwc, - input_shape.n_batches, - input_shape.n_channels, - input_shape.n_rows, - input_shape.n_cols - ); - }, - input_shape.size(), 0, input_shape.size() - ); -} - -/*****************************************************************************/ -template -void winograd_shim_nchw::Winograd2x2_3x3GEMM::nhwc2nchw(const Tensor4DShape& input_shape, const PaddingType padding_type, - void *working_space, TOut* const output) { - - assert(working_space); - int8_t* const ws_bytes = reinterpret_cast(working_space); - - // Extract the top chunk of the working space to store the input and output - // tensors in NHWC format. - const int in_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM::get_input_matrix_size(input_shape, this->kernel_shape, padding_type); - const int out_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM::get_output_matrix_size(input_shape, this->kernel_shape, padding_type); - - TOut* const output_nhwc = reinterpret_cast(ws_bytes + 16*(in_matrix_stride_bytes + out_matrix_stride_bytes) + get_working_nhwc_input_size(input_shape)); - - // Re-order the output tensor into NCHW - const auto output_shape = winograd::Winograd2x2_3x3GEMM::get_output_shape(input_shape, this->kernel_shape, padding_type); - this->prof( - "NHWC -> NCHW", - [output_nhwc, output_shape, output] () { - nhwc_to_nchw( - output_nhwc, output, - output_shape.n_batches, - output_shape.n_rows, - output_shape.n_cols, - output_shape.n_channels - ); - }, - output_shape.size(), 0, output_shape.size() - ); -} - - -/*****************************************************************************/ -template -std::pair winograd_shim_nchw::Winograd2x2_3x3GEMM::get_nhwc_ptrs( - const Tensor4DShape& input_shape, - const PaddingType padding_type, - void *working_space -) { - assert(working_space); - int8_t* const ws_bytes = reinterpret_cast(working_space); - - // Extract the top chunk of the working space to store the input and output - // tensors in NHWC format. 
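[Editor's note] The pointer arithmetic in nchw2nhwc, nhwc2nchw and get_nhwc_ptrs implies a single working-space layout shared by all of these routines, summarised here with descriptive names only (offsets in bytes):

//   [0, 16 * in_matrix_stride)                        16 transformed input matrices
//   [16 * in_matrix_stride, 16 * (in + out strides))  16 output matrices (GEMM results)
//   next get_working_nhwc_input_size() bytes          NHWC copy of the input tensor
//   next get_working_nhwc_output_size() bytes         NHWC copy of the output tensor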
- const int in_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM::get_input_matrix_size(input_shape, this->kernel_shape, padding_type); - const int out_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM::get_output_matrix_size(input_shape, this->kernel_shape, padding_type); - - // Allocate working space for the input and output in NHWC format - TIn* input_nhwc = reinterpret_cast(ws_bytes + 16*(in_matrix_stride_bytes + out_matrix_stride_bytes)); - TOut* output_nhwc = reinterpret_cast(ws_bytes + 16*(in_matrix_stride_bytes + out_matrix_stride_bytes) + get_working_nhwc_input_size(input_shape)); - return std::make_pair(output_nhwc,input_nhwc); -} - - - - -/*****************************************************************************/ -template -size_t winograd_shim_nchw::Winograd2x2_3x3GEMM::get_working_space_size( - const Tensor4DShape& input_shape, const KernelShape &k_shape, const PaddingType padding_type -) { - // TODO Add memory required for NHWC copies of input tensors - return winograd::Winograd2x2_3x3GEMM::get_working_space_size( - input_shape, k_shape, padding_type) - + get_working_nhwc_input_size(input_shape) - + get_working_nhwc_output_size(input_shape, k_shape, padding_type); -} - -template -size_t winograd_shim_nchw::Winograd2x2_3x3GEMM::get_working_nhwc_input_size( - const Tensor4DShape& input_shape -) { - return roundup(input_shape.size() * sizeof(TIn), static_cast(ALLOC_ALIGN)); -} - -template -size_t winograd_shim_nchw::Winograd2x2_3x3GEMM::get_working_nhwc_output_size( - const Tensor4DShape& input_shape, const KernelShape &k_shape, const PaddingType padding_type -) { - const auto output_shape = winograd::Winograd2x2_3x3GEMM::get_output_shape(input_shape,k_shape, padding_type); - return roundup(output_shape.size() * sizeof(TIn), static_cast(ALLOC_ALIGN)); -} diff --git a/arm_compute/runtime/NEON/functions/NEWinogradLayer.h b/arm_compute/runtime/NEON/functions/NEWinogradLayer.h index 7dca4570e5..77707060ec 100644 --- a/arm_compute/runtime/NEON/functions/NEWinogradLayer.h +++ b/arm_compute/runtime/NEON/functions/NEWinogradLayer.h @@ -67,8 +67,6 @@ public: NEWinogradLayer &operator=(const NEWinogradLayer &) = delete; private: - using Winograd3x3F32 = NEWinogradLayerKernel::Winograd3x3F32; - MemoryGroup _memory_group; NEWinogradLayerKernel _winograd_kernel; Tensor _weights_workspace; diff --git a/scripts/clang_tidy_rules.py b/scripts/clang_tidy_rules.py index 5b27dd5be5..7a13d045e7 100755 --- a/scripts/clang_tidy_rules.py +++ b/scripts/clang_tidy_rules.py @@ -91,6 +91,7 @@ def filter_clang_tidy_lines( lines ): ("parameter 'memory_manager' is unused" in line) or ("parameter 'memory_manager' is copied for each invocation but only used as a const reference" in line) or ("DeconvolutionLayer.cpp" in line and "casting (double + 0.5) to integer leads to incorrect rounding; consider using lround" in line) or + ("NEWinogradLayerKernel.cpp" in line and "use '= default' to define a trivial destructor" in line) or "3rdparty" in line): print_context=False continue diff --git a/src/core/NEON/kernels/NEWinogradLayerKernel.cpp b/src/core/NEON/kernels/NEWinogradLayerKernel.cpp index b9109dcff2..fe633368c0 100644 --- a/src/core/NEON/kernels/NEWinogradLayerKernel.cpp +++ b/src/core/NEON/kernels/NEWinogradLayerKernel.cpp @@ -27,9 +27,86 @@ #include "arm_compute/core/Helpers.h" #include "arm_compute/core/ITensor.h" #include "arm_compute/core/TensorInfo.h" +#include "support/ToolchainSupport.h" + +#include "src/core/NEON/kernels/winograd/winograd_shim_nchw.hpp" + +using T = 
winograd_shim_nchw::Winograd2x2_3x3GEMM; namespace arm_compute { +class Winograd3x3F32::Private +{ +public: + Private(const KernelShape &kernel_shape, const Tensor4DShape input_shape, const PaddingType padding_type, void *kernel_storage) + : convolver(kernel_shape, input_shape, padding_type, kernel_storage) + { + } + + T convolver; +}; + +Winograd3x3F32::~Winograd3x3F32() +{ +} + +void Winograd3x3F32::nchw2nhwc(const Tensor4DShape &input_shape, const PaddingType padding_type, void *working_space, const void *const input) +{ + _pimpl->convolver.nchw2nhwc(input_shape, padding_type, working_space, reinterpret_cast(input)); +} + +void Winograd3x3F32::nhwc2nchw(const Tensor4DShape &input_shape, const PaddingType padding_type, void *working_space, void *const output) +{ + _pimpl->convolver.nhwc2nchw(input_shape, padding_type, working_space, reinterpret_cast(output)); +} + +void Winograd3x3F32::transform_weights(const void *const kernel, void *transform_working_space) +{ + _pimpl->convolver.transform_weights(reinterpret_cast(kernel), transform_working_space); +} + +void Winograd3x3F32::reshape_input(const Tensor4DShape &input_shape, const PaddingType padding_type, const void *const input, void *working_space) +{ + _pimpl->convolver.reshape_input(input_shape, padding_type, reinterpret_cast(input), working_space); +} + +void Winograd3x3F32::reshape_output(const Tensor4DShape &input_shape, const PaddingType padding_type, void *const output) +{ +#if defined(__aarch64__) + _pimpl->convolver.reshape_output(input_shape, padding_type, reinterpret_cast(output)); +#else /* __aarch64__ */ + ARM_COMPUTE_UNUSED(input_shape); + ARM_COMPUTE_UNUSED(padding_type); + ARM_COMPUTE_UNUSED(output); + ARM_COMPUTE_ERROR("Not implemented"); +#endif /* __aarch64__ */ +} + +std::pair Winograd3x3F32::get_nhwc_ptrs(const Tensor4DShape &input_shape, const PaddingType padding_type, void *working_space) +{ + return _pimpl->convolver.get_nhwc_ptrs(input_shape, padding_type, working_space); +} + +Winograd3x3F32::Winograd3x3F32(const KernelShape &kernel_shape, const Tensor4DShape input_shape, const PaddingType padding_type, void *kernel_storage) + : _pimpl(support::cpp14::make_unique(kernel_shape, input_shape, padding_type, kernel_storage)) +{ +} + +size_t NEWinogradLayerKernel::get_kernel_storage_size(const KernelShape &shape) +{ + return T::get_kernel_storage_size(shape); +} + +size_t NEWinogradLayerKernel::get_working_space_size(const Tensor4DShape &input_shape, const KernelShape &k_shape, const PaddingType padding) +{ + return T::get_working_space_size(input_shape, k_shape, padding); +} + +size_t NEWinogradLayerKernel::get_kernel_transform_working_size(const KernelShape &shape) +{ + return T::get_kernel_transform_working_size(shape); +} + NEWinogradLayerKernel::NEWinogradLayerKernel() : _convolver(nullptr), _output(nullptr) { @@ -55,6 +132,6 @@ void NEWinogradLayerKernel::run(const Window &window, const ThreadInfo &info) const size_t num_gemms_per_thread = 16 / num_threads; const size_t first_gemm = tid * num_gemms_per_thread; const size_t last_gemm = (tid == (num_threads - 1)) ? 15 : first_gemm + num_gemms_per_thread - 1; - _convolver->execute(first_gemm, last_gemm); + _convolver->_pimpl->convolver.execute(first_gemm, last_gemm); } } // namespace arm_compute diff --git a/src/core/NEON/kernels/winograd/gemm.hpp b/src/core/NEON/kernels/winograd/gemm.hpp new file mode 100644 index 0000000000..564016a646 --- /dev/null +++ b/src/core/NEON/kernels/winograd/gemm.hpp @@ -0,0 +1,127 @@ + +/* + * Copyright (c) 2017 ARM Limited. 
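[Editor's example] The Private class above is a conventional pimpl: the public header exposes only an opaque pointer, so the winograd implementation headers stay confined to this translation unit. A minimal, generic sketch of the same idiom; the Widget/Impl names are placeholders, and the patch itself uses support::cpp14::make_unique rather than std::make_unique:

#include <memory>

// --- public header (sketch) ---
class Widget
{
public:
    Widget();
    ~Widget();                   // defined out of line, where Impl is complete
    void run();
private:
    class Impl;                  // opaque to users of the header
    std::unique_ptr<Impl> _pimpl;
};

// --- implementation file (sketch) ---
class Widget::Impl
{
public:
    void run() { /* free to use implementation-only headers here */ }
};

Widget::Widget() : _pimpl(std::make_unique<Impl>()) {}
Widget::~Widget() = default;
void Widget::run() { _pimpl->run(); }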
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once +#include "utils.hpp" + +template +void Gemm(const TIn* const a, const TIn* const b, TOut *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride, + const bool a_transposed=false, + const bool b_transposed=false) { + // Array access methods + const auto A = [a, a_transposed, M, K, a_row_stride] (const int i, const int j) -> TIn { + return a[(!a_transposed) ? i*a_row_stride + j : i + j*M]; + }; + + const auto B = [b, b_transposed, K, N, b_row_stride] (const int i, const int j) -> TIn { + return b[(!b_transposed) ? i*b_row_stride + j : i + j*N]; + }; + + const auto C = [c, c_row_stride] (const int i, const int j) -> TOut& { + return c[i*c_row_stride + j]; + }; + + // Perform the matrix multiplication + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + for (int k = 0; k < K; k++) { + C(i, j) += A(i, k) * B(k, j); + } + } + } +} + +template +void BlockedGemm( + const TIn* const a, const TIn* const b, TOut *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + // Array access methods + const auto A = [a, M, K, a_row_stride] (const int i, const int j) -> TIn { + return a[i*a_row_stride + j]; + }; + + const auto B = [b, K, N, b_row_stride] (const int i, const int j) -> TIn { + return b[i*b_row_stride + j]; + }; + + const auto C = [c, c_row_stride] (const int i, const int j) -> TOut& { + return c[i*c_row_stride + j]; + }; + + const int M_BLOCKS = iceildiv(M, M_BLOCK); + const int N_BLOCKS = iceildiv(N, N_BLOCK); + + // For each block of output rows + for (int mblock = 0; mblock < M_BLOCKS; mblock++) { + // For each block of output columns + for (int nblock = 0; nblock < N_BLOCKS; nblock++) { + // Create an appropriately sized block of accumulators + TOut accum[M_BLOCK][N_BLOCK]; + for (int i = 0; i < M_BLOCK; i++) { + for (int j = 0; j < N_BLOCK; j++) { + accum[i][j] = static_cast(0); + } + } + + // Perform this portion of the matrix multiply + for (int k = 0; k < K; k++) { + // Load elements of A + TIn elems_a[M_BLOCK]; + for (int i = 0; i < M_BLOCK; i++) { + elems_a[i] = A(mblock*M_BLOCK + i, k); + } + + // Load elements of B + TIn elems_b[N_BLOCK]; + for (int j = 0; j < N_BLOCK; j++) { + elems_b[j] = B(k, nblock*N_BLOCK + j); + } + + // Perform the partial matrix multiply + for (int i = 0; i < M_BLOCK; i++) { + for 
(int j = 0; j < N_BLOCK; j++) { + accum[i][j] += elems_a[i] * elems_b[j]; + } + } + } + + // Store the partial product + for (int i = 0; i < M_BLOCK; i++) { + for (int j = 0; j < N_BLOCK; j++) { + C(mblock*M_BLOCK + i, nblock*N_BLOCK + j) = accum[i][j]; + } + } + } + } +} + +#include "gemm/a64_sgemm.hpp" diff --git a/src/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp b/src/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp new file mode 100644 index 0000000000..e1b7488c31 --- /dev/null +++ b/src/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp @@ -0,0 +1,355 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once +#include +#include "../utils.hpp" + +#ifdef __aarch64__ + +template <> +inline void BlockedGemm<8, 12, float, float>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + const int M_BLOCK = 8; + const int N_BLOCK = 12; + + const int m_blocks = iceildiv(M, M_BLOCK); + const int n_blocks = iceildiv(N, N_BLOCK); + + // For each block of output rows + for (int mblock = 0; mblock < m_blocks; mblock++) { + // For each block of output columns + for (int nblock = 0; nblock < n_blocks; nblock++) { + const float *aptr = a + mblock*M_BLOCK*a_row_stride; + const float *bptr = b + nblock*N_BLOCK; + float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; + int k = K; + + asm volatile ( + // Create an 8x12 block of accumulators + " A_1 .req v27\n" + "sA_1 .req s27\n" + " A_2 .req v28\n" + "sA_2 .req s28\n" + " A_3 .req v29\n" + "sA_3 .req s29\n" + " A_4 .req v30\n" + "sA_4 .req s30\n" + + " B_1 .req v24\n" " B_2 .req v25\n" " B_3 .req v26\n" + "qB_1 .req q24\n" "qB_2 .req q25\n" "qB_3 .req q26\n" + + " C_11 .req v0\n" " C_12 .req v1\n" " C_13 .req v2\n" + " C_21 .req v3\n" " C_22 .req v4\n" " C_23 .req v5\n" + " C_31 .req v6\n" " C_32 .req v7\n" " C_33 .req v8\n" + " C_41 .req v9\n" " C_42 .req v10\n" " C_43 .req v11\n" + " C_51 .req v12\n" " C_52 .req v13\n" " C_53 .req v14\n" + " C_61 .req v15\n" " C_62 .req v16\n" " C_63 .req v17\n" + " C_71 .req v18\n" " C_72 .req v19\n" " C_73 .req v20\n" + " C_81 .req v21\n" " C_82 .req v22\n" " C_83 .req v23\n" + + "qC_11 .req q0\n" "qC_12 .req q1\n" "qC_13 .req q2\n" + "qC_21 .req q3\n" "qC_22 .req q4\n" "qC_23 .req q5\n" + "qC_31 .req q6\n" "qC_32 .req q7\n" "qC_33 .req q8\n" + "qC_41 .req q9\n" "qC_42 .req 
q10\n" "qC_43 .req q11\n" + "qC_51 .req q12\n" "qC_52 .req q13\n" "qC_53 .req q14\n" + "qC_61 .req q15\n" "qC_62 .req q16\n" "qC_63 .req q17\n" + "qC_71 .req q18\n" "qC_72 .req q19\n" "qC_73 .req q20\n" + "qC_81 .req q21\n" "qC_82 .req q22\n" "qC_83 .req q23\n" + + "aptr1 .req x17\n" + "aptr2 .req x18\n" + "aptr3 .req x19\n" + "aptr4 .req x20\n" + "aptr5 .req x21\n" + "aptr6 .req x22\n" + "aptr7 .req x23\n" + + // Initialise accumulators with 0 + // Initialise pointers + "movi C_11.4s, #0\n" + "add aptr1, %x[aptr], %x[a_row_stride]\n" + "movi C_12.4s, #0\n" + "add aptr2, aptr1, %x[a_row_stride]\n" + "movi C_13.4s, #0\n" + "add aptr3, aptr2, %x[a_row_stride]\n" + "movi C_21.4s, #0\n" + "add aptr4, aptr3, %x[a_row_stride]\n" + "movi C_22.4s, #0\n" + "add aptr5, aptr4, %x[a_row_stride]\n" + "movi C_23.4s, #0\n" + "add aptr6, aptr5, %x[a_row_stride]\n" + "movi C_31.4s, #0\n" + "add aptr7, aptr6, %x[a_row_stride]\n" + "movi C_32.4s, #0\n" + "ldr qB_1, [%x[bptr]]\n" + "movi C_33.4s, #0\n" + "ldr qB_2, [%x[bptr], #0x10]\n" + "movi C_41.4s, #0\n" + "prfm pldl1keep, [%x[bptr], #0x00]\n" + "movi C_42.4s, #0\n" + "prfm pldl1keep, [%x[bptr], #0x10]\n" + "movi C_43.4s, #0\n" + "prfm pldl1keep, [%x[bptr], #0x20]\n" + "movi C_51.4s, #0\n" + "prfm pldl1keep, [%x[aptr], #0x00]\n" + "movi C_52.4s, #0\n" + "prfm pldl1keep, [ aptr1, #0x00]\n" + "movi C_53.4s, #0\n" + "prfm pldl1keep, [ aptr2, #0x00]\n" + "movi C_61.4s, #0\n" + "prfm pldl1keep, [ aptr3, #0x00]\n" + "movi C_62.4s, #0\n" + "prfm pldl1keep, [ aptr4, #0x00]\n" + "movi C_63.4s, #0\n" + "prfm pldl1keep, [ aptr5, #0x00]\n" + "movi C_71.4s, #0\n" + "prfm pldl1keep, [ aptr6, #0x00]\n" + "movi C_72.4s, #0\n" + "prfm pldl1keep, [ aptr7, #0x00]\n" + "movi C_73.4s, #0\n" + "ldr sA_1, [%x[aptr]], #0x4\n" + "movi C_81.4s, #0\n" + "ldr sA_2, [ aptr1], #0x4\n" + "movi C_82.4s, #0\n" + "ldr sA_3, [ aptr2], #0x4\n" + "movi C_83.4s, #0\n" + "subs %x[k], %x[k], #1\n" + "beq 2f\n" + + "1:" + "fmla C_11.4s, B_1.4s, A_1.s[0]\n" + "ldr qB_3, [%x[bptr], #0x20]\n" + "fmla C_12.4s, B_2.4s, A_1.s[0]\n" + "ldr sA_4, [ aptr3], #0x4\n" + "fmla C_13.4s, B_3.4s, A_1.s[0]\n" + "ldr sA_1, [ aptr4], #0x04\n" + + "fmla C_21.4s, B_1.4s, A_2.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride]\n" + "fmla C_22.4s, B_2.4s, A_2.s[0]\n" + "prfm pldl1keep, [ aptr3, #0x10]\n" + "fmla C_23.4s, B_3.4s, A_2.s[0]\n" + "ldr sA_2, [ aptr5], #0x04\n" + + "fmla C_31.4s, B_1.4s, A_3.s[0]\n" + "prfm pldl1keep, [%x[bptr], #0x00]\n" + "fmla C_32.4s, B_2.4s, A_3.s[0]\n" + "prfm pldl1keep, [%x[bptr], #0x10]\n" + "fmla C_33.4s, B_3.4s, A_3.s[0]\n" + "ldr sA_3, [ aptr6], #0x04\n" + + "fmla C_41.4s, B_1.4s, A_4.s[0]\n" + "prfm pldl1keep, [%x[bptr], #0x20]\n" + "fmla C_42.4s, B_2.4s, A_4.s[0]\n" + "prfm pldl1keep, [ aptr4, #0x10]\n" + "fmla C_43.4s, B_3.4s, A_4.s[0]\n" + "ldr sA_4, [ aptr7], #0x04\n" + + "fmla C_51.4s, B_1.4s, A_1.s[0]\n" + "prfm pldl1keep, [ aptr5, #0x10]\n" + "fmla C_52.4s, B_2.4s, A_1.s[0]\n" + "prfm pldl1keep, [ aptr6, #0x10]\n" + "fmla C_53.4s, B_3.4s, A_1.s[0]\n" + "ldr sA_1, [%x[aptr]], #0x04\n" + + "fmla C_61.4s, B_1.4s, A_2.s[0]\n" + "prfm pldl1keep, [ aptr7, #0x10]\n" + "fmla C_62.4s, B_2.4s, A_2.s[0]\n" + "subs %x[k], %x[k], #1\n" + "fmla C_63.4s, B_3.4s, A_2.s[0]\n" + "ldr sA_2, [ aptr1], #0x04\n" + + "fmla C_71.4s, B_1.4s, A_3.s[0]\n" + "prfm pldl1keep, [%x[aptr], #0x10]\n" + "fmla C_72.4s, B_2.4s, A_3.s[0]\n" + "prfm pldl1keep, [ aptr1, #0x10]\n" + "fmla C_73.4s, B_3.4s, A_3.s[0]\n" + "ldr sA_3, [ aptr2], #0x04\n" + + "fmla C_81.4s, B_1.4s, A_4.s[0]\n" + "prfm pldl1keep, [ 
aptr2, #0x10]\n" + "fmla C_82.4s, B_2.4s, A_4.s[0]\n" + "ldp qB_1, qB_2, [%x[bptr]]\n" + "fmla C_83.4s, B_3.4s, A_4.s[0]\n" + "bne 1b\n" + + "2:" + "fmla C_11.4s, B_1.4s, A_1.s[0]\n" + "ldr qB_3, [%x[bptr], #0x20]\n" + "fmla C_12.4s, B_2.4s, A_1.s[0]\n" + "stp qC_11, qC_12, [%x[cptr]]\n" + "fmla C_13.4s, B_3.4s, A_1.s[0]\n" + "str qC_13, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + "ldr sA_1, [ aptr4], #0x04\n" + + "fmla C_21.4s, B_1.4s, A_2.s[0]\n" + "ldr sA_4, [ aptr3], #0x4\n" + "fmla C_22.4s, B_2.4s, A_2.s[0]\n" + "stp qC_21, qC_22, [%x[cptr]]\n" + "fmla C_23.4s, B_3.4s, A_2.s[0]\n" + "str qC_23, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + "ldr sA_2, [ aptr5], #0x04\n" + + "fmla C_31.4s, B_1.4s, A_3.s[0]\n" + "fmla C_32.4s, B_2.4s, A_3.s[0]\n" + "stp qC_31, qC_32, [%x[cptr]]\n" + "fmla C_33.4s, B_3.4s, A_3.s[0]\n" + "str qC_33, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + "ldr sA_3, [ aptr6], #0x04\n" + + "fmla C_41.4s, B_1.4s, A_4.s[0]\n" + "fmla C_42.4s, B_2.4s, A_4.s[0]\n" + "stp qC_41, qC_42, [%x[cptr]]\n" + "fmla C_43.4s, B_3.4s, A_4.s[0]\n" + "str qC_43, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + "ldr sA_4, [ aptr7], #0x04\n" + + "fmla C_51.4s, B_1.4s, A_1.s[0]\n" + "fmla C_52.4s, B_2.4s, A_1.s[0]\n" + "stp qC_51, qC_52, [%x[cptr]]\n" + "fmla C_53.4s, B_3.4s, A_1.s[0]\n" + "str qC_53, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + + "fmla C_61.4s, B_1.4s, A_2.s[0]\n" + "fmla C_62.4s, B_2.4s, A_2.s[0]\n" + "stp qC_61, qC_62, [%x[cptr]]\n" + "fmla C_63.4s, B_3.4s, A_2.s[0]\n" + "str qC_63, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + + "fmla C_71.4s, B_1.4s, A_3.s[0]\n" + "fmla C_72.4s, B_2.4s, A_3.s[0]\n" + "stp qC_71, qC_72, [%x[cptr]]\n" + "fmla C_73.4s, B_3.4s, A_3.s[0]\n" + "str qC_73, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + + "fmla C_81.4s, B_1.4s, A_4.s[0]\n" + "fmla C_82.4s, B_2.4s, A_4.s[0]\n" + "stp qC_81, qC_82, [%x[cptr]]\n" + "fmla C_83.4s, B_3.4s, A_4.s[0]\n" + "str qC_83, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + + // Clear aliases + ".unreq aptr1\n" + ".unreq aptr2\n" + ".unreq aptr3\n" + ".unreq aptr4\n" + ".unreq aptr5\n" + ".unreq aptr6\n" + ".unreq aptr7\n" + + ".unreq A_1\n" ".unreq A_2\n" ".unreq A_3\n" ".unreq A_4\n" + ".unreq sA_1\n" ".unreq sA_2\n" ".unreq sA_3\n" ".unreq sA_4\n" + + ".unreq B_1\n" ".unreq B_2\n" ".unreq B_3\n" + ".unreq qB_1\n" ".unreq qB_2\n" ".unreq qB_3\n" + + ".unreq C_11\n" ".unreq C_12\n" ".unreq C_13\n" + ".unreq C_21\n" ".unreq C_22\n" ".unreq C_23\n" + ".unreq C_31\n" ".unreq C_32\n" ".unreq C_33\n" + ".unreq C_41\n" ".unreq C_42\n" ".unreq C_43\n" + ".unreq C_51\n" ".unreq C_52\n" ".unreq C_53\n" + ".unreq C_61\n" ".unreq C_62\n" ".unreq C_63\n" + ".unreq C_71\n" ".unreq C_72\n" ".unreq C_73\n" + ".unreq C_81\n" ".unreq C_82\n" ".unreq C_83\n" + + ".unreq qC_11\n" ".unreq qC_12\n" ".unreq qC_13\n" + ".unreq qC_21\n" ".unreq qC_22\n" ".unreq qC_23\n" + ".unreq qC_31\n" ".unreq qC_32\n" ".unreq qC_33\n" + ".unreq qC_41\n" ".unreq qC_42\n" ".unreq qC_43\n" + ".unreq qC_51\n" ".unreq qC_52\n" ".unreq qC_53\n" + ".unreq qC_61\n" ".unreq qC_62\n" ".unreq qC_63\n" + ".unreq qC_71\n" ".unreq qC_72\n" ".unreq qC_73\n" + ".unreq qC_81\n" ".unreq qC_82\n" ".unreq qC_83\n" + : [aptr] "+r" (aptr), + [bptr] "+r" (bptr), + [cptr] "+r" (cptr), + [k] "+r" (k) + : [a_row_stride] "r" (a_row_stride * sizeof(float)), + [b_row_stride] "r" 
(b_row_stride * sizeof(float)), + [c_row_stride] "r" (c_row_stride * sizeof(float)) + : "cc", "memory", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "x17", "x18", "x19", "x20", "x21", "x22", "x23" + ); + } + } +} + +/*****************************************************************************/ +/* 4x16 blocked GEMM with specialised tails + */ +#include "a64_sgemm_4x16.hpp" + +template <> +inline void BlockedGemm<4, 16, float, float>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + // Despatch based on tail of K + switch (K % 4) { + case 3: + sgemm_4x16_impl<3>( + a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride + ); + break; + case 2: + sgemm_4x16_impl<2>( + a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride + ); + break; + case 1: + sgemm_4x16_impl<1>( + a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride + ); + break; + case 0: + sgemm_4x16_impl<0>( + a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride + ); + break; + default: + assert(0); + break; + } +} + +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp b/src/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp new file mode 100644 index 0000000000..e74610ef27 --- /dev/null +++ b/src/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp @@ -0,0 +1,1445 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
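[Editor's note] The despatch on K % 4 above exists because the 4x16 kernels that follow unroll the accumulation over K by four; each specialisation bakes the remainder length in at compile time instead of testing it inside the inner loop. With illustrative numbers: K = 11 selects sgemm_4x16_impl<3>, which performs (11 - 3) / 4 = 2 unrolled groups of four accumulation steps plus a dedicated 3-step remainder, so 2 * 4 + 3 = 11 steps in total.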
+ */ +template +inline void sgemm_4x16_impl( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +); + +template <> +inline void sgemm_4x16_impl<0>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + const int TAIL_SIZE = 0; + const int M_BLOCK = 4; + const int N_BLOCK = 16; + + const int m_blocks = iceildiv(M, M_BLOCK); + const int n_blocks = iceildiv(N, N_BLOCK); + + // For each block of output rows + for (int mblock = 0; mblock < m_blocks; mblock++) { + // For each block of output columns + for (int nblock = 0; nblock < n_blocks; nblock++) { + const float *aptr = a + mblock*M_BLOCK*a_row_stride; + const float *bptr = b + nblock*N_BLOCK; + float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; + int k = (K - TAIL_SIZE) / 4; + + asm volatile( + "aptr2 .req X20\n" + "aptr3 .req X21\n" + "aptr4 .req X22\n" + "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" + "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" + "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" + "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" + "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" + "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" + "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" + "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" + "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" + "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" + "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" + "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" + "vB1 .req v20\n" "qB1 .req q20\n" + "vB2 .req v21\n" "qB2 .req q21\n" + "vB3 .req v22\n" "qB3 .req q22\n" + "vB4 .req v23\n" "qB4 .req q23\n" + + // Clear accumulators, initialise pointers + "movi vC11.4s, #0\n" + "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n" + "movi vC12.4s, #0\n" + "add aptr3, aptr2, %x[a_row_stride_bytes]\n" + "movi vC13.4s, #0\n" + "add aptr4, aptr3, %x[a_row_stride_bytes]\n" + "movi vC14.4s, #0\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "movi vC21.4s, #0\n" + "ldr qA2, [ aptr2], #0x10\n" + "movi vC22.4s, #0\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "movi vC23.4s, #0\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "movi vC24.4s, #0\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "movi vC31.4s, #0\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "subs %x[k], %x[k], #1\n" + "beq 2f\n" + + "1:" // Loop proper + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, 
vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "subs %x[k], %x[k], #1\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr qA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + "bne 1b\n" + + "2:" // Tail + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, 
vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "stp qC11, qC12, [%x[cptr], #0x00]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "stp qC13, qC14, [%x[cptr], #0x20]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "stp qC21, qC22, [%x[cptr], #0x00]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "stp qC23, qC24, [%x[cptr], #0x20]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "stp qC31, qC32, [%x[cptr], #0x00]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "stp qC33, qC34, [%x[cptr], #0x20]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "stp qC41, qC42, [%x[cptr], #0x00]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + "stp qC43, qC44, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + + ".unreq vB4\n" ".unreq qB4\n" + ".unreq vB3\n" ".unreq qB3\n" + ".unreq vB2\n" ".unreq qB2\n" + ".unreq vB1\n" ".unreq qB1\n" + ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" + ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" + ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" + ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" + ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" + ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n" + ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" + ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" + ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" + ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n" + ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" + ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" + ".unreq aptr2\n" + ".unreq 
aptr3\n" + ".unreq aptr4\n" + + : [aptr] "+r" (aptr), + [bptr] "+r" (bptr), + [cptr] "+r" (cptr), + [k] "+r" (k) + : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), + [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), + [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) + : "cc", "memory", "x20", "x21", "x22", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23" + ); + } + } +} + +template <> +inline void sgemm_4x16_impl<1>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + const int TAIL_SIZE = 1; + const int M_BLOCK = 4; + const int N_BLOCK = 16; + + const int m_blocks = iceildiv(M, M_BLOCK); + const int n_blocks = iceildiv(N, N_BLOCK); + + // For each block of output rows + for (int mblock = 0; mblock < m_blocks; mblock++) { + // For each block of output columns + for (int nblock = 0; nblock < n_blocks; nblock++) { + const float *aptr = a + mblock*M_BLOCK*a_row_stride; + const float *bptr = b + nblock*N_BLOCK; + float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; + int k = (K - TAIL_SIZE) / 4; + + asm volatile( + "aptr2 .req X20\n" + "aptr3 .req X21\n" + "aptr4 .req X22\n" + "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" + "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" + "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" + "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" + "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" + "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" + "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" + "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" + "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" + "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" + "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" + "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" + "vB1 .req v20\n" "qB1 .req q20\n" + "vB2 .req v21\n" "qB2 .req q21\n" + "vB3 .req v22\n" "qB3 .req q22\n" + "vB4 .req v23\n" "qB4 .req q23\n" + + // Clear accumulators, initialise pointers + "movi vC11.4s, #0\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "movi vC12.4s, #0\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "movi vC13.4s, #0\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "movi vC14.4s, #0\n" + "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n" + "movi vC21.4s, #0\n" + "add aptr3, aptr2, %x[a_row_stride_bytes]\n" + "movi vC22.4s, #0\n" + "add aptr4, aptr3, %x[a_row_stride_bytes]\n" + "movi vC23.4s, #0\n" + "cbnz %x[k], 3f\n" + + // Prepare for tail in K + "movi vC24.4s, #0\n" + "ldr sA1, [%x[aptr]], #0x04\n" + "movi vC31.4s, #0\n" + "ldr sA2, [ aptr2], #0x04\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "b 2f\n" // Jump to tail + + "3:" // Prepare for loop over K + "movi vC24.4s, #0\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "movi vC31.4s, #0\n" + "ldr qA2, [ aptr2], #0x10\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "subs %x[k], %x[k], #1\n" + "beq 4f\n" + + "1:" // Loop proper + "fmla vC11.4s, 
vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "subs %x[k], %x[k], #1\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr qA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + "bne 1b\n" + + "4:" // Tail iteration + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], 
%x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr sA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr sA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + + "2:" // Common tail + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "stp qC11, qC12, [%x[cptr], #0x00]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "ldr sA3, [ aptr3], #0x10\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "stp qC13, qC14, [%x[cptr], #0x20]\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "stp qC21, qC22, [%x[cptr], #0x00]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "ldr sA4, [ aptr4], #0x10\n" + "fmla vC24.4s, 
vB4.4s, vA2.s[0]\n" + "stp qC23, qC24, [%x[cptr], #0x20]\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "stp qC31, qC32, [%x[cptr], #0x00]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "stp qC33, qC34, [%x[cptr], #0x20]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "stp qC41, qC42, [%x[cptr], #0x00]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + "stp qC43, qC44, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + + ".unreq vB4\n" ".unreq qB4\n" + ".unreq vB3\n" ".unreq qB3\n" + ".unreq vB2\n" ".unreq qB2\n" + ".unreq vB1\n" ".unreq qB1\n" + ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" + ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" + ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" + ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" + ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" + ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n" + ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" + ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" + ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" + ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n" + ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" + ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" + ".unreq aptr2\n" + ".unreq aptr3\n" + ".unreq aptr4\n" + + : [aptr] "+r" (aptr), + [bptr] "+r" (bptr), + [cptr] "+r" (cptr), + [k] "+r" (k) + : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), + [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), + [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) + : "cc", "memory", "x20", "x21", "x22", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23" + ); + } + } +} + +template <> +inline void sgemm_4x16_impl<2>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + const int TAIL_SIZE = 2; + const int M_BLOCK = 4; + const int N_BLOCK = 16; + + const int m_blocks = iceildiv(M, M_BLOCK); + const int n_blocks = iceildiv(N, N_BLOCK); + + // For each block of output rows + for (int mblock = 0; mblock < m_blocks; mblock++) { + // For each block of output columns + for (int nblock = 0; nblock < n_blocks; nblock++) { + const float *aptr = a + mblock*M_BLOCK*a_row_stride; + const float *bptr = b + nblock*N_BLOCK; + float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; + int k = (K - TAIL_SIZE) / 4; + + asm volatile( + "aptr2 .req X20\n" + "aptr3 .req X21\n" + "aptr4 .req X22\n" + "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" + "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" + "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" + "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" + "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" + "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" + "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" + "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" + "vA1 
.req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" + "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" + "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" + "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" + "vB1 .req v20\n" "qB1 .req q20\n" + "vB2 .req v21\n" "qB2 .req q21\n" + "vB3 .req v22\n" "qB3 .req q22\n" + "vB4 .req v23\n" "qB4 .req q23\n" + + // Clear accumulators, initialise pointers + "movi vC11.4s, #0\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "movi vC12.4s, #0\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "movi vC13.4s, #0\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "movi vC14.4s, #0\n" + "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n" + "movi vC21.4s, #0\n" + "add aptr3, aptr2, %x[a_row_stride_bytes]\n" + "movi vC22.4s, #0\n" + "add aptr4, aptr3, %x[a_row_stride_bytes]\n" + "movi vC23.4s, #0\n" + "cbnz %x[k], 3f\n" + + // Prepare for tail in K + "movi vC24.4s, #0\n" + "ldr dA1, [%x[aptr]], #0x08\n" + "movi vC31.4s, #0\n" + "ldr dA2, [ aptr2], #0x08\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "b 2f\n" // Jump to tail + + "3:" // Prepare for loop over K + "movi vC24.4s, #0\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "movi vC31.4s, #0\n" + "ldr qA2, [ aptr2], #0x10\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "subs %x[k], %x[k], #1\n" + "beq 4f\n" + + "1:" // Loop proper + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + 
"ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "subs %x[k], %x[k], #1\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr qA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + "bne 1b\n" + + "4:" // Tail iteration + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, 
vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr dA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr dA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + + "2:" // Common tail + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr dA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr dA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "stp qC11, qC12, [%x[cptr], #0x00]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "stp qC13, qC14, [%x[cptr], #0x20]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "stp qC21, qC22, [%x[cptr], #0x00]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "stp qC23, qC24, [%x[cptr], #0x20]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "stp qC31, qC32, [%x[cptr], #0x00]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "stp qC33, qC34, [%x[cptr], #0x20]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "stp qC41, qC42, [%x[cptr], #0x00]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + "stp qC43, qC44, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + + ".unreq vB4\n" ".unreq qB4\n" + ".unreq vB3\n" ".unreq qB3\n" + ".unreq vB2\n" ".unreq qB2\n" + ".unreq vB1\n" ".unreq qB1\n" + ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" + ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" + ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" + ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" + ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" + ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" 
".unreq vC44\n" + ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" + ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" + ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" + ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n" + ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" + ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" + ".unreq aptr2\n" + ".unreq aptr3\n" + ".unreq aptr4\n" + + : [aptr] "+r" (aptr), + [bptr] "+r" (bptr), + [cptr] "+r" (cptr), + [k] "+r" (k) + : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), + [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), + [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) + : "cc", "memory", "x20", "x21", "x22", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23" + ); + } + } +} + +template <> +inline void sgemm_4x16_impl<3>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + const int TAIL_SIZE = 3; + const int M_BLOCK = 4; + const int N_BLOCK = 16; + + const int m_blocks = iceildiv(M, M_BLOCK); + const int n_blocks = iceildiv(N, N_BLOCK); + + // For each block of output rows + for (int mblock = 0; mblock < m_blocks; mblock++) { + // For each block of output columns + for (int nblock = 0; nblock < n_blocks; nblock++) { + const float *aptr = a + mblock*M_BLOCK*a_row_stride; + const float *bptr = b + nblock*N_BLOCK; + float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; + int k = (K - TAIL_SIZE) / 4; + + asm volatile( + "aptr2 .req X20\n" + "aptr3 .req X21\n" + "aptr4 .req X22\n" + "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" + "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" + "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" + "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" + "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" + "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" + "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" + "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" + "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" + "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" + "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" + "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" + "vB1 .req v20\n" "qB1 .req q20\n" + "vB2 .req v21\n" "qB2 .req q21\n" + "vB3 .req v22\n" "qB3 .req q22\n" + "vB4 .req v23\n" "qB4 .req q23\n" + + // Clear accumulators, initialise pointers + "movi vC11.4s, #0\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "movi vC12.4s, #0\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "movi vC13.4s, #0\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "movi vC14.4s, #0\n" + "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n" + "movi vC21.4s, #0\n" + "add aptr3, aptr2, %x[a_row_stride_bytes]\n" + "movi vC22.4s, #0\n" + "add aptr4, aptr3, %x[a_row_stride_bytes]\n" + "movi vC23.4s, #0\n" + "cbnz %x[k], 3f\n" + + // Prepare for tail in K + "movi vC24.4s, #0\n" + "ldr dA1, [%x[aptr]], #0x08\n" + "movi vC31.4s, #0\n" + "ldr dA2, [ aptr2], #0x08\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + 
"movi vC44.4s, #0\n" + "b 2f\n" // Jump to tail + + "3:" // Prepare for loop over K + "movi vC24.4s, #0\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "movi vC31.4s, #0\n" + "ldr qA2, [ aptr2], #0x10\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "subs %x[k], %x[k], #1\n" + "beq 4f\n" + + "1:" // Loop proper + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "subs %x[k], %x[k], #1\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + 
"ldr qA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + "bne 1b\n" + + "4:" // Tail iteration + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr dA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr dA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + + "2:" // Common tail + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr dA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr 
dA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "ldr sA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "ldr sA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "stp qC11, qC12, [%x[cptr], #0x00]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "ldr sA3, [ aptr3], #0x10\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "stp qC13, qC14, [%x[cptr], #0x20]\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "stp qC21, qC22, [%x[cptr], #0x00]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "ldr sA4, [ aptr4], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "stp qC23, qC24, [%x[cptr], #0x20]\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "stp qC31, qC32, [%x[cptr], #0x00]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "stp qC33, qC34, [%x[cptr], #0x20]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "stp qC41, qC42, [%x[cptr], #0x00]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + "stp qC43, qC44, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + + ".unreq vB4\n" ".unreq qB4\n" + ".unreq vB3\n" ".unreq qB3\n" + ".unreq vB2\n" ".unreq qB2\n" + ".unreq vB1\n" ".unreq qB1\n" + ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" + ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" + ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" + ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" + ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" + ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n" + ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" + ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" + ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" + ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" 
".unreq vC24\n" + ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" + ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" + ".unreq aptr2\n" + ".unreq aptr3\n" + ".unreq aptr4\n" + + : [aptr] "+r" (aptr), + [bptr] "+r" (bptr), + [cptr] "+r" (cptr), + [k] "+r" (k) + : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), + [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), + [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) + : "cc", "memory", "x20", "x21", "x22", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23" + ); + } + } +} diff --git a/src/core/NEON/kernels/winograd/perf.h b/src/core/NEON/kernels/winograd/perf.h new file mode 100644 index 0000000000..11fb0c452f --- /dev/null +++ b/src/core/NEON/kernels/winograd/perf.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +/* Prototypes from perf.c */ + +void start_counter(int fd); +long long get_counter(int fd); +long long stop_counter(int fd); +int open_instruction_counter(void); +int open_cycle_counter(void); diff --git a/src/core/NEON/kernels/winograd/profiler.hpp b/src/core/NEON/kernels/winograd/profiler.hpp new file mode 100644 index 0000000000..143192b589 --- /dev/null +++ b/src/core/NEON/kernels/winograd/profiler.hpp @@ -0,0 +1,244 @@ + +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "perf.h" +#include + +class profiler { +private: +#ifdef CYCLE_PROFILING + struct ProfileEntry { + int event_id; + long int bytes_read, ops, bytes_written; + long int duration; + }; + + static const int maxevents = 10000; + ProfileEntry events[maxevents]; + int currentevent; + int countfd; + + std::map event_ids; + + int get_event_id(const char *id) { + if (!event_ids.count(id)) { + event_ids.emplace(id, event_ids.size()); + } + return event_ids[id]; + } +#endif // CYCLE_PROFILING + +public: +#ifdef CYCLE_PROFILING + profiler() { + currentevent = 0; + countfd = open_cycle_counter(); + } + + ~profiler() { + close(countfd); + + // Compute performance from recorded events + struct ProfileResult { + ProfileResult() : total_calls(0), + total_duration(0), + total_bytes_read(0), + total_ops(0), + total_bytes_written(0) { + } + + void operator+=(const ProfileEntry &rhs) { + total_calls++; + total_duration += rhs.duration; + total_bytes_read += rhs.bytes_read; + total_ops += rhs.ops; + total_bytes_written = rhs.bytes_written; + } + + float avg_duration(void) const { + return static_cast(total_duration) / + static_cast(total_calls); + } + + float bytes_read_per_cycle(void) const { + return static_cast(total_bytes_read) / + static_cast(total_duration); + } + + float ops_per_cycle(void) const { + return static_cast(total_ops) / + static_cast(total_duration); + } + + float bytes_written_per_cycle(void) const { + return static_cast(total_bytes_written) / + static_cast(total_duration); + } + + long int total_calls, + total_duration, + total_bytes_read, + total_ops, + total_bytes_written; + }; + + std::vector totals; + totals.resize(event_ids.size()); + for (int i = 0; i < currentevent; i++) { + const auto &event = events[i]; + totals[event.event_id] += event; + } + + // Get the longest label + int len_label = 0; + for (const auto &kv : event_ids) { + len_label = std::max(len_label, static_cast(strlen(kv.first))); + } + + // Get the longest values for every other field + const auto get_length_of_field = + [totals] (const char *title, auto f, auto len) -> size_t { + size_t l = strlen(title); + for (const auto &v : totals) { + l = std::max(l, len(f(v))); + } + return l; + }; + + // Get the strlen for an int + const auto intlen = [] (long int x) -> size_t { + size_t len = 0; + do { + x /= 10; + len++; + } while (x); + return len; + }; + + // Get the strlen for a float + const auto floatlen = [] (const int precision) { + return [precision] (float x) { + size_t len = 0; + + if (!std::isfinite(x)) { + return static_cast(3); + } + + do { + x /= 10.0f; + len++; + } while (x > 1.0f); + return len + 1 + precision; + }; + }; + + const int len_calls = get_length_of_field( + "Calls", [] (const auto &v) {return v.total_calls;}, + intlen + ); + const int len_duration = get_length_of_field( + "Duration", [] (const auto &v) {return v.total_duration;}, + intlen + ); + const int len_average_duration = get_length_of_field( + "Average", [] (const auto &v) {return v.avg_duration();}, + floatlen(2) + ); + const int len_reads_per_cycle = get_length_of_field( + "Reads / cycle", + [] (const auto &v) {return v.bytes_read_per_cycle();}, + floatlen(6) + ); + const int 
len_ops_per_cycle = get_length_of_field( + "Ops / cycle", + [] (const auto &v) {return v.ops_per_cycle();}, + floatlen(6) + ); + const int len_writes_per_cycle = get_length_of_field( + "Writes / cycle", + [] (const auto &v) {return v.bytes_written_per_cycle();}, + floatlen(6) + ); + + // Print header + printf( + "%*s %*s %*s %*s %*s %*s %*s\n", + len_label, "", + len_calls, "Calls", + len_duration, "Duration", + len_average_duration, "Average", + len_reads_per_cycle, "Reads / cycle", + len_ops_per_cycle, "Ops / cycle", + len_writes_per_cycle, "Writes / cycle" + ); + for (const auto &kv : event_ids) { + const auto id = kv.second; + printf( + "%*s %*ld %*ld %*.2f %*.6f %*.6f %*.6f\n", + len_label, kv.first, + len_calls, totals[id].total_calls, + len_duration, totals[id].total_duration, + len_average_duration, totals[id].avg_duration(), + len_reads_per_cycle, totals[id].bytes_read_per_cycle(), + len_ops_per_cycle, totals[id].ops_per_cycle(), + len_writes_per_cycle, totals[id].bytes_written_per_cycle() + ); + } + printf("\n"); + } +#endif // CYCLE_PROFILING + + template + void operator() (const char * event, + T func, + long int bytes_read = 0, + long int ops = 0, + long int bytes_written = 0) { +#ifdef CYCLE_PROFILING + if (currentevent==maxevents) { + func(); + } else { + start_counter(countfd); + func(); + long long cycs = stop_counter(countfd); + + // Store the profiling data + events[currentevent++] = { + get_event_id(event), bytes_read, ops, bytes_written, cycs + }; + } +#else + func(); +#endif // CYCLE_PROFILING + } +}; diff --git a/src/core/NEON/kernels/winograd/shims.hpp b/src/core/NEON/kernels/winograd/shims.hpp new file mode 100644 index 0000000000..249e5757f0 --- /dev/null +++ b/src/core/NEON/kernels/winograd/shims.hpp @@ -0,0 +1,319 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +/** Re-order a weight tensor from [Output feature map x Input feature map x + * Height x Width] format to [Height x Width x Input feature map x Output + * feature map] format. 
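+ *
+ * For example (hypothetical sizes): for a 3x3 kernel with 2 input and 4
+ * output feature maps, the value read from in[ofm][ifm][h][w] is written to
+ * out[h][w][ifm][ofm]; any stride argument left as 0 is derived from the
+ * dense tensor extents, as in the generic implementation further below.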
+ */ +template +inline void ofm_ifm_h_w_to_h_w_ifm_ofm( + const T* const in, // Input in [Output x Input x Height x Width] form + T* const out, // Output in [Height x Width x Input x Output] form + const int n_output_feature_maps, + const int n_input_feature_maps, + const int n_rows, + const int n_cols, + int in_output_feature_map_stride=0, + int in_input_feature_map_stride=0, + int in_row_stride=0, + int out_row_stride=0, + int out_col_stride=0, + int out_input_feature_map_stride=0 +); + +/** Re-order a weight tensor from [Height x Width x Input feature map x Output + * feature map] format to [Output feature map x Input feature map x Height x + * Width] format. + */ +template +inline void h_w_ifm_ofm_to_ofm_ifm_h_w( + const T* const in, // Input in [Height x Width x Input x Output] form + T* const out, // Output in [Output x Input x Height x Width] form + const int n_rows, + const int n_cols, + const int n_input_feature_maps, + const int n_output_feature_maps, + int in_row_stride=0, + int in_col_stride=0, + int in_input_feature_map_stride=0, + int out_output_feature_map_stride=0, + int out_input_feature_map_stride=0, + int out_row_stride=0 +); + + +/* Re-order a tensor from NCHW format to NHWC. + */ +template +inline void nchw_to_nhwc( + const T* const in, + T* const out, + const int n_batches, + const int n_channels, + const int n_rows, + const int n_cols, + int in_batch_stride=0, + int in_channel_stride=0, + int in_row_stride=0, + int out_batch_stride=0, + int out_row_stride=0, + int out_col_stride=0 +) +{ + // Fill in the stride values + in_row_stride = (in_row_stride) ? in_row_stride : n_cols; + in_channel_stride = (in_channel_stride) ? in_channel_stride + : n_rows * in_row_stride; + in_batch_stride = (in_batch_stride) ? in_batch_stride + : n_channels * in_channel_stride; + + out_col_stride = (out_col_stride) ? out_col_stride : n_channels; + out_row_stride = (out_row_stride) ? out_row_stride : n_cols * out_col_stride; + out_batch_stride = (out_batch_stride) ? out_batch_stride + : n_rows * out_row_stride; + + // Perform the re-ordering + for (int n = 0; n < n_batches; n++) + { + const T* const in_batch = in + n*in_batch_stride; + T* const out_batch = out + n*out_batch_stride; + + for (int i = 0; i < n_rows; i++) + { + const T* const in_row = in_batch + i*in_row_stride; + T* const out_row = out_batch + i*out_row_stride; + + for (int j = 0; j < n_cols; j++) + { + const T* const in_col = in_row + j; + T* const out_col = out_row + j*out_col_stride; + + for (int c = 0; c < n_channels; c++) + { + const T* const in_channel = in_col + c*in_channel_stride; + out_col[c] = *(in_channel); + } + } + } + } +} + +/* Re-order a tensor from NHWC format to NCHW. + */ +template +inline void nhwc_to_nchw( + const T* const in, // Input data in NHWC form + T* const out, // Output data in NCHW form + const int n_batches, + const int n_rows, + const int n_cols, + const int n_channels, + int in_batch_stride=0, + int in_row_stride=0, + int in_col_stride=0, + int out_batch_stride=0, + int out_channel_stride=0, + int out_row_stride=0 +) +{ + // Fill in stride values + in_col_stride = (in_col_stride) ? in_col_stride : n_channels; + in_row_stride = (in_row_stride) ? in_row_stride : n_cols * in_col_stride; + in_batch_stride = (in_batch_stride) ? in_batch_stride + : n_rows * in_row_stride; + + out_row_stride = (out_row_stride) ? out_row_stride : n_cols; + out_channel_stride = (out_channel_stride) ? out_channel_stride + : n_rows * out_row_stride; + out_batch_stride = (out_batch_stride) ? 
out_batch_stride + : n_channels * out_channel_stride; + + // Perform the re-ordering + // For every batch + for (int n = 0; n < n_batches; n++) + { + const T* const in_batch = in + n*in_batch_stride; + T* const out_batch = out + n*out_batch_stride; + + // For every row + for (int i = 0; i < n_rows; i++) + { + const T* const in_i = in_batch + i*in_row_stride; + T* const out_i = out_batch + i*out_row_stride; + + // For every column + for (int j = 0; j < n_cols; j++) + { + const T* const in_j = in_i + j*in_col_stride; + T* const out_j = out_i + j; + + // For every channel + for (int c = 0; c < n_channels; c++) + { + const T* const in_channel = in_j + c; + T* const out_channel = out_j + c*out_channel_stride; + *(out_channel) = *(in_channel); + } + } + } + } +} + + +/*****************************************************************************/ +/* Generic weight re-order implementation. + */ +template +inline void ofm_ifm_h_w_to_h_w_ifm_ofm( + const T* const in, // Input in [Output x Input x Height x Width] form + T* const out, // Output in [Height x Width x Input x Output] form + const int n_output_feature_maps, + const int n_input_feature_maps, + const int n_rows, + const int n_cols, + int in_output_feature_map_stride, + int in_input_feature_map_stride, + int in_row_stride, + int out_row_stride, + int out_col_stride, + int out_input_feature_map_stride +) +{ + // Fill in stride values + in_row_stride = (in_row_stride) + ? in_row_stride + : n_cols; + in_input_feature_map_stride = (in_input_feature_map_stride) + ? in_input_feature_map_stride + : n_rows * in_row_stride; + in_output_feature_map_stride = (in_output_feature_map_stride) + ? in_output_feature_map_stride + : n_input_feature_maps * in_input_feature_map_stride; + + out_input_feature_map_stride = (out_input_feature_map_stride) + ? out_input_feature_map_stride + : n_output_feature_maps; + out_col_stride = (out_col_stride) + ? out_col_stride + : n_input_feature_maps * out_input_feature_map_stride; + out_row_stride = (out_row_stride) + ? out_row_stride + : n_cols * out_col_stride; + + // Perform the re-ordering + for (int i = 0; i < n_rows; i++) + { + const T* const in_row = in + i * in_row_stride; + T* out_row = out + i * out_row_stride; + + for (int j = 0; j < n_cols; j++) + { + const T* const in_col = in_row + j; + T* const out_col = out_row + j * out_col_stride; + + for (int ifm = 0; ifm < n_input_feature_maps; ifm++) + { + const T* const in_ifm = in_col + ifm * in_input_feature_map_stride; + T* const out_ifm = out_col + ifm * out_input_feature_map_stride; + + for (int ofm = 0; ofm < n_output_feature_maps; ofm++) + { + const T* const in_ofm = in_ifm + ofm * in_output_feature_map_stride; + T* const out_ofm = out_ifm + ofm; + *(out_ofm) = *(in_ofm); + } + } + } + } +} + +/*****************************************************************************/ +/* Generic weight re-order implementation. + */ +template +inline void h_w_ifm_ofm_to_ofm_ifm_h_w( + const T* const in, // Input in [Height x Width x Input x Output] form + T* const out, // Output in [Output x Input x Height x Width] form + const int n_rows, + const int n_cols, + const int n_input_feature_maps, + const int n_output_feature_maps, + int in_row_stride, + int in_col_stride, + int in_input_feature_map_stride, + int out_output_feature_map_stride, + int out_input_feature_map_stride, + int out_row_stride +) +{ + // Fill in the stride values + in_input_feature_map_stride = (in_input_feature_map_stride) + ? 
in_input_feature_map_stride + : n_output_feature_maps; + in_col_stride = (in_col_stride) + ? in_col_stride + : n_input_feature_maps * in_input_feature_map_stride; + in_row_stride = (in_row_stride) + ? in_row_stride + : n_cols * in_col_stride; + + out_row_stride = (out_row_stride) + ? out_row_stride + : n_cols; + out_input_feature_map_stride = (out_input_feature_map_stride) + ? out_input_feature_map_stride + : n_rows * out_row_stride; + out_output_feature_map_stride = (out_output_feature_map_stride) + ? out_output_feature_map_stride + : n_input_feature_maps * out_input_feature_map_stride; + + // Perform the re-ordering + for (int i = 0; i < n_rows; i++) + { + const T* const in_row = in + i * in_row_stride; + T* const out_row = out + i * out_row_stride; + + for (int j = 0; j < n_cols; j++) + { + const T* const in_col = in_row + j * in_col_stride; + T* const out_col = out_row + j; + + for (int ifm = 0; ifm < n_input_feature_maps; ifm++) + { + const T* const in_ifm = in_col + ifm * in_input_feature_map_stride; + T* const out_ifm = out_col + ifm * out_input_feature_map_stride; + + for (int ofm = 0; ofm < n_output_feature_maps; ofm++) + { + const T* const in_ofm = in_ifm + ofm; + T* const out_ofm = out_ifm + ofm * out_output_feature_map_stride; + *(out_ofm) = *(in_ofm); + } + } + } + } +} + diff --git a/src/core/NEON/kernels/winograd/transforms.hpp b/src/core/NEON/kernels/winograd/transforms.hpp new file mode 100644 index 0000000000..8546ee9e2e --- /dev/null +++ b/src/core/NEON/kernels/winograd/transforms.hpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#include "transforms/input_2x2_3x3.hpp" +#include "transforms/kernel_2x2_3x3.hpp" +#include "transforms/output_2x2_3x3.hpp" diff --git a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp b/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp new file mode 100644 index 0000000000..ca8d012e5e --- /dev/null +++ b/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp @@ -0,0 +1,639 @@ +/* + * Copyright (c) 2017 ARM Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once +#include "arm_compute/core/NEON/kernels/winograd/tensor.hpp" + + +namespace winograd { + /* Transform an input tensor into the Winograd domain. + */ + template + struct Winograd2x2_3x3GemmInput { + static void execute( + const T *inptr, + const Tensor4DShape& input_shape, + const PaddingType padding_type, + const int tile_M, + const int tile_N, + T *outptr_base, + const int matrix_stride, + const int matrix_batch_stride, + const int matrix_row_stride + ); + + static size_t bytes_read(const Tensor4DShape &input_shape, + const Tensor4DShape &output_shape) { + const int tile_rows = iceildiv(output_shape.n_rows, 2); + const int tile_cols = iceildiv(output_shape.n_cols, 2); + return input_shape.n_batches * tile_rows * (16 + 8*(tile_cols - 1)) * input_shape.n_channels * sizeof(T); + } + + static int flops_performed(const Tensor4DShape &input_shape, + const Tensor4DShape &output_shape) { + const int tile_rows = iceildiv(output_shape.n_rows, 2); + const int tile_cols = iceildiv(output_shape.n_cols, 2); + return input_shape.n_batches * tile_rows * (32 + 24*(tile_cols - 1)) * input_shape.n_channels; + } + + static size_t bytes_written(const Tensor4DShape &input_shape, + const Tensor4DShape &output_shape) { + const int tile_rows = iceildiv(output_shape.n_rows, 2); + const int tile_cols = iceildiv(output_shape.n_cols, 2); + const int M = input_shape.n_batches * tile_rows * tile_cols; + return 16 * M * input_shape.n_channels * sizeof(T); + } + + protected: + template + static void process_tile_tensor( + const int tile_M, // Number of rows of tiles + const int tile_N, // Number of columns of tiles + int n_channels, // Number of input channels + const T* const input, // Base input pointer (appropriate to batch and channel) + const int input_row_stride, // Stride between rows of the input + const int input_col_stride, // Stride between columns of the input + T* const matrix, // 1st output matrix (appropriate to batch and channel) + const int matrix_stride, // Stride between matrices + const int matrix_row_stride // Stride between rows of the output matrix + ); + + template + static void process_tile_row( + const int tile_N, // Number of tiles in the row + const T* const input, // Base input pointer (appropriate to batch, channel and row) + const int input_row_stride, // Stride between rows of the input + const int input_col_stride, // Stride between columns of the input + T* 
const matrix, // 1st output matrix (appropriate to batch, channel and row) + const int matrix_stride, // Stride between matrices + const int matrix_row_stride // Stride between rows of the output matrix + ); + }; + + template + struct Winograd2x2_3x3GemmInputChannelwise { + static void execute( + const T *inptr, + const Tensor4DShape& input_shape, + const PaddingType padding_type, + const int tile_M, + const int tile_N, + T *outptr_base, + const int matrix_stride, + const int matrix_batch_stride, + const int matrix_row_stride + ); + + static size_t bytes_read(const Tensor4DShape &input_shape, + const Tensor4DShape &output_shape) { + // We read as many bytes as we write + return bytes_written(input_shape, output_shape); + } + + static int flops_performed(const Tensor4DShape &input_shape, + const Tensor4DShape &output_shape) { + const int tile_rows = iceildiv(output_shape.n_rows, 2); + const int tile_cols = iceildiv(output_shape.n_cols, 2); + return input_shape.n_batches * tile_rows * 32 * tile_cols * input_shape.n_channels; + } + + static size_t bytes_written(const Tensor4DShape &input_shape, + const Tensor4DShape &output_shape) { + return winograd::Winograd2x2_3x3GemmInput::bytes_written(input_shape, output_shape); + } + + protected: + typedef void (*tilefunc)(int, const T*, int, int, T*, int); + template + static void process_tile( + int n_channels, // Number of channels in the tile + const T* const input_base, + const int input_row_stride, + const int input_col_stride, + T* const matrix_base, + const int matrix_stride + ); + + private: + template + static void _process_tile( + int &n_channels, const T* &inptr, + const int input_row_stride, const int input_col_stride, + T* &outptr, const int matrix_stride + ); + }; +} + +/*****************************************************************************/ +// Include specialised implementations here +#include "input_2x2_3x3/a64_float.hpp" +#include "input_2x2_3x3/a64_float_channelwise.hpp" +/*****************************************************************************/ + +/*****************************************************************************/ +template +void winograd::Winograd2x2_3x3GemmInput::execute( + const T *inptr_base, + const Tensor4DShape& input_shape, + const PaddingType padding_type, + const int tile_M, + const int tile_N, + T *outptr_base, + const int matrix_stride, + const int matrix_batch_stride, + const int matrix_row_stride +) { + // Select an appropriate matrix processing method for the shape and padding + // of the input tensor. 
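+ //
+ // With PADDING_VALID only the remainder of the input extent modulo two
+ // needs zero-padding at the bottom/right so that whole 2x2 output tiles
+ // are produced; with PADDING_SAME an additional implicit zero row/column
+ // at the top/left shifts that remainder, hence the "1 +" terms below.
+ // Each selected function computes X.T x X for every 4x4 input tile x and
+ // scatters the sixteen results across the sixteen output matrices.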
+ typedef void (*tensorfunc)(int, int, int, const T*, int, int, T*, int, int); + const auto process_tensor = [&padding_type, &input_shape] () -> tensorfunc { + if (padding_type == PADDING_VALID) { + const int pad_bottom = input_shape.n_rows % 2; + const int pad_right = input_shape.n_cols % 2; + + if (pad_bottom == 0 && pad_right == 0) { + return process_tile_tensor; + } else if (pad_bottom == 0 && pad_right == 1) { + return process_tile_tensor; + } else if (pad_bottom == 1 && pad_right == 0) { + return process_tile_tensor; + } else if (pad_bottom == 1 && pad_right == 1) { + return process_tile_tensor; + } + } else { // PADDING_SAME + const int pad_bottom = 1 + input_shape.n_rows % 2; + const int pad_right = 1 + input_shape.n_cols % 2; + + if (pad_bottom == 1 && pad_right == 1) { + return process_tile_tensor; + } else if (pad_bottom == 1 && pad_right == 2) { + return process_tile_tensor; + } else if (pad_bottom == 2 && pad_right == 1) { + return process_tile_tensor; + } else if (pad_bottom == 2 && pad_right == 2) { + return process_tile_tensor; + } + } + + printf("%s::%u Uncovered case.\n", __FILE__, __LINE__); + exit(-1); + return NULL; // No function found + } (); + + // Compute strides + const int input_row_stride = input_shape.n_cols * input_shape.n_channels; + const int input_col_stride = input_shape.n_channels; + + // Process each batch of the tensor in turn. + for (int batch = 0; batch < input_shape.n_batches; batch++) { + // Work out pointers + const T *inptr = inptr_base + (batch * input_shape.n_rows * + input_shape.n_cols * input_shape.n_channels); + T *outptr = outptr_base + batch * matrix_batch_stride; + + // Delegate doing the actual work + process_tensor( + tile_M, tile_N, input_shape.n_channels, + inptr, input_row_stride, input_col_stride, + outptr, matrix_stride, matrix_row_stride + ); + } +} + +/*****************************************************************************/ +template +template +void winograd::Winograd2x2_3x3GemmInput::process_tile_tensor( + const int tile_M, // Number of rows of tiles + const int tile_N, // Number of columns of tiles + int n_channels, // Number of input channels + const T* const input, // Base input pointer (appropriate to batch and channel) + const int input_row_stride, // Stride between rows of the input + const int input_col_stride, // Stride between columns of the input + T* const matrix, // 1st output matrix (appropriate to batch and channel) + const int matrix_stride, // Stride between matrices + const int matrix_row_stride // Stride between rows of the output matrix +) { + // Base row processing functions + typedef void (*rowfunc)(int, const T*, int, int, T*, int, int); + const rowfunc process_top_row[3] = { + (padding == PADDING_VALID) + ? process_tile_row<0, 0, 0, pad_right, 1> + : process_tile_row<1, 1, 0, pad_right, 1>, + (padding == PADDING_VALID) + ? process_tile_row<0, 0, 0, pad_right, 2> + : process_tile_row<1, 1, 0, pad_right, 2>, + (padding == PADDING_VALID) + ? process_tile_row<0, 0, 0, pad_right, 4> + : process_tile_row<1, 1, 0, pad_right, 4>, + }; + const rowfunc process_middle_row[3] = { + (padding == PADDING_VALID) + ? process_tile_row<0, 0, 0, pad_right, 1> + : process_tile_row<0, 1, 0, pad_right, 1>, + (padding == PADDING_VALID) + ? process_tile_row<0, 0, 0, pad_right, 2> + : process_tile_row<0, 1, 0, pad_right, 2>, + (padding == PADDING_VALID) + ? process_tile_row<0, 0, 0, pad_right, 4> + : process_tile_row<0, 1, 0, pad_right, 4>, + }; + const rowfunc process_bottom_row[3] = { + (padding == PADDING_VALID) + ? 
process_tile_row<0, 0, pad_bottom, pad_right, 1> + : process_tile_row<0, 1, pad_bottom, pad_right, 1>, + (padding == PADDING_VALID) + ? process_tile_row<0, 0, pad_bottom, pad_right, 2> + : process_tile_row<0, 1, pad_bottom, pad_right, 2>, + (padding == PADDING_VALID) + ? process_tile_row<0, 0, pad_bottom, pad_right, 4> + : process_tile_row<0, 1, pad_bottom, pad_right, 4>, + }; + + // Method to get an input pointer for the given tile row + const auto get_inptr = [&input, &input_row_stride] (const int tile_i) { + if (padding == PADDING_VALID) { + return input + 2 * tile_i * input_row_stride; + } else { + return input + (2 * tile_i - (tile_i ? 1 : 0)) * input_row_stride; + } + }; + + // Wrapper to process a row of tiles, covering all channels. + const auto process_row = + [tile_N, input_row_stride, input_col_stride, matrix_stride, matrix_row_stride, n_channels] + (const rowfunc f[3], const T *inptr, T *outptr) { + int rem_channels = n_channels; + + // While there remain channels to process continue to process the + // row. + for (; rem_channels >= 4; rem_channels -= 4, inptr += 4, outptr += 4) { + f[2](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride); + } + for (; rem_channels >= 2; rem_channels -= 2, inptr += 2, outptr += 2) { + f[1](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride); + } + if (rem_channels) { + f[0](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride); + } + }; + + // Process all rows of tiles in the tensor + for (int tile_i = 0; tile_i < tile_M; tile_i++) { + T* const m_row = matrix + tile_i * tile_N * matrix_row_stride; + const T *row_inptr = get_inptr(tile_i); + + if (tile_i == 0) { + // Top row of the input + process_row(process_top_row, row_inptr, m_row); + } else if (tile_i == tile_M - 1) { + // Bottom row of the input + process_row(process_bottom_row, row_inptr, m_row); + } else { + // Any other row of the input + process_row(process_middle_row, row_inptr, m_row); + } + } +} + +/*****************************************************************************/ +template +template +void winograd::Winograd2x2_3x3GemmInput::process_tile_row( + const int tile_N, // Number of tiles in the row + const T* const input, // Base input pointer (appropriate to batch, channel and row) + const int input_row_stride, // Stride between rows of the input + const int input_col_stride, // Stride between columns of the input + T* const matrix, // 1st output matrix (appropriate to batch, channel and row) + const int matrix_stride, // Stride between matrices + const int matrix_row_stride // Stride between rows of the output matrix +) { + // Construct copies of the pointers + const T *inptr = input; + T *outptr = matrix; + + // Storage for the tensors x, X.T x, and X.T x X. + T x[4][4][proc_channels], XTx[4][4][proc_channels], XTxX[4][4][proc_channels]; + + // For every tile in the row + for (int tile_j = 0; tile_j < tile_N; tile_j++) { + // Determine the padding for the tile + const int tile_pad_left = (tile_j == 0) ? pad_left : 0; + const int tile_pad_right = (tile_j == tile_N - 1) ? pad_right : 0; + + // Load tile values. If this is the first tile in the row then we must load + // all values, otherwise we can just load the final two columns of the input. + for (int i = 0; i < 4; i++) { + for (int j = ((tile_j == 0) ? 
0 : 2); j < 4; j++) { + // Fill with padding if required + if (i < pad_top || 4 - pad_bottom <= i || + j < tile_pad_left || 4 - tile_pad_right <= j) { + for (int c = 0; c < proc_channels; c++) { + x[i][j][c] = static_cast(0); // Padding + } + } else { + // Load values, note that the initial padding offsets the pointer we + // were provided. + for (int c = 0; c < proc_channels; c++) { + const int row_offset = (i - pad_top) * input_row_stride; + const int col_offset = (j - tile_pad_left) * input_col_stride; + x[i][j][c] = inptr[row_offset + col_offset + c]; + } + } + } + } + + // Compute the matrix X.T x. Note, can elide operations depending on the + // padding. Furthermore, if this isn't the left-most tile we can skip half + // of the operations by copying results from the previous version of X.T x. + // This latter optimisation can be simplified by unrolling the outermost + // loop by two and by renaming the registers containing XTx. + if (tile_j == 0) { + for (int j = 0; j < 4; j++) { + for (int c = 0; c < proc_channels; c++) { + XTx[0][j][c] = x[0][j][c] - x[2][j][c]; + XTx[1][j][c] = x[1][j][c] + x[2][j][c]; + XTx[2][j][c] = -x[1][j][c] + x[2][j][c]; + XTx[3][j][c] = x[1][j][c] - x[3][j][c]; + } + } + } else { + for (int j = 0; j < 2; j++) { + for (int c = 0; c < proc_channels; c++) { + XTx[0][j][c] = XTx[0][j + 2][c]; + XTx[1][j][c] = XTx[1][j + 2][c]; + XTx[2][j][c] = XTx[2][j + 2][c]; + XTx[3][j][c] = XTx[3][j + 2][c]; + } + } + for (int j = 2; j < 4; j++) { + for (int c = 0; c < proc_channels; c++) { + XTx[0][j][c] = x[0][j][c] - x[2][j][c]; + XTx[1][j][c] = x[1][j][c] + x[2][j][c]; + XTx[2][j][c] = -x[1][j][c] + x[2][j][c]; + XTx[3][j][c] = x[1][j][c] - x[3][j][c]; + } + } + } + + // Compute the matrix X.T x X. Note, can elide operations based on the + // padding. + for (int i = 0; i < 4; i++) { + for (int c = 0; c < proc_channels; c++) { + XTxX[i][0][c] = XTx[i][0][c] - XTx[i][2][c]; + XTxX[i][1][c] = XTx[i][1][c] + XTx[i][2][c]; + XTxX[i][2][c] = -XTx[i][1][c] + XTx[i][2][c]; + XTxX[i][3][c] = XTx[i][1][c] - XTx[i][3][c]; + } + } + + // Store the output matrix (X.T x X) + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + // Get a pointer to the relevant output matrix + T *mptr = outptr + (i*4 + j)*matrix_stride; + + // Write out the channels + for (int c = 0; c < proc_channels; c++) { + mptr[c] = XTxX[i][j][c]; + } + } + } + + // Update the pointers + inptr += input_col_stride * ((tile_j == 0 && pad_left) ? 1 : 2); + outptr += matrix_row_stride; + } +} + +/*****************************************************************************/ +template +void winograd::Winograd2x2_3x3GemmInputChannelwise::execute( + const T *inptr, + const Tensor4DShape& input_shape, + const PaddingType padding_type, + const int tile_M, + const int tile_N, + T *outptr_base, + const int matrix_stride, + const int matrix_batch_stride, + const int matrix_row_stride +) { + const int n_channels = input_shape.n_channels; + const int input_col_stride = n_channels; + const int input_row_stride = input_shape.n_cols * input_col_stride; + + // Determine the padding and hence select appropriate methods for each tile. + tilefunc fs[3][3]; + + if (padding_type == PADDING_VALID) { + constexpr int pad_top = 0; + constexpr int pad_left = 0; + const int pad_right = input_shape.n_cols % 2 == 0; + + fs[0][0] = process_tile; + fs[0][1] = process_tile; + fs[0][2] = (pad_right) ? 
process_tile : process_tile; + + fs[1][0] = process_tile<0, pad_left, 0, 0>; + fs[1][1] = process_tile<0, 0, 0, 0>; + fs[1][2] = (pad_right) ? process_tile<0, 0, 0, 0> : process_tile<0, 0, 0, 1>; + + if (input_shape.n_rows % 2 == 0) { + constexpr int pad_bottom = 0; + fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>; + fs[2][1] = process_tile<0, 0, pad_bottom, 0>; + fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 0> : process_tile<0, 0, pad_bottom, 1>; + } else { + constexpr int pad_bottom = 1; + fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>; + fs[2][1] = process_tile<0, 0, pad_bottom, 0>; + fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 0> : process_tile<0, 0, pad_bottom, 1>; + } + } else { + constexpr int pad_top = 1; + constexpr int pad_left = 1; + const int pad_right = input_shape.n_cols % 2 == 0; + + fs[0][0] = process_tile; + fs[0][1] = process_tile; + fs[0][2] = (pad_right) ? process_tile : process_tile; + + fs[1][0] = process_tile<0, pad_left, 0, 0>; + fs[1][1] = process_tile<0, 0, 0, 0>; + fs[1][2] = (pad_right) ? process_tile<0, 0, 0, 1> : process_tile<0, 0, 0, 2>; + + if (input_shape.n_rows % 2 == 0) { + constexpr int pad_bottom = 1; + fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>; + fs[2][1] = process_tile<0, 0, pad_bottom, 0>; + fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 1> : process_tile<0, 0, pad_bottom, 2>; + } else { + constexpr int pad_bottom = 2; + fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>; + fs[2][1] = process_tile<0, 0, pad_bottom, 0>; + fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 1> : process_tile<0, 0, pad_bottom, 2>; + } + } + + // Process each tile in turn + for (int batch = 0; batch < input_shape.n_batches; batch++) { + const T* const input_base_batch = inptr + batch*input_shape.n_rows*input_shape.n_cols*n_channels; + + for (int tile_i = 0; tile_i < tile_M; tile_i++) { + const int row_offset = (tile_i == 0) ? 0 : ((padding_type == PADDING_VALID) ? 0 : 1); + const T* const input_base_row = input_base_batch + (2*tile_i - row_offset)*input_shape.n_cols*n_channels; + + // Select the set of functions for the row + const int fs_i = (tile_i == 0) ? 0 : ((tile_i < tile_M - 1) ? 1 : 2); + + for (int tile_j = 0; tile_j < tile_N; tile_j++) { + // Select the function for the column + const int fs_j = (tile_j == 0) ? 0 : ((tile_j < tile_N - 1) ? 1 : 2); + const auto f = fs[fs_i][fs_j]; + + // Get pointers into the input and outputs + const int col_offset = (tile_j == 0) ? 0 : ((padding_type == PADDING_VALID) ? 
0 : 1); + const T* const input_base_col = input_base_row + (2*tile_j - col_offset)*n_channels; + T* const matrix_base = outptr_base + batch*matrix_batch_stride + (tile_i*tile_N + tile_j)*matrix_row_stride; + f(n_channels, input_base_col, input_row_stride, input_col_stride, + matrix_base, matrix_stride); + } + } + } +} + +template +template +void winograd::Winograd2x2_3x3GemmInputChannelwise::process_tile( + int n_channels, // Number of channels in the tile + const T* const input_base, + const int input_row_stride, + const int input_col_stride, + T* const matrix_base, + const int matrix_stride +) { + // Copy pointers + const T *inptr = input_base; + T *outptr = matrix_base; + + // Process channels (modifies inptr, outptr and n_channels) + _process_tile( + n_channels, inptr, input_row_stride, input_col_stride, + outptr, matrix_stride + ); + _process_tile( + n_channels, inptr, input_row_stride, input_col_stride, + outptr, matrix_stride + ); + _process_tile( + n_channels, inptr, input_row_stride, input_col_stride, + outptr, matrix_stride + ); +} + +template +template +void winograd::Winograd2x2_3x3GemmInputChannelwise::_process_tile( + int &n_channels, + const T* &inptr, const int input_row_stride, const int input_col_stride, + T* &outptr, const int matrix_stride +) { + // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three + // offsets to access the intermediate matrices. + T* outptrs[4] = { + outptr, + outptr + matrix_stride * 4, + outptr + matrix_stride * 8, + outptr + matrix_stride * 12 + }; + + // The matrix X; zeroed to account for padding. + T x[4][4]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + x[i][j] = 0; + } + } + + // The matrices X.T x and U + T XTx[4][4], U[4][4]; + + // Now progress through each channel + for (; n_channels >= proc_channels; n_channels -= proc_channels) { + for (int n = 0; n < proc_channels; n++) { + // Load the matrix X + for (int cell_i = pad_top, i = 0; cell_i < 4 - pad_bottom; cell_i++, i++) { + for (int cell_j = pad_left, j = 0; cell_j < 4 - pad_right; cell_j++, j++) { + x[cell_i][cell_j] = inptr[i*input_row_stride + j*input_col_stride]; + } + } + inptr++; + + // Compute the matrix X.T + for (int j = 0; j < 4; j++) { + XTx[0][j] = x[0][j] - x[2][j]; + XTx[1][j] = x[1][j] + x[2][j]; + XTx[2][j] = x[2][j] - x[1][j]; + XTx[3][j] = x[1][j] - x[3][j]; + } + + // Hence compute the matrix U + for (int i = 0; i < 4; i++) { + U[i][0] = XTx[i][0] - XTx[i][2]; + U[i][1] = XTx[i][1] + XTx[i][2]; + U[i][2] = XTx[i][2] - XTx[i][1]; + U[i][3] = XTx[i][1] - XTx[i][3]; + } + + // Store the matrix U + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + outptrs[i][j * matrix_stride] = U[i][j]; + } + outptrs[i]++; + } + } + } + + // Update the output pointer for future calls + outptr = outptrs[0]; +} diff --git a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp b/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp new file mode 100644 index 0000000000..a99cbe325b --- /dev/null +++ b/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp @@ -0,0 +1,1498 @@ +/* + * Copyright (c) 2017 ARM Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once +#include "../input_2x2_3x3.hpp" + +#ifdef __aarch64__ +namespace winograd { + +// Pad left by one column, pad right by one column, no upper or lower padding, 4 channels +template <> +template <> +inline void Winograd2x2_3x3GemmInput::process_tile_row<0, 1, 0, 1, 4>( + const int tile_N, // Number of tiles in the row + const float* const input, // Base input pointer (appropriate to batch, channel and row) + const int input_row_stride, // Stride between rows of the input + const int input_col_stride, // Stride between columns of the input + float* const matrix, // 1st output matrix (appropriate to batch, channel and row) + const int matrix_stride, // Stride between matrices + const int matrix_row_stride // Stride between rows of the output matrix +) { + /* SIMD register allocation + * ======================== + * + * In the following code we read 4x4 tiles of a matrix `x`, with which we + * compute another matrix `X.T x` where: + * + * / 1 0 0 0 \ + * X = | 0 1 -1 1 | + * | -1 1 1 0 | + * \ 0 0 0 -1 / + * + * Hence, `X.T` is a program which operates upon rows of the matrix `X`. + * We subsequently compute and store the matrix `U = (X.T x) X`. + * + * Importantly, each iteration of the loop below loads a new matrix `x'` + * where the final two columns of `x'` are the first two columns of the + * previous `x`. That is: + * + * x11 x12 x13 x14 + * x21 x22 x23 x24 + * x31 x32 x33 x34 + * x41 x42 x43 x44 + * + * x'11 x'12 x'13 x'14 + * x'21 x'22 x'23 x'24 + * x'31 x'32 x'33 x'34 + * x'41 x'42 x'43 x'44 + * + * Consequently, while the first iteration of the below loop must load 16 + * values for `x`, the second need load only 8. *Furthermore*, since we noted + * above that the operation `X.T x` was a program which operated upon *rows* + * of the matrix `x` it follows that that the relation that `x'[i][1] = + * x[i][3]` and `x'[i][2] = x[i][4]` applies also the matrices `X.T x'` and + * `X.T x`. 
That is: + * + * (X.T x)11 (X.T x)12 (X.T x)13 (X.T x)14 + * (X.T x)21 (X.T x)22 (X.T x)23 (X.T x)24 + * (X.T x)31 (X.T x)32 (X.T x)33 (X.T x)34 + * (X.T x)41 (X.T x)42 (X.T x)43 (X.T x)44 + * + * (X.T x')11 (X.T x')12 (X.T x')13 (X.T x')14 + * (X.T x')12 (X.T x')12 (X.T x')12 (X.T x')12 + * (X.T x')13 (X.T x')13 (X.T x')13 (X.T x')13 + * (X.T x')14 (X.T x')14 (X.T x')14 (X.T x')14 + * + * Hence, as well as not needing to load new values for x'[i][1..2] it is + * also unnecessary to recompute values for (X.T x')[i][1..2]. + * + * Following this we break the registers into blocks `A` and `B` used by the + * two stages of the unrolled loop. These registers named such that the + * latter columns of `A` become the earlier columns of `B` and vice-versa: + * + * AXTx11 AXTx12 > AXTx13 AXTx14 | + * AXTx21 AXTx22 > AXTx23 AXTx24 | + * AXTx31 AXTx32 > AXTx33 AXTx34 | + * AXTx41 AXTx42 > AXTx43 AXTx44 | + * + * BXTx13 BXTx14 | BXTx11 BXTx12 > + * BXTx23 BXTx24 | BXTx21 BXTx22 > + * BXTx33 BXTx34 | BXTx31 BXTx32 > + * BXTx43 BXTx44 | BXTx41 BXTx42 > + * + * These 32 named registers require only 16 architectural registers. 1 + * additional architectural register is used as scratch space and 8 + * architectural registers are used to load in the values x[1..4][3,4]. + * + * Input and output addressing + * =========================== + * TODO Description + */ + const float *inptr0 = input; + const float *inptr1 = input + input_row_stride; + const float *inptr2 = input + input_row_stride * 2; + const float *inptr3 = input + input_row_stride * 3; + + float *outptr0 = matrix; + float *outptr4 = matrix + matrix_stride * 4; + float *outptr8 = matrix + matrix_stride * 8; + float *outptr12 = matrix + matrix_stride * 12; + + int tile_j = tile_N; // Tiles to process + + asm volatile ( + // Named SIMD registers according to the policy given above + // Registers into which to load the latter two columns of `x` + "x_13 .req v0\n qx_13 .req q0\n" "x_14 .req v4\n qx_14 .req q4\n" + "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n" + "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n" + "x_43 .req v3\n qx_43 .req q3\n" "x_44 .req v7\n qx_44 .req q7\n" + + // Registers for storing X.T x (both A and B halves) + "AXTx11 .req v8\n" "BXTx13 .req v8\n" + "AXTx12 .req v9\n" "BXTx14 .req v9\n" "qAXTx12 .req q9\n" + "AXTx21 .req v10\n" "BXTx23 .req v10\n" + "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n" + "AXTx31 .req v12\n" "BXTx33 .req v12\n" + "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n" + "AXTx41 .req v14\n" "BXTx43 .req v14\n" + "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n" + "AXTx13 .req v16\n" "BXTx11 .req v16\n" + "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n" + "AXTx23 .req v18\n" "BXTx21 .req v18\n" + "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n" + "AXTx33 .req v20\n" "BXTx31 .req v20\n" + "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n" + "AXTx43 .req v22\n" "BXTx41 .req v22\n" + "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n" + + // Result register (TODO Does using more registers yield better + // performance) + "U .req v24\n qU .req q24\n" + + // ---------------------------------------------------------------------- + // Head of loop + // Loads a complete 4x4 tile of x, computes X.T x, computes and stores + // `U = X.T x X`. Prepares for the 'A' half of the loop. 
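+    // In scalar terms, each fsub/fadd on a q register below computes one element
+    // of X.T x for four channels at once; per channel the rows of X.T x are:
+    //   XTx[0][j] = x[0][j] - x[2][j];   XTx[1][j] = x[1][j] + x[2][j];
+    //   XTx[2][j] = x[2][j] - x[1][j];   XTx[3][j] = x[1][j] - x[3][j];
+    // matching the generic C++ implementation in input_2x2_3x3.hpp.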
+ // NOTE: Since the first tile has the leftmost column padded we can + // skip 4 loads and 4 calculations for the matrix X.T x X. + + // Temporarily alias registers for computing the first (non-padded) + // column of x. + "x_12 .req v0\n qx_12 .req q0\n" + "x_22 .req v1\n qx_22 .req q1\n" + "x_32 .req v2\n qx_32 .req q2\n" + "x_42 .req v3\n qx_42 .req q3\n" + + "ldr qx_12, [%x[inptr0]]\n" + "ldr qx_22, [%x[inptr1]]\n" + "ldr qx_32, [%x[inptr2]]\n" + "ldr qx_42, [%x[inptr3]]\n" + + "fsub BXTx12.4s, x_12.4s, x_32.4s\n" + "fadd BXTx22.4s, x_22.4s, x_32.4s\n" + "fsub BXTx32.4s, x_32.4s, x_22.4s\n" + "fsub BXTx42.4s, x_22.4s, x_42.4s\n" + + ".unreq x_12\n .unreq qx_12\n" + ".unreq x_22\n .unreq qx_22\n" + ".unreq x_32\n .unreq qx_32\n" + ".unreq x_42\n .unreq qx_42\n" + + // Load and compute latter two columns of the first tile. Progress the + // input pointers (by three columns so that the each points are the + // second column of the next tile, that is, each points at the first + // column which must be read for the next tile. + "ldr qx_13, [%x[inptr0], %x[colstride1]]\n" + "ldr qx_23, [%x[inptr1], %x[colstride1]]\n" + "ldr qx_33, [%x[inptr2], %x[colstride1]]\n" + "ldr qx_43, [%x[inptr3], %x[colstride1]]\n" + + "fsub BXTx13.4s, x_13.4s, x_33.4s\n" + "ldr qx_14, [%x[inptr0], %x[colstride2]]\n" + + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride2]]\n" + + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride2]]\n" + + "fsub BXTx43.4s, x_23.4s, x_43.4s\n" + "ldr qx_44, [%x[inptr3], %x[colstride2]]\n" + + "fsub BXTx14.4s, x_14.4s, x_34.4s\n" + "add %x[inptr0], %x[inptr0], %x[colstride3]\n" + + "fadd BXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride3]\n" + + "fsub BXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride3]\n" + + "fsub BXTx44.4s, x_24.4s, x_44.4s\n" + "add %x[inptr3], %x[inptr3], %x[colstride3]\n" + + // Compute and store U for the first tile + // First row + "fneg U.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, BXTx12.4s, BXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fneg U.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, BXTx22.4s, BXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fneg U.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, BXTx32.4s, BXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row, simultaneously load the first column of inputs for the + // next tile. 
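+    // For this first tile the leftmost column of x is padding, so XTx[i][0] == 0
+    // and U[i][0] = XTx[i][0] - XTx[i][2] reduces to a negation; hence the fneg
+    // used for the first store of each row (above and below) instead of an fsub.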
+ "fneg U.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "ldr qx_13, [%x[inptr0]]\n" + + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, BXTx42.4s, BXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "ldr qx_43, [%x[inptr3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + + // Update the loop counter, subtract two to account for both the head and + // the tail. + "subs %x[tile_j], %x[tile_j], #2\n" + "beq 2f\n" // Jump to "A" tail if out of tiles + + // ---------------------------------------------------------------------- + "1:" + // Start part A + // Load last column of this tile (the first column has already been + // loaded) and compute latter two columns of X.T x. + "fsub AXTx13.4s, x_13.4s, x_33.4s\n" + "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" + "fadd AXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" + "fsub AXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" + "fsub AXTx43.4s, x_23.4s, x_43.4s\n" + "ldr qx_44, [%x[inptr3], %x[colstride1]]\n" + "fsub AXTx14.4s, x_14.4s, x_34.4s\n" + "add %x[inptr0], %x[inptr0], %x[colstride2]\n" + "fadd AXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride2]\n" + "fsub AXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride2]\n" + "fsub AXTx44.4s, x_24.4s, x_44.4s\n" + "add %x[inptr3], %x[inptr3], %x[colstride2]\n" + + // Compute and store U. + // First row + "fsub U.4s, AXTx11.4s, AXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, AXTx12.4s, AXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, AXTx13.4s, AXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, AXTx12.4s, AXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fsub U.4s, AXTx21.4s, AXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, AXTx22.4s, AXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, AXTx23.4s, AXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, AXTx22.4s, AXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fsub U.4s, AXTx31.4s, AXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, AXTx32.4s, AXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, AXTx33.4s, AXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, AXTx32.4s, AXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row + "fsub U.4s, AXTx41.4s, AXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "ldr qx_13, [%x[inptr0]]\n" + + "fadd U.4s, AXTx42.4s, AXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, AXTx43.4s, AXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, AXTx42.4s, AXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "ldr qx_43, [%x[inptr3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + + "subs %x[tile_j], %x[tile_j], #1\n" + "beq 3f\n" // Jump to 'B' tail + + // Start part B + // Load last column of this tile (the first column has already been + // loaded) and compute latter two columns of X.T x. 
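+    // Parts A and B are the two halves of the unrolled loop: the register aliases
+    // declared above mean that the last two columns of X.T x produced by part A
+    // are already in place as the first two columns needed by part B (and vice
+    // versa), so each half only loads and transforms the two new columns.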
+ "fsub BXTx13.4s, x_13.4s, x_33.4s\n" + "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" + "fsub BXTx43.4s, x_23.4s, x_43.4s\n" + "ldr qx_44, [%x[inptr3], %x[colstride1]]\n" + "fsub BXTx14.4s, x_14.4s, x_34.4s\n" + "add %x[inptr0], %x[inptr0], %x[colstride2]\n" + "fadd BXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride2]\n" + "fsub BXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride2]\n" + "fsub BXTx44.4s, x_24.4s, x_44.4s\n" + "add %x[inptr3], %x[inptr3], %x[colstride2]\n" + + // Compute and store U. + // First row + "fsub U.4s, BXTx11.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, BXTx12.4s, BXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fsub U.4s, BXTx21.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, BXTx22.4s, BXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fsub U.4s, BXTx31.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, BXTx32.4s, BXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row + "fsub U.4s, BXTx41.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "ldr qx_13, [%x[inptr0]]\n" + + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, BXTx42.4s, BXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "ldr qx_43, [%x[inptr3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + "subs %x[tile_j], %x[tile_j], #1\n" + "bne 1b\n" // Continue loop, otherwise flow into 'A' tail + + // ---------------------------------------------------------------------- + "2:" + // 'A' tail + // Since the final column is padding and the last-but-one column has + // already been loaded just compute the 3rd column of `X.T x'. + "fsub AXTx13.4s, x_13.4s, x_33.4s\n" + "fadd AXTx23.4s, x_23.4s, x_33.4s\n" + "fsub AXTx33.4s, x_33.4s, x_23.4s\n" + "fsub AXTx43.4s, x_23.4s, x_43.4s\n" + + // Compute and store U. Modified to account for the final column of X.T + // x containing padding. Note, it is also unnecessary to update the + // output pointers. 
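+    // Since the final column of x is padding in the tail, XTx[i][3] == 0 and
+    // U[i][3] = XTx[i][1] - XTx[i][3] reduces to XTx[i][1]; this is why the last
+    // store of each row below is a plain str of qAXTx12/qAXTx22/qAXTx32/qAXTx42
+    // rather than an fsub.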
+ // First row + "fsub U.4s, AXTx11.4s, AXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, AXTx12.4s, AXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, AXTx13.4s, AXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "str qAXTx12, [%x[outptr0], %x[mstride3]]\n" + + // Second row + "fsub U.4s, AXTx21.4s, AXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, AXTx22.4s, AXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, AXTx23.4s, AXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "str qAXTx22, [%x[outptr4], %x[mstride3]]\n" + + // Third row + "fsub U.4s, AXTx31.4s, AXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, AXTx32.4s, AXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, AXTx33.4s, AXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "str qAXTx32, [%x[outptr8], %x[mstride3]]\n" + + // Fourth row + "fsub U.4s, AXTx41.4s, AXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "fadd U.4s, AXTx42.4s, AXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, AXTx43.4s, AXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "str qAXTx42, [%x[outptr12], %x[mstride3]]\n" + + "b 4f\n" // Jump to end of function + + // ---------------------------------------------------------------------- + "3:" + // 'B' tail + // Since the final column is padding and the last-but-one column has + // already been loaded just compute the 3rd column of `X.T x'. + "fsub BXTx13.4s, x_13.4s, x_33.4s\n" + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "fsub BXTx43.4s, x_23.4s, x_43.4s\n" + + // Compute and store U. Modified to account for the final column of X.T + // x containing padding. Note, it is also unnecessary to update the + // output pointers. + // First row + "fsub U.4s, BXTx11.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "str qBXTx12, [%x[outptr0], %x[mstride3]]\n" + + // Second row + "fsub U.4s, BXTx21.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "str qBXTx22, [%x[outptr4], %x[mstride3]]\n" + + // Third row + "fsub U.4s, BXTx31.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "str qBXTx32, [%x[outptr8], %x[mstride3]]\n" + + // Fourth row + "fsub U.4s, BXTx41.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "str qBXTx42, [%x[outptr12], %x[mstride3]]\n" + + // ---------------------------------------------------------------------- + "4:" + // End of function + + // Clear names + ".unreq x_13\n" ".unreq qx_13\n" ".unreq x_14\n" ".unreq qx_14\n" + ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n" + ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n" + ".unreq x_43\n" ".unreq qx_43\n" ".unreq x_44\n" ".unreq qx_44\n" + ".unreq AXTx11\n" ".unreq BXTx13\n" + ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n" + ".unreq AXTx21\n" ".unreq BXTx23\n" + ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n" + ".unreq AXTx31\n" ".unreq BXTx33\n" + 
".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n" + ".unreq AXTx41\n" ".unreq BXTx43\n" + ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n" + ".unreq AXTx13\n" ".unreq BXTx11\n" + ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n" + ".unreq AXTx23\n" ".unreq BXTx21\n" + ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n" + ".unreq AXTx33\n" ".unreq BXTx31\n" + ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n" + ".unreq AXTx43\n" ".unreq BXTx41\n" + ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n" + ".unreq U\n" ".unreq qU\n" + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [inptr3] "+r" (inptr3), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr4), + [outptr8] "+r" (outptr8), + [outptr12] "+r" (outptr12), + [tile_j] "+r" (tile_j) // Tile counter + : [colstride1] "r" (1 * input_col_stride * sizeof(float)), + [colstride2] "r" (2 * input_col_stride * sizeof(float)), + [colstride3] "r" (3 * input_col_stride * sizeof(float)), + [mstride1] "r" (1 * matrix_stride * sizeof(float)), + [mstride2] "r" (2 * matrix_stride * sizeof(float)), + [mstride3] "r" (3 * matrix_stride * sizeof(float)), + [matrix_row_stride] "r" (matrix_row_stride * sizeof(float)) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24" + ); +} + +// Pad top, left and right by 1. +template <> +template <> +inline void Winograd2x2_3x3GemmInput::process_tile_row<1, 1, 0, 1, 4>( + const int tile_N, + const float* const input, + const int input_row_stride, + const int input_col_stride, + float* const matrix, + const int matrix_stride, + const int matrix_row_stride +) { + const float *inptr0 = input; + const float *inptr1 = input + input_row_stride; + const float *inptr2 = input + input_row_stride * 2; + + float *outptr0 = matrix; + float *outptr4 = matrix + matrix_stride * 4; + float *outptr8 = matrix + matrix_stride * 8; + float *outptr12 = matrix + matrix_stride * 12; + + int tile_j = tile_N; // Tiles to process + + asm volatile ( + // Named SIMD registers according to the policy given above + // Registers into which to load the latter two columns of `x` + // NOTE: We need only load the latter three rows since we know that the + // first row is padded. 
+ "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n" + "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n" + "x_43 .req v3\n qx_43 .req q3\n" "x_44 .req v7\n qx_44 .req q7\n" + + // Registers for storing X.T x (both A and B halves) + "AXTx11 .req v8\n" "BXTx13 .req v8\n" + "AXTx12 .req v9\n" "BXTx14 .req v9\n" "qAXTx12 .req q9\n" + "AXTx21 .req v10\n" "BXTx23 .req v10\n" + "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n" + "AXTx31 .req v12\n" "BXTx33 .req v12\n" + "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n" + "AXTx41 .req v14\n" "BXTx43 .req v14\n" + "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n" + "AXTx13 .req v16\n" "BXTx11 .req v16\n" + "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n" + "AXTx23 .req v18\n" "BXTx21 .req v18\n" + "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n" + "AXTx33 .req v20\n" "BXTx31 .req v20\n" + "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n" + "AXTx43 .req v22\n" "BXTx41 .req v22\n" + "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n" + + // Result register (TODO Does using more registers yield better + // performance) + "U .req v24\n qU .req q24\n" + + // ---------------------------------------------------------------------- + // Head of loop + // Loads a complete 4x4 tile of x, computes X.T x, computes and stores + // `U = X.T x X`. Prepares for the 'A' half of the loop. + // NOTE: Since the first tile has the leftmost column padded we can + // skip 4 loads and 4 calculations for the matrix X.T x X. + + // Temporarily alias registers for computing the first (non-padded) + // column of x. + "x_22 .req v1\n qx_22 .req q1\n" + "x_32 .req v2\n qx_32 .req q2\n" + "x_42 .req v3\n qx_42 .req q3\n" + + "ldr qx_22, [%x[inptr1]]\n" + "ldr qx_32, [%x[inptr2]]\n" + "ldr qx_42, [%x[inptr3]]\n" + + "fneg BXTx12.4s, x_32.4s\n" + "fadd BXTx22.4s, x_22.4s, x_32.4s\n" + "fsub BXTx32.4s, x_32.4s, x_22.4s\n" + "fsub BXTx42.4s, x_22.4s, x_42.4s\n" + + ".unreq x_22\n .unreq qx_22\n" + ".unreq x_32\n .unreq qx_32\n" + ".unreq x_42\n .unreq qx_42\n" + + // Load and compute latter two columns of the first tile. Progress the + // input pointers (by three columns so that the each points are the + // second column of the next tile, that is, each points at the first + // column which must be read for the next tile. 
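+    // Because the top row of x is zeroed by padding, XTx[0][j] = x[0][j] - x[2][j]
+    // reduces to -x[2][j]; the first row of X.T x is therefore produced with fneg
+    // (above for column two, below for columns three and four) instead of fsub.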
+ "ldr qx_23, [%x[inptr1], %x[colstride1]]\n" + "ldr qx_33, [%x[inptr2], %x[colstride1]]\n" + "ldr qx_43, [%x[inptr3], %x[colstride1]]\n" + + "fneg BXTx13.4s, x_33.4s\n" + + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride2]]\n" + + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride2]]\n" + + "fsub BXTx43.4s, x_23.4s, x_43.4s\n" + "ldr qx_44, [%x[inptr3], %x[colstride2]]\n" + + "fneg BXTx14.4s, x_34.4s\n" + + "fadd BXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride3]\n" + + "fsub BXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride3]\n" + + "fsub BXTx44.4s, x_24.4s, x_44.4s\n" + "add %x[inptr3], %x[inptr3], %x[colstride3]\n" + + // Compute and store U for the first tile + // First row + "fneg U.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, BXTx12.4s, BXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fneg U.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, BXTx22.4s, BXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fneg U.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, BXTx32.4s, BXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row, simultaneously load the first column of inputs for the + // next tile. + "fneg U.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, BXTx42.4s, BXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "ldr qx_43, [%x[inptr3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + + // Update the loop counter, subtract two to account for both the head and + // the tail. + "subs %x[tile_j], %x[tile_j], #2\n" + "beq 2f\n" // Jump to "A" tail if out of tiles + + // ---------------------------------------------------------------------- + "1:" + // Start part A + // Load last column of this tile (the first column has already been + // loaded) and compute latter two columns of X.T x. + "fneg AXTx13.4s, x_33.4s\n" + "fadd AXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" + "fsub AXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" + "fsub AXTx43.4s, x_23.4s, x_43.4s\n" + "ldr qx_44, [%x[inptr3], %x[colstride1]]\n" + "fneg AXTx14.4s, x_34.4s\n" + "fadd AXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride2]\n" + "fsub AXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride2]\n" + "fsub AXTx44.4s, x_24.4s, x_44.4s\n" + "add %x[inptr3], %x[inptr3], %x[colstride2]\n" + + // Compute and store U. 
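+    // Element (i, j) of U is written to output matrix i*4 + j: outptr0/4/8/12
+    // select the row i and mstride1..3 provide the byte offsets for columns
+    // j = 1..3, mirroring `outptr + (i*4 + j)*matrix_stride` in the generic
+    // implementation.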
+ // First row + "fsub U.4s, AXTx11.4s, AXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, AXTx12.4s, AXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, AXTx13.4s, AXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, AXTx12.4s, AXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fsub U.4s, AXTx21.4s, AXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, AXTx22.4s, AXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, AXTx23.4s, AXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, AXTx22.4s, AXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fsub U.4s, AXTx31.4s, AXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, AXTx32.4s, AXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, AXTx33.4s, AXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, AXTx32.4s, AXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row + "fsub U.4s, AXTx41.4s, AXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + + "fadd U.4s, AXTx42.4s, AXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, AXTx43.4s, AXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, AXTx42.4s, AXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "ldr qx_43, [%x[inptr3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + + "subs %x[tile_j], %x[tile_j], #1\n" + "beq 3f\n" // Jump to 'B' tail + + // Start part B + // Load last column of this tile (the first column has already been + // loaded) and compute latter two columns of X.T x. + "fneg BXTx13.4s, x_33.4s\n" + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" + "fsub BXTx43.4s, x_23.4s, x_43.4s\n" + "ldr qx_44, [%x[inptr3], %x[colstride1]]\n" + "fneg BXTx14.4s, x_34.4s\n" + "fadd BXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride2]\n" + "fsub BXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride2]\n" + "fsub BXTx44.4s, x_24.4s, x_44.4s\n" + "add %x[inptr3], %x[inptr3], %x[colstride2]\n" + + // Compute and store U. 
+ // First row + "fsub U.4s, BXTx11.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, BXTx12.4s, BXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fsub U.4s, BXTx21.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, BXTx22.4s, BXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fsub U.4s, BXTx31.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, BXTx32.4s, BXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row + "fsub U.4s, BXTx41.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, BXTx42.4s, BXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "ldr qx_43, [%x[inptr3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + "subs %x[tile_j], %x[tile_j], #1\n" + "bne 1b\n" // Continue loop, otherwise flow into 'A' tail + + // ---------------------------------------------------------------------- + "2:" + // 'A' tail + // Since the final column is padding and the last-but-one column has + // already been loaded just compute the 3rd column of `X.T x'. + "fneg AXTx13.4s, x_33.4s\n" + "fadd AXTx23.4s, x_23.4s, x_33.4s\n" + "fsub AXTx33.4s, x_33.4s, x_23.4s\n" + "fsub AXTx43.4s, x_23.4s, x_43.4s\n" + + // Compute and store U. Modified to account for the final column of X.T + // x containing padding. Note, it is also unnecessary to update the + // output pointers. 
+ // First row + "fsub U.4s, AXTx11.4s, AXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, AXTx12.4s, AXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, AXTx13.4s, AXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "str qAXTx12, [%x[outptr0], %x[mstride3]]\n" + + // Second row + "fsub U.4s, AXTx21.4s, AXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, AXTx22.4s, AXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, AXTx23.4s, AXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "str qAXTx22, [%x[outptr4], %x[mstride3]]\n" + + // Third row + "fsub U.4s, AXTx31.4s, AXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, AXTx32.4s, AXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, AXTx33.4s, AXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "str qAXTx32, [%x[outptr8], %x[mstride3]]\n" + + // Fourth row + "fsub U.4s, AXTx41.4s, AXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "fadd U.4s, AXTx42.4s, AXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, AXTx43.4s, AXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "str qAXTx42, [%x[outptr12], %x[mstride3]]\n" + + "b 4f\n" // Jump to end of function + + // ---------------------------------------------------------------------- + "3:" + // 'B' tail + // Since the final column is padding and the last-but-one column has + // already been loaded just compute the 3rd column of `X.T x'. + "fneg BXTx13.4s, x_33.4s\n" + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "fsub BXTx43.4s, x_23.4s, x_43.4s\n" + + // Compute and store U. Modified to account for the final column of X.T + // x containing padding. Note, it is also unnecessary to update the + // output pointers. + // First row + "fsub U.4s, BXTx11.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "str qBXTx12, [%x[outptr0], %x[mstride3]]\n" + + // Second row + "fsub U.4s, BXTx21.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "str qBXTx22, [%x[outptr4], %x[mstride3]]\n" + + // Third row + "fsub U.4s, BXTx31.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "str qBXTx32, [%x[outptr8], %x[mstride3]]\n" + + // Fourth row + "fsub U.4s, BXTx41.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "str qBXTx42, [%x[outptr12], %x[mstride3]]\n" + + // ---------------------------------------------------------------------- + "4:" + // End of function + + // Clear names + ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n" + ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n" + ".unreq x_43\n" ".unreq qx_43\n" ".unreq x_44\n" ".unreq qx_44\n" + ".unreq AXTx11\n" ".unreq BXTx13\n" + ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n" + ".unreq AXTx21\n" ".unreq BXTx23\n" + ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n" + ".unreq AXTx31\n" ".unreq BXTx33\n" + ".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n" + ".unreq AXTx41\n" 
".unreq BXTx43\n" + ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n" + ".unreq AXTx13\n" ".unreq BXTx11\n" + ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n" + ".unreq AXTx23\n" ".unreq BXTx21\n" + ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n" + ".unreq AXTx33\n" ".unreq BXTx31\n" + ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n" + ".unreq AXTx43\n" ".unreq BXTx41\n" + ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n" + ".unreq U\n" ".unreq qU\n" + : [inptr1] "+r" (inptr0), // Offset to account for padded row + [inptr2] "+r" (inptr1), // Offset to account for padded row + [inptr3] "+r" (inptr2), // Offset to account for padded row + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr4), + [outptr8] "+r" (outptr8), + [outptr12] "+r" (outptr12), + [tile_j] "+r" (tile_j) // Tile counter + : [colstride1] "r" (1 * input_col_stride * sizeof(float)), + [colstride2] "r" (2 * input_col_stride * sizeof(float)), + [colstride3] "r" (3 * input_col_stride * sizeof(float)), + [mstride1] "r" (1 * matrix_stride * sizeof(float)), + [mstride2] "r" (2 * matrix_stride * sizeof(float)), + [mstride3] "r" (3 * matrix_stride * sizeof(float)), + [matrix_row_stride] "r" (matrix_row_stride * sizeof(float)) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24" + ); +} + +// Pad left, right and bottom by 1. +template <> +template <> +inline void Winograd2x2_3x3GemmInput::process_tile_row<0, 1, 1, 1, 4>( + const int tile_N, + const float* const input, + const int input_row_stride, + const int input_col_stride, + float* const matrix, + const int matrix_stride, + const int matrix_row_stride +) { + const float *inptr0 = input; + const float *inptr1 = input + input_row_stride; + const float *inptr2 = input + input_row_stride * 2; + + float *outptr0 = matrix; + float *outptr4 = matrix + matrix_stride * 4; + float *outptr8 = matrix + matrix_stride * 8; + float *outptr12 = matrix + matrix_stride * 12; + + int tile_j = tile_N; // Tiles to process + + asm volatile ( + // Named SIMD registers according to the policy given above + // Registers into which to load the latter two columns of `x` + // NOTE: Bottom row is not required since since it is padded. 
+ "x_13 .req v0\n qx_13 .req q0\n" "x_14 .req v4\n qx_14 .req q4\n" + "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n" + "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n" + + // Registers for storing X.T x (both A and B halves) + "AXTx11 .req v8\n" "BXTx13 .req v8\n" + "AXTx12 .req v9\n" "BXTx14 .req v9\n" "qAXTx12 .req q9\n" + "AXTx21 .req v10\n" "BXTx23 .req v10\n" + "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n" + "AXTx31 .req v12\n" "BXTx33 .req v12\n" + "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n" + "AXTx41 .req v14\n" "BXTx43 .req v14\n" + "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n" + "AXTx13 .req v16\n" "BXTx11 .req v16\n" + "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n" + "AXTx23 .req v18\n" "BXTx21 .req v18\n" + "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n" + "AXTx33 .req v20\n" "BXTx31 .req v20\n" + "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n" + "AXTx43 .req v22\n" "BXTx41 .req v22\n" + "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n" + + // Result register (TODO Does using more registers yield better + // performance) + "U .req v24\n qU .req q24\n" + + // ---------------------------------------------------------------------- + // Head of loop + // Loads a complete 4x4 tile of x, computes X.T x, computes and stores + // `U = X.T x X`. Prepares for the 'A' half of the loop. + // NOTE: Since the first tile has the leftmost column padded we can + // skip 4 loads and 4 calculations for the matrix X.T x X. + + // Temporarily alias registers for computing the first (non-padded) + // column of x. + "x_12 .req v0\n qx_12 .req q0\n" + "x_22 .req v1\n qx_22 .req q1\n" + "x_32 .req v2\n qx_32 .req q2\n" + + "ldr qx_12, [%x[inptr0]]\n" + "ldr qx_22, [%x[inptr1]]\n" + "ldr qx_32, [%x[inptr2]]\n" + + "fsub BXTx12.4s, x_12.4s, x_32.4s\n" + "fadd BXTx22.4s, x_22.4s, x_32.4s\n" + "fsub BXTx32.4s, x_32.4s, x_22.4s\n" + "mov BXTx42.16b, x_22.16b\n" // Probably should do better + + ".unreq x_12\n .unreq qx_12\n" + ".unreq x_22\n .unreq qx_22\n" + ".unreq x_32\n .unreq qx_32\n" + + // Load and compute latter two columns of the first tile. Progress the + // input pointers (by three columns so that the each points are the + // second column of the next tile, that is, each points at the first + // column which must be read for the next tile. 
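+    // With the bottom row of x zeroed by padding, XTx[3][j] = x[1][j] - x[3][j]
+    // reduces to x[1][j]; the last row of X.T x is therefore just a copy of the
+    // second row of x, handled with the mov instructions above and below rather
+    // than an fsub.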
+ "ldr qx_13, [%x[inptr0], %x[colstride1]]\n" + "ldr qx_23, [%x[inptr1], %x[colstride1]]\n" + "ldr qx_33, [%x[inptr2], %x[colstride1]]\n" + + "fsub BXTx13.4s, x_13.4s, x_33.4s\n" + "ldr qx_14, [%x[inptr0], %x[colstride2]]\n" + + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride2]]\n" + + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride2]]\n" + + "mov BXTx43.16b, x_23.16b\n" + "fsub BXTx14.4s, x_14.4s, x_34.4s\n" + "add %x[inptr0], %x[inptr0], %x[colstride3]\n" + + "fadd BXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride3]\n" + + "fsub BXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride3]\n" + + "mov BXTx44.16b, x_24.16b\n" + + // Compute and store U for the first tile + // First row + "fneg U.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, BXTx12.4s, BXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fneg U.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, BXTx22.4s, BXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fneg U.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, BXTx32.4s, BXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row, simultaneously load the first column of inputs for the + // next tile. + "fneg U.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "ldr qx_13, [%x[inptr0]]\n" + + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, BXTx42.4s, BXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + + // Update the loop counter, subtract two to account for both the head and + // the tail. + "subs %x[tile_j], %x[tile_j], #2\n" + "beq 2f\n" // Jump to "A" tail if out of tiles + + // ---------------------------------------------------------------------- + "1:" + // Start part A + // Load last column of this tile (the first column has already been + // loaded) and compute latter two columns of X.T x. + "fsub AXTx13.4s, x_13.4s, x_33.4s\n" + "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" + "fadd AXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" + "fsub AXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" + "mov AXTx43.16b, x_23.16b\n" + + "fsub AXTx14.4s, x_14.4s, x_34.4s\n" + "add %x[inptr0], %x[inptr0], %x[colstride2]\n" + "fadd AXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride2]\n" + "fsub AXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride2]\n" + "mov AXTx44.16b, x_24.16b\n" + + // Compute and store U. 
+ // First row + "fsub U.4s, AXTx11.4s, AXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, AXTx12.4s, AXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, AXTx13.4s, AXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, AXTx12.4s, AXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fsub U.4s, AXTx21.4s, AXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, AXTx22.4s, AXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, AXTx23.4s, AXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, AXTx22.4s, AXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fsub U.4s, AXTx31.4s, AXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, AXTx32.4s, AXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, AXTx33.4s, AXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, AXTx32.4s, AXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row + "fsub U.4s, AXTx41.4s, AXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "ldr qx_13, [%x[inptr0]]\n" + + "fadd U.4s, AXTx42.4s, AXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, AXTx43.4s, AXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, AXTx42.4s, AXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + + "subs %x[tile_j], %x[tile_j], #1\n" + "beq 3f\n" // Jump to 'B' tail + + // Start part B + // Load last column of this tile (the first column has already been + // loaded) and compute latter two columns of X.T x. + "fsub BXTx13.4s, x_13.4s, x_33.4s\n" + "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" + "mov BXTx43.16b, x_23.16b\n" + + "fsub BXTx14.4s, x_14.4s, x_34.4s\n" + "add %x[inptr0], %x[inptr0], %x[colstride2]\n" + "fadd BXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride2]\n" + "fsub BXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride2]\n" + "mov BXTx44.16b, x_24.16b\n" + + // Compute and store U. 
+ // First row + "fsub U.4s, BXTx11.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, BXTx12.4s, BXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fsub U.4s, BXTx21.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, BXTx22.4s, BXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fsub U.4s, BXTx31.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, BXTx32.4s, BXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row + "fsub U.4s, BXTx41.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "ldr qx_13, [%x[inptr0]]\n" + + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, BXTx42.4s, BXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + "subs %x[tile_j], %x[tile_j], #1\n" + "bne 1b\n" // Continue loop, otherwise flow into 'A' tail + + // ---------------------------------------------------------------------- + "2:" + // 'A' tail + // Since the final column is padding and the last-but-one column has + // already been loaded just compute the 3rd column of `X.T x'. + "fsub AXTx13.4s, x_13.4s, x_33.4s\n" + "fadd AXTx23.4s, x_23.4s, x_33.4s\n" + "fsub AXTx33.4s, x_33.4s, x_23.4s\n" + "mov AXTx43.16b, x_23.16b\n" + + // Compute and store U. Modified to account for the final column of X.T + // x containing padding. Note, it is also unnecessary to update the + // output pointers. 
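+      // Since the padded final column of the tile is zero, the fourth column
+      // of X.T x vanishes and U_i4 reduces to XTx_i2, which is why the
+      // existing registers are stored directly at mstride3 below.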
+ // First row + "fsub U.4s, AXTx11.4s, AXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, AXTx12.4s, AXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, AXTx13.4s, AXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "str qAXTx12, [%x[outptr0], %x[mstride3]]\n" + + // Second row + "fsub U.4s, AXTx21.4s, AXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, AXTx22.4s, AXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, AXTx23.4s, AXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "str qAXTx22, [%x[outptr4], %x[mstride3]]\n" + + // Third row + "fsub U.4s, AXTx31.4s, AXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, AXTx32.4s, AXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, AXTx33.4s, AXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "str qAXTx32, [%x[outptr8], %x[mstride3]]\n" + + // Fourth row + "fsub U.4s, AXTx41.4s, AXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "fadd U.4s, AXTx42.4s, AXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, AXTx43.4s, AXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "str qAXTx42, [%x[outptr12], %x[mstride3]]\n" + + "b 4f\n" // Jump to end of function + + // ---------------------------------------------------------------------- + "3:" + // 'B' tail + // Since the final column is padding and the last-but-one column has + // already been loaded just compute the 3rd column of `X.T x'. + "fsub BXTx13.4s, x_13.4s, x_33.4s\n" + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "mov BXTx43.16b, x_23.16b\n" + + // Compute and store U. Modified to account for the final column of X.T + // x containing padding. Note, it is also unnecessary to update the + // output pointers. + // First row + "fsub U.4s, BXTx11.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "str qBXTx12, [%x[outptr0], %x[mstride3]]\n" + + // Second row + "fsub U.4s, BXTx21.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "str qBXTx22, [%x[outptr4], %x[mstride3]]\n" + + // Third row + "fsub U.4s, BXTx31.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "str qBXTx32, [%x[outptr8], %x[mstride3]]\n" + + // Fourth row + "fsub U.4s, BXTx41.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "str qBXTx42, [%x[outptr12], %x[mstride3]]\n" + + // ---------------------------------------------------------------------- + "4:" + // End of function + + // Clear names + ".unreq x_13\n" ".unreq qx_13\n" ".unreq x_14\n" ".unreq qx_14\n" + ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n" + ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n" + ".unreq AXTx11\n" ".unreq BXTx13\n" + ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n" + ".unreq AXTx21\n" ".unreq BXTx23\n" + ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n" + ".unreq AXTx31\n" ".unreq BXTx33\n" + ".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n" + ".unreq AXTx41\n" 
".unreq BXTx43\n" + ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n" + ".unreq AXTx13\n" ".unreq BXTx11\n" + ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n" + ".unreq AXTx23\n" ".unreq BXTx21\n" + ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n" + ".unreq AXTx33\n" ".unreq BXTx31\n" + ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n" + ".unreq AXTx43\n" ".unreq BXTx41\n" + ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n" + ".unreq U\n" ".unreq qU\n" + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr4), + [outptr8] "+r" (outptr8), + [outptr12] "+r" (outptr12), + [tile_j] "+r" (tile_j) // Tile counter + : [colstride1] "r" (1 * input_col_stride * sizeof(float)), + [colstride2] "r" (2 * input_col_stride * sizeof(float)), + [colstride3] "r" (3 * input_col_stride * sizeof(float)), + [mstride1] "r" (1 * matrix_stride * sizeof(float)), + [mstride2] "r" (2 * matrix_stride * sizeof(float)), + [mstride3] "r" (3 * matrix_stride * sizeof(float)), + [matrix_row_stride] "r" (matrix_row_stride * sizeof(float)) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24" + ); +} +} +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp b/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp new file mode 100644 index 0000000000..ad1ad55291 --- /dev/null +++ b/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp @@ -0,0 +1,961 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once +#include "../input_2x2_3x3.hpp" + +#ifdef __aarch64__ + +namespace winograd { + +template <> +template <> +inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<0, 0, 0, 0, 4>( + int &n_channels, // Number of channels in the tile + const float* &inptr0, + const int input_row_stride, + const int input_col_stride, + float* &outptr0, + const int matrix_stride +) { + // We use 4 pointers to point to the starting position on each row and use + // three offsets to extract elements from each of the other 3 columns. 
+ auto inptr1 = inptr0 + 1*input_row_stride; + auto inptr2 = inptr0 + 2*input_row_stride; + auto inptr3 = inptr0 + 3*input_row_stride; + + // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three + // offsets to access the intermediate matrices. + auto outptr1 = outptr0 + matrix_stride * 4; + auto outptr2 = outptr0 + matrix_stride * 8; + auto outptr3 = outptr0 + matrix_stride * 12; + + for (; n_channels > 3; n_channels -= 4) { + asm volatile ( + "X_11 .req v0\n" "qX_11 .req q0\n" + "X_12 .req v1\n" "qX_12 .req q1\n" + "X_13 .req v2\n" "qX_13 .req q2\n" + "X_14 .req v3\n" "qX_14 .req q3\n" + "X_21 .req v4\n" "qX_21 .req q4\n" + "X_22 .req v5\n" "qX_22 .req q5\n" + "X_23 .req v6\n" "qX_23 .req q6\n" + "X_24 .req v7\n" "qX_24 .req q7\n" + "X_31 .req v8\n" "qX_31 .req q8\n" + "X_32 .req v9\n" "qX_32 .req q9\n" + "X_33 .req v10\n" "qX_33 .req q10\n" + "X_34 .req v11\n" "qX_34 .req q11\n" + "X_41 .req v12\n" "qX_41 .req q12\n" + "X_42 .req v13\n" "qX_42 .req q13\n" + "X_43 .req v14\n" "qX_43 .req q14\n" + "X_44 .req v15\n" "qX_44 .req q15\n" + "xX_11 .req v16\n" + "xX_12 .req v17\n" + "xX_13 .req v18\n" + "xX_14 .req v19\n" + "xX_21 .req v20\n" + "xX_22 .req v21\n" + "xX_23 .req v22\n" + "xX_24 .req v23\n" + "xX_31 .req v24\n" + "xX_32 .req v25\n" + "xX_33 .req v26\n" + "xX_34 .req v27\n" + "xX_41 .req v28\n" + "xX_42 .req v29\n" + "xX_43 .req v30\n" + "xX_44 .req v31\n" + " U .req v0\n" + "qU .req q0\n" + + // Load the tile, and compute compute the matrix xX + "ldr qX_11, [%x[inptr0]]\n" + "ldr qX_12, [%x[inptr0], %x[colstride1]]\n" + "ldr qX_13, [%x[inptr0], %x[colstride2]]\n" + "ldr qX_14, [%x[inptr0], %x[colstride3]]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "ldr qX_21, [%x[inptr1]]\n" + "fsub xX_11.4s, x_11.4s, x_13.4s\n" + "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" + "fadd xX_12.4s, x_12.4s, x_13.4s\n" + "ldr qX_23, [%x[inptr1], %x[colstride2]]\n" + "fsub xX_13.4s, x_13.4s, x_12.4s\n" + "ldr qX_24, [%x[inptr1], %x[colstride3]]\n" + "fsub xX_14.4s, x_12.4s, x_14.4s\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "ldr qX_31, [%x[inptr2]]\n" + "fsub xX_21.4s, x_21.4s, x_23.4s\n" + "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" + "fadd xX_22.4s, x_22.4s, x_23.4s\n" + "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" + "fsub xX_23.4s, x_23.4s, x_22.4s\n" + "ldr qX_34, [%x[inptr2], %x[colstride3]]\n" + "fsub xX_24.4s, x_22.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + "ldr qX_41, [%x[inptr3]]\n" + "fsub xX_31.4s, x_31.4s, x_33.4s\n" + "ldr qX_42, [%x[inptr3], %x[colstride1]]\n" + "fadd xX_32.4s, x_32.4s, x_33.4s\n" + "ldr qX_43, [%x[inptr3], %x[colstride2]]\n" + "fsub xX_33.4s, x_33.4s, x_32.4s\n" + "ldr qX_44, [%x[inptr3], %x[colstride3]]\n" + "fsub xX_34.4s, x_32.4s, x_34.4s\n" + "add %x[inptr3], %x[inptr3], #0x10\n" + + // Complete computing xX while beginning to compute and store + // $U = X.T x X$ + + "fsub xX_41.4s, x_41.4s, x_43.4s\n" + + "fsub U.4s, xX_11.4s, xX_31.4s\n" + "str qU, [%x[outptr0]]\n" + "fsub U.4s, xX_12.4s, xX_32.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, xX_13.4s, xX_33.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, xX_14.4s, xX_34.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd xX_42.4s, x_42.4s, x_43.4s\n" + + "fadd U.4s, xX_21.4s, xX_31.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, xX_22.4s, xX_32.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fadd U.4s, xX_23.4s, xX_33.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fadd U.4s, xX_24.4s, xX_34.4s\n" + 
"str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fsub xX_43.4s, x_43.4s, x_42.4s\n" + + "fsub U.4s, xX_31.4s, xX_21.4s\n" + "str qU, [%x[outptr8]]\n" + "fsub U.4s, xX_32.4s, xX_22.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, xX_33.4s, xX_23.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, xX_34.4s, xX_24.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "fsub xX_44.4s, x_42.4s, x_44.4s\n" + + "fsub U.4s, xX_21.4s, xX_41.4s\n" + "str qU, [%x[outptr12]]\n" + "fsub U.4s, xX_22.4s, xX_42.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, xX_23.4s, xX_43.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "fsub U.4s, xX_24.4s, xX_44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + ".unreq qU\n" + ".unreq U\n" + ".unreq X_11\n" ".unreq qX_11\n" + ".unreq X_12\n" ".unreq qX_12\n" + ".unreq X_13\n" ".unreq qX_13\n" + ".unreq X_14\n" ".unreq qX_14\n" + ".unreq X_21\n" ".unreq qX_21\n" + ".unreq X_22\n" ".unreq qX_22\n" + ".unreq X_23\n" ".unreq qX_23\n" + ".unreq X_24\n" ".unreq qX_24\n" + ".unreq X_31\n" ".unreq qX_31\n" + ".unreq X_32\n" ".unreq qX_32\n" + ".unreq X_33\n" ".unreq qX_33\n" + ".unreq X_34\n" ".unreq qX_34\n" + ".unreq X_41\n" ".unreq qX_41\n" + ".unreq X_42\n" ".unreq qX_42\n" + ".unreq X_43\n" ".unreq qX_43\n" + ".unreq X_44\n" ".unreq qX_44\n" + ".unreq xX_11\n" + ".unreq xX_12\n" + ".unreq xX_13\n" + ".unreq xX_14\n" + ".unreq xX_21\n" + ".unreq xX_22\n" + ".unreq xX_23\n" + ".unreq xX_24\n" + ".unreq xX_31\n" + ".unreq xX_32\n" + ".unreq xX_33\n" + ".unreq xX_34\n" + ".unreq xX_41\n" + ".unreq xX_42\n" + ".unreq xX_43\n" + ".unreq xX_44\n" + + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [inptr3] "+r" (inptr3), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr1), + [outptr8] "+r" (outptr2), + [outptr12] "+r" (outptr3) + : [colstride1] "r" (input_col_stride * sizeof(float)), + [colstride2] "r" (input_col_stride * sizeof(float) * 2), + [colstride3] "r" (input_col_stride * sizeof(float) * 3), + [mstride1] "r" (matrix_stride * sizeof(float)), + [mstride2] "r" (matrix_stride * sizeof(float) * 2), + [mstride3] "r" (matrix_stride * sizeof(float) * 3) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31" + ); + } +} + +// Pad top by 1 +template <> +template <> +inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<1, 0, 0, 0, 4>( + int &n_channels, // Number of channels in the tile + const float* &inptr0, + const int input_row_stride, + const int input_col_stride, + float* &outptr0, + const int matrix_stride +) { + // We use 4 pointers to point to the starting position on each row and use + // three offsets to extract elements from each of the other 3 columns. + auto inptr1 = inptr0 + 0*input_row_stride; + auto inptr2 = inptr0 + 1*input_row_stride; + + // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three + // offsets to access the intermediate matrices. 
+ auto outptr1 = outptr0 + matrix_stride * 4; + auto outptr2 = outptr0 + matrix_stride * 8; + auto outptr3 = outptr0 + matrix_stride * 12; + + for (; n_channels > 3; n_channels -= 4) { + asm volatile ( + "X_21 .req v4\n" "qX_21 .req q4\n" + "X_22 .req v5\n" "qX_22 .req q5\n" + "X_23 .req v6\n" "qX_23 .req q6\n" + "X_24 .req v7\n" "qX_24 .req q7\n" + "X_31 .req v8\n" "qX_31 .req q8\n" + "X_32 .req v9\n" "qX_32 .req q9\n" + "X_33 .req v10\n" "qX_33 .req q10\n" + "X_34 .req v11\n" "qX_34 .req q11\n" + "X_41 .req v12\n" "qX_41 .req q12\n" + "X_42 .req v13\n" "qX_42 .req q13\n" + "X_43 .req v14\n" "qX_43 .req q14\n" + "X_44 .req v15\n" "qX_44 .req q15\n" + "xX_21 .req v20\n" + "xX_22 .req v21\n" + "xX_23 .req v22\n" + "xX_24 .req v23\n" + "xX_31 .req v24\n" + "xX_32 .req v25\n" + "xX_33 .req v26\n" + "xX_34 .req v27\n" + "xX_41 .req v28\n" + "xX_42 .req v29\n" + "xX_43 .req v30\n" + "xX_44 .req v31\n" + " U .req v0\n" + "qU .req q0\n" + + // Load the tile, and compute compute the matrix xX + "ldr qX_21, [%x[inptr1]]\n" + "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" + "ldr qX_23, [%x[inptr1], %x[colstride2]]\n" + "ldr qX_24, [%x[inptr1], %x[colstride3]]\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "ldr qX_31, [%x[inptr2]]\n" + "fsub xX_21.4s, x_21.4s, x_23.4s\n" + "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" + "fadd xX_22.4s, x_22.4s, x_23.4s\n" + "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" + "fsub xX_23.4s, x_23.4s, x_22.4s\n" + "ldr qX_34, [%x[inptr2], %x[colstride3]]\n" + "fsub xX_24.4s, x_22.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + "ldr qX_41, [%x[inptr3]]\n" + "fsub xX_31.4s, x_31.4s, x_33.4s\n" + "ldr qX_42, [%x[inptr3], %x[colstride1]]\n" + "fadd xX_32.4s, x_32.4s, x_33.4s\n" + "ldr qX_43, [%x[inptr3], %x[colstride2]]\n" + "fsub xX_33.4s, x_33.4s, x_32.4s\n" + "ldr qX_44, [%x[inptr3], %x[colstride3]]\n" + "fsub xX_34.4s, x_32.4s, x_34.4s\n" + "add %x[inptr3], %x[inptr3], #0x10\n" + + // Complete computing xX while beginning to compute and store + // $U = X.T x X$ + + "fsub xX_41.4s, x_41.4s, x_43.4s\n" + + "fneg U.4s, xX_31.4s\n" + "str qU, [%x[outptr0]]\n" + "fneg U.4s, xX_32.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fneg U.4s, xX_33.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fneg U.4s, xX_34.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd xX_42.4s, x_42.4s, x_43.4s\n" + + "fadd U.4s, xX_21.4s, xX_31.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, xX_22.4s, xX_32.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fadd U.4s, xX_23.4s, xX_33.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fadd U.4s, xX_24.4s, xX_34.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fsub xX_43.4s, x_43.4s, x_42.4s\n" + + "fsub U.4s, xX_31.4s, xX_21.4s\n" + "str qU, [%x[outptr8]]\n" + "fsub U.4s, xX_32.4s, xX_22.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, xX_33.4s, xX_23.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, xX_34.4s, xX_24.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "fsub xX_44.4s, x_42.4s, x_44.4s\n" + + "fsub U.4s, xX_21.4s, xX_41.4s\n" + "str qU, [%x[outptr12]]\n" + "fsub U.4s, xX_22.4s, xX_42.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, xX_23.4s, xX_43.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "fsub U.4s, xX_24.4s, xX_44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + ".unreq qU\n" + ".unreq U\n" + ".unreq 
X_21\n" ".unreq qX_21\n" + ".unreq X_22\n" ".unreq qX_22\n" + ".unreq X_23\n" ".unreq qX_23\n" + ".unreq X_24\n" ".unreq qX_24\n" + ".unreq X_31\n" ".unreq qX_31\n" + ".unreq X_32\n" ".unreq qX_32\n" + ".unreq X_33\n" ".unreq qX_33\n" + ".unreq X_34\n" ".unreq qX_34\n" + ".unreq X_41\n" ".unreq qX_41\n" + ".unreq X_42\n" ".unreq qX_42\n" + ".unreq X_43\n" ".unreq qX_43\n" + ".unreq X_44\n" ".unreq qX_44\n" + ".unreq xX_21\n" + ".unreq xX_22\n" + ".unreq xX_23\n" + ".unreq xX_24\n" + ".unreq xX_31\n" + ".unreq xX_32\n" + ".unreq xX_33\n" + ".unreq xX_34\n" + ".unreq xX_41\n" + ".unreq xX_42\n" + ".unreq xX_43\n" + ".unreq xX_44\n" + + : [inptr1] "+r" (inptr0), // Offset for missing row + [inptr2] "+r" (inptr1), // Offset for missing row + [inptr3] "+r" (inptr2), // Offset for missing row + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr1), + [outptr8] "+r" (outptr2), + [outptr12] "+r" (outptr3) + : [colstride1] "r" (input_col_stride * sizeof(float)), + [colstride2] "r" (input_col_stride * sizeof(float) * 2), + [colstride3] "r" (input_col_stride * sizeof(float) * 3), + [mstride1] "r" (matrix_stride * sizeof(float)), + [mstride2] "r" (matrix_stride * sizeof(float) * 2), + [mstride3] "r" (matrix_stride * sizeof(float) * 3) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31" + ); + } +} + +// Pad left by 1 +template <> +template <> +inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<0, 1, 0, 0, 4>( + int &n_channels, // Number of channels in the tile + const float* &inptr0, + const int input_row_stride, + const int input_col_stride, + float* &outptr0, + const int matrix_stride +) { + // We use 4 pointers to point to the starting position on each row and use + // three offsets to extract elements from each of the other 3 columns. + auto inptr1 = inptr0 + 1*input_row_stride; + auto inptr2 = inptr0 + 2*input_row_stride; + auto inptr3 = inptr0 + 3*input_row_stride; + + // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three + // offsets to access the intermediate matrices. 
+ auto outptr1 = outptr0 + matrix_stride * 4; + auto outptr2 = outptr0 + matrix_stride * 8; + auto outptr3 = outptr0 + matrix_stride * 12; + + for (; n_channels > 3; n_channels -= 4) { + asm volatile ( + "X_12 .req v1\n" "qX_12 .req q1\n" + "X_13 .req v2\n" "qX_13 .req q2\n" + "X_14 .req v3\n" "qX_14 .req q3\n" + "X_22 .req v5\n" "qX_22 .req q5\n" + "X_23 .req v6\n" "qX_23 .req q6\n" + "X_24 .req v7\n" "qX_24 .req q7\n" + "X_32 .req v9\n" "qX_32 .req q9\n" + "X_33 .req v10\n" "qX_33 .req q10\n" + "X_34 .req v11\n" "qX_34 .req q11\n" + "X_42 .req v13\n" "qX_42 .req q13\n" + "X_43 .req v14\n" "qX_43 .req q14\n" + "X_44 .req v15\n" "qX_44 .req q15\n" + "xX_11 .req v16\n" + "xX_12 .req v17\n" + "xX_13 .req v18\n" + "xX_14 .req v19\n" + "xX_21 .req v20\n" + "xX_22 .req v21\n" + "xX_23 .req v22\n" + "xX_24 .req v23\n" + "xX_31 .req v24\n" + "xX_32 .req v25\n" + "xX_33 .req v26\n" + "xX_34 .req v27\n" + "xX_41 .req v28\n" + "xX_42 .req v29\n" + "xX_43 .req v30\n" + "xX_44 .req v31\n" + " U .req v0\n" + "qU .req q0\n" + + // Load the tile, and compute compute the matrix xX + "ldr qX_12, [%x[inptr0]]\n" + "ldr qX_13, [%x[inptr0], %x[colstride1]]\n" + "ldr qX_14, [%x[inptr0], %x[colstride2]]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "fneg xX_11.4s, x_13.4s\n" + "ldr qX_22, [%x[inptr1]]\n" + "fadd xX_12.4s, x_12.4s, x_13.4s\n" + "ldr qX_23, [%x[inptr1], %x[colstride1]]\n" + "fsub xX_13.4s, x_13.4s, x_12.4s\n" + "ldr qX_24, [%x[inptr1], %x[colstride2]]\n" + "fsub xX_14.4s, x_12.4s, x_14.4s\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "fneg xX_21.4s, x_23.4s\n" + "ldr qX_32, [%x[inptr2]]\n" + "fadd xX_22.4s, x_22.4s, x_23.4s\n" + "ldr qX_33, [%x[inptr2], %x[colstride1]]\n" + "fsub xX_23.4s, x_23.4s, x_22.4s\n" + "ldr qX_34, [%x[inptr2], %x[colstride2]]\n" + "fsub xX_24.4s, x_22.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + "fneg xX_31.4s, x_33.4s\n" + "ldr qX_42, [%x[inptr3]]\n" + "fadd xX_32.4s, x_32.4s, x_33.4s\n" + "ldr qX_43, [%x[inptr3], %x[colstride1]]\n" + "fsub xX_33.4s, x_33.4s, x_32.4s\n" + "ldr qX_44, [%x[inptr3], %x[colstride2]]\n" + "fsub xX_34.4s, x_32.4s, x_34.4s\n" + "add %x[inptr3], %x[inptr3], #0x10\n" + + // Complete computing xX while beginning to compute and store + // $U = X.T x X$ + + "fneg xX_41.4s, x_43.4s\n" + + "fsub U.4s, xX_11.4s, xX_31.4s\n" + "str qU, [%x[outptr0]]\n" + "fsub U.4s, xX_12.4s, xX_32.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, xX_13.4s, xX_33.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, xX_14.4s, xX_34.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd xX_42.4s, x_42.4s, x_43.4s\n" + + "fadd U.4s, xX_21.4s, xX_31.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, xX_22.4s, xX_32.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fadd U.4s, xX_23.4s, xX_33.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fadd U.4s, xX_24.4s, xX_34.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fsub xX_43.4s, x_43.4s, x_42.4s\n" + + "fsub U.4s, xX_31.4s, xX_21.4s\n" + "str qU, [%x[outptr8]]\n" + "fsub U.4s, xX_32.4s, xX_22.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, xX_33.4s, xX_23.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, xX_34.4s, xX_24.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "fsub xX_44.4s, x_42.4s, x_44.4s\n" + + "fsub U.4s, xX_21.4s, xX_41.4s\n" + "str qU, [%x[outptr12]]\n" + "fsub U.4s, xX_22.4s, xX_42.4s\n" + "str qU, 
[%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, xX_23.4s, xX_43.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "fsub U.4s, xX_24.4s, xX_44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + ".unreq X_12\n" ".unreq qX_12\n" + ".unreq X_13\n" ".unreq qX_13\n" + ".unreq X_14\n" ".unreq qX_14\n" + ".unreq X_22\n" ".unreq qX_22\n" + ".unreq X_23\n" ".unreq qX_23\n" + ".unreq X_24\n" ".unreq qX_24\n" + ".unreq X_32\n" ".unreq qX_32\n" + ".unreq X_33\n" ".unreq qX_33\n" + ".unreq X_34\n" ".unreq qX_34\n" + ".unreq X_42\n" ".unreq qX_42\n" + ".unreq X_43\n" ".unreq qX_43\n" + ".unreq X_44\n" ".unreq qX_44\n" + ".unreq xX_11\n" + ".unreq xX_12\n" + ".unreq xX_13\n" + ".unreq xX_14\n" + ".unreq xX_21\n" + ".unreq xX_22\n" + ".unreq xX_23\n" + ".unreq xX_24\n" + ".unreq xX_31\n" + ".unreq xX_32\n" + ".unreq xX_33\n" + ".unreq xX_34\n" + ".unreq xX_41\n" + ".unreq xX_42\n" + ".unreq xX_43\n" + ".unreq xX_44\n" + ".unreq U\n" + ".unreq qU\n" + + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [inptr3] "+r" (inptr3), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr1), + [outptr8] "+r" (outptr2), + [outptr12] "+r" (outptr3) + : [colstride1] "r" (input_col_stride * sizeof(float)), + [colstride2] "r" (input_col_stride * sizeof(float) * 2), + [mstride1] "r" (matrix_stride * sizeof(float)), + [mstride2] "r" (matrix_stride * sizeof(float) * 2), + [mstride3] "r" (matrix_stride * sizeof(float) * 3) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31" + ); + } +} + +// Pad bottom by 1 +template <> +template <> +inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<0, 0, 1, 0, 4>( + int &n_channels, // Number of channels in the tile + const float* &inptr0, + const int input_row_stride, + const int input_col_stride, + float* &outptr0, + const int matrix_stride +) { + // We use 4 pointers to point to the starting position on each row and use + // three offsets to extract elements from each of the other 3 columns. + auto inptr1 = inptr0 + 1*input_row_stride; + auto inptr2 = inptr0 + 2*input_row_stride; + + // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three + // offsets to access the intermediate matrices. 
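+  // With the bottom row of the tile padded, the fourth row of xX is zero and
+  // the fourth row of U collapses to the second row of xX, which is stored
+  // to outptr12 unchanged below (hence the direct stores of the qxX_2x
+  // registers).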
+ auto outptr1 = outptr0 + matrix_stride * 4; + auto outptr2 = outptr0 + matrix_stride * 8; + auto outptr3 = outptr0 + matrix_stride * 12; + + for (; n_channels > 3; n_channels -= 4) { + asm volatile ( + "X_11 .req v0\n" "qX_11 .req q0\n" + "X_12 .req v1\n" "qX_12 .req q1\n" + "X_13 .req v2\n" "qX_13 .req q2\n" + "X_14 .req v3\n" "qX_14 .req q3\n" + "X_21 .req v4\n" "qX_21 .req q4\n" + "X_22 .req v5\n" "qX_22 .req q5\n" + "X_23 .req v6\n" "qX_23 .req q6\n" + "X_24 .req v7\n" "qX_24 .req q7\n" + "X_31 .req v8\n" "qX_31 .req q8\n" + "X_32 .req v9\n" "qX_32 .req q9\n" + "X_33 .req v10\n" "qX_33 .req q10\n" + "X_34 .req v11\n" "qX_34 .req q11\n" + "xX_11 .req v16\n" + "xX_12 .req v17\n" + "xX_13 .req v18\n" + "xX_14 .req v19\n" + "xX_21 .req v20\n" "qxX_21 .req q20\n" + "xX_22 .req v21\n" "qxX_22 .req q21\n" + "xX_23 .req v22\n" "qxX_23 .req q22\n" + "xX_24 .req v23\n" "qxX_24 .req q23\n" + "xX_31 .req v24\n" + "xX_32 .req v25\n" + "xX_33 .req v26\n" + "xX_34 .req v27\n" + " U .req v0\n" + "qU .req q0\n" + + // Load the tile, and compute compute the matrix xX + "ldr qX_11, [%x[inptr0]]\n" + "ldr qX_12, [%x[inptr0], %x[colstride1]]\n" + "ldr qX_13, [%x[inptr0], %x[colstride2]]\n" + "ldr qX_14, [%x[inptr0], %x[colstride3]]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "ldr qX_21, [%x[inptr1]]\n" + "fsub xX_11.4s, x_11.4s, x_13.4s\n" + "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" + "fadd xX_12.4s, x_12.4s, x_13.4s\n" + "ldr qX_23, [%x[inptr1], %x[colstride2]]\n" + "fsub xX_13.4s, x_13.4s, x_12.4s\n" + "ldr qX_24, [%x[inptr1], %x[colstride3]]\n" + "fsub xX_14.4s, x_12.4s, x_14.4s\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "ldr qX_31, [%x[inptr2]]\n" + "fsub xX_21.4s, x_21.4s, x_23.4s\n" + "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" + "fadd xX_22.4s, x_22.4s, x_23.4s\n" + "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" + "fsub xX_23.4s, x_23.4s, x_22.4s\n" + "ldr qX_34, [%x[inptr2], %x[colstride3]]\n" + "fsub xX_24.4s, x_22.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + "fsub xX_31.4s, x_31.4s, x_33.4s\n" + "fadd xX_32.4s, x_32.4s, x_33.4s\n" + "fsub xX_33.4s, x_33.4s, x_32.4s\n" + "fsub xX_34.4s, x_32.4s, x_34.4s\n" + + // Complete computing xX while beginning to compute and store + // $U = X.T x X$ + + "fsub U.4s, xX_11.4s, xX_31.4s\n" + "str qU, [%x[outptr0]]\n" + "fsub U.4s, xX_12.4s, xX_32.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, xX_13.4s, xX_33.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, xX_14.4s, xX_34.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd U.4s, xX_21.4s, xX_31.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, xX_22.4s, xX_32.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fadd U.4s, xX_23.4s, xX_33.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fadd U.4s, xX_24.4s, xX_34.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fsub U.4s, xX_31.4s, xX_21.4s\n" + "str qU, [%x[outptr8]]\n" + "fsub U.4s, xX_32.4s, xX_22.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, xX_33.4s, xX_23.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, xX_34.4s, xX_24.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "str qxX_21, [%x[outptr12]]\n" + "str qxX_22, [%x[outptr12], %x[mstride1]]\n" + "str qxX_23, [%x[outptr12], %x[mstride2]]\n" + "str qxX_24, [%x[outptr12], %x[mstride3]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + ".unreq qU\n" + ".unreq U\n" + ".unreq X_11\n" ".unreq 
qX_11\n" + ".unreq X_12\n" ".unreq qX_12\n" + ".unreq X_13\n" ".unreq qX_13\n" + ".unreq X_14\n" ".unreq qX_14\n" + ".unreq X_21\n" ".unreq qX_21\n" + ".unreq X_22\n" ".unreq qX_22\n" + ".unreq X_23\n" ".unreq qX_23\n" + ".unreq X_24\n" ".unreq qX_24\n" + ".unreq X_31\n" ".unreq qX_31\n" + ".unreq X_32\n" ".unreq qX_32\n" + ".unreq X_33\n" ".unreq qX_33\n" + ".unreq X_34\n" ".unreq qX_34\n" + ".unreq xX_11\n" + ".unreq xX_12\n" + ".unreq xX_13\n" + ".unreq xX_14\n" + ".unreq xX_21\n" ".unreq qxX_21\n" + ".unreq xX_22\n" ".unreq qxX_22\n" + ".unreq xX_23\n" ".unreq qxX_23\n" + ".unreq xX_24\n" ".unreq qxX_24\n" + ".unreq xX_31\n" + ".unreq xX_32\n" + ".unreq xX_33\n" + ".unreq xX_34\n" + + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr1), + [outptr8] "+r" (outptr2), + [outptr12] "+r" (outptr3) + : [colstride1] "r" (input_col_stride * sizeof(float)), + [colstride2] "r" (input_col_stride * sizeof(float) * 2), + [colstride3] "r" (input_col_stride * sizeof(float) * 3), + [mstride1] "r" (matrix_stride * sizeof(float)), + [mstride2] "r" (matrix_stride * sizeof(float) * 2), + [mstride3] "r" (matrix_stride * sizeof(float) * 3) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31" + ); + } +} + +// Pad right by 1 +template <> +template <> +inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<0, 0, 0, 1, 4>( + int &n_channels, // Number of channels in the tile + const float* &inptr0, + const int input_row_stride, + const int input_col_stride, + float* &outptr0, + const int matrix_stride +) { + // We use 4 pointers to point to the starting position on each row and use + // three offsets to extract elements from each of the other 3 columns. + auto inptr1 = inptr0 + 1*input_row_stride; + auto inptr2 = inptr0 + 2*input_row_stride; + auto inptr3 = inptr0 + 3*input_row_stride; + + // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three + // offsets to access the intermediate matrices. 
+ auto outptr1 = outptr0 + matrix_stride * 4; + auto outptr2 = outptr0 + matrix_stride * 8; + auto outptr3 = outptr0 + matrix_stride * 12; + + for (; n_channels > 3; n_channels -= 4) { + asm volatile ( + "X_11 .req v0\n" "qX_11 .req q0\n" + "X_12 .req v1\n" "qX_12 .req q1\n" + "X_13 .req v2\n" "qX_13 .req q2\n" + "X_21 .req v4\n" "qX_21 .req q4\n" + "X_22 .req v5\n" "qX_22 .req q5\n" + "X_23 .req v6\n" "qX_23 .req q6\n" + "X_31 .req v8\n" "qX_31 .req q8\n" + "X_32 .req v9\n" "qX_32 .req q9\n" + "X_33 .req v10\n" "qX_33 .req q10\n" + "X_41 .req v12\n" "qX_41 .req q12\n" + "X_42 .req v13\n" "qX_42 .req q13\n" + "X_43 .req v14\n" "qX_43 .req q14\n" + "xX_11 .req v16\n" + "xX_12 .req v17\n" + "xX_13 .req v18\n" + "xX_14 .req x_12\n" + "xX_21 .req v20\n" + "xX_22 .req v21\n" + "xX_23 .req v22\n" + "xX_24 .req x_22\n" + "xX_31 .req v24\n" + "xX_32 .req v25\n" + "xX_33 .req v26\n" + "xX_34 .req x_32\n" + "xX_41 .req v28\n" + "xX_42 .req v29\n" + "xX_43 .req v30\n" + "xX_44 .req x_42\n" + " U .req v0\n" + "qU .req q0\n" + + // Load the tile, and compute compute the matrix xX + "ldr qX_11, [%x[inptr0]]\n" + "ldr qX_12, [%x[inptr0], %x[colstride1]]\n" + "ldr qX_13, [%x[inptr0], %x[colstride2]]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "ldr qX_21, [%x[inptr1]]\n" + "fsub xX_11.4s, x_11.4s, x_13.4s\n" + "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" + "fadd xX_12.4s, x_12.4s, x_13.4s\n" + "ldr qX_23, [%x[inptr1], %x[colstride2]]\n" + "fsub xX_13.4s, x_13.4s, x_12.4s\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "ldr qX_31, [%x[inptr2]]\n" + "fsub xX_21.4s, x_21.4s, x_23.4s\n" + "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" + "fadd xX_22.4s, x_22.4s, x_23.4s\n" + "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" + "fsub xX_23.4s, x_23.4s, x_22.4s\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + "ldr qX_41, [%x[inptr3]]\n" + "fsub xX_31.4s, x_31.4s, x_33.4s\n" + "ldr qX_42, [%x[inptr3], %x[colstride1]]\n" + "fadd xX_32.4s, x_32.4s, x_33.4s\n" + "ldr qX_43, [%x[inptr3], %x[colstride2]]\n" + "fsub xX_33.4s, x_33.4s, x_32.4s\n" + "add %x[inptr3], %x[inptr3], #0x10\n" + + // Complete computing xX while beginning to compute and store + // $U = X.T x X$ + + "fsub xX_41.4s, x_41.4s, x_43.4s\n" + + "fsub U.4s, xX_11.4s, xX_31.4s\n" + "str qU, [%x[outptr0]]\n" + "fsub U.4s, xX_12.4s, xX_32.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, xX_13.4s, xX_33.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, xX_14.4s, xX_34.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd xX_42.4s, x_42.4s, x_43.4s\n" + + "fadd U.4s, xX_21.4s, xX_31.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, xX_22.4s, xX_32.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fadd U.4s, xX_23.4s, xX_33.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fadd U.4s, xX_24.4s, xX_34.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fsub xX_43.4s, x_43.4s, x_42.4s\n" + + "fsub U.4s, xX_31.4s, xX_21.4s\n" + "str qU, [%x[outptr8]]\n" + "fsub U.4s, xX_32.4s, xX_22.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, xX_33.4s, xX_23.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, xX_34.4s, xX_24.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "fsub U.4s, xX_21.4s, xX_41.4s\n" + "str qU, [%x[outptr12]]\n" + "fsub U.4s, xX_22.4s, xX_42.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, xX_23.4s, xX_43.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "fsub U.4s, 
xX_24.4s, xX_44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + ".unreq qU\n" + ".unreq U\n" + ".unreq X_11\n" ".unreq qX_11\n" + ".unreq X_12\n" ".unreq qX_12\n" + ".unreq X_13\n" ".unreq qX_13\n" + ".unreq X_21\n" ".unreq qX_21\n" + ".unreq X_22\n" ".unreq qX_22\n" + ".unreq X_23\n" ".unreq qX_23\n" + ".unreq X_31\n" ".unreq qX_31\n" + ".unreq X_32\n" ".unreq qX_32\n" + ".unreq X_33\n" ".unreq qX_33\n" + ".unreq X_41\n" ".unreq qX_41\n" + ".unreq X_42\n" ".unreq qX_42\n" + ".unreq X_43\n" ".unreq qX_43\n" + ".unreq xX_11\n" + ".unreq xX_12\n" + ".unreq xX_13\n" + ".unreq xX_14\n" + ".unreq xX_21\n" + ".unreq xX_22\n" + ".unreq xX_23\n" + ".unreq xX_24\n" + ".unreq xX_31\n" + ".unreq xX_32\n" + ".unreq xX_33\n" + ".unreq xX_34\n" + ".unreq xX_41\n" + ".unreq xX_42\n" + ".unreq xX_43\n" + ".unreq xX_44\n" + + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [inptr3] "+r" (inptr3), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr1), + [outptr8] "+r" (outptr2), + [outptr12] "+r" (outptr3) + : [colstride1] "r" (input_col_stride * sizeof(float)), + [colstride2] "r" (input_col_stride * sizeof(float) * 2), + [mstride1] "r" (matrix_stride * sizeof(float)), + [mstride2] "r" (matrix_stride * sizeof(float) * 2), + [mstride3] "r" (matrix_stride * sizeof(float) * 3) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31" + ); + } +} +} +#endif diff --git a/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp b/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp new file mode 100644 index 0000000000..033442aa14 --- /dev/null +++ b/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +namespace winograd { + /* Transform a kernel into the Winograd domain. + * + * NOTE: It is assumed that the kernel is in the form [height x width x + * input_channels x output_channel]. 
+ */
+  template <typename T>
+  struct winograd2x2_3x3_gemm_kernel_transform_impl{
+    static void execute(
+      const KernelShape &shape,
+      const T* const kernel,
+      T* const matrix_base,
+      const int matrix_stride,
+      const int matrix_row_stride
+    );
+
+    protected:
+    template <int output_channel_tail>
+    static void transform_kernel(
+      const T* const kernel,
+      const int n_input_channels,
+      const int n_output_channels,
+      T* const matrix_base,
+      const int matrix_stride,
+      const int matrix_row_stride
+    );
+  };
+}
+
+/*****************************************************************************/
+/* Transform a fp32 kernel into the Winograd domain.
+ */
+#include "kernel_2x2_3x3/a64_float.hpp"  // AArch64 specialisations
+
+namespace winograd
+{
+template <>
+inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::execute(
+  const KernelShape &shape,
+  const float* const kernel,
+  float* const matrix_base,
+  const int matrix_stride,
+  const int matrix_row_stride
+) {
+  // Delegate based on tail size
+  const int n_input_channels = shape.n_input_channels;
+  const int n_output_channels = shape.n_output_channels;
+
+  switch (n_output_channels % 4) {
+    case 0:
+      transform_kernel<0>(
+        kernel, n_input_channels, n_output_channels,
+        matrix_base, matrix_stride, matrix_row_stride
+      );
+      break;
+    case 1:
+      transform_kernel<1>(
+        kernel, n_input_channels, n_output_channels,
+        matrix_base, matrix_stride, matrix_row_stride
+      );
+      break;
+    case 2:
+      transform_kernel<2>(
+        kernel, n_input_channels, n_output_channels,
+        matrix_base, matrix_stride, matrix_row_stride
+      );
+      break;
+    case 3:
+      transform_kernel<3>(
+        kernel, n_input_channels, n_output_channels,
+        matrix_base, matrix_stride, matrix_row_stride
+      );
+      break;
+    default:
+      ARM_COMPUTE_ERROR("Cannot happen");
+      break;
+  }
+}
+
+template <>
+template <int output_channel_tail>
+inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::transform_kernel(
+  const float* const kernel,
+  const int n_input_channels,
+  const int n_output_channels,
+  float* const matrix_base,
+  const int mstride,
+  const int matrix_row_stride
+) {
+  // Use one input pointer for each row of the kernel, use two additional
+  // offsets to extract columns.
+  const int kernel_col_stride = n_input_channels * n_output_channels;
+  const int kernel_row_stride = 3 * kernel_col_stride;
+  const float *inptr0 = kernel;
+  const float *inptr1 = kernel + kernel_row_stride;
+  const float *inptr2 = kernel + kernel_row_stride*2;
+
+  // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three
+  // offsets to extract further matrices.
+  float *outptr0 = matrix_base;
+  float *outptr4 = matrix_base + mstride * 4;
+  float *outptr8 = matrix_base + mstride * 8;
+  float *outptr12 = matrix_base + mstride * 12;
+
+  // For every input channel
+  for (int in_c = 0; in_c < n_input_channels; in_c++) {
+    // For every output channel
+    for (int c = 0; c < n_output_channels; c++) {
+      // Read in the kernel
+      float w11 = inptr0[0], w12 = inptr0[kernel_col_stride], w13 = inptr0[kernel_col_stride*2];
+      float w21 = inptr1[0], w22 = inptr1[kernel_col_stride], w23 = inptr1[kernel_col_stride*2];
+      float w31 = inptr2[0], w32 = inptr2[kernel_col_stride], w33 = inptr2[kernel_col_stride*2];
+
+      // Progress input pointers
+      inptr0++;
+      inptr1++;
+      inptr2++;
+
+      // Compute the kernel W w; note we need only compute the middle two rows
+      // (2 and 3) because the first and last rows are merely copies of values
+      // from the matrix w.
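+      // In matrix form the next few statements compute U = W w W.T with
+      //   W = [   1    0    0 ]
+      //       [ 1/2  1/2  1/2 ]
+      //       [ 1/2 -1/2  1/2 ]
+      //       [   0    0    1 ],
+      // the standard Winograd F(2x2, 3x3) kernel transform.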
+ float Ww11 = w11, Ww12 = w12, Ww13 = w13; + float Ww21 = 0.5*(w11 + w21 + w31), Ww22 = 0.5*(w12 + w22 + w32), Ww23 = 0.5*(w13 + w23 + w33); + float Ww31 = 0.5*(w11 - w21 + w31), Ww32 = 0.5*(w12 - w22 + w32), Ww33 = 0.5*(w13 - w23 + w33); + float Ww41 = w31, Ww42 = w32, Ww43 = w33; + + // Hence compute W w W.T; again note we need compute only the middle two + // columns since the first and last columns are copies of the first and + // last columns of the previous matrix. + float WwWT11 = Ww11, WwWT12 = 0.5*(Ww11 + Ww12 + Ww13), WwWT13 = 0.5*(Ww11 - Ww12 + Ww13), WwWT14 = Ww13; + float WwWT21 = Ww21, WwWT22 = 0.5*(Ww21 + Ww22 + Ww23), WwWT23 = 0.5*(Ww21 - Ww22 + Ww23), WwWT24 = Ww23; + float WwWT31 = Ww31, WwWT32 = 0.5*(Ww31 + Ww32 + Ww33), WwWT33 = 0.5*(Ww31 - Ww32 + Ww33), WwWT34 = Ww33; + float WwWT41 = Ww41, WwWT42 = 0.5*(Ww41 + Ww42 + Ww43), WwWT43 = 0.5*(Ww41 - Ww42 + Ww43), WwWT44 = Ww43; + + // Store the computed weights + outptr0[0 * mstride] = WwWT11; + outptr0[1 * mstride] = WwWT12; + outptr0[2 * mstride] = WwWT13; + outptr0[3 * mstride] = WwWT14; + + outptr4[0 * mstride] = WwWT21; + outptr4[1 * mstride] = WwWT22; + outptr4[2 * mstride] = WwWT23; + outptr4[3 * mstride] = WwWT24; + + outptr8[0 * mstride] = WwWT31; + outptr8[1 * mstride] = WwWT32; + outptr8[2 * mstride] = WwWT33; + outptr8[3 * mstride] = WwWT34; + + outptr12[0 * mstride] = WwWT41; + outptr12[1 * mstride] = WwWT42; + outptr12[2 * mstride] = WwWT43; + outptr12[3 * mstride] = WwWT44; + + // Progress output pointers + outptr0++; + outptr4++; + outptr8++; + outptr12++; + } + + // Progression to complete stride + outptr0 += matrix_row_stride - n_output_channels; + outptr4 += matrix_row_stride - n_output_channels; + outptr8 += matrix_row_stride - n_output_channels; + outptr12 += matrix_row_stride - n_output_channels; + } +} +} diff --git a/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp b/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp new file mode 100644 index 0000000000..3dd62d1ac1 --- /dev/null +++ b/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp @@ -0,0 +1,822 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#pragma once + +#ifdef __aarch64__ +namespace winograd { +template <> +template <> +inline void winograd2x2_3x3_gemm_kernel_transform_impl::transform_kernel<0>( + const float* const kernel, + const int n_input_channels, + const int n_output_channels, + float* const matrix_base, + const int mstride, + const int matrix_row_stride +) { + // Use one input pointer for each row of the kernel, use two additional + // offsets to extract columns. + const int kernel_col_stride = n_input_channels * n_output_channels; + const int kernel_row_stride = 3 * kernel_col_stride; + const float *inptr0 = kernel; + const float *inptr1 = kernel + kernel_row_stride; + const float *inptr2 = kernel + kernel_row_stride*2; + + // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three + // offsets to extract further matrices. + float *outptr0 = matrix_base; + float *outptr4 = matrix_base + mstride * 4; + float *outptr8 = matrix_base + mstride * 8; + float *outptr12 = matrix_base + mstride * 12; + + // For every input channel + for (int in_c = 0; in_c < n_input_channels; in_c++) { + int n_remaining_channels = n_output_channels; + + asm volatile ( + // Registers into which to read the kernel + "w_11 .req v0\n" "qw_11 .req q0\n" + "w_12 .req v1\n" "qw_12 .req q1\n" + "w_13 .req v2\n" "qw_13 .req q2\n" + "w_21 .req v3\n" "qw_21 .req q3\n" + "w_22 .req v4\n" "qw_22 .req q4\n" + "w_23 .req v5\n" "qw_23 .req q5\n" + "w_31 .req v6\n" "qw_31 .req q6\n" + "w_32 .req v7\n" "qw_32 .req q7\n" + "w_33 .req v8\n" "qw_33 .req q8\n" + + // Transformed matrix Ww + "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n" + "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n" + "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n" + "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n" + + // Output matrix U = WwWT + "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n" + "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n" + "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n" + "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n" + + // Storage view of output matrices + "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n" + "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n" + "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n" + "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n" + + "half .req v23\n" // {0.5, ..., 0.5} + "dup half.4s, %w[one_half]\n" + "scratch .req v24\n" + + "1:" + // Load tile of the kernel + "ldr qw_11, [%x[inptr0]]\n" + "str qU11, [%x[outptr0]]\n" + "ldr qw_12, [%x[inptr0], %x[colstride1]]\n" + "ldr qw_13, [%x[inptr0], %x[colstride2]]\n" + "str qU14, [%x[outptr0], %x[mstride3]]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "ldr qw_21, [%x[inptr1]]\n" + "ldr qw_22, [%x[inptr1], %x[colstride1]]\n" + "ldr qw_23, [%x[inptr1], %x[colstride2]]\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "ldr qw_31, [%x[inptr2]]\n" + "str qU41, [%x[outptr12]]\n" + "ldr qw_32, [%x[inptr2], %x[colstride1]]\n" + "ldr qw_33, [%x[inptr2], %x[colstride2]]\n" + "str qU44, [%x[outptr12], %x[mstride3]]\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + // Compute 2nd and 3rd rows of Ww + "fadd scratch.4s, w_11.4s, w_31.4s\n" + "fmul Ww21.4s, scratch.4s, half.4s\n" + "fmla Ww21.4s, w_21.4s, half.4s\n" + "str qU21, [%x[outptr4]]\n" + "fmul Ww31.4s, scratch.4s, half.4s\n" + "fmls Ww31.4s, w_21.4s, half.4s\n" + "str qU31, [%x[outptr8]]\n" + + "fadd scratch.4s, w_12.4s, w_32.4s\n" + 
"fmul Ww22.4s, scratch.4s, half.4s\n" + "fmla Ww22.4s, w_22.4s, half.4s\n" + "fmul Ww32.4s, scratch.4s, half.4s\n" + "fmls Ww32.4s, w_22.4s, half.4s\n" + + "fadd scratch.4s, w_13.4s, w_33.4s\n" + "fmul Ww23.4s, scratch.4s, half.4s\n" + "fmla Ww23.4s, w_23.4s, half.4s\n" + "str qU24, [%x[outptr4], %x[mstride3]]\n" + "fmul Ww33.4s, scratch.4s, half.4s\n" + "fmls Ww33.4s, w_23.4s, half.4s\n" + "str qU34, [%x[outptr8], %x[mstride3]]\n" + + // Compute and store U, only need to compute the 2nd and 3rd columns + // of U and update output pointers + "fadd scratch.4s, Ww11.4s, Ww13.4s\n" + "fmul U12.4s, scratch.4s, half.4s\n" + "fmla U12.4s, Ww12.4s, half.4s\n" + "str qU12, [%x[outptr0], %x[mstride1]]\n" + "fmul U13.4s, scratch.4s, half.4s\n" + "fmls U13.4s, Ww12.4s, half.4s\n" + "str qU13, [%x[outptr0], %x[mstride2]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd scratch.4s, Ww21.4s, Ww23.4s\n" + "fmul U22.4s, scratch.4s, half.4s\n" + "fmla U22.4s, Ww22.4s, half.4s\n" + "str qU22, [%x[outptr4], %x[mstride1]]\n" + "fmul U23.4s, scratch.4s, half.4s\n" + "fmls U23.4s, Ww22.4s, half.4s\n" + "str qU23, [%x[outptr4], %x[mstride2]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fadd scratch.4s, Ww31.4s, Ww33.4s\n" + "fmul U32.4s, scratch.4s, half.4s\n" + "fmla U32.4s, Ww32.4s, half.4s\n" + "str qU32, [%x[outptr8], %x[mstride1]]\n" + "fmul U33.4s, scratch.4s, half.4s\n" + "fmls U33.4s, Ww32.4s, half.4s\n" + "str qU33, [%x[outptr8], %x[mstride2]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "fadd scratch.4s, Ww41.4s, Ww43.4s\n" + "fmul U42.4s, scratch.4s, half.4s\n" + "fmla U42.4s, Ww42.4s, half.4s\n" + "str qU42, [%x[outptr12], %x[mstride1]]\n" + "fmul U43.4s, scratch.4s, half.4s\n" + "fmls U43.4s, Ww42.4s, half.4s\n" + "str qU43, [%x[outptr12], %x[mstride2]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n" + "bne 1b\n" + + // Clear aliases + ".unreq half\n" + ".unreq scratch\n" + ".unreq w_11\n" ".unreq qw_11\n" + ".unreq w_12\n" ".unreq qw_12\n" + ".unreq w_13\n" ".unreq qw_13\n" + ".unreq w_21\n" ".unreq qw_21\n" + ".unreq w_22\n" ".unreq qw_22\n" + ".unreq w_23\n" ".unreq qw_23\n" + ".unreq w_31\n" ".unreq qw_31\n" + ".unreq w_32\n" ".unreq qw_32\n" + ".unreq w_33\n" ".unreq qw_33\n" + ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n" + ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n" + ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n" + ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n" + ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n" + ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n" + ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n" + ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n" + ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n" + ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n" + ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n" + ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq qU44\n" + + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr4), + [outptr8] "+r" (outptr8), + [outptr12] "+r" (outptr12), + [n_remaining_channels] "+r" (n_remaining_channels) + : [mstride1] "r" (sizeof(float) * mstride), + [mstride2] "r" (sizeof(float) * mstride * 2), + [mstride3] "r" (sizeof(float) * mstride * 3), + [colstride1] "r" (sizeof(float) * kernel_col_stride), + [colstride2] "r" (sizeof(float) * kernel_col_stride * 2), + [one_half] "r" (0.5f) + : "v0", "v1", 
"v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24" + ); + + // Progression to complete stride + outptr0 += matrix_row_stride - n_output_channels; + outptr4 += matrix_row_stride - n_output_channels; + outptr8 += matrix_row_stride - n_output_channels; + outptr12 += matrix_row_stride - n_output_channels; + } +} + +template <> +template <> +inline void winograd2x2_3x3_gemm_kernel_transform_impl::transform_kernel<2>( + const float* const kernel, + const int n_input_channels, + const int n_output_channels, + float* const matrix_base, + const int mstride, + const int matrix_row_stride +) { + // Use one input pointer for each row of the kernel, use two additional + // offsets to extract columns. + const int kernel_col_stride = n_input_channels * n_output_channels; + const int kernel_row_stride = 3 * kernel_col_stride; + const float *inptr0 = kernel; + const float *inptr1 = kernel + kernel_row_stride; + const float *inptr2 = kernel + kernel_row_stride*2; + + // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three + // offsets to extract further matrices. + float *outptr0 = matrix_base; + float *outptr4 = matrix_base + mstride * 4; + float *outptr8 = matrix_base + mstride * 8; + float *outptr12 = matrix_base + mstride * 12; + + // For every input channel + for (int in_c = 0; in_c < n_input_channels; in_c++) { + int n_remaining_channels = n_output_channels; + + asm volatile ( + // Registers into which to read the kernel + "w_11 .req v0\n" "qw_11 .req q0\n" "dw_11 .req d0\n" + "w_12 .req v1\n" "qw_12 .req q1\n" "dw_12 .req d1\n" + "w_13 .req v2\n" "qw_13 .req q2\n" "dw_13 .req d2\n" + "w_21 .req v3\n" "qw_21 .req q3\n" "dw_21 .req d3\n" + "w_22 .req v4\n" "qw_22 .req q4\n" "dw_22 .req d4\n" + "w_23 .req v5\n" "qw_23 .req q5\n" "dw_23 .req d5\n" + "w_31 .req v6\n" "qw_31 .req q6\n" "dw_31 .req d6\n" + "w_32 .req v7\n" "qw_32 .req q7\n" "dw_32 .req d7\n" + "w_33 .req v8\n" "qw_33 .req q8\n" "dw_33 .req d8\n" + + // Transformed matrix Ww + "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n" + "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n" + "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n" + "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n" + + // Output matrix U = WwWT + "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n" + "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n" + "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n" + "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n" + + // Storage view of output matrices + "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n" + "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n" + "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n" + "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n" + + "dU11 .req d0\n" "dU12 .req d15\n" "dU13 .req d16\n" "dU14 .req d2\n" + "dU21 .req d9\n" "dU22 .req d17\n" "dU23 .req d18\n" "dU24 .req d11\n" + "dU31 .req d12\n" "dU32 .req d19\n" "dU33 .req d20\n" "dU34 .req d14\n" + "dU41 .req d6\n" "dU42 .req d21\n" "dU43 .req d22\n" "dU44 .req d8\n" + + "half .req v23\n" // {0.5, ..., 0.5} + "dup half.4s, %w[one_half]\n" + "scratch .req v24\n" + + // Subtract the tail from the number of remaining channels and jump to + // the tail if necessary. 
+ "subs %x[n_remaining_channels], %x[n_remaining_channels], #2\n" + "beq 2f\n" + + "1:" + // Load tile of the kernel + "ldr qw_11, [%x[inptr0]]\n" + "str qU11, [%x[outptr0]]\n" + "ldr qw_12, [%x[inptr0], %x[colstride1]]\n" + "ldr qw_13, [%x[inptr0], %x[colstride2]]\n" + "str qU14, [%x[outptr0], %x[mstride3]]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "ldr qw_21, [%x[inptr1]]\n" + "ldr qw_22, [%x[inptr1], %x[colstride1]]\n" + "ldr qw_23, [%x[inptr1], %x[colstride2]]\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "ldr qw_31, [%x[inptr2]]\n" + "str qU41, [%x[outptr12]]\n" + "ldr qw_32, [%x[inptr2], %x[colstride1]]\n" + "ldr qw_33, [%x[inptr2], %x[colstride2]]\n" + "str qU44, [%x[outptr12], %x[mstride3]]\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + // Compute 2nd and 3rd rows of Ww + "fadd scratch.4s, w_11.4s, w_31.4s\n" + "fmul Ww21.4s, scratch.4s, half.4s\n" + "fmla Ww21.4s, w_21.4s, half.4s\n" + "str qU21, [%x[outptr4]]\n" + "fmul Ww31.4s, scratch.4s, half.4s\n" + "fmls Ww31.4s, w_21.4s, half.4s\n" + "str qU31, [%x[outptr8]]\n" + + "fadd scratch.4s, w_12.4s, w_32.4s\n" + "fmul Ww22.4s, scratch.4s, half.4s\n" + "fmla Ww22.4s, w_22.4s, half.4s\n" + "fmul Ww32.4s, scratch.4s, half.4s\n" + "fmls Ww32.4s, w_22.4s, half.4s\n" + + "fadd scratch.4s, w_13.4s, w_33.4s\n" + "fmul Ww23.4s, scratch.4s, half.4s\n" + "fmla Ww23.4s, w_23.4s, half.4s\n" + "str qU24, [%x[outptr4], %x[mstride3]]\n" + "fmul Ww33.4s, scratch.4s, half.4s\n" + "fmls Ww33.4s, w_23.4s, half.4s\n" + "str qU34, [%x[outptr8], %x[mstride3]]\n" + + // Compute and store U, only need to compute the 2nd and 3rd columns + // of U and update output pointers + "fadd scratch.4s, Ww11.4s, Ww13.4s\n" + "fmul U12.4s, scratch.4s, half.4s\n" + "fmla U12.4s, Ww12.4s, half.4s\n" + "str qU12, [%x[outptr0], %x[mstride1]]\n" + "fmul U13.4s, scratch.4s, half.4s\n" + "fmls U13.4s, Ww12.4s, half.4s\n" + "str qU13, [%x[outptr0], %x[mstride2]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd scratch.4s, Ww21.4s, Ww23.4s\n" + "fmul U22.4s, scratch.4s, half.4s\n" + "fmla U22.4s, Ww22.4s, half.4s\n" + "str qU22, [%x[outptr4], %x[mstride1]]\n" + "fmul U23.4s, scratch.4s, half.4s\n" + "fmls U23.4s, Ww22.4s, half.4s\n" + "str qU23, [%x[outptr4], %x[mstride2]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fadd scratch.4s, Ww31.4s, Ww33.4s\n" + "fmul U32.4s, scratch.4s, half.4s\n" + "fmla U32.4s, Ww32.4s, half.4s\n" + "str qU32, [%x[outptr8], %x[mstride1]]\n" + "fmul U33.4s, scratch.4s, half.4s\n" + "fmls U33.4s, Ww32.4s, half.4s\n" + "str qU33, [%x[outptr8], %x[mstride2]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "fadd scratch.4s, Ww41.4s, Ww43.4s\n" + "fmul U42.4s, scratch.4s, half.4s\n" + "fmla U42.4s, Ww42.4s, half.4s\n" + "str qU42, [%x[outptr12], %x[mstride1]]\n" + "fmul U43.4s, scratch.4s, half.4s\n" + "fmls U43.4s, Ww42.4s, half.4s\n" + "str qU43, [%x[outptr12], %x[mstride2]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n" + "bne 1b\n" + + // Tail size 2 + "2:" + // Load tile of the kernel + "ldr dw_11, [%x[inptr0]]\n" + "str dU11, [%x[outptr0]]\n" + "ldr dw_12, [%x[inptr0], %x[colstride1]]\n" + "ldr dw_13, [%x[inptr0], %x[colstride2]]\n" + "str dU14, [%x[outptr0], %x[mstride3]]\n" + "add %x[inptr0], %x[inptr0], #0x08\n" + + "ldr dw_21, [%x[inptr1]]\n" + "ldr dw_22, [%x[inptr1], %x[colstride1]]\n" + "ldr dw_23, [%x[inptr1], %x[colstride2]]\n" + "add %x[inptr1], %x[inptr1], #0x08\n" + + "ldr dw_31, [%x[inptr2]]\n" + "str dU41, [%x[outptr12]]\n" + "ldr dw_32, [%x[inptr2], 
%x[colstride1]]\n" + "ldr dw_33, [%x[inptr2], %x[colstride2]]\n" + "str dU44, [%x[outptr12], %x[mstride3]]\n" + "add %x[inptr2], %x[inptr2], #0x08\n" + + // Compute 2nd and 3rd rows of Ww + "fadd scratch.2s, w_11.2s, w_31.2s\n" + "fmul Ww21.2s, scratch.2s, half.2s\n" + "fmla Ww21.2s, w_21.2s, half.2s\n" + "str dU21, [%x[outptr4]]\n" + "fmul Ww31.2s, scratch.2s, half.2s\n" + "fmls Ww31.2s, w_21.2s, half.2s\n" + "str dU31, [%x[outptr8]]\n" + + "fadd scratch.2s, w_12.2s, w_32.2s\n" + "fmul Ww22.2s, scratch.2s, half.2s\n" + "fmla Ww22.2s, w_22.2s, half.2s\n" + "fmul Ww32.2s, scratch.2s, half.2s\n" + "fmls Ww32.2s, w_22.2s, half.2s\n" + + "fadd scratch.2s, w_13.2s, w_33.2s\n" + "fmul Ww23.2s, scratch.2s, half.2s\n" + "fmla Ww23.2s, w_23.2s, half.2s\n" + "str dU24, [%x[outptr4], %x[mstride3]]\n" + "fmul Ww33.2s, scratch.2s, half.2s\n" + "fmls Ww33.2s, w_23.2s, half.2s\n" + "str dU34, [%x[outptr8], %x[mstride3]]\n" + + // Compute and store U, only need to compute the 2nd and 3rd columns of + // U and update output pointers + "fadd scratch.2s, Ww11.2s, Ww13.2s\n" + "fmul U12.2s, scratch.2s, half.2s\n" + "fmla U12.2s, Ww12.2s, half.2s\n" + "str dU12, [%x[outptr0], %x[mstride1]]\n" + "fmul U13.2s, scratch.2s, half.2s\n" + "fmls U13.2s, Ww12.2s, half.2s\n" + "str dU13, [%x[outptr0], %x[mstride2]]\n" + "add %x[outptr0], %x[outptr0], #0x08\n" + + "fadd scratch.2s, Ww21.2s, Ww23.2s\n" + "fmul U22.2s, scratch.2s, half.2s\n" + "fmla U22.2s, Ww22.2s, half.2s\n" + "str dU22, [%x[outptr4], %x[mstride1]]\n" + "fmul U23.2s, scratch.2s, half.2s\n" + "fmls U23.2s, Ww22.2s, half.2s\n" + "str dU23, [%x[outptr4], %x[mstride2]]\n" + "add %x[outptr4], %x[outptr4], #0x08\n" + + "fadd scratch.2s, Ww31.2s, Ww33.2s\n" + "fmul U32.2s, scratch.2s, half.2s\n" + "fmla U32.2s, Ww32.2s, half.2s\n" + "str dU32, [%x[outptr8], %x[mstride1]]\n" + "fmul U33.2s, scratch.2s, half.2s\n" + "fmls U33.2s, Ww32.2s, half.2s\n" + "str dU33, [%x[outptr8], %x[mstride2]]\n" + "add %x[outptr8], %x[outptr8], #0x08\n" + + "fadd scratch.2s, Ww41.2s, Ww43.2s\n" + "fmul U42.2s, scratch.2s, half.2s\n" + "fmla U42.2s, Ww42.2s, half.2s\n" + "str dU42, [%x[outptr12], %x[mstride1]]\n" + "fmul U43.2s, scratch.2s, half.2s\n" + "fmls U43.2s, Ww42.2s, half.2s\n" + "str dU43, [%x[outptr12], %x[mstride2]]\n" + "add %x[outptr12], %x[outptr12], #0x08\n" + + // Clear aliases + ".unreq half\n" + ".unreq scratch\n" + ".unreq w_11\n" ".unreq qw_11\n" ".unreq dw_11\n" + ".unreq w_12\n" ".unreq qw_12\n" ".unreq dw_12\n" + ".unreq w_13\n" ".unreq qw_13\n" ".unreq dw_13\n" + ".unreq w_21\n" ".unreq qw_21\n" ".unreq dw_21\n" + ".unreq w_22\n" ".unreq qw_22\n" ".unreq dw_22\n" + ".unreq w_23\n" ".unreq qw_23\n" ".unreq dw_23\n" + ".unreq w_31\n" ".unreq qw_31\n" ".unreq dw_31\n" + ".unreq w_32\n" ".unreq qw_32\n" ".unreq dw_32\n" + ".unreq w_33\n" ".unreq qw_33\n" ".unreq dw_33\n" + ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n" + ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n" + ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n" + ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n" + ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n" + ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n" + ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n" + ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n" + ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n" + ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n" + ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n" + ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq 
qU44\n" + ".unreq dU11\n" ".unreq dU12\n" ".unreq dU13\n" ".unreq dU14\n" + ".unreq dU21\n" ".unreq dU22\n" ".unreq dU23\n" ".unreq dU24\n" + ".unreq dU31\n" ".unreq dU32\n" ".unreq dU33\n" ".unreq dU34\n" + ".unreq dU41\n" ".unreq dU42\n" ".unreq dU43\n" ".unreq dU44\n" + + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr4), + [outptr8] "+r" (outptr8), + [outptr12] "+r" (outptr12), + [n_remaining_channels] "+r" (n_remaining_channels) + : [mstride1] "r" (sizeof(float) * mstride), + [mstride2] "r" (sizeof(float) * mstride * 2), + [mstride3] "r" (sizeof(float) * mstride * 3), + [colstride1] "r" (sizeof(float) * kernel_col_stride), + [colstride2] "r" (sizeof(float) * kernel_col_stride * 2), + [one_half] "r" (0.5f) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24" + ); + + // Progression to complete stride + outptr0 += matrix_row_stride - n_output_channels; + outptr4 += matrix_row_stride - n_output_channels; + outptr8 += matrix_row_stride - n_output_channels; + outptr12 += matrix_row_stride - n_output_channels; + } +} + +template <> +template <> +inline void winograd2x2_3x3_gemm_kernel_transform_impl::transform_kernel<1>( + const float* const kernel, + const int n_input_channels, + const int n_output_channels, + float* const matrix_base, + const int mstride, + const int matrix_row_stride +) { + // Use one input pointer for each row of the kernel, use two additional + // offsets to extract columns. + const int kernel_col_stride = n_input_channels * n_output_channels; + const int kernel_row_stride = 3 * kernel_col_stride; + const float *inptr0 = kernel; + const float *inptr1 = kernel + kernel_row_stride; + const float *inptr2 = kernel + kernel_row_stride*2; + + // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three + // offsets to extract further matrices. 
+ float *outptr0 = matrix_base; + float *outptr4 = matrix_base + mstride * 4; + float *outptr8 = matrix_base + mstride * 8; + float *outptr12 = matrix_base + mstride * 12; + + // For every input channel + for (int in_c = 0; in_c < n_input_channels; in_c++) { + int n_remaining_channels = n_output_channels; + + asm volatile ( + // Registers into which to read the kernel + "w_11 .req v0\n" "qw_11 .req q0\n" "sw_11 .req s0\n" + "w_12 .req v1\n" "qw_12 .req q1\n" "sw_12 .req s1\n" + "w_13 .req v2\n" "qw_13 .req q2\n" "sw_13 .req s2\n" + "w_21 .req v3\n" "qw_21 .req q3\n" "sw_21 .req s3\n" + "w_22 .req v4\n" "qw_22 .req q4\n" "sw_22 .req s4\n" + "w_23 .req v5\n" "qw_23 .req q5\n" "sw_23 .req s5\n" + "w_31 .req v6\n" "qw_31 .req q6\n" "sw_31 .req s6\n" + "w_32 .req v7\n" "qw_32 .req q7\n" "sw_32 .req s7\n" + "w_33 .req v8\n" "qw_33 .req q8\n" "sw_33 .req s8\n" + + // Transformed matrix Ww + "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n" + "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n" + "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n" + "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n" + + // Output matrix U = WwWT + "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n" + "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n" + "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n" + "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n" + + // Storage view of output matrices + "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n" + "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n" + "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n" + "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n" + + "sU11 .req s0\n" "sU12 .req s15\n" "sU13 .req s16\n" "sU14 .req s2\n" + "sU21 .req s9\n" "sU22 .req s17\n" "sU23 .req s18\n" "sU24 .req s11\n" + "sU31 .req s12\n" "sU32 .req s19\n" "sU33 .req s20\n" "sU34 .req s14\n" + "sU41 .req s6\n" "sU42 .req s21\n" "sU43 .req s22\n" "sU44 .req s8\n" + + "half .req v23\n" // {0.5, ..., 0.5} + "dup half.4s, %w[one_half]\n" + "scratch .req v24\n" + + // Subtract the tail from the number of remaining channels and jump to + // the tail if necessary. 
+ "subs %x[n_remaining_channels], %x[n_remaining_channels], #1\n" + "beq 2f\n" + + "1:" + // Load tile of the kernel + "ldr qw_11, [%x[inptr0]]\n" + "str qU11, [%x[outptr0]]\n" + "ldr qw_12, [%x[inptr0], %x[colstride1]]\n" + "ldr qw_13, [%x[inptr0], %x[colstride2]]\n" + "str qU14, [%x[outptr0], %x[mstride3]]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "ldr qw_21, [%x[inptr1]]\n" + "ldr qw_22, [%x[inptr1], %x[colstride1]]\n" + "ldr qw_23, [%x[inptr1], %x[colstride2]]\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "ldr qw_31, [%x[inptr2]]\n" + "str qU41, [%x[outptr12]]\n" + "ldr qw_32, [%x[inptr2], %x[colstride1]]\n" + "ldr qw_33, [%x[inptr2], %x[colstride2]]\n" + "str qU44, [%x[outptr12], %x[mstride3]]\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + // Compute 2nd and 3rd rows of Ww + "fadd scratch.4s, w_11.4s, w_31.4s\n" + "fmul Ww21.4s, scratch.4s, half.4s\n" + "fmla Ww21.4s, w_21.4s, half.4s\n" + "str qU21, [%x[outptr4]]\n" + "fmul Ww31.4s, scratch.4s, half.4s\n" + "fmls Ww31.4s, w_21.4s, half.4s\n" + "str qU31, [%x[outptr8]]\n" + + "fadd scratch.4s, w_12.4s, w_32.4s\n" + "fmul Ww22.4s, scratch.4s, half.4s\n" + "fmla Ww22.4s, w_22.4s, half.4s\n" + "fmul Ww32.4s, scratch.4s, half.4s\n" + "fmls Ww32.4s, w_22.4s, half.4s\n" + + "fadd scratch.4s, w_13.4s, w_33.4s\n" + "fmul Ww23.4s, scratch.4s, half.4s\n" + "fmla Ww23.4s, w_23.4s, half.4s\n" + "str qU24, [%x[outptr4], %x[mstride3]]\n" + "fmul Ww33.4s, scratch.4s, half.4s\n" + "fmls Ww33.4s, w_23.4s, half.4s\n" + "str qU34, [%x[outptr8], %x[mstride3]]\n" + + // Compute and store U, only need to compute the 2nd and 3rd columns + // of U and update output pointers + "fadd scratch.4s, Ww11.4s, Ww13.4s\n" + "fmul U12.4s, scratch.4s, half.4s\n" + "fmla U12.4s, Ww12.4s, half.4s\n" + "str qU12, [%x[outptr0], %x[mstride1]]\n" + "fmul U13.4s, scratch.4s, half.4s\n" + "fmls U13.4s, Ww12.4s, half.4s\n" + "str qU13, [%x[outptr0], %x[mstride2]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd scratch.4s, Ww21.4s, Ww23.4s\n" + "fmul U22.4s, scratch.4s, half.4s\n" + "fmla U22.4s, Ww22.4s, half.4s\n" + "str qU22, [%x[outptr4], %x[mstride1]]\n" + "fmul U23.4s, scratch.4s, half.4s\n" + "fmls U23.4s, Ww22.4s, half.4s\n" + "str qU23, [%x[outptr4], %x[mstride2]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fadd scratch.4s, Ww31.4s, Ww33.4s\n" + "fmul U32.4s, scratch.4s, half.4s\n" + "fmla U32.4s, Ww32.4s, half.4s\n" + "str qU32, [%x[outptr8], %x[mstride1]]\n" + "fmul U33.4s, scratch.4s, half.4s\n" + "fmls U33.4s, Ww32.4s, half.4s\n" + "str qU33, [%x[outptr8], %x[mstride2]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "fadd scratch.4s, Ww41.4s, Ww43.4s\n" + "fmul U42.4s, scratch.4s, half.4s\n" + "fmla U42.4s, Ww42.4s, half.4s\n" + "str qU42, [%x[outptr12], %x[mstride1]]\n" + "fmul U43.4s, scratch.4s, half.4s\n" + "fmls U43.4s, Ww42.4s, half.4s\n" + "str qU43, [%x[outptr12], %x[mstride2]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n" + "bne 1b\n" + + // Tail size 1 + "2:" + // Load tile of the kernel + "ldr sw_11, [%x[inptr0]]\n" + "str sU11, [%x[outptr0]]\n" + "ldr sw_12, [%x[inptr0], %x[colstride1]]\n" + "ldr sw_13, [%x[inptr0], %x[colstride2]]\n" + "str sU14, [%x[outptr0], %x[mstride3]]\n" + "add %x[inptr0], %x[inptr0], #0x04\n" + + "ldr sw_21, [%x[inptr1]]\n" + "ldr sw_22, [%x[inptr1], %x[colstride1]]\n" + "ldr sw_23, [%x[inptr1], %x[colstride2]]\n" + "add %x[inptr1], %x[inptr1], #0x04\n" + + "ldr sw_31, [%x[inptr2]]\n" + "str sU41, [%x[outptr12]]\n" + "ldr sw_32, [%x[inptr2], 
%x[colstride1]]\n" + "ldr sw_33, [%x[inptr2], %x[colstride2]]\n" + "str sU44, [%x[outptr12], %x[mstride3]]\n" + "add %x[inptr2], %x[inptr2], #0x04\n" + + // Compute 2nd and 3rd rows of Ww + "fadd scratch.2s, w_11.2s, w_31.2s\n" + "fmul Ww21.2s, scratch.2s, half.2s\n" + "fmla Ww21.2s, w_21.2s, half.2s\n" + "str sU21, [%x[outptr4]]\n" + "fmul Ww31.2s, scratch.2s, half.2s\n" + "fmls Ww31.2s, w_21.2s, half.2s\n" + "str sU31, [%x[outptr8]]\n" + + "fadd scratch.2s, w_12.2s, w_32.2s\n" + "fmul Ww22.2s, scratch.2s, half.2s\n" + "fmla Ww22.2s, w_22.2s, half.2s\n" + "fmul Ww32.2s, scratch.2s, half.2s\n" + "fmls Ww32.2s, w_22.2s, half.2s\n" + + "fadd scratch.2s, w_13.2s, w_33.2s\n" + "fmul Ww23.2s, scratch.2s, half.2s\n" + "fmla Ww23.2s, w_23.2s, half.2s\n" + "str sU24, [%x[outptr4], %x[mstride3]]\n" + "fmul Ww33.2s, scratch.2s, half.2s\n" + "fmls Ww33.2s, w_23.2s, half.2s\n" + "str sU34, [%x[outptr8], %x[mstride3]]\n" + + // Compute and store U, only need to compute the 2nd and 3rd columns of + // U and update output pointers + "fadd scratch.2s, Ww11.2s, Ww13.2s\n" + "fmul U12.2s, scratch.2s, half.2s\n" + "fmla U12.2s, Ww12.2s, half.2s\n" + "str sU12, [%x[outptr0], %x[mstride1]]\n" + "fmul U13.2s, scratch.2s, half.2s\n" + "fmls U13.2s, Ww12.2s, half.2s\n" + "str sU13, [%x[outptr0], %x[mstride2]]\n" + "add %x[outptr0], %x[outptr0], #0x04\n" + + "fadd scratch.2s, Ww21.2s, Ww23.2s\n" + "fmul U22.2s, scratch.2s, half.2s\n" + "fmla U22.2s, Ww22.2s, half.2s\n" + "str sU22, [%x[outptr4], %x[mstride1]]\n" + "fmul U23.2s, scratch.2s, half.2s\n" + "fmls U23.2s, Ww22.2s, half.2s\n" + "str sU23, [%x[outptr4], %x[mstride2]]\n" + "add %x[outptr4], %x[outptr4], #0x04\n" + + "fadd scratch.2s, Ww31.2s, Ww33.2s\n" + "fmul U32.2s, scratch.2s, half.2s\n" + "fmla U32.2s, Ww32.2s, half.2s\n" + "str sU32, [%x[outptr8], %x[mstride1]]\n" + "fmul U33.2s, scratch.2s, half.2s\n" + "fmls U33.2s, Ww32.2s, half.2s\n" + "str sU33, [%x[outptr8], %x[mstride2]]\n" + "add %x[outptr8], %x[outptr8], #0x04\n" + + "fadd scratch.2s, Ww41.2s, Ww43.2s\n" + "fmul U42.2s, scratch.2s, half.2s\n" + "fmla U42.2s, Ww42.2s, half.2s\n" + "str sU42, [%x[outptr12], %x[mstride1]]\n" + "fmul U43.2s, scratch.2s, half.2s\n" + "fmls U43.2s, Ww42.2s, half.2s\n" + "str sU43, [%x[outptr12], %x[mstride2]]\n" + "add %x[outptr12], %x[outptr12], #0x04\n" + + // Clear aliases + ".unreq half\n" + ".unreq scratch\n" + ".unreq w_11\n" ".unreq qw_11\n" ".unreq sw_11\n" + ".unreq w_12\n" ".unreq qw_12\n" ".unreq sw_12\n" + ".unreq w_13\n" ".unreq qw_13\n" ".unreq sw_13\n" + ".unreq w_21\n" ".unreq qw_21\n" ".unreq sw_21\n" + ".unreq w_22\n" ".unreq qw_22\n" ".unreq sw_22\n" + ".unreq w_23\n" ".unreq qw_23\n" ".unreq sw_23\n" + ".unreq w_31\n" ".unreq qw_31\n" ".unreq sw_31\n" + ".unreq w_32\n" ".unreq qw_32\n" ".unreq sw_32\n" + ".unreq w_33\n" ".unreq qw_33\n" ".unreq sw_33\n" + ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n" + ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n" + ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n" + ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n" + ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n" + ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n" + ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n" + ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n" + ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n" + ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n" + ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n" + ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq 
qU44\n" + ".unreq sU11\n" ".unreq sU12\n" ".unreq sU13\n" ".unreq sU14\n" + ".unreq sU21\n" ".unreq sU22\n" ".unreq sU23\n" ".unreq sU24\n" + ".unreq sU31\n" ".unreq sU32\n" ".unreq sU33\n" ".unreq sU34\n" + ".unreq sU41\n" ".unreq sU42\n" ".unreq sU43\n" ".unreq sU44\n" + + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr4), + [outptr8] "+r" (outptr8), + [outptr12] "+r" (outptr12), + [n_remaining_channels] "+r" (n_remaining_channels) + : [mstride1] "r" (sizeof(float) * mstride), + [mstride2] "r" (sizeof(float) * mstride * 2), + [mstride3] "r" (sizeof(float) * mstride * 3), + [colstride1] "r" (sizeof(float) * kernel_col_stride), + [colstride2] "r" (sizeof(float) * kernel_col_stride * 2), + [one_half] "r" (0.5f) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24" + ); + + // Progression to complete stride + outptr0 += matrix_row_stride - n_output_channels; + outptr4 += matrix_row_stride - n_output_channels; + outptr8 += matrix_row_stride - n_output_channels; + outptr12 += matrix_row_stride - n_output_channels; + } +} +} +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp b/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp new file mode 100644 index 0000000000..0992c0bb44 --- /dev/null +++ b/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp @@ -0,0 +1,356 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +namespace winograd { + /* Transform from the Winograd domain back to the spatial domain. + */ + template + struct Winograd2x2_3x3GemmOutput { + static void execute( + const Tensor4DShape &output_shape, + T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + T* const output + ); + + protected: + /* Specialised implementation method. */ + template + static void _execute( + const Tensor4DShape &output_shape, + T *output, + const T *input, + const int matrix_stride, + const int matrix_row_stride + ); + }; + + /* Two-stage implementation of the transformation from the Winograd domain. + * + * First computes Z.F and then computes (Z.F).Z^T. 
+ */ + template + struct Winograd2x2_3x3GemmOutput_TwoStage { + static void execute( + const Tensor4DShape &output_shape, + T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + T* const output + ); + + protected: + template + static void compute_zf( + const int n_rows, const int n_channels, + T* const zf, const T* const input[16] + ); + + template + static void compute_zfzT( + const Tensor4DShape &output_shape, + T* const output, const T* const zf + ); + }; +} + +#include "output_2x2_3x3/a64_float.hpp" +// #include "output_2x2_3x3/a64_float_two_stage.hpp" + +/*****************************************************************************/ +/* +template +void winograd::Winograd2x2_3x3GemmOutput::execute( + const Tensor4DShape &output_shape, + const int tile_M, + const int tile_N, + T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + T* const output +) { + T* const antipadding = reinterpret_cast(malloc(sizeof(T) * output_shape.n_channels)); + + // Get input pointers + const T* inptrs[16]; + for (int i = 0; i < 16; i++) { + inptrs[i] = matrices[i]; + } + + for (int batch = 0; batch < output_shape.n_batches; batch++) { + for (int tile_i = 0; tile_i < tile_M; tile_i++) { + for (int tile_j = 0; tile_j < tile_N; tile_j++) { + // Get pointers for each of the 4 output cells required for this computation + T* outptrs[4]; + for (int cell_i = 0, c = 0; cell_i < 2; cell_i++) { + for (int cell_j = 0; cell_j < 2; cell_j++, c++) { + const int i = tile_i*2 + cell_i; + const int j = tile_j*2 + cell_j; + + if (i < output_shape.n_rows && j < output_shape.n_cols) { + outptrs[c] = output + ( + (batch*output_shape.n_rows + i) * output_shape.n_cols + + j) * output_shape.n_channels; + } else { + outptrs[c] = antipadding; + } + } // cell_j + } // cell_i + + for (int n = 0; n < output_shape.n_channels; n++) { + // Read 16 values and progress pointers + T v[16]; + for (int i = 0; i < 16; i++) { + v[i] = *(inptrs[i]++); + } + + // Compute output for 4 pixels + *(outptrs[0]++) = v[ 0] + v[ 1] + v[ 2] + + v[ 4] + v[ 5] + v[ 6] + + v[ 8] + v[ 9] + v[10]; + *(outptrs[1]++) = v[ 1] - v[ 2] - v[ 3] + + v[ 5] - v[ 6] - v[ 7] + + v[ 9] - v[10] - v[11]; + *(outptrs[2]++) = v[ 4] + v[ 5] + v[ 6] - + v[ 8] - v[ 9] - v[10] - + v[12] - v[13] - v[14]; + *(outptrs[3]++) = v[ 5] - v[ 6] - v[ 7] - + v[ 9] + v[10] + v[11] - + v[13] + v[14] + v[15]; + } // output_channel + } // tile_j + } // tile_i + } // batch + + free(antipadding); +} +*/ + +/*****************************************************************************/ +/* +template +void winograd::Winograd2x2_3x3GemmOutput_TwoStage::execute( + const Tensor4DShape &output_shape, + T* const matrices[16], T* const output +) { + // Allocate memory for the intermediate matrices + const int tile_M = iceildiv(output_shape.n_rows, 2); + const int tile_N = iceildiv(output_shape.n_cols, 2); + const int n_rows = output_shape.n_batches * tile_M * tile_N; + const int n_channels = output_shape.n_channels; + T* matrices_zf = reinterpret_cast( + calloc(8 * n_rows * n_channels, sizeof(T)) + ); + + // Perform the first stage transform, computing ZF. + // Specializations should dispatch to different methods based on tail size. + compute_zf<0>(n_rows, n_channels, matrices_zf, matrices); + + // Perform the second stage transform, finishing Z F Z^T - variable dispatch + // based on size of the output. Specialisations can also dispatch based on + // the tail-size of the channel. 
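// Illustrative note (not part of the patch): in both the single-stage and
// two-stage variants the underlying arithmetic is the F(2x2, 3x3) output
// transform f = Z^T . F . Z with Z = [1 0; 1 1; 1 -1; 0 -1]. A minimal scalar
// restatement for one 4x4 Winograd-domain tile; the two-stage variant simply
// materialises the intermediate F . Z (the "ZF" matrices) between the stages:
inline void reference_output_transform_2x2_3x3(const float F[4][4], float f[2][2])
{
    float FZ[4][2]; // First stage: FZ = F . Z
    for (int i = 0; i < 4; i++)
    {
        FZ[i][0] = F[i][0] + F[i][1] + F[i][2];
        FZ[i][1] = F[i][1] - F[i][2] - F[i][3];
    }
    for (int j = 0; j < 2; j++) // Second stage: f = Z^T . FZ
    {
        f[0][j] = FZ[0][j] + FZ[1][j] + FZ[2][j];
        f[1][j] = FZ[1][j] - FZ[2][j] - FZ[3][j];
    }
}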
+ if (output_shape.n_rows % 2 && output_shape.n_cols % 2) { + compute_zfzT(output_shape, output, matrices_zf); + } else if (output_shape.n_rows % 2) { + compute_zfzT(output_shape, output, matrices_zf); + } else if (output_shape.n_cols % 2) { + compute_zfzT(output_shape, output, matrices_zf); + } else { + compute_zfzT(output_shape, output, matrices_zf); + } + + free(reinterpret_cast(matrices_zf)); +} + +template +template +void winograd::Winograd2x2_3x3GemmOutput_TwoStage::compute_zf( + const int n_rows, const int n_channels, + T* output, const T* const input[16] +) { + // Extract 8 output pointers + T* outptr[8]; + for (int i = 0; i < 8; i++) { + outptr[i] = output + i*n_rows*n_channels; + } + + // Copy the 16 input pointers + const T* inptr[16]; + for (int i = 0; i < 16; i++) { + inptr[i] = input[i]; + } + + // For every row of the matrices + for (int i = 0; i < n_rows; i++) { + // For every channel + for (int j = 0; j < n_channels; j++) { + // Extract values from the input matrices + T val[16]; + for (int n = 0; n < 16; n++) { + val[n] = *(inptr[n]++); + } + + // Compute output values + *(outptr[0]++) = val[0] + val[1] + val[2]; + *(outptr[1]++) = val[1] - val[2] - val[3]; + *(outptr[2]++) = val[4] + val[5] + val[6]; + *(outptr[3]++) = val[5] - val[6] - val[7]; + *(outptr[4]++) = val[8] + val[9] + val[10]; + *(outptr[5]++) = val[9] - val[10] - val[11]; + *(outptr[6]++) = val[12] + val[13] + val[14]; + *(outptr[7]++) = val[13] - val[14] - val[15]; + } + } +} + +template +template +void winograd::Winograd2x2_3x3GemmOutput_TwoStage::compute_zfzT( + const Tensor4DShape &output_shape, + T* const output, const T* const input +) { + // Sizing information + const int tile_M = output_shape.n_rows / 2; + const int tile_N = output_shape.n_cols / 2; + + const int n_rows = (output_shape.n_batches * + (tile_M + (tail_M ? 1 : 0)) * + (tile_N + (tail_N ? 1 : 0))); + const int n_channels = output_shape.n_channels; + + // Extract 8 input pointers + const T* inptr[8]; + for (int i = 0; i < 8; i++) { + inptr[i] = input + i*n_rows*n_channels; + } + + // Extract 4 output pointers + T* outptr00 = output; + T* outptr01 = outptr00 + n_channels; + T* outptr10 = outptr00 + output_shape.n_cols * n_channels; + T* outptr11 = outptr10 + n_channels; + + // Progress over the output tiles, generating output values. + for (int batch = 0; batch < output_shape.n_batches; batch++) { + for (int tile_i = 0; tile_i < tile_M; tile_i++) { + for (int tile_j = 0; tile_j < tile_N; tile_j++) { + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + T v[8]; + for (int i = 0; i < 8; i++) { + v[i] = *(inptr[i]++); + } + + // Compute the output values and progress the output pointers. + *(outptr00++) = v[0] + v[2] + v[4]; + *(outptr01++) = v[1] + v[3] + v[5]; + *(outptr10++) = v[2] - v[4] - v[6]; + *(outptr11++) = v[3] - v[5] - v[7]; + } + + // Progress the output pointers to the next column + outptr00 += n_channels; + outptr01 += n_channels; + outptr10 += n_channels; + outptr11 += n_channels; + } + + if (tail_N) { + // Only evaluate the left-most columns of the output + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + T v[8]; + for (int i = 0; i < 4; i++) { + v[i * 2] = *inptr[i * 2]; + } + for (int i = 0; i < 8; i++) { + inptr[i]++; + } + + // Compute the output values and progress the output pointers. 
+ *(outptr00++) = v[0] + v[2] + v[4]; + *(outptr10++) = v[2] - v[4] - v[6]; + } + + // Progress the output pointers to the next column + outptr01 += n_channels; // Account for being skipped above + outptr11 += n_channels; // Account for being skipped above + } + + // Progress the output pointers to the next row + outptr00 += output_shape.n_cols * n_channels; + outptr01 += output_shape.n_cols * n_channels; + outptr10 += output_shape.n_cols * n_channels; + outptr11 += output_shape.n_cols * n_channels; + } + + if (tail_M) { + // Only work on the upper row of the output + for (int tile_j = 0; tile_j < tile_N; tile_j++) { + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + T v[8]; + for (int i = 0; i < 8; i++) { + v[i] = *(inptr[i]++); + } + + // Compute the output values and progress the output pointers. + *(outptr00++) = v[0] + v[2] + v[4]; + *(outptr01++) = v[1] + v[3] + v[5]; + } + + // Progress the output pointers to the next column + outptr00 += n_channels; + outptr01 += n_channels; + outptr10 += 2 * n_channels; // Account for being skipped above + outptr11 += 2 * n_channels; // Account for being skipped above + } + + if (tail_N) { + // Only evaluate the upper-left cell of the output + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + T v[8]; + for (int i = 0; i < 3; i++) { + v[i * 2] = *inptr[i * 2]; + } + for (int i = 0; i < 8; i++) { + inptr[i]++; + } + + // Compute the output values and progress the output pointers. + *(outptr00++) = v[0] + v[2] + v[4]; + } + + // Progress the output pointers to the next column + outptr01 += n_channels; // Account for being skipped above + outptr10 += n_channels; // Account for being skipped above + outptr11 += n_channels; // Account for being skipped above + } + } + } +} +*/ diff --git a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp b/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp new file mode 100644 index 0000000000..5925f9d569 --- /dev/null +++ b/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp @@ -0,0 +1,650 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +/* Float implementation for AArch64. 
+ */ +#ifdef __aarch64__ +namespace winograd { + + +template <> +template <> +inline void Winograd2x2_3x3GemmOutput::_execute( + const Tensor4DShape &output_shape, + float *output, + const float *input, + const int mstride, + const int matrix_row_stride +) { + const int tile_M = output_shape.n_rows / 2; + const int tile_N = output_shape.n_cols / 2; + int batch = output_shape.n_batches; + float *outptr = output; + + const float *inptr0 = input; + const float *inptr4 = input + 4 * mstride; + const float *inptr8 = input + 8 * mstride; + const float *inptr12 = input + 12 * mstride; + + const size_t col_stride = sizeof(float) * output_shape.n_channels; + const size_t row_stride = col_stride * tile_N * 2; + + asm volatile ( + // Aliases for elements of the input matrix `F` + // V-register Q-register + "F11 .req v0\n" "qF11 .req q0\n" + "F12 .req v1\n" "qF12 .req q1\n" + "F13 .req v2\n" "qF13 .req q2\n" + "F14 .req v3\n" "qF14 .req q3\n" + "F21 .req v4\n" "qF21 .req q4\n" + "F22 .req v5\n" "qF22 .req q5\n" + "F23 .req v6\n" "qF23 .req q6\n" + "F24 .req v7\n" "qF24 .req q7\n" + "F31 .req v8\n" "qF31 .req q8\n" + "F32 .req v9\n" "qF32 .req q9\n" + "F33 .req v10\n" "qF33 .req q10\n" + "F34 .req v11\n" "qF34 .req q11\n" + "F41 .req v12\n" "qF41 .req q12\n" + "F42 .req v13\n" "qF42 .req q13\n" + "F43 .req v14\n" "qF43 .req q14\n" + "F44 .req v15\n" "qF44 .req q15\n" + + // Aliases for elements of the intermediate matrix `FZ` + "FZ11 .req v16\n" + "FZ12 .req v17\n" + "FZ21 .req v18\n" + "FZ22 .req v19\n" + "FZ31 .req v20\n" + "FZ32 .req v21\n" + "FZ41 .req v22\n" + "FZ42 .req v23\n" + + // Aliases for elements of the output matrix `f` (called `g` due to case + // insensitivity of aliases). + " g11 .req v24\n" + "qg11 .req q24\n" + " g12 .req v25\n" + "qg12 .req q25\n" + " g21 .req v26\n" + "qg21 .req q26\n" + " g22 .req v27\n" + "qg22 .req q27\n" + + // Prepare the various strides + "col_stride .req %x[col_stride]\n" + "row_stride .req %x[row_stride]\n" + "row_plus_col_stride .req %x[row_plus_col_stride]\n" + + "mstride1 .req %x[mstride1]\n" + "mstride2 .req %x[mstride2]\n" + "mstride3 .req %x[mstride3]\n" + + "tile_i .req x19\n" // Tile row counter + "tile_j .req x20\n" // Tile column counter + "channel .req x21\n" // Channel counter + + "1:" // Loop over batches + "mov tile_i, %x[tile_M]\n" // Reset tile row counter + + "2:" // Loop over rows of tiles + "mov tile_j, %x[tile_N]\n" // Reset tile column counter + + "3:" // Loop over columns of tiles + // Perform initial loads of the matrix `F` + "ldr qF11, [%x[inptr0]]\n" + "ldr qF12, [%x[inptr0], mstride1]\n" + "ldr qF13, [%x[inptr0], mstride2]\n" + "ldr qF14, [%x[inptr0], mstride3]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + "ldr qF21, [%x[inptr4]]\n" + "ldr qF22, [%x[inptr4], mstride1]\n" + "subs channel, %x[n_channels], #4\n" // Reset channel counter + + "ldr qF23, [%x[inptr4], mstride2]\n" + "ldr qF24, [%x[inptr4], mstride3]\n" + "add %x[inptr4], %x[inptr4], #0x10\n" + "beq 5f\n" // Jump straight to tail if necessary + + "4:" // Loop over channels + "ldr qF31, [%x[inptr8]]\n" + "fadd FZ11.4s, F11.4s, F12.4s\n" + + "ldr qF32, [%x[inptr8], mstride1]\n" + "fsub FZ12.4s, F12.4s, F13.4s\n" + + "ldr qF33, [%x[inptr8], mstride2]\n" + "fadd FZ11.4s, FZ11.4s, F13.4s\n" + + "ldr qF34, [%x[inptr8], mstride3]\n" + "fsub FZ12.4s, FZ12.4s, F14.4s\n" + + "ldr qF41, [%x[inptr12]]\n" + "fadd FZ21.4s, F21.4s, F22.4s\n" + + "ldr qF42, [%x[inptr12], mstride1]\n" + "fsub FZ22.4s, F22.4s, F23.4s\n" + + "ldr qF43, [%x[inptr12], mstride2]\n" + "fadd FZ21.4s, FZ21.4s, 
F23.4s\n" + + "ldr qF44, [%x[inptr12], mstride3]\n" + "fsub FZ22.4s, FZ22.4s, F24.4s\n" + + "fadd FZ31.4s, F31.4s, F32.4s\n" + "add %x[inptr8], %x[inptr8], #0x10\n" + + "fsub FZ32.4s, F32.4s, F33.4s\n" + "add %x[inptr12], %x[inptr12], #0x10\n" + + "fadd FZ31.4s, FZ31.4s, F33.4s\n" + + "fsub FZ32.4s, FZ32.4s, F34.4s\n" + + "fadd g11.4s, FZ11.4s, FZ21.4s\n" + + "fadd g12.4s, FZ12.4s, FZ22.4s\n" + + "fadd g11.4s, g11.4s, FZ31.4s\n" + + "fadd g12.4s, g12.4s, FZ32.4s\n" + + "ldr qF11, [%x[inptr0]]\n" + "fadd FZ41.4s, F41.4s, F42.4s\n" + + "ldr qF12, [%x[inptr0], mstride1]\n" + "fsub g21.4s, FZ21.4s, FZ31.4s\n" + + "ldr qF13, [%x[inptr0], mstride2]\n" + "fsub FZ42.4s, F42.4s, F43.4s\n" + + "ldr qF14, [%x[inptr0], mstride3]\n" + "str qg11, [%x[outptr]]\n" + + "ldr qF21, [%x[inptr4]]\n" + "fadd FZ41.4s, FZ41.4s, F43.4s\n" + + "ldr qF22, [%x[inptr4], mstride1]\n" + "str qg12, [%x[outptr], col_stride]\n" + + "ldr qF23, [%x[inptr4], mstride2]\n" + "fsub FZ42.4s, FZ42.4s, F44.4s\n" + + "ldr qF24, [%x[inptr4], mstride3]\n" + "fsub g22.4s, FZ22.4s, FZ32.4s\n" + + "fsub g21.4s, g21.4s, FZ41.4s\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "fsub g22.4s, g22.4s, FZ42.4s\n" + "add %x[inptr4], %x[inptr4], #0x10\n" + + "subs channel, channel, #4\n" + + "str qg21, [%x[outptr], row_stride]\n" + + "str qg22, [%x[outptr], row_plus_col_stride]\n" + + "add %x[outptr], %x[outptr], #0x10\n" + + "bne 4b\n" + + "5:" // Channel tail + "ldr qF31, [%x[inptr8]]\n" + "fadd FZ11.4s, F11.4s, F12.4s\n" + + "ldr qF32, [%x[inptr8], mstride1]\n" + "fsub FZ12.4s, F12.4s, F13.4s\n" + + "ldr qF33, [%x[inptr8], mstride2]\n" + "fadd FZ11.4s, FZ11.4s, F13.4s\n" + + "ldr qF34, [%x[inptr8], mstride3]\n" + "fsub FZ12.4s, FZ12.4s, F14.4s\n" + + "ldr qF41, [%x[inptr12]]\n" + "fadd FZ21.4s, F21.4s, F22.4s\n" + + "ldr qF42, [%x[inptr12], mstride1]\n" + "fsub FZ22.4s, F22.4s, F23.4s\n" + + "ldr qF43, [%x[inptr12], mstride2]\n" + "fadd FZ21.4s, FZ21.4s, F23.4s\n" + + "ldr qF44, [%x[inptr12], mstride3]\n" + "fsub FZ22.4s, FZ22.4s, F24.4s\n" + + "fadd FZ31.4s, F31.4s, F32.4s\n" + "add %x[inptr8], %x[inptr8], #0x10\n" + + "fsub FZ32.4s, F32.4s, F33.4s\n" + "add %x[inptr12], %x[inptr12], #0x10\n" + + "fadd FZ31.4s, FZ31.4s, F33.4s\n" + + "fsub FZ32.4s, FZ32.4s, F34.4s\n" + + "fadd g11.4s, FZ11.4s, FZ21.4s\n" + + "fadd g12.4s, FZ12.4s, FZ22.4s\n" + + "fadd g11.4s, g11.4s, FZ31.4s\n" + + "fadd g12.4s, g12.4s, FZ32.4s\n" + + "fadd FZ41.4s, F41.4s, F42.4s\n" + + "fsub g21.4s, FZ21.4s, FZ31.4s\n" + + "fsub FZ42.4s, F42.4s, F43.4s\n" + + "str qg11, [%x[outptr]]\n" + + "fadd FZ41.4s, FZ41.4s, F43.4s\n" + + "str qg12, [%x[outptr], col_stride]\n" + + "fsub FZ42.4s, FZ42.4s, F44.4s\n" + + "fsub g22.4s, FZ22.4s, FZ32.4s\n" + + "fsub g21.4s, g21.4s, FZ41.4s\n" + + "fsub g22.4s, g22.4s, FZ42.4s\n" + + "subs channel, channel, #4\n" + + "str qg21, [%x[outptr], row_stride]\n" + + // Progress input pointers to the next row of the matrix + "add %x[inptr0], %x[inptr0], %x[mrowpad]\n" + "add %x[inptr4], %x[inptr4], %x[mrowpad]\n" + "add %x[inptr8], %x[inptr8], %x[mrowpad]\n" + "add %x[inptr12], %x[inptr12], %x[mrowpad]\n" + + "str qg22, [%x[outptr], row_plus_col_stride]\n" + + "add %x[outptr], %x[outptr], #0x10\n" + + + "add %x[outptr], %x[outptr], col_stride\n" + "subs tile_j, tile_j, #1\n" + "bne 3b\n" + + "add %x[outptr], %x[outptr], row_stride\n" + "subs tile_i, tile_i, #1\n" + "bne 2b\n" + + "subs %[batch], %[batch], #1\n" + "bne 1b\n" + + ".unreq F11\n" ".unreq qF11\n" + ".unreq F12\n" ".unreq qF12\n" + ".unreq F13\n" ".unreq qF13\n" + ".unreq F14\n" ".unreq 
qF14\n" + ".unreq F21\n" ".unreq qF21\n" + ".unreq F22\n" ".unreq qF22\n" + ".unreq F23\n" ".unreq qF23\n" + ".unreq F24\n" ".unreq qF24\n" + ".unreq F31\n" ".unreq qF31\n" + ".unreq F32\n" ".unreq qF32\n" + ".unreq F33\n" ".unreq qF33\n" + ".unreq F34\n" ".unreq qF34\n" + ".unreq F41\n" ".unreq qF41\n" + ".unreq F42\n" ".unreq qF42\n" + ".unreq F43\n" ".unreq qF43\n" + ".unreq F44\n" ".unreq qF44\n" + + ".unreq FZ11\n" ".unreq FZ12\n" + ".unreq FZ21\n" ".unreq FZ22\n" + ".unreq FZ31\n" ".unreq FZ32\n" + ".unreq FZ41\n" ".unreq FZ42\n" + + ".unreq g11\n" ".unreq qg11\n" + ".unreq g12\n" ".unreq qg12\n" + ".unreq g21\n" ".unreq qg21\n" + ".unreq g22\n" ".unreq qg22\n" + + ".unreq col_stride\n" + ".unreq row_stride\n" + ".unreq row_plus_col_stride\n" + + ".unreq mstride1\n" + ".unreq mstride2\n" + ".unreq mstride3\n" + + ".unreq tile_i \n" + ".unreq tile_j \n" + ".unreq channel\n" + + : [batch] "+r" (batch), + [outptr] "+r" (outptr), + [inptr0] "+r" (inptr0), + [inptr4] "+r" (inptr4), + [inptr8] "+r" (inptr8), + [inptr12] "+r" (inptr12) + : [tile_M] "r" (tile_M), + [tile_N] "r" (tile_N), + [n_channels] "r" (output_shape.n_channels), + [col_stride] "r" (col_stride), + [row_stride] "r" (row_stride), + [row_plus_col_stride] "r" (row_stride + col_stride), + [mstride1] "r" (mstride * sizeof(float)), + [mstride2] "r" (2 * mstride * sizeof(float)), + [mstride3] "r" (3 * mstride * sizeof(float)), + [mrowpad] "r" ((matrix_row_stride - output_shape.n_channels) * sizeof(float)) + : "x19", "x20", "x21", + "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "q16", "q17", "q18", "q19", "q20", "q21", + "q22", "q23", "q24", "q25", "q26", "q27", + "cc", "memory" + ); +} + +template <> +template +inline void Winograd2x2_3x3GemmOutput::_execute( + const Tensor4DShape &output_shape, + float *output, + const float *input, + const int mstride, + const int matrix_row_stride +) { + // Compute basic information about the shape of the matrices + const int tile_M = output_shape.n_rows / 2; + const int tile_N = output_shape.n_cols / 2; + const int n_channels = output_shape.n_channels; + + // Extract 16 input pointers + const float* inptr[16]; + for (int i = 0; i < 16; i++) { + inptr[i] = input + i*mstride; + } + + // Extract 4 output pointers + float *outptr00 = output; + float *outptr01 = outptr00 + n_channels; + float *outptr10 = outptr00 + output_shape.n_cols * n_channels; + float *outptr11 = outptr10 + n_channels; + + // Progress over the output tiles, generating output values. + for (int batch = 0; batch < output_shape.n_batches; batch++) { + for (int tile_i = 0; tile_i < tile_M; tile_i++) { + for (int tile_j = 0; tile_j < tile_N; tile_j++) { + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + float F[4][4]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + F[i][j] = *(inptr[i*4 + j]++); + } + } + + // Compute the matrix F.Z + float ZF[4][2]; + ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; + ZF[0][1] = F[0][1] - F[0][2] - F[0][3]; + ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; + ZF[1][1] = F[1][1] - F[1][2] - F[1][3]; + ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; + ZF[2][1] = F[2][1] - F[2][2] - F[2][3]; + ZF[3][0] = F[3][0] + F[3][1] + F[3][2]; + ZF[3][1] = F[3][1] - F[3][2] - F[3][3]; + + // Hence compute the output matrix Z^T . 
(F.Z) + *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; + *(outptr01++) = ZF[0][1] + ZF[1][1] + ZF[2][1]; + *(outptr10++) = ZF[1][0] - ZF[2][0] - ZF[3][0]; + *(outptr11++) = ZF[1][1] - ZF[2][1] - ZF[3][1]; + } + + // Progress the input pointers to the next row + for (int i = 0; i < 16; i++) { + inptr[i] += matrix_row_stride - n_channels; + } + + // Progress the output pointers to the next column + outptr00 += n_channels; + outptr01 += n_channels; + outptr10 += n_channels; + outptr11 += n_channels; + } + + if (tail_N) { + // Only evaluate the left-most columns of the output + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + float F[4][3]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 3; j++) { + F[i][j] = *(inptr[i*4 + j]++); + } + } + for (int i = 0; i < 4; i++) { + inptr[i*4 + 3]++; + } + + // Compute the matrix F.Z + float ZF[4][1]; + ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; + ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; + ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; + ZF[3][0] = F[3][0] + F[3][1] + F[3][2]; + + // Hence compute the output matrix Z^T . (F.Z) + *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; + *(outptr10++) = ZF[1][0] - ZF[2][0] - ZF[3][0]; + } + + // Progress the input pointers to the next row + for (int i = 0; i < 16; i++) { + inptr[i] += matrix_row_stride - n_channels; + } + + // Progress the output pointers to the next column + outptr01 += n_channels; // Account for being skipped above + outptr11 += n_channels; // Account for being skipped above + } + + // Progress the output pointers to the next row + outptr00 += output_shape.n_cols * n_channels; + outptr01 += output_shape.n_cols * n_channels; + outptr10 += output_shape.n_cols * n_channels; + outptr11 += output_shape.n_cols * n_channels; + } + + if (tail_M) { + // Only work on the upper row of the output + for (int tile_j = 0; tile_j < tile_N; tile_j++) { + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + float F[3][4]; + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 4; j++) { + F[i][j] = *(inptr[i*4 + j]++); + } + } + for (int j = 0; j < 4; j++) { + inptr[12 + j]++; + } + + // Compute the matrix F.Z + float ZF[3][2]; + ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; + ZF[0][1] = F[0][1] - F[0][2] - F[0][3]; + ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; + ZF[1][1] = F[1][1] - F[1][2] - F[1][3]; + ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; + ZF[2][1] = F[2][1] - F[2][2] - F[2][3]; + + // Hence compute the output matrix Z^T . 
(F.Z) + *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; + *(outptr01++) = ZF[0][1] + ZF[1][1] + ZF[2][1]; + } + + // Progress the input pointers to the next row + for (int i = 0; i < 16; i++) { + inptr[i] += matrix_row_stride - n_channels; + } + + // Progress the output pointers to the next column + outptr00 += n_channels; + outptr01 += n_channels; + outptr10 += 2 * n_channels; // Account for being skipped above + outptr11 += 2 * n_channels; // Account for being skipped above + } + + if (tail_N) { + // Only evaluate the upper-left cell of the output + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + float F[3][3]; + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + F[i][j] = *(inptr[i*4 + j]); + } + } + for (int i = 0; i < 16; i++) { + inptr[i]++; + } + + // Compute the matrix F.Z + float ZF[3][1]; + ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; + ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; + ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; + + // Hence compute the output matrix Z^T . (F.Z) + *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; + } + + // Progress the input pointers to the next row + for (int i = 0; i < 16; i++) { + inptr[i] += matrix_row_stride - n_channels; + } + + // Progress the output pointers to the next column + outptr01 += n_channels; // Account for being skipped above + outptr10 += n_channels; // Account for being skipped above + outptr11 += n_channels; // Account for being skipped above + } + } + } +} + +/*****************************************************************************/ +template <> +inline void Winograd2x2_3x3GemmOutput::execute( + const Tensor4DShape &output_shape, + float* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + float* const output +) { + // Dispatch to an appropriate implementation based on the shape of the output + // tensor. 
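// Illustrative usage sketch (not part of the patch): the pointer and stride
// values below are placeholders, and Tensor4DShape is assumed to be default-
// constructible with the public n_batches/n_rows/n_cols/n_channels fields used
// throughout this file.
void store_winograd_output_example(float *winograd_matrices, float *output_nhwc,
                                   const int matrix_stride, const int matrix_row_stride)
{
    Tensor4DShape out_shape;
    out_shape.n_batches  = 1;
    out_shape.n_rows     = 56;  // even here; odd sizes take the tail_M/tail_N paths
    out_shape.n_cols     = 56;
    out_shape.n_channels = 64;  // n_channels % 4 picks the channel-tail variant

    winograd::Winograd2x2_3x3GemmOutput<float>::execute(
        out_shape, winograd_matrices, matrix_stride, matrix_row_stride, output_nhwc);
}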
+ if (output_shape.n_rows % 2 && output_shape.n_cols % 2) { + constexpr bool tail_M = true, tail_N = true; + switch (output_shape.n_channels % 4) { + case 0: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 1: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 2: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 3: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + default: + assert(0); + break; + } + } else if (output_shape.n_rows % 2) { + constexpr bool tail_M = true, tail_N = false; + switch (output_shape.n_channels % 4) { + case 0: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 1: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 2: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 3: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + default: + assert(0); + break; + } + } else if (output_shape.n_cols % 2) { + constexpr bool tail_M = false, tail_N = true; + switch (output_shape.n_channels % 4) { + case 0: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 1: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 2: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 3: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + default: + assert(0); + break; + + } + } else { + constexpr bool tail_M = false, tail_N = false; + switch (output_shape.n_channels % 4) { + case 0: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 1: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 2: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 3: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + default: + assert(0); + break; + + } + } +} +/*****************************************************************************/ + +} // namespace winograd +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp b/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp new file mode 100644 index 0000000000..f551b12b52 --- /dev/null +++ b/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp @@ -0,0 +1,655 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#ifdef __aarch64__ + +/*****************************************************************************/ +// Compute ZF specializations + +template <> +template <> +inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage::compute_zf<0>( + const int n_rows, const int n_channels, + float* output, const float* const input[16] +) { + // Make copies of some variables + int row = n_rows; + float* outptr = output; + const float* inptr = input[0]; + + // Perform the transformation + asm volatile ( + // "inptr0 .req %x[inptr]\n" + "inptr1 .req x0\n" + "inptr2 .req x1\n" + "inptr3 .req x2\n" + "inptr4 .req x3\n" + "inptr5 .req x4\n" + "inptr6 .req x5\n" + "inptr7 .req x6\n" + "inptr8 .req x7\n" + "inptr9 .req x8\n" + "inptr10 .req x9\n" + "inptr11 .req x10\n" + "inptr12 .req x11\n" + "inptr13 .req x12\n" + "inptr14 .req x13\n" + "inptr15 .req x14\n" + + // "outptr0 .req %x[outptr]\n" + "outptr1 .req x15\n" + "outptr2 .req x16\n" + "outptr3 .req x17\n" + "outptr4 .req x18\n" + "outptr5 .req x19\n" + "outptr6 .req x20\n" + "outptr7 .req x21\n" + + // Compute additional pointers into the input and output matrices. + "mstride .req x22\n" // Matrix stride + "mul mstride, %x[row], %x[n_channels]\n" + "lsl mstride, mstride, #2\n" // * sizeof(float) + + "add inptr1, %x[inptr], mstride\n" + "add inptr2, %x[inptr], mstride, LSL #1\n" + "add inptr3, inptr2, mstride\n" + "add inptr4, inptr3, mstride\n" + "add inptr5, inptr4, mstride\n" + "add inptr6, inptr5, mstride\n" + "add inptr7, inptr6, mstride\n" + "add inptr8, inptr7, mstride\n" + "add inptr9, inptr8, mstride\n" + "add inptr10, inptr9, mstride\n" + "add inptr11, inptr10, mstride\n" + "add inptr12, inptr11, mstride\n" + "add inptr13, inptr12, mstride\n" + "add inptr14, inptr13, mstride\n" + "add inptr15, inptr14, mstride\n" + + "add outptr1, %[outptr], mstride\n" + "add outptr2, outptr1, mstride\n" + "add outptr3, outptr2, mstride\n" + "add outptr4, outptr3, mstride\n" + "add outptr5, outptr4, mstride\n" + "add outptr6, outptr5, mstride\n" + "add outptr7, outptr6, mstride\n" + + ".unreq mstride\n" + + "column .req x22\n" // Column loop counter + + "1:" // Loop over rows + "ldr q0, [%x[inptr]], #0x10\n" + "ldr q1, [inptr1], #0x10\n" + "ldr q2, [inptr2], #0x10\n" + "ldr q3, [inptr3], #0x10\n" + "ldr q4, [inptr4], #0x10\n" + "ldr q5, [inptr5], #0x10\n" + "ldr q6, [inptr6], #0x10\n" + "ldr q7, [inptr7], #0x10\n" + "subs column, %x[n_channels], #0x4\n" + "beq 3f\n" + + "2:" // Loop over columns + "ldr q8, [inptr8], #0x10\n" + "prfm pldl1keep, [%x[inptr], #196]\n" + "fadd v16.4s, v0.4s, v1.4s\n" + + "ldr q9, [inptr9], #0x10\n" + "prfm pldl1keep, [inptr1, #196]\n" + "fsub v17.4s, v1.4s, v2.4s\n" + + "ldr q10, [inptr10], #0x10\n" + "prfm pldl1keep, [inptr2, #196]\n" + "fadd v16.4s, v16.4s, v2.4s\n" + + "ldr q11, [inptr11], #0x10\n" + "prfm pldl1keep, [inptr3, #196]\n" + "fsub v17.4s, v17.4s, v3.4s\n" + + "ldr q12, [inptr12], #0x10\n" + "prfm pldl1keep, [inptr4, #196]\n" + "str q16, [%x[outptr]], #0x10\n" + + "ldr q13, [inptr13], #0x10\n" + "prfm pldl1keep, 
[inptr5, #196]\n" + "str q17, [outptr1], #0x10\n" + + "ldr q14, [inptr14], #0x10\n" + "prfm pldl1keep, [inptr6, #196]\n" + "fadd v16.4s, v4.4s, v5.4s\n" + + "ldr q15, [inptr15], #0x10\n" + "prfm pldl1keep, [inptr7, #196]\n" + "fsub v17.4s, v5.4s, v6.4s\n" + + "ldr q0, [%x[inptr]], #0x10\n" + "prfm pldl1keep, [inptr8, #196]\n" + "fadd v16.4s, v16.4s, v6.4s\n" + + "ldr q1, [inptr1], #0x10\n" + "prfm pldl1keep, [inptr9, #196]\n" + "fsub v17.4s, v17.4s, v7.4s\n" + + "ldr q2, [inptr2], #0x10\n" + "prfm pldl1keep, [inptr10, #196]\n" + "str q16, [outptr2], #0x10\n" + + "ldr q3, [inptr3], #0x10\n" + "prfm pldl1keep, [inptr11, #196]\n" + "str q17, [outptr3], #0x10\n" + + "ldr q4, [inptr4], #0x10\n" + "prfm pldl1keep, [inptr12, #196]\n" + "fadd v16.4s, v8.4s, v9.4s\n" + + "ldr q5, [inptr5], #0x10\n" + "prfm pldl1keep, [inptr13, #196]\n" + "fsub v17.4s, v9.4s, v10.4s\n" + + "ldr q6, [inptr6], #0x10\n" + "prfm pldl1keep, [inptr14, #196]\n" + "fadd v16.4s, v16.4s, v10.4s\n" + + "ldr q7, [inptr7], #0x10\n" + "prfm pldl1keep, [inptr15, #196]\n" + "fsub v17.4s, v17.4s, v11.4s\n" + + "str q16, [outptr4], #0x10\n" + "fadd v16.4s, v12.4s, v13.4s\n" + "fsub v18.4s, v13.4s, v14.4s\n" + + "str q17, [outptr5], #0x10\n" + "fadd v16.4s, v16.4s, v14.4s\n" + "fsub v18.4s, v18.4s, v15.4s\n" + + "str q16, [outptr6], #0x10\n" + "subs column, column, #0x4\n" + + "str q18, [outptr7], #0x10\n" + "bne 2b\n" + + "3:" // Tail + "ldr q8, [inptr8], #0x10\n" + "prfm pldl1keep, [%x[inptr], #196]\n" + "fadd v16.4s, v0.4s, v1.4s\n" + + "ldr q9, [inptr9], #0x10\n" + "prfm pldl1keep, [inptr1, #196]\n" + "fsub v17.4s, v1.4s, v2.4s\n" + + "ldr q10, [inptr10], #0x10\n" + "prfm pldl1keep, [inptr2, #196]\n" + "fadd v16.4s, v16.4s, v2.4s\n" + + "ldr q11, [inptr11], #0x10\n" + "prfm pldl1keep, [inptr3, #196]\n" + "fsub v17.4s, v17.4s, v3.4s\n" + + "ldr q12, [inptr12], #0x10\n" + "prfm pldl1keep, [inptr4, #196]\n" + "str q16, [%x[outptr]], #0x10\n" + + "ldr q13, [inptr13], #0x10\n" + "prfm pldl1keep, [inptr5, #196]\n" + "str q17, [outptr1], #0x10\n" + + "ldr q14, [inptr14], #0x10\n" + "prfm pldl1keep, [inptr6, #196]\n" + "fadd v16.4s, v4.4s, v5.4s\n" + + "ldr q15, [inptr15], #0x10\n" + "prfm pldl1keep, [inptr7, #196]\n" + "fsub v17.4s, v5.4s, v6.4s\n" + + "prfm pldl1keep, [inptr8, #196]\n" + "prfm pldl1keep, [inptr9, #196]\n" + "fadd v16.4s, v16.4s, v6.4s\n" + + "prfm pldl1keep, [inptr10, #196]\n" + "prfm pldl1keep, [inptr11, #196]\n" + "fsub v17.4s, v17.4s, v7.4s\n" + + "prfm pldl1keep, [inptr12, #196]\n" + "prfm pldl1keep, [inptr13, #196]\n" + "str q16, [outptr2], #0x10\n" + + "prfm pldl1keep, [inptr14, #196]\n" + "prfm pldl1keep, [inptr15, #196]\n" + "str q17, [outptr3], #0x10\n" + + "fadd v16.4s, v8.4s, v9.4s\n" + "fsub v17.4s, v9.4s, v10.4s\n" + + "fadd v16.4s, v16.4s, v10.4s\n" + "fsub v17.4s, v17.4s, v11.4s\n" + + "str q16, [outptr4], #0x10\n" + "fadd v16.4s, v12.4s, v13.4s\n" + "fsub v18.4s, v13.4s, v14.4s\n" + + "str q17, [outptr5], #0x10\n" + "fadd v16.4s, v16.4s, v14.4s\n" + "fsub v18.4s, v18.4s, v15.4s\n" + + "str q16, [outptr6], #0x10\n" + "str q18, [outptr7], #0x10\n" + + "subs %x[row], %x[row], #0x1\n" + "bne 1b\n" + + ".unreq inptr1\n" + ".unreq inptr2\n" + ".unreq inptr3\n" + ".unreq inptr4\n" + ".unreq inptr5\n" + ".unreq inptr6\n" + ".unreq inptr7\n" + ".unreq inptr8\n" + ".unreq inptr9\n" + ".unreq inptr10\n" + ".unreq inptr11\n" + ".unreq inptr12\n" + ".unreq inptr13\n" + ".unreq inptr14\n" + ".unreq inptr15\n" + ".unreq outptr1\n" + ".unreq outptr2\n" + ".unreq outptr3\n" + ".unreq outptr4\n" + ".unreq outptr5\n" + 
".unreq outptr6\n" + ".unreq outptr7\n" + + : [row] "+r" (row), + [inptr] "+r" (inptr), + [outptr] "+r" (outptr) + : [n_channels] "r" (n_channels), + [sizeof_float] "i" (sizeof(float)) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "q16", "q17", "x0", "x1", "x2", "x3", "x4", + "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", + "x16", "x17", "x18", "x19", "x20", "x21", "x22", "cc", "memory" + ); +} + +/*****************************************************************************/ +// Compute ZFZ^T specializations + +template <> +template <> +inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage::compute_zfzT( + const Tensor4DShape &output_shape, + float* const output, const float* const input +) { + const int tile_M = output_shape.n_rows / 2; + const int tile_N = output_shape.n_cols / 2; + int batch = output_shape.n_batches; + float *outptr = output; + const float *inptr = input; + + asm volatile ( + // Compute input pointers + "inptr1 .req x0\n" + "inptr2 .req x1\n" + "inptr3 .req x2\n" + "inptr4 .req x3\n" + "inptr5 .req x4\n" + "inptr6 .req x5\n" + "inptr7 .req x6\n" + "inptr8 .req x7\n" + + "mstride .req x8\n" + "mul mstride, %x[tile_M], %x[tile_N]\n" + "mul mstride, mstride, %x[n_channels]\n" + "lsl mstride, mstride, #2\n" // * sizeof(float) + + "add inptr1, %[inptr], mstride\n" + "add inptr2, inptr1, mstride\n" + "add inptr3, inptr2, mstride\n" + "add inptr4, inptr3, mstride\n" + "add inptr5, inptr4, mstride\n" + "add inptr6, inptr5, mstride\n" + "add inptr7, inptr6, mstride\n" + "add inptr8, inptr7, mstride\n" + + ".unreq mstride\n" + + // Compute initial output pointers + "outptr01 .req x8\n" + "outptr10 .req x9\n" + "outptr11 .req x10\n" + + "add outptr01, %x[outptr], %x[n_channels], LSL #2\n" + "add outptr10, %x[outptr], %x[row_stride], LSL #2\n" + "add outptr11, outptr10, %x[n_channels], LSL #2\n" + + "tile_i .req x11\n" + "tile_j .req x12\n" + "channel .req x13\n" + + "1:" // Loop over batches + "mov tile_i, %x[tile_M]\n" + + "2:" // Loop over rows of output tiles + "mov tile_j, %x[tile_N]\n" + + "3:" // Loop over columns of output tiles + "ldr q0, [%x[inptr]], #0x10\n" + "ldr q2, [inptr2], #0x10\n" + "subs channel, %x[n_channels], #0x4\n" + + "ldr q1, [inptr1], #0x10\n" + "ldr q3, [inptr3], #0x10\n" + "beq 6f\n" + + "4:" + "ldr q4, [inptr4], #0x10\n" + "ldr q5, [inptr5], #0x10\n" + "fadd v16.4s, v0.4s, v2.4s\n" + + "ldr q6, [inptr6], #0x10\n" + "ldr q7, [inptr7], #0x10\n" + "fadd v17.4s, v1.4s, v3.4s\n" + + "ldr q8, [%x[inptr]], #0x10\n" + "ldr q10, [inptr2], #0x10\n" + "fadd v16.4s, v16.4s, v4.4s\n" + + "ldr q9, [inptr1], #0x10\n" + "ldr q11, [inptr3], #0x10\n" + "fadd v17.4s, v17.4s, v5.4s\n" + + "str q16, [%x[outptr]], #0x10\n" + "prfm pldl1strm, [%x[inptr], #196]\n" + "fsub v18.4s, v2.4s, v4.4s\n" + + "str q17, [outptr01], #0x10\n" + "prfm pldl1strm, [inptr2, #196]\n" + "fsub v19.4s, v3.4s, v5.4s\n" + + "prfm pldl1strm, [inptr1, #196]\n" + "prfm pldl1strm, [inptr3, #196]\n" + "fsub v18.4s, v18.4s, v6.4s\n" + + "prfm pldl1strm, [inptr4, #196]\n" + "prfm pldl1strm, [inptr5, #196]\n" + "fsub v19.4s, v19.4s, v7.4s\n" + + "str q18, [outptr10], #0x10\n" + "prfm pldl1strm, [inptr6, #196]\n" + "prfm pldl1strm, [inptr7, #196]\n" + + "subs channel, channel, #0x4\n" + + "str q19, [outptr11], #0x10\n" + "beq 6f\n" // Branch to tail + + "ldr q12, [inptr4], #0x10\n" + "ldr q13, [inptr5], #0x10\n" + "fadd v16.4s, v8.4s, v10.4s\n" + + "ldr q14, [inptr6], #0x10\n" + "ldr q15, [inptr7], #0x10\n" + "fadd 
v17.4s, v9.4s, v11.4s\n" + + "ldr q0, [%x[inptr]], #0x10\n" + "ldr q2, [inptr2], #0x10\n" + "fadd v16.4s, v16.4s, v12.4s\n" + + "ldr q1, [inptr1], #0x10\n" + "ldr q3, [inptr3], #0x10\n" + "fadd v17.4s, v17.4s, v13.4s\n" + + "str q16, [%x[outptr]], #0x10\n" + "prfm pldl1strm, [%x[inptr], #196]\n" + "fsub v18.4s, v10.4s, v12.4s\n" + + "str q17, [outptr01], #0x10\n" + "prfm pldl1strm, [inptr2, #196]\n" + "fsub v19.4s, v11.4s, v13.4s\n" + + "prfm pldl1strm, [inptr1, #196]\n" + "prfm pldl1strm, [inptr3, #196]\n" + "fsub v18.4s, v18.4s, v14.4s\n" + + "prfm pldl1strm, [inptr4, #196]\n" + "prfm pldl1strm, [inptr5, #196]\n" + "fsub v19.4s, v19.4s, v15.4s\n" + + "str q18, [outptr10], #0x10\n" + "prfm pldl1strm, [inptr6, #196]\n" + "prfm pldl1strm, [inptr7, #196]\n" + + "subs channel, channel, #0x4\n" + + "str q19, [outptr11], #0x10\n" + "bne 4b\n" // Continue loop + + "5:" // Tail + "ldr q12, [inptr4], #0x10\n" + "ldr q13, [inptr5], #0x10\n" + "fadd v16.4s, v8.4s, v10.4s\n" + + "ldr q14, [inptr6], #0x10\n" + "ldr q15, [inptr7], #0x10\n" + "fadd v17.4s, v9.4s, v11.4s\n" + + "fadd v16.4s, v16.4s, v12.4s\n" + + "fadd v17.4s, v17.4s, v13.4s\n" + + "str q16, [%x[outptr]], #0x10\n" + "fsub v18.4s, v10.4s, v12.4s\n" + "fsub v19.4s, v11.4s, v13.4s\n" + + "str q17, [outptr01], #0x10\n" + "fsub v18.4s, v18.4s, v14.4s\n" + "fsub v19.4s, v19.4s, v15.4s\n" + + "str q18, [outptr10], #0x10\n" + "str q19, [outptr11], #0x10\n" + "b 7f\n" + + "6:" // Tail + "ldr q4, [inptr4], #0x10\n" + "ldr q5, [inptr5], #0x10\n" + "fadd v16.4s, v0.4s, v2.4s\n" + + "ldr q6, [inptr6], #0x10\n" + "ldr q7, [inptr7], #0x10\n" + "fadd v17.4s, v1.4s, v3.4s\n" + + "fadd v16.4s, v16.4s, v4.4s\n" + + "fadd v17.4s, v17.4s, v5.4s\n" + + "str q16, [%x[outptr]], #0x10\n" + "fsub v18.4s, v2.4s, v4.4s\n" + "fsub v19.4s, v3.4s, v5.4s\n" + + "str q17, [outptr01], #0x10\n" + "fsub v18.4s, v18.4s, v6.4s\n" + "fsub v19.4s, v19.4s, v7.4s\n" + + "str q18, [outptr10], #0x10\n" + "str q19, [outptr11], #0x10\n" + + "7:" + "add %x[outptr], %x[outptr], %x[n_channels], LSL #2\n" + "add outptr01, outptr01, %x[n_channels], LSL #2\n" + "add outptr10, outptr10, %x[n_channels], LSL #2\n" + "add outptr11, outptr11, %x[n_channels], LSL #2\n" + + "subs tile_j, tile_j, #1\n" + "bne 3b\n" + + // Progress the output pointers to the new row + "add %x[outptr], %x[outptr], %x[row_stride], LSL #2\n" + "add outptr01, outptr01, %x[row_stride], LSL #2\n" + "add outptr10, outptr10, %x[row_stride], LSL #2\n" + "add outptr11, outptr11, %x[row_stride], LSL #2\n" + + "subs tile_i, tile_i, #1\n" + "bne 2b\n" + + "subs %[batch], %[batch], #1\n" + "bne 1b\n" + "5:" + + ".unreq inptr1\n" + ".unreq inptr2\n" + ".unreq inptr3\n" + ".unreq inptr4\n" + ".unreq inptr5\n" + ".unreq inptr6\n" + ".unreq inptr7\n" + ".unreq inptr8\n" + ".unreq outptr01\n" + ".unreq outptr10\n" + ".unreq outptr11\n" + : [batch] "+r" (batch), + [outptr] "+r" (outptr), + [inptr] "+r" (inptr) + : [tile_M] "r" (tile_M), + [tile_N] "r" (tile_N), + [n_channels] "r" (output_shape.n_channels), + [row_stride] "r" (output_shape.n_cols * output_shape.n_channels) + : "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", + "x12", "x13", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", + "cc", "memory" + ); +} +/*****************************************************************************/ + +/*****************************************************************************/ +template <> +inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage::execute( + const Tensor4DShape &output_shape, + float* 
const matrices[16], float* const output +) { + // profiler prof; + + // Allocate memory for the intermediate matrices + const int tile_M = iceildiv(output_shape.n_rows, 2); + const int tile_N = iceildiv(output_shape.n_cols, 2); + const int n_rows = output_shape.n_batches * tile_M * tile_N; + const int n_channels = output_shape.n_channels; + float* matrices_zf = reinterpret_cast( + calloc(8 * n_rows * n_channels, sizeof(float)) + ); + + // Perform the first stage transform, computing ZF. + const auto f_compute_zf = [&] () { + switch (n_channels % 4) { + case 0: + compute_zf<0>(n_rows, n_channels, matrices_zf, matrices); + break; + case 1: + compute_zf<1>(n_rows, n_channels, matrices_zf, matrices); + break; + case 2: + compute_zf<2>(n_rows, n_channels, matrices_zf, matrices); + break; + case 3: + compute_zf<3>(n_rows, n_channels, matrices_zf, matrices); + }; + }; + // prof("Compute ZF", f_compute_zf, 16 * n_rows * n_channels * sizeof(float), 0, 8 * n_rows * n_channels * sizeof(float)); + f_compute_zf(); + + // Perform the second stage transform, finishing Z F Z^T - variable dispatch + // based on size of the output and the channel tail. + const auto f_compute_zfzT = [&] () { + if (output_shape.n_rows % 2 && output_shape.n_cols % 2) { + constexpr bool tail_M = true, tail_N = true; + switch (n_channels % 4) { + case 0: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 1: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 2: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 3: + compute_zfzT(output_shape, output, matrices_zf); + } + } else if (output_shape.n_rows % 2) { + constexpr bool tail_M = true, tail_N = false; + switch (n_channels % 4) { + case 0: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 1: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 2: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 3: + compute_zfzT(output_shape, output, matrices_zf); + } + } else if (output_shape.n_cols % 2) { + constexpr bool tail_M = false, tail_N = true; + switch (n_channels % 4) { + case 0: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 1: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 2: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 3: + compute_zfzT(output_shape, output, matrices_zf); + } + } else { + constexpr bool tail_M = false, tail_N = false; + switch (n_channels % 4) { + case 0: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 1: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 2: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 3: + compute_zfzT(output_shape, output, matrices_zf); + } + } + }; + // prof("Compute ZFZT", f_compute_zfzT, 8 * n_rows * n_channels * sizeof(float), 0, 4 * n_rows * n_channels * sizeof(float)); + f_compute_zfzT(); + + free(reinterpret_cast(matrices_zf)); +} +/*****************************************************************************/ + +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/winograd/utils.hpp b/src/core/NEON/kernels/winograd/utils.hpp new file mode 100644 index 0000000000..14e709f028 --- /dev/null +++ b/src/core/NEON/kernels/winograd/utils.hpp @@ -0,0 +1,55 @@ + +/* + * Copyright (c) 2017 ARM Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once +#include + +inline double TimeInUs(void) { +#ifdef CYCLE_PROFILING + timespec t; + clock_gettime(CLOCK_THREAD_CPUTIME_ID, &t); + return 1e6*t.tv_sec + 1e-3*t.tv_nsec; +#else + return 0; +#endif +} + +inline int iceildiv(const int a, const int b) { + return (a + b - 1) / b; +} + +template +inline T roundup(const T a, const T b) { + return a + b - (a % b); +} + +inline void PrintMatrix(const float* const m, const int M, const int N, const int row_stride) { + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + printf("%.3f ", m[i*row_stride + j]); + } + printf("\n"); + } + printf("\n"); +} diff --git a/src/core/NEON/kernels/winograd/winograd_gemm.hpp b/src/core/NEON/kernels/winograd/winograd_gemm.hpp new file mode 100644 index 0000000000..59afa2f5ab --- /dev/null +++ b/src/core/NEON/kernels/winograd/winograd_gemm.hpp @@ -0,0 +1,345 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once +#include +#include +#include + +#include "gemm.hpp" +#include "profiler.hpp" +#include "utils.hpp" +#include "shims.hpp" + +#include "transforms.hpp" + +namespace winograd { + /***************************************************************************/ + /* Implementation of the Winograd F(2x2, 3x3, 4x4) algorithm using GEMM + * internally. 
+ */ + template + class Winograd2x2_3x3GEMM { + public: + /* Instantiate a new Winograd operator. + */ + Winograd2x2_3x3GEMM(const KernelShape &kernel_shape, const Tensor4DShape input_shape, const PaddingType padding_type, void *kernel_storage); + virtual ~Winograd2x2_3x3GEMM(); + + /** Transform the weights into the Winograd domain. + */ + template > + void transform_weights(const TIn* const kernel, void *transform_working_space); + + /* Initializes matrices pointers, to be called once before execute() + */ + template > + void reshape_input(const Tensor4DShape &input_shape, const PaddingType padding_type, const TIn* const input, void* working_space); + + /* Apply the Winograd operator to some input. + */ + template > + void reshape_output(const Tensor4DShape& input_shape, const PaddingType padding_type, TOut* const output); + + + /* Apply the Winograd operator to some input. + */ + void execute(size_t first, size_t last); + + /* Get the memory required to transform the kernel. + */ + static inline size_t get_kernel_transform_working_size(const KernelShape &shape); + + /* Get the output shape of a convolution. + */ + static Tensor4DShape get_output_shape(const Tensor4DShape &input_shape, const KernelShape &k_shape, + const PaddingType padding_type); + + /* Get the memory required to instantiate a new Winograd operator. + */ + static size_t get_kernel_storage_size(const KernelShape &shape); + + /* Get the memory required to apply a Winograd operator to some input. + */ + static size_t get_working_space_size(const Tensor4DShape &input_shape,const KernelShape &k_shape, + const PaddingType padding); + + + Winograd2x2_3x3GEMM(const Winograd2x2_3x3GEMM &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + Winograd2x2_3x3GEMM &operator=(const Winograd2x2_3x3GEMM &) = delete; + /** Allow instances of this class to be moved */ + Winograd2x2_3x3GEMM(Winograd2x2_3x3GEMM &&) = default; + /** Allow instances of this class to be moved */ + Winograd2x2_3x3GEMM &operator=(Winograd2x2_3x3GEMM &&) = default; + + protected: + /* Get the memory required by a single "input" matrix. + */ + static size_t get_input_matrix_size(const Tensor4DShape &input_shape,const KernelShape &k_shape, + const PaddingType padding); + + /* Get the memory required by a single "output" matrix. + */ + static size_t get_output_matrix_size(const Tensor4DShape &input_shape, const KernelShape &k_shape, + const PaddingType padding); + + /* Get the memory required by a single "kernel" matrix. + */ + static size_t get_kernel_matrix_size(const KernelShape &shape); + + const KernelShape kernel_shape; // Shape of applied kernel + const Tensor4DShape in_shape; + const PaddingType padding; + + const int kernel_matrix_row_stride; // Stride within kernel matrix + + const bool manage_kernel_storage; // Free kernel storage when done + void* const _kernel_storage; // Base pointer for kernel matrices + + profiler prof; // Profiler + + TIn *kernel_matrices[16]; // Prepared form of kernel + TIn *input_matrices[16]; + TOut *output_matrices[16]; + + + static const int M_BLOCK = 4; + static const int N_BLOCK = 16; + }; +} // namespace winograd + +template +size_t winograd::Winograd2x2_3x3GEMM::get_kernel_transform_working_size( + const KernelShape &shape +) +{ + // Need to re-order the kernel into HWIO form, require enough space to + // represent the tensor. 
+ return sizeof(TIn) * shape.size(); +} + + +template +template +void winograd::Winograd2x2_3x3GEMM::transform_weights( + const TIn* const kernel, + void *transform_working_space +) +{ + const int kernel_matrix_size_bytes = get_kernel_matrix_size(kernel_shape); + int8_t* const ks_bytes = reinterpret_cast(_kernel_storage); + for (int i = 0; i < 16; i++) { + kernel_matrices[i] = reinterpret_cast( + ks_bytes + i*kernel_matrix_size_bytes); + } + + const TIn *kernel_hwio = kernel; + if( transform_working_space) + { + kernel_hwio = reinterpret_cast(transform_working_space); + ofm_ifm_h_w_to_h_w_ifm_ofm( + kernel, const_cast(kernel_hwio), + kernel_shape.n_output_channels, + kernel_shape.n_input_channels, + kernel_shape.n_rows, + kernel_shape.n_cols + ); + } + KernelTransform::execute( + kernel_shape, kernel_hwio, kernel_matrices[0], + kernel_matrix_size_bytes / sizeof(TIn), + kernel_matrix_row_stride + ); +} + +template +winograd::Winograd2x2_3x3GEMM::Winograd2x2_3x3GEMM( const KernelShape &kernel_shape, const Tensor4DShape input_shape, + const PaddingType padding_type, void *kernel_storage) + : kernel_shape(kernel_shape), in_shape(input_shape), padding(padding_type),kernel_matrix_row_stride(roundup(kernel_shape.n_output_channels, N_BLOCK)), manage_kernel_storage(false), + _kernel_storage(kernel_storage), prof() { + memset(kernel_matrices, 0x00, sizeof(TIn)*16); + memset(input_matrices, 0x00, sizeof(TIn)*16); + memset(output_matrices, 0x00, sizeof(TOut)*16); +} + +/*****************************************************************************/ +template +winograd::Winograd2x2_3x3GEMM::~Winograd2x2_3x3GEMM() {} + +/*****************************************************************************/ +template +template +void winograd::Winograd2x2_3x3GEMM::reshape_input( + const Tensor4DShape& input_shape, + const PaddingType padding_type, + const TIn* const input, + void *working_space +) { + assert(working_space); + int8_t* const ws_bytes = reinterpret_cast(working_space); + // Split the working space into that required for 16 input matrices and + // output matrices. 
+ const int in_matrix_stride_bytes = get_input_matrix_size(input_shape, kernel_shape, padding_type); + const int out_matrix_stride_bytes = get_output_matrix_size(input_shape, kernel_shape, padding_type); + + for (int i = 0; i < 16; i++) { + input_matrices[i] = reinterpret_cast( + ws_bytes + i*in_matrix_stride_bytes); + output_matrices[i] = reinterpret_cast( + ws_bytes + 16*in_matrix_stride_bytes + i*out_matrix_stride_bytes); + } + + // Compute shape for the GEMM + const auto output_shape = get_output_shape(input_shape,kernel_shape, padding_type); + const int tile_rows = iceildiv(output_shape.n_rows, 2); + const int tile_cols = iceildiv(output_shape.n_cols, 2); + const int K = kernel_shape.n_input_channels; + + const int in_matrix_row_stride = K; + const int in_matrix_batch_stride = tile_rows*tile_cols*in_matrix_row_stride; + + // Transform the input tensor into an appropriate form + auto input_prep = [&] () { + InputTransform::execute( + input, input_shape, padding_type, tile_rows, tile_cols, + input_matrices[0], in_matrix_stride_bytes / sizeof(TIn), + in_matrix_batch_stride, in_matrix_row_stride + ); + }; + prof( + "Input Prep", input_prep, + InputTransform::bytes_read(input_shape, output_shape), + InputTransform::flops_performed(input_shape, output_shape), + InputTransform::bytes_written(input_shape, output_shape) + ); + +} + +/*****************************************************************************/ +template +template +void winograd::Winograd2x2_3x3GEMM::reshape_output(const Tensor4DShape& input_shape, const PaddingType padding_type, TOut* const output) { + assert(output_matrices[0]); + const int out_matrix_stride_bytes = get_output_matrix_size(input_shape, kernel_shape, padding_type); + const auto output_shape = get_output_shape(input_shape,kernel_shape, padding_type); + const int out_matrix_row_stride = kernel_matrix_row_stride; + + // Transform the output tensor into an appropriate form + OutputTransform::execute( + output_shape, + output_matrices[0], + out_matrix_stride_bytes / sizeof(TOut), + out_matrix_row_stride, + output + ); +} + + +/*****************************************************************************/ +template +void winograd::Winograd2x2_3x3GEMM::execute( size_t first, size_t last ) { + assert(input_matrices[0] && kernel_matrices[0] && output_matrices[0]); + assert(first < 16 && last < 16 && first < last); + // Compute shape for the GEMM + const auto output_shape = get_output_shape(in_shape,kernel_shape, padding); + const int tile_rows = iceildiv(output_shape.n_rows, 2); + const int tile_cols = iceildiv(output_shape.n_cols, 2); + const int M = in_shape.n_batches * tile_rows * tile_cols; + const int K = kernel_shape.n_input_channels; + const int N = kernel_shape.n_output_channels; + + const int in_matrix_row_stride = K; + const int out_matrix_row_stride = kernel_matrix_row_stride; + // Perform the GEMMs + for (size_t i = first; i <= last; i++) { + BlockedGemm( + input_matrices[i], kernel_matrices[i], output_matrices[i], M, K, N, + in_matrix_row_stride, kernel_matrix_row_stride, out_matrix_row_stride + ); +// prof("GEMM", perform_gemm, 0, 2*M*K*N, 0); // TODO Memory + } + +} + +/*****************************************************************************/ +template +Tensor4DShape winograd::Winograd2x2_3x3GEMM::get_output_shape( + const Tensor4DShape &in_shape, const KernelShape &k_shape, const PaddingType padding) { + return Tensor4DShape { + in_shape.n_batches, + (padding == PADDING_SAME) ? in_shape.n_rows : in_shape.n_rows - 2, + (padding == PADDING_SAME) ? 
in_shape.n_cols : in_shape.n_cols - 2, + k_shape.n_output_channels + }; +} + +template +size_t winograd::Winograd2x2_3x3GEMM::get_kernel_storage_size( + const KernelShape &shape) { + return 16 * get_kernel_matrix_size(shape); +} + +template +size_t winograd::Winograd2x2_3x3GEMM::get_kernel_matrix_size( + const KernelShape &shape) { + const int K = shape.n_input_channels; + const int N = roundup(shape.n_output_channels, N_BLOCK); + return sizeof(TIn) * K * N; +} + +template +size_t winograd::Winograd2x2_3x3GEMM::get_working_space_size( + const Tensor4DShape& input_shape, const KernelShape &k_shape, const PaddingType padding_type +) { + return 16 * get_input_matrix_size(input_shape, k_shape, padding_type) + + 16 * get_output_matrix_size(input_shape, k_shape, padding_type); +} + +template +size_t winograd::Winograd2x2_3x3GEMM::get_input_matrix_size( + const Tensor4DShape& input_shape, const KernelShape &k_shape, const PaddingType padding_type +) { + // Compute shape for the GEMM + const auto output_shape = get_output_shape(input_shape, k_shape, padding_type); + const int tile_rows = iceildiv(output_shape.n_rows, 2); + const int tile_cols = iceildiv(output_shape.n_cols, 2); + const int M = roundup(tile_rows * tile_cols, M_BLOCK); + const int K = k_shape.n_input_channels; + + return input_shape.n_batches * M * K * sizeof(TIn); +} + +template +size_t winograd::Winograd2x2_3x3GEMM::get_output_matrix_size( + const Tensor4DShape& input_shape, const KernelShape &k_shape,const PaddingType padding_type +) { + // Compute shape for the GEMM + const auto output_shape = get_output_shape(input_shape, k_shape, padding_type); + const int tile_rows = iceildiv(output_shape.n_rows, 2); + const int tile_cols = iceildiv(output_shape.n_cols, 2); + const int M = roundup(tile_rows * tile_cols, M_BLOCK); + const int N = roundup(k_shape.n_output_channels, N_BLOCK); + + return input_shape.n_batches * M * N * sizeof(TOut); +} diff --git a/src/core/NEON/kernels/winograd/winograd_shim_nchw.hpp b/src/core/NEON/kernels/winograd/winograd_shim_nchw.hpp new file mode 100644 index 0000000000..c5bcffbaef --- /dev/null +++ b/src/core/NEON/kernels/winograd/winograd_shim_nchw.hpp @@ -0,0 +1,191 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#pragma once +#include +#include + +#include "gemm.hpp" +#include "profiler.hpp" +#include "utils.hpp" +#include "shims.hpp" +#include "winograd_gemm.hpp" + +#include "transforms.hpp" + +#ifndef ALLOC_ALIGN +#define ALLOC_ALIGN 64 +#endif // ALLOC_ALIGN + + +namespace winograd_shim_nchw { + /***************************************************************************/ + /* Implementation of the Winograd F(2x2, 3x3, 4x4) algorithm using GEMM + * internally. + */ + template + class Winograd2x2_3x3GEMM : public winograd::Winograd2x2_3x3GEMM { + public: + /* Instantiate a new Winograd operator. + */ + Winograd2x2_3x3GEMM(const KernelShape &kernel_shape, const Tensor4DShape input_shape, const PaddingType padding_type, void *kernel_storage); + + void nchw2nhwc( const Tensor4DShape& input_shape, const PaddingType padding_type, void *working_space, const TIn* const input); + void nhwc2nchw( const Tensor4DShape& input_shape, const PaddingType padding_type, void *working_space, TOut* const output); + + + std::pair get_nhwc_ptrs(const Tensor4DShape& input_shape,const PaddingType padding_type,void *working_space); + + static size_t get_working_space_size(const Tensor4DShape &input_shape,const KernelShape &k_shape, const PaddingType padding); + protected: + /* Get the memory required to store an NHWC copy of the input tensor. */ + static size_t get_working_nhwc_input_size(const Tensor4DShape &input_shape); + + /* Get the memory required to store an NHWC copy of the input tensor. */ + static size_t get_working_nhwc_output_size(const Tensor4DShape &output_shape, const KernelShape &k_shape, const PaddingType padding) ; + }; +} // namespace winograd + +/*****************************************************************************/ +template +winograd_shim_nchw::Winograd2x2_3x3GEMM::Winograd2x2_3x3GEMM( + const KernelShape &kernel_shape, const Tensor4DShape input_shape, + const PaddingType padding_type, void *kernel_storage +) : winograd::Winograd2x2_3x3GEMM(kernel_shape,input_shape,padding_type,kernel_storage) { +} + +/*****************************************************************************/ +template +void winograd_shim_nchw::Winograd2x2_3x3GEMM::nchw2nhwc(const Tensor4DShape& input_shape, const PaddingType padding_type, void *working_space, const TIn* const input) { + assert(working_space); + int8_t* const ws_bytes = reinterpret_cast(working_space); + + // Extract the top chunk of the working space to store the input and output + // tensors in NHWC format. 
+ const int in_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM::get_input_matrix_size(input_shape, this->kernel_shape, padding_type); + const int out_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM::get_output_matrix_size(input_shape, this->kernel_shape, padding_type); + + // Allocate working space for the input and output in NHWC format + TIn* const input_nhwc = reinterpret_cast( + ws_bytes + 16*(in_matrix_stride_bytes + out_matrix_stride_bytes) + ); + + // Re-order the input tensor + this->prof( + "NCHW -> NHWC", + [input, input_shape, input_nhwc] () { + nchw_to_nhwc( + input, input_nhwc, + input_shape.n_batches, + input_shape.n_channels, + input_shape.n_rows, + input_shape.n_cols + ); + }, + input_shape.size(), 0, input_shape.size() + ); +} + +/*****************************************************************************/ +template +void winograd_shim_nchw::Winograd2x2_3x3GEMM::nhwc2nchw(const Tensor4DShape& input_shape, const PaddingType padding_type, + void *working_space, TOut* const output) { + + assert(working_space); + int8_t* const ws_bytes = reinterpret_cast(working_space); + + // Extract the top chunk of the working space to store the input and output + // tensors in NHWC format. + const int in_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM::get_input_matrix_size(input_shape, this->kernel_shape, padding_type); + const int out_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM::get_output_matrix_size(input_shape, this->kernel_shape, padding_type); + + TOut* const output_nhwc = reinterpret_cast(ws_bytes + 16*(in_matrix_stride_bytes + out_matrix_stride_bytes) + get_working_nhwc_input_size(input_shape)); + + // Re-order the output tensor into NCHW + const auto output_shape = winograd::Winograd2x2_3x3GEMM::get_output_shape(input_shape, this->kernel_shape, padding_type); + this->prof( + "NHWC -> NCHW", + [output_nhwc, output_shape, output] () { + nhwc_to_nchw( + output_nhwc, output, + output_shape.n_batches, + output_shape.n_rows, + output_shape.n_cols, + output_shape.n_channels + ); + }, + output_shape.size(), 0, output_shape.size() + ); +} + + +/*****************************************************************************/ +template +std::pair winograd_shim_nchw::Winograd2x2_3x3GEMM::get_nhwc_ptrs( + const Tensor4DShape& input_shape, + const PaddingType padding_type, + void *working_space +) { + assert(working_space); + int8_t* const ws_bytes = reinterpret_cast(working_space); + + // Extract the top chunk of the working space to store the input and output + // tensors in NHWC format. 
+ const int in_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM::get_input_matrix_size(input_shape, this->kernel_shape, padding_type); + const int out_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM::get_output_matrix_size(input_shape, this->kernel_shape, padding_type); + + // Allocate working space for the input and output in NHWC format + TIn* input_nhwc = reinterpret_cast(ws_bytes + 16*(in_matrix_stride_bytes + out_matrix_stride_bytes)); + TOut* output_nhwc = reinterpret_cast(ws_bytes + 16*(in_matrix_stride_bytes + out_matrix_stride_bytes) + get_working_nhwc_input_size(input_shape)); + return std::make_pair(output_nhwc,input_nhwc); +} + + + + +/*****************************************************************************/ +template +size_t winograd_shim_nchw::Winograd2x2_3x3GEMM::get_working_space_size( + const Tensor4DShape& input_shape, const KernelShape &k_shape, const PaddingType padding_type +) { + // TODO Add memory required for NHWC copies of input tensors + return winograd::Winograd2x2_3x3GEMM::get_working_space_size( + input_shape, k_shape, padding_type) + + get_working_nhwc_input_size(input_shape) + + get_working_nhwc_output_size(input_shape, k_shape, padding_type); +} + +template +size_t winograd_shim_nchw::Winograd2x2_3x3GEMM::get_working_nhwc_input_size( + const Tensor4DShape& input_shape +) { + return roundup(input_shape.size() * sizeof(TIn), static_cast(ALLOC_ALIGN)); +} + +template +size_t winograd_shim_nchw::Winograd2x2_3x3GEMM::get_working_nhwc_output_size( + const Tensor4DShape& input_shape, const KernelShape &k_shape, const PaddingType padding_type +) { + const auto output_shape = winograd::Winograd2x2_3x3GEMM::get_output_shape(input_shape,k_shape, padding_type); + return roundup(output_shape.size() * sizeof(TIn), static_cast(ALLOC_ALIGN)); +} diff --git a/src/runtime/NEON/functions/NEWinogradLayer.cpp b/src/runtime/NEON/functions/NEWinogradLayer.cpp index a9dec4ea0d..3251de4ae4 100644 --- a/src/runtime/NEON/functions/NEWinogradLayer.cpp +++ b/src/runtime/NEON/functions/NEWinogradLayer.cpp @@ -83,18 +83,18 @@ void NEWinogradLayer::configure(const ITensor *input, const ITensor *weights, co // Get the memory required to instantiate a new Winograd operator. 
     constexpr size_t kstore_alignment = 64;
-    const size_t kernel_storage_per_thread = Winograd3x3F32::get_kernel_storage_size(kernel_shape);
+    const size_t kernel_storage_per_thread = NEWinogradLayerKernel::get_kernel_storage_size(kernel_shape);
     _kernel_storage.allocator()->init(TensorInfo(TensorShape{ (kernel_storage_per_thread + kstore_alignment - 1) }, 1, DataType::U8));
     _memory_group.manage(&_kernel_storage);
 
     // Get workbench size and allocate memory
     constexpr size_t wspace_alignment = 64;
-    const size_t ws_size = Winograd3x3F32::get_working_space_size(in_shape, kernel_shape, padding);
+    const size_t ws_size = NEWinogradLayerKernel::get_working_space_size(in_shape, kernel_shape, padding);
     _workspace.allocator()->init(TensorInfo(TensorShape{ (ws_size + wspace_alignment - 1) }, 1, DataType::U8));
     _memory_group.manage(&_workspace);
 
     // Workspace for weights transform
-    const size_t weights_transform_size = Winograd3x3F32::get_kernel_transform_working_size(kernel_shape);
+    const size_t weights_transform_size = NEWinogradLayerKernel::get_kernel_transform_working_size(kernel_shape);
     _weights_workspace.allocator()->init(TensorInfo(TensorShape{ (weights_transform_size + wspace_alignment - 1) }, 1, DataType::U8));
     _memory_group.manage(&_weights_workspace);
@@ -125,7 +125,7 @@ void NEWinogradLayer::run()
     _conv->nchw2nhwc(in_shape, padding, _workspace.buffer(), reinterpret_cast(_input->buffer()));
 
     //Get ptrs into the workspace
-    std::pair nhwc_ptrs = _conv->get_nhwc_ptrs(in_shape, padding, _workspace.buffer());
+    std::pair nhwc_ptrs = _conv->get_nhwc_ptrs(in_shape, padding, _workspace.buffer());
 
     //Setup matrices ptrs and transfor the input tensor to the appropriate form before running GEMM.
     _conv->reshape_input(in_shape, padding, nhwc_ptrs.second, _workspace.buffer());
-- cgit v1.2.1
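
Note on the two-stage output transform moved in this patch: compute_zf and compute_zfzT together evaluate Z F Z^T for each 2x2 output tile, with Z = [[1, 1, 1, 0], [0, 1, -1, -1]], vectorised four channels at a time (the n_channels % 4 switch in execute() selects the matching tail specialisation). A minimal scalar sketch of the same arithmetic, assuming that Z and using illustrative names only (nothing below appears in the library):

// Scalar sketch of the two-stage 2x2 output transform: given the 16 per-tile
// values f[4][4] produced by the Winograd-domain GEMMs, first form zf = Z * f
// (16 -> 8 values, as in compute_zf), then out = zf * Z^T (8 -> 4 values, as
// in compute_zfzT).
inline void winograd_2x2_3x3_output_tile_ref(const float f[4][4], float out[2][2])
{
    float zf[2][4];
    for (int j = 0; j < 4; j++)
    {
        zf[0][j] = f[0][j] + f[1][j] + f[2][j]; // row combination (first stage)
        zf[1][j] = f[1][j] - f[2][j] - f[3][j];
    }
    for (int i = 0; i < 2; i++)
    {
        out[i][0] = zf[i][0] + zf[i][1] + zf[i][2]; // column combination (second stage)
        out[i][1] = zf[i][1] - zf[i][2] - zf[i][3];
    }
}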
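
The storage queries in winograd_gemm.hpp reduce to products of tile count, channel count and element size, padded to the GEMM blocking (M_BLOCK = 4, N_BLOCK = 16). A rough standalone sketch of the same formulas, using the iceildiv and roundup helpers from utils.hpp and made-up example shapes:

#include <cstddef>
#include <cstdio>

// Illustrative restatement of the size formulas in winograd_gemm.hpp; the
// shapes below are example values only.
static int iceildiv_(int a, int b) { return (a + b - 1) / b; }
static int roundup_(int a, int b) { return a + b - (a % b); } // as defined in utils.hpp

int main()
{
    const int batches = 1, out_rows = 5, out_cols = 5;   // e.g. output of a 7x7 VALID convolution
    const int in_channels = 32, out_channels = 60;
    const int tiles = iceildiv_(out_rows, 2) * iceildiv_(out_cols, 2); // 2x2 output tiles per batch
    const int M = roundup_(tiles, 4);                    // GEMM rows, padded to M_BLOCK
    const int N = roundup_(out_channels, 16);            // GEMM cols, padded to N_BLOCK
    const int K = in_channels;

    const std::size_t input_matrix  = batches * M * K * sizeof(float);
    const std::size_t output_matrix = batches * M * N * sizeof(float);
    const std::size_t kernel_matrix = K * N * sizeof(float);

    // 16 Winograd-domain matrices of each kind
    std::printf("working space: %zu B, kernel storage: %zu B\n",
                16 * (input_matrix + output_matrix), 16 * kernel_matrix);
    return 0;
}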
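
The winograd_shim_nchw wrapper only re-orders tensors around the NHWC core: nchw2nhwc() copies the input into the top chunk of the working space and nhwc2nchw() copies the result back out. A naive reference for that re-ordering (illustrative names; the library's versions in shims.hpp may be specialised for AArch64):

// Naive reference for the NCHW -> NHWC re-ordering performed by the shim;
// indices follow the batches/channels/rows/cols order of the Tensor4DShape
// fields used in this patch.
template <typename T>
void nchw_to_nhwc_ref(const T* in, T* out,
                      int n_batches, int n_channels, int n_rows, int n_cols)
{
    for (int n = 0; n < n_batches; n++)
        for (int c = 0; c < n_channels; c++)
            for (int i = 0; i < n_rows; i++)
                for (int j = 0; j < n_cols; j++)
                {
                    const T v = in[((n * n_channels + c) * n_rows + i) * n_cols + j];
                    out[((n * n_rows + i) * n_cols + j) * n_channels + c] = v;
                }
}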