From 8951933e5dd7be8d922affea3cc23a48a05b694d Mon Sep 17 00:00:00 2001 From: Pablo Tello Date: Fri, 17 Nov 2017 11:52:36 +0000 Subject: COMPMID-687: Winograd layer. Change-Id: Ica682d08e851491bf4a26b8d17908c014844055e Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/110990 Reviewed-by: Anthony Barbier Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com --- SConstruct | 2 + arm_compute/core/NEON/NEKernels.h | 1 + .../core/NEON/kernels/NEWinogradLayerKernel.h | 70 + arm_compute/core/NEON/kernels/winograd/alloc.hpp | 30 + arm_compute/core/NEON/kernels/winograd/gemm.hpp | 127 ++ .../core/NEON/kernels/winograd/gemm/a64_sgemm.hpp | 355 +++++ .../NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp | 1445 +++++++++++++++++++ arm_compute/core/NEON/kernels/winograd/perf.h | 32 + .../core/NEON/kernels/winograd/profiler.hpp | 244 ++++ arm_compute/core/NEON/kernels/winograd/shims.hpp | 319 +++++ arm_compute/core/NEON/kernels/winograd/tensor.hpp | 210 +++ .../core/NEON/kernels/winograd/transforms.hpp | 29 + .../kernels/winograd/transforms/input_2x2_3x3.hpp | 638 +++++++++ .../transforms/input_2x2_3x3/a64_float.hpp | 1498 ++++++++++++++++++++ .../input_2x2_3x3/a64_float_channelwise.hpp | 961 +++++++++++++ .../kernels/winograd/transforms/kernel_2x2_3x3.hpp | 195 +++ .../transforms/kernel_2x2_3x3/a64_float.hpp | 822 +++++++++++ .../kernels/winograd/transforms/output_2x2_3x3.hpp | 356 +++++ .../transforms/output_2x2_3x3/a64_float.hpp | 650 +++++++++ .../output_2x2_3x3/a64_float_two_stage.hpp | 655 +++++++++ arm_compute/core/NEON/kernels/winograd/utils.hpp | 55 + .../core/NEON/kernels/winograd/winograd_gemm.hpp | 346 +++++ .../NEON/kernels/winograd/winograd_shim_nchw.hpp | 192 +++ arm_compute/runtime/NEON/NEFunctions.h | 1 + .../runtime/NEON/functions/NEWinogradLayer.h | 84 ++ scripts/check_bad_style.sh | 16 +- scripts/clang_tidy_rules.py | 3 + src/core/NEON/kernels/NEWinogradLayerKernel.cpp | 60 + src/runtime/NEON/functions/NEWinogradLayer.cpp | 155 ++ tests/datasets/SmallConvolutionLayerDataset.h | 12 + tests/validation/NEON/ConvolutionLayer.cpp | 19 + tests/validation/fixtures/WinogradLayerFixture.h | 145 ++ 32 files changed, 9719 insertions(+), 8 deletions(-) create mode 100644 arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h create mode 100644 arm_compute/core/NEON/kernels/winograd/alloc.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/gemm.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/perf.h create mode 100644 arm_compute/core/NEON/kernels/winograd/profiler.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/shims.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/tensor.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/transforms.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp create mode 100644 
arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/utils.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/winograd_shim_nchw.hpp create mode 100644 arm_compute/runtime/NEON/functions/NEWinogradLayer.h create mode 100644 src/core/NEON/kernels/NEWinogradLayerKernel.cpp create mode 100644 src/runtime/NEON/functions/NEWinogradLayer.cpp create mode 100644 tests/validation/fixtures/WinogradLayerFixture.h diff --git a/SConstruct b/SConstruct index 6f4835828a..e7504228d3 100644 --- a/SConstruct +++ b/SConstruct @@ -180,6 +180,8 @@ if not GetOption("help"): if env['standalone']: env.Append(CXXFLAGS = ['-fPIC']) env.Append(LINKFLAGS = ['-static-libgcc','-static-libstdc++']) + if env['cppthreads']: + env.Append(LINKFLAGS = ['-lpthread']) if env['Werror']: env.Append(CXXFLAGS = ['-Werror']) diff --git a/arm_compute/core/NEON/NEKernels.h b/arm_compute/core/NEON/NEKernels.h index 7fb5f78f13..281f06305f 100644 --- a/arm_compute/core/NEON/NEKernels.h +++ b/arm_compute/core/NEON/NEKernels.h @@ -111,6 +111,7 @@ #include "arm_compute/core/NEON/kernels/NETransposeKernel.h" #include "arm_compute/core/NEON/kernels/NEWarpKernel.h" #include "arm_compute/core/NEON/kernels/NEWeightsReshapeKernel.h" +#include "arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h" #include "arm_compute/core/NEON/kernels/arm32/NEGEMMAArch32Kernel.h" #include "arm_compute/core/NEON/kernels/arm64/NEGEMMAArch64Kernel.h" #include "arm_compute/core/NEON/kernels/arm64/NEGEMMLowpAArch64A53Kernel.h" diff --git a/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h b/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h new file mode 100644 index 0000000000..1e7ca64b8c --- /dev/null +++ b/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */
+#ifndef __ARM_COMPUTE_NEGEMMWINOGRADLAYERKERNEL_H__
+#define __ARM_COMPUTE_NEGEMMWINOGRADLAYERKERNEL_H__
+
+#include "arm_compute/core/NEON/INEKernel.h"
+
+#include "arm_compute/core/NEON/kernels/winograd/winograd_shim_nchw.hpp"
+
+namespace arm_compute
+{
+class ITensor;
+
+class NEWinogradLayerKernel : public INEKernel
+{
+public:
+    using Winograd3x3F32 = winograd_shim_nchw::Winograd2x2_3x3GEMM;
+
+    /** Constructor */
+    NEWinogradLayerKernel();
+
+    /** Prevent instances of this class from being copied (As this class contains pointers) */
+    NEWinogradLayerKernel(const NEWinogradLayerKernel &) = delete;
+    /** Prevent instances of this class from being copied (As this class contains pointers) */
+    NEWinogradLayerKernel &operator=(const NEWinogradLayerKernel &) = delete;
+    /** Allow instances of this class to be moved */
+    NEWinogradLayerKernel(NEWinogradLayerKernel &&) = default;
+    /** Allow instances of this class to be moved */
+    NEWinogradLayerKernel &operator=(NEWinogradLayerKernel &&) = default;
+
+    virtual ~NEWinogradLayerKernel() = default;
+
+    /** Initialise the kernel
+     *
+     * @param[in,out] output    Output tensor to store the result of the matrix multiplication.
+     * @param[in]     convolver A pointer to the Winograd convolver; this object must already have been configured and be ready to execute the 16 GEMMs.
+     */
+    void configure(ITensor *output, Winograd3x3F32 *convolver);
+
+    // Inherited methods overridden:
+    void run(const Window &window, const ThreadInfo &info) override;
+
+protected:
+    Winograd3x3F32 *_convolver;
+    ITensor        *_output;
+};
+
+} // namespace arm_compute
+#endif /*__ARM_COMPUTE_NEGEMMWINOGRADLAYERKERNEL_H__*/
diff --git a/arm_compute/core/NEON/kernels/winograd/alloc.hpp b/arm_compute/core/NEON/kernels/winograd/alloc.hpp
new file mode 100644
index 0000000000..ef6f2b5115
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/alloc.hpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+#ifdef ALLOC_ALIGN
+#define ALLOCATE(x) aligned_alloc(ALLOC_ALIGN, x)
+#else
+#define ALLOCATE(x) malloc(x)
+#endif
diff --git a/arm_compute/core/NEON/kernels/winograd/gemm.hpp b/arm_compute/core/NEON/kernels/winograd/gemm.hpp
new file mode 100644
index 0000000000..564016a646
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/gemm.hpp
@@ -0,0 +1,127 @@
+
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+#include "utils.hpp"
+
+template <typename TIn, typename TOut>
+void Gemm(const TIn* const a, const TIn* const b, TOut *c,
+          const int M, const int K, const int N,
+          const int a_row_stride,
+          const int b_row_stride,
+          const int c_row_stride,
+          const bool a_transposed=false,
+          const bool b_transposed=false) {
+  // Array access methods
+  const auto A = [a, a_transposed, M, K, a_row_stride] (const int i, const int j) -> TIn {
+    return a[(!a_transposed) ? i*a_row_stride + j : i + j*M];
+  };
+
+  const auto B = [b, b_transposed, K, N, b_row_stride] (const int i, const int j) -> TIn {
+    return b[(!b_transposed) ? i*b_row_stride + j : i + j*N];
+  };
+
+  const auto C = [c, c_row_stride] (const int i, const int j) -> TOut& {
+    return c[i*c_row_stride + j];
+  };
+
+  // Perform the matrix multiplication
+  for (int i = 0; i < M; i++) {
+    for (int j = 0; j < N; j++) {
+      for (int k = 0; k < K; k++) {
+        C(i, j) += A(i, k) * B(k, j);
+      }
+    }
+  }
+}
+
+template <const int M_BLOCK, const int N_BLOCK, typename TIn, typename TOut>
+void BlockedGemm(
+  const TIn* const a, const TIn* const b, TOut *c,
+  const int M, const int K, const int N,
+  const int a_row_stride,
+  const int b_row_stride,
+  const int c_row_stride
+) {
+  // Array access methods
+  const auto A = [a, M, K, a_row_stride] (const int i, const int j) -> TIn {
+    return a[i*a_row_stride + j];
+  };
+
+  const auto B = [b, K, N, b_row_stride] (const int i, const int j) -> TIn {
+    return b[i*b_row_stride + j];
+  };
+
+  const auto C = [c, c_row_stride] (const int i, const int j) -> TOut& {
+    return c[i*c_row_stride + j];
+  };
+
+  const int M_BLOCKS = iceildiv(M, M_BLOCK);
+  const int N_BLOCKS = iceildiv(N, N_BLOCK);
+
+  // For each block of output rows
+  for (int mblock = 0; mblock < M_BLOCKS; mblock++) {
+    // For each block of output columns
+    for (int nblock = 0; nblock < N_BLOCKS; nblock++) {
+      // Create an appropriately sized block of accumulators
+      TOut accum[M_BLOCK][N_BLOCK];
+      for (int i = 0; i < M_BLOCK; i++) {
+        for (int j = 0; j < N_BLOCK; j++) {
+          accum[i][j] = static_cast<TOut>(0);
+        }
+      }
+
+      // Perform this portion of the matrix multiply
+      for (int k = 0; k < K; k++) {
+        // Load elements of A
+        TIn elems_a[M_BLOCK];
+        for (int i = 0; i < M_BLOCK; i++) {
+          elems_a[i] = A(mblock*M_BLOCK + i, k);
+        }
+
+        // Load elements of B
+        TIn elems_b[N_BLOCK];
+        for (int j = 0; j < N_BLOCK; j++) {
+          elems_b[j] = B(k, nblock*N_BLOCK + j);
+        }
+
+        // Perform the partial matrix multiply
+        for (int i = 0; i < M_BLOCK; i++) {
+          for 
(int j = 0; j < N_BLOCK; j++) { + accum[i][j] += elems_a[i] * elems_b[j]; + } + } + } + + // Store the partial product + for (int i = 0; i < M_BLOCK; i++) { + for (int j = 0; j < N_BLOCK; j++) { + C(mblock*M_BLOCK + i, nblock*N_BLOCK + j) = accum[i][j]; + } + } + } + } +} + +#include "gemm/a64_sgemm.hpp" diff --git a/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp new file mode 100644 index 0000000000..e1b7488c31 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp @@ -0,0 +1,355 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once +#include +#include "../utils.hpp" + +#ifdef __aarch64__ + +template <> +inline void BlockedGemm<8, 12, float, float>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + const int M_BLOCK = 8; + const int N_BLOCK = 12; + + const int m_blocks = iceildiv(M, M_BLOCK); + const int n_blocks = iceildiv(N, N_BLOCK); + + // For each block of output rows + for (int mblock = 0; mblock < m_blocks; mblock++) { + // For each block of output columns + for (int nblock = 0; nblock < n_blocks; nblock++) { + const float *aptr = a + mblock*M_BLOCK*a_row_stride; + const float *bptr = b + nblock*N_BLOCK; + float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; + int k = K; + + asm volatile ( + // Create an 8x12 block of accumulators + " A_1 .req v27\n" + "sA_1 .req s27\n" + " A_2 .req v28\n" + "sA_2 .req s28\n" + " A_3 .req v29\n" + "sA_3 .req s29\n" + " A_4 .req v30\n" + "sA_4 .req s30\n" + + " B_1 .req v24\n" " B_2 .req v25\n" " B_3 .req v26\n" + "qB_1 .req q24\n" "qB_2 .req q25\n" "qB_3 .req q26\n" + + " C_11 .req v0\n" " C_12 .req v1\n" " C_13 .req v2\n" + " C_21 .req v3\n" " C_22 .req v4\n" " C_23 .req v5\n" + " C_31 .req v6\n" " C_32 .req v7\n" " C_33 .req v8\n" + " C_41 .req v9\n" " C_42 .req v10\n" " C_43 .req v11\n" + " C_51 .req v12\n" " C_52 .req v13\n" " C_53 .req v14\n" + " C_61 .req v15\n" " C_62 .req v16\n" " C_63 .req v17\n" + " C_71 .req v18\n" " C_72 .req v19\n" " C_73 .req v20\n" + " C_81 .req v21\n" " C_82 .req v22\n" " C_83 .req v23\n" + + "qC_11 .req q0\n" "qC_12 .req q1\n" "qC_13 .req q2\n" + "qC_21 .req q3\n" "qC_22 .req q4\n" "qC_23 .req q5\n" + "qC_31 .req q6\n" "qC_32 .req q7\n" "qC_33 .req q8\n" + "qC_41 .req 
q9\n" "qC_42 .req q10\n" "qC_43 .req q11\n" + "qC_51 .req q12\n" "qC_52 .req q13\n" "qC_53 .req q14\n" + "qC_61 .req q15\n" "qC_62 .req q16\n" "qC_63 .req q17\n" + "qC_71 .req q18\n" "qC_72 .req q19\n" "qC_73 .req q20\n" + "qC_81 .req q21\n" "qC_82 .req q22\n" "qC_83 .req q23\n" + + "aptr1 .req x17\n" + "aptr2 .req x18\n" + "aptr3 .req x19\n" + "aptr4 .req x20\n" + "aptr5 .req x21\n" + "aptr6 .req x22\n" + "aptr7 .req x23\n" + + // Initialise accumulators with 0 + // Initialise pointers + "movi C_11.4s, #0\n" + "add aptr1, %x[aptr], %x[a_row_stride]\n" + "movi C_12.4s, #0\n" + "add aptr2, aptr1, %x[a_row_stride]\n" + "movi C_13.4s, #0\n" + "add aptr3, aptr2, %x[a_row_stride]\n" + "movi C_21.4s, #0\n" + "add aptr4, aptr3, %x[a_row_stride]\n" + "movi C_22.4s, #0\n" + "add aptr5, aptr4, %x[a_row_stride]\n" + "movi C_23.4s, #0\n" + "add aptr6, aptr5, %x[a_row_stride]\n" + "movi C_31.4s, #0\n" + "add aptr7, aptr6, %x[a_row_stride]\n" + "movi C_32.4s, #0\n" + "ldr qB_1, [%x[bptr]]\n" + "movi C_33.4s, #0\n" + "ldr qB_2, [%x[bptr], #0x10]\n" + "movi C_41.4s, #0\n" + "prfm pldl1keep, [%x[bptr], #0x00]\n" + "movi C_42.4s, #0\n" + "prfm pldl1keep, [%x[bptr], #0x10]\n" + "movi C_43.4s, #0\n" + "prfm pldl1keep, [%x[bptr], #0x20]\n" + "movi C_51.4s, #0\n" + "prfm pldl1keep, [%x[aptr], #0x00]\n" + "movi C_52.4s, #0\n" + "prfm pldl1keep, [ aptr1, #0x00]\n" + "movi C_53.4s, #0\n" + "prfm pldl1keep, [ aptr2, #0x00]\n" + "movi C_61.4s, #0\n" + "prfm pldl1keep, [ aptr3, #0x00]\n" + "movi C_62.4s, #0\n" + "prfm pldl1keep, [ aptr4, #0x00]\n" + "movi C_63.4s, #0\n" + "prfm pldl1keep, [ aptr5, #0x00]\n" + "movi C_71.4s, #0\n" + "prfm pldl1keep, [ aptr6, #0x00]\n" + "movi C_72.4s, #0\n" + "prfm pldl1keep, [ aptr7, #0x00]\n" + "movi C_73.4s, #0\n" + "ldr sA_1, [%x[aptr]], #0x4\n" + "movi C_81.4s, #0\n" + "ldr sA_2, [ aptr1], #0x4\n" + "movi C_82.4s, #0\n" + "ldr sA_3, [ aptr2], #0x4\n" + "movi C_83.4s, #0\n" + "subs %x[k], %x[k], #1\n" + "beq 2f\n" + + "1:" + "fmla C_11.4s, B_1.4s, A_1.s[0]\n" + "ldr qB_3, [%x[bptr], #0x20]\n" + "fmla C_12.4s, B_2.4s, A_1.s[0]\n" + "ldr sA_4, [ aptr3], #0x4\n" + "fmla C_13.4s, B_3.4s, A_1.s[0]\n" + "ldr sA_1, [ aptr4], #0x04\n" + + "fmla C_21.4s, B_1.4s, A_2.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride]\n" + "fmla C_22.4s, B_2.4s, A_2.s[0]\n" + "prfm pldl1keep, [ aptr3, #0x10]\n" + "fmla C_23.4s, B_3.4s, A_2.s[0]\n" + "ldr sA_2, [ aptr5], #0x04\n" + + "fmla C_31.4s, B_1.4s, A_3.s[0]\n" + "prfm pldl1keep, [%x[bptr], #0x00]\n" + "fmla C_32.4s, B_2.4s, A_3.s[0]\n" + "prfm pldl1keep, [%x[bptr], #0x10]\n" + "fmla C_33.4s, B_3.4s, A_3.s[0]\n" + "ldr sA_3, [ aptr6], #0x04\n" + + "fmla C_41.4s, B_1.4s, A_4.s[0]\n" + "prfm pldl1keep, [%x[bptr], #0x20]\n" + "fmla C_42.4s, B_2.4s, A_4.s[0]\n" + "prfm pldl1keep, [ aptr4, #0x10]\n" + "fmla C_43.4s, B_3.4s, A_4.s[0]\n" + "ldr sA_4, [ aptr7], #0x04\n" + + "fmla C_51.4s, B_1.4s, A_1.s[0]\n" + "prfm pldl1keep, [ aptr5, #0x10]\n" + "fmla C_52.4s, B_2.4s, A_1.s[0]\n" + "prfm pldl1keep, [ aptr6, #0x10]\n" + "fmla C_53.4s, B_3.4s, A_1.s[0]\n" + "ldr sA_1, [%x[aptr]], #0x04\n" + + "fmla C_61.4s, B_1.4s, A_2.s[0]\n" + "prfm pldl1keep, [ aptr7, #0x10]\n" + "fmla C_62.4s, B_2.4s, A_2.s[0]\n" + "subs %x[k], %x[k], #1\n" + "fmla C_63.4s, B_3.4s, A_2.s[0]\n" + "ldr sA_2, [ aptr1], #0x04\n" + + "fmla C_71.4s, B_1.4s, A_3.s[0]\n" + "prfm pldl1keep, [%x[aptr], #0x10]\n" + "fmla C_72.4s, B_2.4s, A_3.s[0]\n" + "prfm pldl1keep, [ aptr1, #0x10]\n" + "fmla C_73.4s, B_3.4s, A_3.s[0]\n" + "ldr sA_3, [ aptr2], #0x04\n" + + "fmla C_81.4s, B_1.4s, A_4.s[0]\n" + 
"prfm pldl1keep, [ aptr2, #0x10]\n" + "fmla C_82.4s, B_2.4s, A_4.s[0]\n" + "ldp qB_1, qB_2, [%x[bptr]]\n" + "fmla C_83.4s, B_3.4s, A_4.s[0]\n" + "bne 1b\n" + + "2:" + "fmla C_11.4s, B_1.4s, A_1.s[0]\n" + "ldr qB_3, [%x[bptr], #0x20]\n" + "fmla C_12.4s, B_2.4s, A_1.s[0]\n" + "stp qC_11, qC_12, [%x[cptr]]\n" + "fmla C_13.4s, B_3.4s, A_1.s[0]\n" + "str qC_13, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + "ldr sA_1, [ aptr4], #0x04\n" + + "fmla C_21.4s, B_1.4s, A_2.s[0]\n" + "ldr sA_4, [ aptr3], #0x4\n" + "fmla C_22.4s, B_2.4s, A_2.s[0]\n" + "stp qC_21, qC_22, [%x[cptr]]\n" + "fmla C_23.4s, B_3.4s, A_2.s[0]\n" + "str qC_23, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + "ldr sA_2, [ aptr5], #0x04\n" + + "fmla C_31.4s, B_1.4s, A_3.s[0]\n" + "fmla C_32.4s, B_2.4s, A_3.s[0]\n" + "stp qC_31, qC_32, [%x[cptr]]\n" + "fmla C_33.4s, B_3.4s, A_3.s[0]\n" + "str qC_33, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + "ldr sA_3, [ aptr6], #0x04\n" + + "fmla C_41.4s, B_1.4s, A_4.s[0]\n" + "fmla C_42.4s, B_2.4s, A_4.s[0]\n" + "stp qC_41, qC_42, [%x[cptr]]\n" + "fmla C_43.4s, B_3.4s, A_4.s[0]\n" + "str qC_43, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + "ldr sA_4, [ aptr7], #0x04\n" + + "fmla C_51.4s, B_1.4s, A_1.s[0]\n" + "fmla C_52.4s, B_2.4s, A_1.s[0]\n" + "stp qC_51, qC_52, [%x[cptr]]\n" + "fmla C_53.4s, B_3.4s, A_1.s[0]\n" + "str qC_53, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + + "fmla C_61.4s, B_1.4s, A_2.s[0]\n" + "fmla C_62.4s, B_2.4s, A_2.s[0]\n" + "stp qC_61, qC_62, [%x[cptr]]\n" + "fmla C_63.4s, B_3.4s, A_2.s[0]\n" + "str qC_63, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + + "fmla C_71.4s, B_1.4s, A_3.s[0]\n" + "fmla C_72.4s, B_2.4s, A_3.s[0]\n" + "stp qC_71, qC_72, [%x[cptr]]\n" + "fmla C_73.4s, B_3.4s, A_3.s[0]\n" + "str qC_73, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + + "fmla C_81.4s, B_1.4s, A_4.s[0]\n" + "fmla C_82.4s, B_2.4s, A_4.s[0]\n" + "stp qC_81, qC_82, [%x[cptr]]\n" + "fmla C_83.4s, B_3.4s, A_4.s[0]\n" + "str qC_83, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + + // Clear aliases + ".unreq aptr1\n" + ".unreq aptr2\n" + ".unreq aptr3\n" + ".unreq aptr4\n" + ".unreq aptr5\n" + ".unreq aptr6\n" + ".unreq aptr7\n" + + ".unreq A_1\n" ".unreq A_2\n" ".unreq A_3\n" ".unreq A_4\n" + ".unreq sA_1\n" ".unreq sA_2\n" ".unreq sA_3\n" ".unreq sA_4\n" + + ".unreq B_1\n" ".unreq B_2\n" ".unreq B_3\n" + ".unreq qB_1\n" ".unreq qB_2\n" ".unreq qB_3\n" + + ".unreq C_11\n" ".unreq C_12\n" ".unreq C_13\n" + ".unreq C_21\n" ".unreq C_22\n" ".unreq C_23\n" + ".unreq C_31\n" ".unreq C_32\n" ".unreq C_33\n" + ".unreq C_41\n" ".unreq C_42\n" ".unreq C_43\n" + ".unreq C_51\n" ".unreq C_52\n" ".unreq C_53\n" + ".unreq C_61\n" ".unreq C_62\n" ".unreq C_63\n" + ".unreq C_71\n" ".unreq C_72\n" ".unreq C_73\n" + ".unreq C_81\n" ".unreq C_82\n" ".unreq C_83\n" + + ".unreq qC_11\n" ".unreq qC_12\n" ".unreq qC_13\n" + ".unreq qC_21\n" ".unreq qC_22\n" ".unreq qC_23\n" + ".unreq qC_31\n" ".unreq qC_32\n" ".unreq qC_33\n" + ".unreq qC_41\n" ".unreq qC_42\n" ".unreq qC_43\n" + ".unreq qC_51\n" ".unreq qC_52\n" ".unreq qC_53\n" + ".unreq qC_61\n" ".unreq qC_62\n" ".unreq qC_63\n" + ".unreq qC_71\n" ".unreq qC_72\n" ".unreq qC_73\n" + ".unreq qC_81\n" ".unreq qC_82\n" ".unreq qC_83\n" + : [aptr] "+r" (aptr), + [bptr] "+r" (bptr), + [cptr] "+r" (cptr), + [k] "+r" (k) + : [a_row_stride] "r" (a_row_stride * sizeof(float)), + 
[b_row_stride] "r" (b_row_stride * sizeof(float)), + [c_row_stride] "r" (c_row_stride * sizeof(float)) + : "cc", "memory", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "x17", "x18", "x19", "x20", "x21", "x22", "x23" + ); + } + } +} + +/*****************************************************************************/ +/* 4x16 blocked GEMM with specialised tails + */ +#include "a64_sgemm_4x16.hpp" + +template <> +inline void BlockedGemm<4, 16, float, float>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + // Despatch based on tail of K + switch (K % 4) { + case 3: + sgemm_4x16_impl<3>( + a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride + ); + break; + case 2: + sgemm_4x16_impl<2>( + a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride + ); + break; + case 1: + sgemm_4x16_impl<1>( + a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride + ); + break; + case 0: + sgemm_4x16_impl<0>( + a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride + ); + break; + default: + assert(0); + break; + } +} + +#endif // __aarch64__ diff --git a/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp new file mode 100644 index 0000000000..e74610ef27 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp @@ -0,0 +1,1445 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +template +inline void sgemm_4x16_impl( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +); + +template <> +inline void sgemm_4x16_impl<0>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + const int TAIL_SIZE = 0; + const int M_BLOCK = 4; + const int N_BLOCK = 16; + + const int m_blocks = iceildiv(M, M_BLOCK); + const int n_blocks = iceildiv(N, N_BLOCK); + + // For each block of output rows + for (int mblock = 0; mblock < m_blocks; mblock++) { + // For each block of output columns + for (int nblock = 0; nblock < n_blocks; nblock++) { + const float *aptr = a + mblock*M_BLOCK*a_row_stride; + const float *bptr = b + nblock*N_BLOCK; + float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; + int k = (K - TAIL_SIZE) / 4; + + asm volatile( + "aptr2 .req X20\n" + "aptr3 .req X21\n" + "aptr4 .req X22\n" + "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" + "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" + "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" + "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" + "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" + "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" + "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" + "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" + "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" + "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" + "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" + "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" + "vB1 .req v20\n" "qB1 .req q20\n" + "vB2 .req v21\n" "qB2 .req q21\n" + "vB3 .req v22\n" "qB3 .req q22\n" + "vB4 .req v23\n" "qB4 .req q23\n" + + // Clear accumulators, initialise pointers + "movi vC11.4s, #0\n" + "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n" + "movi vC12.4s, #0\n" + "add aptr3, aptr2, %x[a_row_stride_bytes]\n" + "movi vC13.4s, #0\n" + "add aptr4, aptr3, %x[a_row_stride_bytes]\n" + "movi vC14.4s, #0\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "movi vC21.4s, #0\n" + "ldr qA2, [ aptr2], #0x10\n" + "movi vC22.4s, #0\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "movi vC23.4s, #0\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "movi vC24.4s, #0\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "movi vC31.4s, #0\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "subs %x[k], %x[k], #1\n" + "beq 2f\n" + + "1:" // Loop proper + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, 
vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "subs %x[k], %x[k], #1\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr qA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + "bne 1b\n" + + "2:" // Tail + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, 
vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "stp qC11, qC12, [%x[cptr], #0x00]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "stp qC13, qC14, [%x[cptr], #0x20]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "stp qC21, qC22, [%x[cptr], #0x00]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "stp qC23, qC24, [%x[cptr], #0x20]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "stp qC31, qC32, [%x[cptr], #0x00]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "stp qC33, qC34, [%x[cptr], #0x20]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "stp qC41, qC42, [%x[cptr], #0x00]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + "stp qC43, qC44, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + + ".unreq vB4\n" ".unreq qB4\n" + ".unreq vB3\n" ".unreq qB3\n" + ".unreq vB2\n" ".unreq qB2\n" + ".unreq vB1\n" ".unreq qB1\n" + ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" + ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" + ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" + ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" + ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" + ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n" + ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" + ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" + ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" + ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n" + ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" + ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" + ".unreq aptr2\n" + ".unreq 
aptr3\n" + ".unreq aptr4\n" + + : [aptr] "+r" (aptr), + [bptr] "+r" (bptr), + [cptr] "+r" (cptr), + [k] "+r" (k) + : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), + [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), + [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) + : "cc", "memory", "x20", "x21", "x22", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23" + ); + } + } +} + +template <> +inline void sgemm_4x16_impl<1>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + const int TAIL_SIZE = 1; + const int M_BLOCK = 4; + const int N_BLOCK = 16; + + const int m_blocks = iceildiv(M, M_BLOCK); + const int n_blocks = iceildiv(N, N_BLOCK); + + // For each block of output rows + for (int mblock = 0; mblock < m_blocks; mblock++) { + // For each block of output columns + for (int nblock = 0; nblock < n_blocks; nblock++) { + const float *aptr = a + mblock*M_BLOCK*a_row_stride; + const float *bptr = b + nblock*N_BLOCK; + float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; + int k = (K - TAIL_SIZE) / 4; + + asm volatile( + "aptr2 .req X20\n" + "aptr3 .req X21\n" + "aptr4 .req X22\n" + "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" + "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" + "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" + "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" + "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" + "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" + "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" + "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" + "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" + "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" + "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" + "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" + "vB1 .req v20\n" "qB1 .req q20\n" + "vB2 .req v21\n" "qB2 .req q21\n" + "vB3 .req v22\n" "qB3 .req q22\n" + "vB4 .req v23\n" "qB4 .req q23\n" + + // Clear accumulators, initialise pointers + "movi vC11.4s, #0\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "movi vC12.4s, #0\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "movi vC13.4s, #0\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "movi vC14.4s, #0\n" + "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n" + "movi vC21.4s, #0\n" + "add aptr3, aptr2, %x[a_row_stride_bytes]\n" + "movi vC22.4s, #0\n" + "add aptr4, aptr3, %x[a_row_stride_bytes]\n" + "movi vC23.4s, #0\n" + "cbnz %x[k], 3f\n" + + // Prepare for tail in K + "movi vC24.4s, #0\n" + "ldr sA1, [%x[aptr]], #0x04\n" + "movi vC31.4s, #0\n" + "ldr sA2, [ aptr2], #0x04\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "b 2f\n" // Jump to tail + + "3:" // Prepare for loop over K + "movi vC24.4s, #0\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "movi vC31.4s, #0\n" + "ldr qA2, [ aptr2], #0x10\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "subs %x[k], %x[k], #1\n" + "beq 4f\n" + + "1:" // Loop proper + "fmla vC11.4s, 
vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "subs %x[k], %x[k], #1\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr qA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + "bne 1b\n" + + "4:" // Tail iteration + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], 
%x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr sA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr sA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + + "2:" // Common tail + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "stp qC11, qC12, [%x[cptr], #0x00]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "ldr sA3, [ aptr3], #0x10\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "stp qC13, qC14, [%x[cptr], #0x20]\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "stp qC21, qC22, [%x[cptr], #0x00]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "ldr sA4, [ aptr4], #0x10\n" + "fmla vC24.4s, 
vB4.4s, vA2.s[0]\n" + "stp qC23, qC24, [%x[cptr], #0x20]\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "stp qC31, qC32, [%x[cptr], #0x00]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "stp qC33, qC34, [%x[cptr], #0x20]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "stp qC41, qC42, [%x[cptr], #0x00]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + "stp qC43, qC44, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + + ".unreq vB4\n" ".unreq qB4\n" + ".unreq vB3\n" ".unreq qB3\n" + ".unreq vB2\n" ".unreq qB2\n" + ".unreq vB1\n" ".unreq qB1\n" + ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" + ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" + ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" + ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" + ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" + ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n" + ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" + ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" + ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" + ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n" + ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" + ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" + ".unreq aptr2\n" + ".unreq aptr3\n" + ".unreq aptr4\n" + + : [aptr] "+r" (aptr), + [bptr] "+r" (bptr), + [cptr] "+r" (cptr), + [k] "+r" (k) + : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), + [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), + [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) + : "cc", "memory", "x20", "x21", "x22", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23" + ); + } + } +} + +template <> +inline void sgemm_4x16_impl<2>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + const int TAIL_SIZE = 2; + const int M_BLOCK = 4; + const int N_BLOCK = 16; + + const int m_blocks = iceildiv(M, M_BLOCK); + const int n_blocks = iceildiv(N, N_BLOCK); + + // For each block of output rows + for (int mblock = 0; mblock < m_blocks; mblock++) { + // For each block of output columns + for (int nblock = 0; nblock < n_blocks; nblock++) { + const float *aptr = a + mblock*M_BLOCK*a_row_stride; + const float *bptr = b + nblock*N_BLOCK; + float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; + int k = (K - TAIL_SIZE) / 4; + + asm volatile( + "aptr2 .req X20\n" + "aptr3 .req X21\n" + "aptr4 .req X22\n" + "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" + "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" + "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" + "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" + "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" + "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" + "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" + "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" + "vA1 
.req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" + "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" + "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" + "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" + "vB1 .req v20\n" "qB1 .req q20\n" + "vB2 .req v21\n" "qB2 .req q21\n" + "vB3 .req v22\n" "qB3 .req q22\n" + "vB4 .req v23\n" "qB4 .req q23\n" + + // Clear accumulators, initialise pointers + "movi vC11.4s, #0\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "movi vC12.4s, #0\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "movi vC13.4s, #0\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "movi vC14.4s, #0\n" + "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n" + "movi vC21.4s, #0\n" + "add aptr3, aptr2, %x[a_row_stride_bytes]\n" + "movi vC22.4s, #0\n" + "add aptr4, aptr3, %x[a_row_stride_bytes]\n" + "movi vC23.4s, #0\n" + "cbnz %x[k], 3f\n" + + // Prepare for tail in K + "movi vC24.4s, #0\n" + "ldr dA1, [%x[aptr]], #0x08\n" + "movi vC31.4s, #0\n" + "ldr dA2, [ aptr2], #0x08\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "b 2f\n" // Jump to tail + + "3:" // Prepare for loop over K + "movi vC24.4s, #0\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "movi vC31.4s, #0\n" + "ldr qA2, [ aptr2], #0x10\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "subs %x[k], %x[k], #1\n" + "beq 4f\n" + + "1:" // Loop proper + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + 
"ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "subs %x[k], %x[k], #1\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr qA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + "bne 1b\n" + + "4:" // Tail iteration + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, 
vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr dA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr dA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + + "2:" // Common tail + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr dA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr dA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "stp qC11, qC12, [%x[cptr], #0x00]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "stp qC13, qC14, [%x[cptr], #0x20]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "stp qC21, qC22, [%x[cptr], #0x00]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "stp qC23, qC24, [%x[cptr], #0x20]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "stp qC31, qC32, [%x[cptr], #0x00]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "stp qC33, qC34, [%x[cptr], #0x20]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "stp qC41, qC42, [%x[cptr], #0x00]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + "stp qC43, qC44, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + + ".unreq vB4\n" ".unreq qB4\n" + ".unreq vB3\n" ".unreq qB3\n" + ".unreq vB2\n" ".unreq qB2\n" + ".unreq vB1\n" ".unreq qB1\n" + ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" + ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" + ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" + ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" + ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" + ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" 
".unreq vC44\n" + ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" + ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" + ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" + ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n" + ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" + ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" + ".unreq aptr2\n" + ".unreq aptr3\n" + ".unreq aptr4\n" + + : [aptr] "+r" (aptr), + [bptr] "+r" (bptr), + [cptr] "+r" (cptr), + [k] "+r" (k) + : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), + [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), + [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) + : "cc", "memory", "x20", "x21", "x22", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23" + ); + } + } +} + +template <> +inline void sgemm_4x16_impl<3>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + const int TAIL_SIZE = 3; + const int M_BLOCK = 4; + const int N_BLOCK = 16; + + const int m_blocks = iceildiv(M, M_BLOCK); + const int n_blocks = iceildiv(N, N_BLOCK); + + // For each block of output rows + for (int mblock = 0; mblock < m_blocks; mblock++) { + // For each block of output columns + for (int nblock = 0; nblock < n_blocks; nblock++) { + const float *aptr = a + mblock*M_BLOCK*a_row_stride; + const float *bptr = b + nblock*N_BLOCK; + float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; + int k = (K - TAIL_SIZE) / 4; + + asm volatile( + "aptr2 .req X20\n" + "aptr3 .req X21\n" + "aptr4 .req X22\n" + "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" + "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" + "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" + "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" + "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" + "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" + "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" + "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" + "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" + "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" + "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" + "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" + "vB1 .req v20\n" "qB1 .req q20\n" + "vB2 .req v21\n" "qB2 .req q21\n" + "vB3 .req v22\n" "qB3 .req q22\n" + "vB4 .req v23\n" "qB4 .req q23\n" + + // Clear accumulators, initialise pointers + "movi vC11.4s, #0\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "movi vC12.4s, #0\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "movi vC13.4s, #0\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "movi vC14.4s, #0\n" + "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n" + "movi vC21.4s, #0\n" + "add aptr3, aptr2, %x[a_row_stride_bytes]\n" + "movi vC22.4s, #0\n" + "add aptr4, aptr3, %x[a_row_stride_bytes]\n" + "movi vC23.4s, #0\n" + "cbnz %x[k], 3f\n" + + // Prepare for tail in K + "movi vC24.4s, #0\n" + "ldr dA1, [%x[aptr]], #0x08\n" + "movi vC31.4s, #0\n" + "ldr dA2, [ aptr2], #0x08\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + 
"movi vC44.4s, #0\n" + "b 2f\n" // Jump to tail + + "3:" // Prepare for loop over K + "movi vC24.4s, #0\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "movi vC31.4s, #0\n" + "ldr qA2, [ aptr2], #0x10\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "subs %x[k], %x[k], #1\n" + "beq 4f\n" + + "1:" // Loop proper + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "subs %x[k], %x[k], #1\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + 
"ldr qA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + "bne 1b\n" + + "4:" // Tail iteration + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr dA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr dA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + + "2:" // Common tail + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr dA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr 
dA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "ldr sA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "ldr sA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "stp qC11, qC12, [%x[cptr], #0x00]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "ldr sA3, [ aptr3], #0x10\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "stp qC13, qC14, [%x[cptr], #0x20]\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "stp qC21, qC22, [%x[cptr], #0x00]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "ldr sA4, [ aptr4], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "stp qC23, qC24, [%x[cptr], #0x20]\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "stp qC31, qC32, [%x[cptr], #0x00]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "stp qC33, qC34, [%x[cptr], #0x20]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "stp qC41, qC42, [%x[cptr], #0x00]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + "stp qC43, qC44, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + + ".unreq vB4\n" ".unreq qB4\n" + ".unreq vB3\n" ".unreq qB3\n" + ".unreq vB2\n" ".unreq qB2\n" + ".unreq vB1\n" ".unreq qB1\n" + ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" + ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" + ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" + ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" + ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" + ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n" + ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" + ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" + ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" + ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" 
".unreq vC24\n" + ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" + ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" + ".unreq aptr2\n" + ".unreq aptr3\n" + ".unreq aptr4\n" + + : [aptr] "+r" (aptr), + [bptr] "+r" (bptr), + [cptr] "+r" (cptr), + [k] "+r" (k) + : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), + [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), + [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) + : "cc", "memory", "x20", "x21", "x22", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23" + ); + } + } +} diff --git a/arm_compute/core/NEON/kernels/winograd/perf.h b/arm_compute/core/NEON/kernels/winograd/perf.h new file mode 100644 index 0000000000..11fb0c452f --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/perf.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +/* Prototypes from perf.c */ + +void start_counter(int fd); +long long get_counter(int fd); +long long stop_counter(int fd); +int open_instruction_counter(void); +int open_cycle_counter(void); diff --git a/arm_compute/core/NEON/kernels/winograd/profiler.hpp b/arm_compute/core/NEON/kernels/winograd/profiler.hpp new file mode 100644 index 0000000000..143192b589 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/profiler.hpp @@ -0,0 +1,244 @@ + +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "perf.h" +#include + +class profiler { +private: +#ifdef CYCLE_PROFILING + struct ProfileEntry { + int event_id; + long int bytes_read, ops, bytes_written; + long int duration; + }; + + static const int maxevents = 10000; + ProfileEntry events[maxevents]; + int currentevent; + int countfd; + + std::map event_ids; + + int get_event_id(const char *id) { + if (!event_ids.count(id)) { + event_ids.emplace(id, event_ids.size()); + } + return event_ids[id]; + } +#endif // CYCLE_PROFILING + +public: +#ifdef CYCLE_PROFILING + profiler() { + currentevent = 0; + countfd = open_cycle_counter(); + } + + ~profiler() { + close(countfd); + + // Compute performance from recorded events + struct ProfileResult { + ProfileResult() : total_calls(0), + total_duration(0), + total_bytes_read(0), + total_ops(0), + total_bytes_written(0) { + } + + void operator+=(const ProfileEntry &rhs) { + total_calls++; + total_duration += rhs.duration; + total_bytes_read += rhs.bytes_read; + total_ops += rhs.ops; + total_bytes_written = rhs.bytes_written; + } + + float avg_duration(void) const { + return static_cast(total_duration) / + static_cast(total_calls); + } + + float bytes_read_per_cycle(void) const { + return static_cast(total_bytes_read) / + static_cast(total_duration); + } + + float ops_per_cycle(void) const { + return static_cast(total_ops) / + static_cast(total_duration); + } + + float bytes_written_per_cycle(void) const { + return static_cast(total_bytes_written) / + static_cast(total_duration); + } + + long int total_calls, + total_duration, + total_bytes_read, + total_ops, + total_bytes_written; + }; + + std::vector totals; + totals.resize(event_ids.size()); + for (int i = 0; i < currentevent; i++) { + const auto &event = events[i]; + totals[event.event_id] += event; + } + + // Get the longest label + int len_label = 0; + for (const auto &kv : event_ids) { + len_label = std::max(len_label, static_cast(strlen(kv.first))); + } + + // Get the longest values for every other field + const auto get_length_of_field = + [totals] (const char *title, auto f, auto len) -> size_t { + size_t l = strlen(title); + for (const auto &v : totals) { + l = std::max(l, len(f(v))); + } + return l; + }; + + // Get the strlen for an int + const auto intlen = [] (long int x) -> size_t { + size_t len = 0; + do { + x /= 10; + len++; + } while (x); + return len; + }; + + // Get the strlen for a float + const auto floatlen = [] (const int precision) { + return [precision] (float x) { + size_t len = 0; + + if (!std::isfinite(x)) { + return static_cast(3); + } + + do { + x /= 10.0f; + len++; + } while (x > 1.0f); + return len + 1 + precision; + }; + }; + + const int len_calls = get_length_of_field( + "Calls", [] (const auto &v) {return v.total_calls;}, + intlen + ); + const int len_duration = get_length_of_field( + "Duration", [] (const auto &v) {return v.total_duration;}, + intlen + ); + const int len_average_duration = get_length_of_field( + "Average", [] (const auto &v) {return v.avg_duration();}, + floatlen(2) + ); + const int len_reads_per_cycle = get_length_of_field( + "Reads / cycle", + [] (const auto &v) {return v.bytes_read_per_cycle();}, + floatlen(6) + ); + const int 
len_ops_per_cycle = get_length_of_field( + "Ops / cycle", + [] (const auto &v) {return v.ops_per_cycle();}, + floatlen(6) + ); + const int len_writes_per_cycle = get_length_of_field( + "Writes / cycle", + [] (const auto &v) {return v.bytes_written_per_cycle();}, + floatlen(6) + ); + + // Print header + printf( + "%*s %*s %*s %*s %*s %*s %*s\n", + len_label, "", + len_calls, "Calls", + len_duration, "Duration", + len_average_duration, "Average", + len_reads_per_cycle, "Reads / cycle", + len_ops_per_cycle, "Ops / cycle", + len_writes_per_cycle, "Writes / cycle" + ); + for (const auto &kv : event_ids) { + const auto id = kv.second; + printf( + "%*s %*ld %*ld %*.2f %*.6f %*.6f %*.6f\n", + len_label, kv.first, + len_calls, totals[id].total_calls, + len_duration, totals[id].total_duration, + len_average_duration, totals[id].avg_duration(), + len_reads_per_cycle, totals[id].bytes_read_per_cycle(), + len_ops_per_cycle, totals[id].ops_per_cycle(), + len_writes_per_cycle, totals[id].bytes_written_per_cycle() + ); + } + printf("\n"); + } +#endif // CYCLE_PROFILING + + template + void operator() (const char * event, + T func, + long int bytes_read = 0, + long int ops = 0, + long int bytes_written = 0) { +#ifdef CYCLE_PROFILING + if (currentevent==maxevents) { + func(); + } else { + start_counter(countfd); + func(); + long long cycs = stop_counter(countfd); + + // Store the profiling data + events[currentevent++] = { + get_event_id(event), bytes_read, ops, bytes_written, cycs + }; + } +#else + func(); +#endif // CYCLE_PROFILING + } +}; diff --git a/arm_compute/core/NEON/kernels/winograd/shims.hpp b/arm_compute/core/NEON/kernels/winograd/shims.hpp new file mode 100644 index 0000000000..249e5757f0 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/shims.hpp @@ -0,0 +1,319 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +/** Re-order a weight tensor from [Output feature map x Input feature map x + * Height x Width] format to [Height x Width x Input feature map x Output + * feature map] format. 
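 *
 * A minimal usage sketch (the buffer names and sizes below are purely
 * illustrative, assuming a dense 3x3 kernel with 2 input and 4 output
 * feature maps and the default zero strides, which the implementation
 * expands to packed strides):
 *
 *   float oihw[4 * 2 * 3 * 3] = {};  // weights as [OFM x IFM x H x W]
 *   float hwio[3 * 3 * 2 * 4] = {};  // re-ordered as [H x W x IFM x OFM]
 *   ofm_ifm_h_w_to_h_w_ifm_ofm(oihw, hwio, 4, 2, 3, 3);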
+ */ +template +inline void ofm_ifm_h_w_to_h_w_ifm_ofm( + const T* const in, // Input in [Output x Input x Height x Width] form + T* const out, // Output in [Height x Width x Input x Output] form + const int n_output_feature_maps, + const int n_input_feature_maps, + const int n_rows, + const int n_cols, + int in_output_feature_map_stride=0, + int in_input_feature_map_stride=0, + int in_row_stride=0, + int out_row_stride=0, + int out_col_stride=0, + int out_input_feature_map_stride=0 +); + +/** Re-order a weight tensor from [Height x Width x Input feature map x Output + * feature map] format to [Output feature map x Input feature map x Height x + * Width] format. + */ +template +inline void h_w_ifm_ofm_to_ofm_ifm_h_w( + const T* const in, // Input in [Height x Width x Input x Output] form + T* const out, // Output in [Output x Input x Height x Width] form + const int n_rows, + const int n_cols, + const int n_input_feature_maps, + const int n_output_feature_maps, + int in_row_stride=0, + int in_col_stride=0, + int in_input_feature_map_stride=0, + int out_output_feature_map_stride=0, + int out_input_feature_map_stride=0, + int out_row_stride=0 +); + + +/* Re-order a tensor from NCHW format to NHWC. + */ +template +inline void nchw_to_nhwc( + const T* const in, + T* const out, + const int n_batches, + const int n_channels, + const int n_rows, + const int n_cols, + int in_batch_stride=0, + int in_channel_stride=0, + int in_row_stride=0, + int out_batch_stride=0, + int out_row_stride=0, + int out_col_stride=0 +) +{ + // Fill in the stride values + in_row_stride = (in_row_stride) ? in_row_stride : n_cols; + in_channel_stride = (in_channel_stride) ? in_channel_stride + : n_rows * in_row_stride; + in_batch_stride = (in_batch_stride) ? in_batch_stride + : n_channels * in_channel_stride; + + out_col_stride = (out_col_stride) ? out_col_stride : n_channels; + out_row_stride = (out_row_stride) ? out_row_stride : n_cols * out_col_stride; + out_batch_stride = (out_batch_stride) ? out_batch_stride + : n_rows * out_row_stride; + + // Perform the re-ordering + for (int n = 0; n < n_batches; n++) + { + const T* const in_batch = in + n*in_batch_stride; + T* const out_batch = out + n*out_batch_stride; + + for (int i = 0; i < n_rows; i++) + { + const T* const in_row = in_batch + i*in_row_stride; + T* const out_row = out_batch + i*out_row_stride; + + for (int j = 0; j < n_cols; j++) + { + const T* const in_col = in_row + j; + T* const out_col = out_row + j*out_col_stride; + + for (int c = 0; c < n_channels; c++) + { + const T* const in_channel = in_col + c*in_channel_stride; + out_col[c] = *(in_channel); + } + } + } + } +} + +/* Re-order a tensor from NHWC format to NCHW. + */ +template +inline void nhwc_to_nchw( + const T* const in, // Input data in NHWC form + T* const out, // Output data in NCHW form + const int n_batches, + const int n_rows, + const int n_cols, + const int n_channels, + int in_batch_stride=0, + int in_row_stride=0, + int in_col_stride=0, + int out_batch_stride=0, + int out_channel_stride=0, + int out_row_stride=0 +) +{ + // Fill in stride values + in_col_stride = (in_col_stride) ? in_col_stride : n_channels; + in_row_stride = (in_row_stride) ? in_row_stride : n_cols * in_col_stride; + in_batch_stride = (in_batch_stride) ? in_batch_stride + : n_rows * in_row_stride; + + out_row_stride = (out_row_stride) ? out_row_stride : n_cols; + out_channel_stride = (out_channel_stride) ? out_channel_stride + : n_rows * out_row_stride; + out_batch_stride = (out_batch_stride) ? 
out_batch_stride + : n_channels * out_channel_stride; + + // Perform the re-ordering + // For every batch + for (int n = 0; n < n_batches; n++) + { + const T* const in_batch = in + n*in_batch_stride; + T* const out_batch = out + n*out_batch_stride; + + // For every row + for (int i = 0; i < n_rows; i++) + { + const T* const in_i = in_batch + i*in_row_stride; + T* const out_i = out_batch + i*out_row_stride; + + // For every column + for (int j = 0; j < n_cols; j++) + { + const T* const in_j = in_i + j*in_col_stride; + T* const out_j = out_i + j; + + // For every channel + for (int c = 0; c < n_channels; c++) + { + const T* const in_channel = in_j + c; + T* const out_channel = out_j + c*out_channel_stride; + *(out_channel) = *(in_channel); + } + } + } + } +} + + +/*****************************************************************************/ +/* Generic weight re-order implementation. + */ +template +inline void ofm_ifm_h_w_to_h_w_ifm_ofm( + const T* const in, // Input in [Output x Input x Height x Width] form + T* const out, // Output in [Height x Width x Input x Output] form + const int n_output_feature_maps, + const int n_input_feature_maps, + const int n_rows, + const int n_cols, + int in_output_feature_map_stride, + int in_input_feature_map_stride, + int in_row_stride, + int out_row_stride, + int out_col_stride, + int out_input_feature_map_stride +) +{ + // Fill in stride values + in_row_stride = (in_row_stride) + ? in_row_stride + : n_cols; + in_input_feature_map_stride = (in_input_feature_map_stride) + ? in_input_feature_map_stride + : n_rows * in_row_stride; + in_output_feature_map_stride = (in_output_feature_map_stride) + ? in_output_feature_map_stride + : n_input_feature_maps * in_input_feature_map_stride; + + out_input_feature_map_stride = (out_input_feature_map_stride) + ? out_input_feature_map_stride + : n_output_feature_maps; + out_col_stride = (out_col_stride) + ? out_col_stride + : n_input_feature_maps * out_input_feature_map_stride; + out_row_stride = (out_row_stride) + ? out_row_stride + : n_cols * out_col_stride; + + // Perform the re-ordering + for (int i = 0; i < n_rows; i++) + { + const T* const in_row = in + i * in_row_stride; + T* out_row = out + i * out_row_stride; + + for (int j = 0; j < n_cols; j++) + { + const T* const in_col = in_row + j; + T* const out_col = out_row + j * out_col_stride; + + for (int ifm = 0; ifm < n_input_feature_maps; ifm++) + { + const T* const in_ifm = in_col + ifm * in_input_feature_map_stride; + T* const out_ifm = out_col + ifm * out_input_feature_map_stride; + + for (int ofm = 0; ofm < n_output_feature_maps; ofm++) + { + const T* const in_ofm = in_ifm + ofm * in_output_feature_map_stride; + T* const out_ofm = out_ifm + ofm; + *(out_ofm) = *(in_ofm); + } + } + } + } +} + +/*****************************************************************************/ +/* Generic weight re-order implementation. + */ +template +inline void h_w_ifm_ofm_to_ofm_ifm_h_w( + const T* const in, // Input in [Height x Width x Input x Output] form + T* const out, // Output in [Output x Input x Height x Width] form + const int n_rows, + const int n_cols, + const int n_input_feature_maps, + const int n_output_feature_maps, + int in_row_stride, + int in_col_stride, + int in_input_feature_map_stride, + int out_output_feature_map_stride, + int out_input_feature_map_stride, + int out_row_stride +) +{ + // Fill in the stride values + in_input_feature_map_stride = (in_input_feature_map_stride) + ? 
in_input_feature_map_stride + : n_output_feature_maps; + in_col_stride = (in_col_stride) + ? in_col_stride + : n_input_feature_maps * in_input_feature_map_stride; + in_row_stride = (in_row_stride) + ? in_row_stride + : n_cols * in_col_stride; + + out_row_stride = (out_row_stride) + ? out_row_stride + : n_cols; + out_input_feature_map_stride = (out_input_feature_map_stride) + ? out_input_feature_map_stride + : n_rows * out_row_stride; + out_output_feature_map_stride = (out_output_feature_map_stride) + ? out_output_feature_map_stride + : n_input_feature_maps * out_input_feature_map_stride; + + // Perform the re-ordering + for (int i = 0; i < n_rows; i++) + { + const T* const in_row = in + i * in_row_stride; + T* const out_row = out + i * out_row_stride; + + for (int j = 0; j < n_cols; j++) + { + const T* const in_col = in_row + j * in_col_stride; + T* const out_col = out_row + j; + + for (int ifm = 0; ifm < n_input_feature_maps; ifm++) + { + const T* const in_ifm = in_col + ifm * in_input_feature_map_stride; + T* const out_ifm = out_col + ifm * out_input_feature_map_stride; + + for (int ofm = 0; ofm < n_output_feature_maps; ofm++) + { + const T* const in_ofm = in_ifm + ofm; + T* const out_ofm = out_ifm + ofm * out_output_feature_map_stride; + *(out_ofm) = *(in_ofm); + } + } + } + } +} + diff --git a/arm_compute/core/NEON/kernels/winograd/tensor.hpp b/arm_compute/core/NEON/kernels/winograd/tensor.hpp new file mode 100644 index 0000000000..70ef65d2a5 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/tensor.hpp @@ -0,0 +1,210 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +#pragma once +#include +#include +#include + +#include "alloc.hpp" + +/*****************************************************************************/ +/* Padding definitions */ +enum PaddingType { + PADDING_SAME, PADDING_VALID +}; + +/*****************************************************************************/ +/* Shape of a kernel */ +struct KernelShape { + int n_output_channels, n_rows, n_cols, n_input_channels; + + int size(void) const { + return n_output_channels * n_rows * n_cols * n_input_channels; + } +}; + +struct Tensor4DShape { + int n_batches, + n_rows, + n_cols, + n_channels; + + int size() const { + return n_batches * n_rows * n_cols * n_channels; + } + + bool TestEq(const Tensor4DShape& other) const { + return (n_batches == other.n_batches && + n_rows == other.n_rows && + n_cols == other.n_cols && + n_channels == other.n_channels); + } +}; + +template +class Tensor4D final { + public: + Tensor4D(ShapeT shape) : + _shape(shape), + _data(reinterpret_cast(ALLOCATE(size_bytes()))) { + Clear(); + } + + ~Tensor4D() { + free(_data); + } + + T* ptr() const { + return _data; + } + + const ShapeT& shape() const { + return _shape; + } + + size_t size_bytes() const { + return _shape.size() * sizeof(T); + } + + bool TestEq(Tensor4D& other) const; + T& element(int, int, int, int) const; + void Print() const; + + void Clear() { + Fill(static_cast(0)); + } + + void Fill(T val) { + for (int i = 0; i < _shape.size(); i++) + _data[i] = val; + } + + void TestPattern() { + for (int i = 0; i < _shape.size(); i++) + _data[i] = static_cast(i); + } + + void Rand(const int seed=2311) { + std::mt19937 gen(seed); + std::uniform_int_distribution<> dis(-50, +50); + + for (int i = 0; i < _shape.size(); i++) { + _data[i] = static_cast(dis(gen)); + } + } + Tensor4D(const Tensor4D &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + Tensor4D &operator=(const Tensor4D &) = delete; + /** Allow instances of this class to be moved */ + Tensor4D(Tensor4D &&) = default; + /** Allow instances of this class to be moved */ + Tensor4D &operator=(Tensor4D &&) = default; + + + private: + const ShapeT _shape; + T* const _data; +}; + + +template <> +inline float& Tensor4D::element(int n, int i, int j, int c) const { + int index = ((n*_shape.n_rows + i)*_shape.n_cols + j)*_shape.n_channels + c; + return _data[index]; +} + + +template <> +inline float& Tensor4D::element(int oc, int i, int j, int ic) const { + int index = ((i*_shape.n_cols + j)*_shape.n_input_channels + ic)*_shape.n_output_channels + oc; + return _data[index]; +} + +template <> +inline bool Tensor4D::TestEq(Tensor4D& other) const { + // Test equivalence, printing errors + // First test the shapes are the same + if (!_shape.TestEq(other.shape())) { + printf("Tensors have different shapes.\n"); + return false; + } else { + int incorrects = 0; + + for (int n = 0; n < _shape.n_batches; n++) { + for (int i = 0; i < _shape.n_rows; i++) { + for (int j = 0; j < _shape.n_cols; j++) { + for (int c = 0; c < _shape.n_channels; c++) { + // Check elements for equivalence + const auto a = this->element(n, i, j, c); + const auto b = other.element(n, i, j, c); + + if (a != b) { + printf("Difference at element {%d, %d, %d, %d}: %.3f != %.3f\n", n, i, j, c, a, b); + + if (++incorrects > 100) { + printf("More than 100 incorrect values, stopping test.\n"); + return false; + } + } + } + } + } + } + + return incorrects == 0; + } +} + + +template <> +inline void Tensor4D::Print() const { + for (int n = 0; n < 
_shape.n_batches; n++) { + for (int c = 0; c < _shape.n_channels; c++) { + for (int i = 0; i < _shape.n_rows; i++) { + for (int j = 0; j < _shape.n_cols; j++) { + printf("%5.2f ", element(n, i, j, c)); + } + printf("\n"); + } + printf("\n"); + } + } +} + + +template <> +inline void Tensor4D::Print() const { + for (int oc = 0; oc < _shape.n_output_channels; oc++) { + for (int ic = 0; ic < _shape.n_input_channels; ic++) { + for (int i = 0; i < _shape.n_rows; i++) { + for (int j = 0; j < _shape.n_cols; j++) { + printf("%5.2f ", element(oc, i, j, ic)); + } + printf("\n"); + } + printf("\n"); + } + } +} diff --git a/arm_compute/core/NEON/kernels/winograd/transforms.hpp b/arm_compute/core/NEON/kernels/winograd/transforms.hpp new file mode 100644 index 0000000000..8546ee9e2e --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/transforms.hpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#include "transforms/input_2x2_3x3.hpp" +#include "transforms/kernel_2x2_3x3.hpp" +#include "transforms/output_2x2_3x3.hpp" diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp new file mode 100644 index 0000000000..7013c66ac0 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp @@ -0,0 +1,638 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once +#include "../tensor.hpp" + +namespace winograd { + /* Transform an input tensor into the Winograd domain. + */ + template + struct Winograd2x2_3x3GemmInput { + static void execute( + const T *inptr, + const Tensor4DShape& input_shape, + const PaddingType padding_type, + const int tile_M, + const int tile_N, + T *outptr_base, + const int matrix_stride, + const int matrix_batch_stride, + const int matrix_row_stride + ); + + static size_t bytes_read(const Tensor4DShape &input_shape, + const Tensor4DShape &output_shape) { + const int tile_rows = iceildiv(output_shape.n_rows, 2); + const int tile_cols = iceildiv(output_shape.n_cols, 2); + return input_shape.n_batches * tile_rows * (16 + 8*(tile_cols - 1)) * input_shape.n_channels * sizeof(T); + } + + static int flops_performed(const Tensor4DShape &input_shape, + const Tensor4DShape &output_shape) { + const int tile_rows = iceildiv(output_shape.n_rows, 2); + const int tile_cols = iceildiv(output_shape.n_cols, 2); + return input_shape.n_batches * tile_rows * (32 + 24*(tile_cols - 1)) * input_shape.n_channels; + } + + static size_t bytes_written(const Tensor4DShape &input_shape, + const Tensor4DShape &output_shape) { + const int tile_rows = iceildiv(output_shape.n_rows, 2); + const int tile_cols = iceildiv(output_shape.n_cols, 2); + const int M = input_shape.n_batches * tile_rows * tile_cols; + return 16 * M * input_shape.n_channels * sizeof(T); + } + + protected: + template + static void process_tile_tensor( + const int tile_M, // Number of rows of tiles + const int tile_N, // Number of columns of tiles + int n_channels, // Number of input channels + const T* const input, // Base input pointer (appropriate to batch and channel) + const int input_row_stride, // Stride between rows of the input + const int input_col_stride, // Stride between columns of the input + T* const matrix, // 1st output matrix (appropriate to batch and channel) + const int matrix_stride, // Stride between matrices + const int matrix_row_stride // Stride between rows of the output matrix + ); + + template + static void process_tile_row( + const int tile_N, // Number of tiles in the row + const T* const input, // Base input pointer (appropriate to batch, channel and row) + const int input_row_stride, // Stride between rows of the input + const int input_col_stride, // Stride between columns of the input + T* const matrix, // 1st output matrix (appropriate to batch, channel and row) + const int matrix_stride, // Stride between matrices + const int matrix_row_stride // Stride between rows of the output matrix + ); + }; + + template + struct Winograd2x2_3x3GemmInputChannelwise { + static void execute( + const T *inptr, + const Tensor4DShape& input_shape, + const PaddingType padding_type, + const int tile_M, + const int tile_N, + T *outptr_base, + const int matrix_stride, + const int matrix_batch_stride, + const int matrix_row_stride + ); + + static size_t bytes_read(const Tensor4DShape &input_shape, + const Tensor4DShape &output_shape) { + // We read as many bytes as we write + return bytes_written(input_shape, output_shape); + } + + static int flops_performed(const Tensor4DShape &input_shape, + const Tensor4DShape &output_shape) { + const int tile_rows = iceildiv(output_shape.n_rows, 2); 
+ const int tile_cols = iceildiv(output_shape.n_cols, 2); + return input_shape.n_batches * tile_rows * 32 * tile_cols * input_shape.n_channels; + } + + static size_t bytes_written(const Tensor4DShape &input_shape, + const Tensor4DShape &output_shape) { + return winograd::Winograd2x2_3x3GemmInput::bytes_written(input_shape, output_shape); + } + + protected: + typedef void (*tilefunc)(int, const T*, int, int, T*, int); + template + static void process_tile( + int n_channels, // Number of channels in the tile + const T* const input_base, + const int input_row_stride, + const int input_col_stride, + T* const matrix_base, + const int matrix_stride + ); + + private: + template + static void _process_tile( + int &n_channels, const T* &inptr, + const int input_row_stride, const int input_col_stride, + T* &outptr, const int matrix_stride + ); + }; +} + +/*****************************************************************************/ +// Include specialised implementations here +#include "input_2x2_3x3/a64_float.hpp" +#include "input_2x2_3x3/a64_float_channelwise.hpp" +/*****************************************************************************/ + +/*****************************************************************************/ +template +void winograd::Winograd2x2_3x3GemmInput::execute( + const T *inptr_base, + const Tensor4DShape& input_shape, + const PaddingType padding_type, + const int tile_M, + const int tile_N, + T *outptr_base, + const int matrix_stride, + const int matrix_batch_stride, + const int matrix_row_stride +) { + // Select an appropriate matrix processing method for the shape and padding + // of the input tensor. + typedef void (*tensorfunc)(int, int, int, const T*, int, int, T*, int, int); + const auto process_tensor = [&padding_type, &input_shape] () -> tensorfunc { + if (padding_type == PADDING_VALID) { + const int pad_bottom = input_shape.n_rows % 2; + const int pad_right = input_shape.n_cols % 2; + + if (pad_bottom == 0 && pad_right == 0) { + return process_tile_tensor; + } else if (pad_bottom == 0 && pad_right == 1) { + return process_tile_tensor; + } else if (pad_bottom == 1 && pad_right == 0) { + return process_tile_tensor; + } else if (pad_bottom == 1 && pad_right == 1) { + return process_tile_tensor; + } + } else { // PADDING_SAME + const int pad_bottom = 1 + input_shape.n_rows % 2; + const int pad_right = 1 + input_shape.n_cols % 2; + + if (pad_bottom == 1 && pad_right == 1) { + return process_tile_tensor; + } else if (pad_bottom == 1 && pad_right == 2) { + return process_tile_tensor; + } else if (pad_bottom == 2 && pad_right == 1) { + return process_tile_tensor; + } else if (pad_bottom == 2 && pad_right == 2) { + return process_tile_tensor; + } + } + + printf("%s::%u Uncovered case.\n", __FILE__, __LINE__); + exit(-1); + return NULL; // No function found + } (); + + // Compute strides + const int input_row_stride = input_shape.n_cols * input_shape.n_channels; + const int input_col_stride = input_shape.n_channels; + + // Process each batch of the tensor in turn. 
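  // As a concrete illustration (hypothetical sizes): for a 1 x 8 x 8 x 32
  // NHWC input with SAME padding the output is also 8 x 8, so
  // tile_M = tile_N = iceildiv(8, 2) = 4 and each of the 16 output matrices
  // receives tile_M * tile_N = 16 rows of 32 channels for the batch.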
+ for (int batch = 0; batch < input_shape.n_batches; batch++) { + // Work out pointers + const T *inptr = inptr_base + (batch * input_shape.n_rows * + input_shape.n_cols * input_shape.n_channels); + T *outptr = outptr_base + batch * matrix_batch_stride; + + // Delegate doing the actual work + process_tensor( + tile_M, tile_N, input_shape.n_channels, + inptr, input_row_stride, input_col_stride, + outptr, matrix_stride, matrix_row_stride + ); + } +} + +/*****************************************************************************/ +template +template +void winograd::Winograd2x2_3x3GemmInput::process_tile_tensor( + const int tile_M, // Number of rows of tiles + const int tile_N, // Number of columns of tiles + int n_channels, // Number of input channels + const T* const input, // Base input pointer (appropriate to batch and channel) + const int input_row_stride, // Stride between rows of the input + const int input_col_stride, // Stride between columns of the input + T* const matrix, // 1st output matrix (appropriate to batch and channel) + const int matrix_stride, // Stride between matrices + const int matrix_row_stride // Stride between rows of the output matrix +) { + // Base row processing functions + typedef void (*rowfunc)(int, const T*, int, int, T*, int, int); + const rowfunc process_top_row[3] = { + (padding == PADDING_VALID) + ? process_tile_row<0, 0, 0, pad_right, 1> + : process_tile_row<1, 1, 0, pad_right, 1>, + (padding == PADDING_VALID) + ? process_tile_row<0, 0, 0, pad_right, 2> + : process_tile_row<1, 1, 0, pad_right, 2>, + (padding == PADDING_VALID) + ? process_tile_row<0, 0, 0, pad_right, 4> + : process_tile_row<1, 1, 0, pad_right, 4>, + }; + const rowfunc process_middle_row[3] = { + (padding == PADDING_VALID) + ? process_tile_row<0, 0, 0, pad_right, 1> + : process_tile_row<0, 1, 0, pad_right, 1>, + (padding == PADDING_VALID) + ? process_tile_row<0, 0, 0, pad_right, 2> + : process_tile_row<0, 1, 0, pad_right, 2>, + (padding == PADDING_VALID) + ? process_tile_row<0, 0, 0, pad_right, 4> + : process_tile_row<0, 1, 0, pad_right, 4>, + }; + const rowfunc process_bottom_row[3] = { + (padding == PADDING_VALID) + ? process_tile_row<0, 0, pad_bottom, pad_right, 1> + : process_tile_row<0, 1, pad_bottom, pad_right, 1>, + (padding == PADDING_VALID) + ? process_tile_row<0, 0, pad_bottom, pad_right, 2> + : process_tile_row<0, 1, pad_bottom, pad_right, 2>, + (padding == PADDING_VALID) + ? process_tile_row<0, 0, pad_bottom, pad_right, 4> + : process_tile_row<0, 1, pad_bottom, pad_right, 4>, + }; + + // Method to get an input pointer for the given tile row + const auto get_inptr = [&input, &input_row_stride] (const int tile_i) { + if (padding == PADDING_VALID) { + return input + 2 * tile_i * input_row_stride; + } else { + return input + (2 * tile_i - (tile_i ? 1 : 0)) * input_row_stride; + } + }; + + // Wrapper to process a row of tiles, covering all channels. + const auto process_row = + [tile_N, input_row_stride, input_col_stride, matrix_stride, matrix_row_stride, n_channels] + (const rowfunc f[3], const T *inptr, T *outptr) { + int rem_channels = n_channels; + + // While there remain channels to process continue to process the + // row. 
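      // The row is consumed in channel blocks of four, then two, then one,
      // so the widest specialised kernel (f[2]) does the bulk of the work;
      // e.g. with a hypothetical 19 channels this issues four 4-channel
      // calls, one 2-channel call and one final 1-channel call.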
+ for (; rem_channels >= 4; rem_channels -= 4, inptr += 4, outptr += 4) { + f[2](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride); + } + for (; rem_channels >= 2; rem_channels -= 2, inptr += 2, outptr += 2) { + f[1](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride); + } + if (rem_channels) { + f[0](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride); + } + }; + + // Process all rows of tiles in the tensor + for (int tile_i = 0; tile_i < tile_M; tile_i++) { + T* const m_row = matrix + tile_i * tile_N * matrix_row_stride; + const T *row_inptr = get_inptr(tile_i); + + if (tile_i == 0) { + // Top row of the input + process_row(process_top_row, row_inptr, m_row); + } else if (tile_i == tile_M - 1) { + // Bottom row of the input + process_row(process_bottom_row, row_inptr, m_row); + } else { + // Any other row of the input + process_row(process_middle_row, row_inptr, m_row); + } + } +} + +/*****************************************************************************/ +template +template +void winograd::Winograd2x2_3x3GemmInput::process_tile_row( + const int tile_N, // Number of tiles in the row + const T* const input, // Base input pointer (appropriate to batch, channel and row) + const int input_row_stride, // Stride between rows of the input + const int input_col_stride, // Stride between columns of the input + T* const matrix, // 1st output matrix (appropriate to batch, channel and row) + const int matrix_stride, // Stride between matrices + const int matrix_row_stride // Stride between rows of the output matrix +) { + // Construct copies of the pointers + const T *inptr = input; + T *outptr = matrix; + + // Storage for the tensors x, X.T x, and X.T x X. + T x[4][4][proc_channels], XTx[4][4][proc_channels], XTxX[4][4][proc_channels]; + + // For every tile in the row + for (int tile_j = 0; tile_j < tile_N; tile_j++) { + // Determine the padding for the tile + const int tile_pad_left = (tile_j == 0) ? pad_left : 0; + const int tile_pad_right = (tile_j == tile_N - 1) ? pad_right : 0; + + // Load tile values. If this is the first tile in the row then we must load + // all values, otherwise we can just load the final two columns of the input. + for (int i = 0; i < 4; i++) { + for (int j = ((tile_j == 0) ? 0 : 2); j < 4; j++) { + // Fill with padding if required + if (i < pad_top || 4 - pad_bottom <= i || + j < tile_pad_left || 4 - tile_pad_right <= j) { + for (int c = 0; c < proc_channels; c++) { + x[i][j][c] = static_cast(0); // Padding + } + } else { + // Load values, note that the initial padding offsets the pointer we + // were provided. + for (int c = 0; c < proc_channels; c++) { + const int row_offset = (i - pad_top) * input_row_stride; + const int col_offset = (j - tile_pad_left) * input_col_stride; + x[i][j][c] = inptr[row_offset + col_offset + c]; + } + } + } + } + + // Compute the matrix X.T x. Note, can elide operations depending on the + // padding. Furthermore, if this isn't the left-most tile we can skip half + // of the operations by copying results from the previous version of X.T x. + // This latter optimisation can be simplified by unrolling the outermost + // loop by two and by renaming the registers containing XTx. 
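      // For reference, the arithmetic below evaluates X.T x X, the
      // F(2x2, 3x3) Winograd input transform (X here is the matrix usually
      // written as B), with
      //
      //         [ 1   0  -1   0 ]
      //   X.T = [ 0   1   1   0 ]
      //         [ 0  -1   1   0 ]
      //         [ 0   1   0  -1 ]
      //
      // XTx holds the intermediate product X.T x and XTxX the final result.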
+    if (tile_j == 0) {
+      for (int j = 0; j < 4; j++) {
+        for (int c = 0; c < proc_channels; c++) {
+          XTx[0][j][c] = x[0][j][c] - x[2][j][c];
+          XTx[1][j][c] = x[1][j][c] + x[2][j][c];
+          XTx[2][j][c] = -x[1][j][c] + x[2][j][c];
+          XTx[3][j][c] = x[1][j][c] - x[3][j][c];
+        }
+      }
+    } else {
+      for (int j = 0; j < 2; j++) {
+        for (int c = 0; c < proc_channels; c++) {
+          XTx[0][j][c] = XTx[0][j + 2][c];
+          XTx[1][j][c] = XTx[1][j + 2][c];
+          XTx[2][j][c] = XTx[2][j + 2][c];
+          XTx[3][j][c] = XTx[3][j + 2][c];
+        }
+      }
+      for (int j = 2; j < 4; j++) {
+        for (int c = 0; c < proc_channels; c++) {
+          XTx[0][j][c] = x[0][j][c] - x[2][j][c];
+          XTx[1][j][c] = x[1][j][c] + x[2][j][c];
+          XTx[2][j][c] = -x[1][j][c] + x[2][j][c];
+          XTx[3][j][c] = x[1][j][c] - x[3][j][c];
+        }
+      }
+    }
+
+    // Compute the matrix X.T x X. Note that some operations can be elided
+    // depending on the padding.
+    for (int i = 0; i < 4; i++) {
+      for (int c = 0; c < proc_channels; c++) {
+        XTxX[i][0][c] = XTx[i][0][c] - XTx[i][2][c];
+        XTxX[i][1][c] = XTx[i][1][c] + XTx[i][2][c];
+        XTxX[i][2][c] = -XTx[i][1][c] + XTx[i][2][c];
+        XTxX[i][3][c] = XTx[i][1][c] - XTx[i][3][c];
+      }
+    }
+
+    // Store the output matrix (X.T x X)
+    for (int i = 0; i < 4; i++) {
+      for (int j = 0; j < 4; j++) {
+        // Get a pointer to the relevant output matrix
+        T *mptr = outptr + (i*4 + j)*matrix_stride;
+
+        // Write out the channels
+        for (int c = 0; c < proc_channels; c++) {
+          mptr[c] = XTxX[i][j][c];
+        }
+      }
+    }
+
+    // Update the pointers
+    inptr += input_col_stride * ((tile_j == 0 && pad_left) ? 1 : 2);
+    outptr += matrix_row_stride;
+  }
+}
+
+/*****************************************************************************/
+template <typename T>
+void winograd::Winograd2x2_3x3GemmInputChannelwise<T>::execute(
+  const T *inptr,
+  const Tensor4DShape& input_shape,
+  const PaddingType padding_type,
+  const int tile_M,
+  const int tile_N,
+  T *outptr_base,
+  const int matrix_stride,
+  const int matrix_batch_stride,
+  const int matrix_row_stride
+) {
+  const int n_channels = input_shape.n_channels;
+  const int input_col_stride = n_channels;
+  const int input_row_stride = input_shape.n_cols * input_col_stride;
+
+  // Determine the padding and hence select appropriate methods for each tile.
+  tilefunc fs[3][3];
+
+  if (padding_type == PADDING_VALID) {
+    constexpr int pad_top = 0;
+    constexpr int pad_left = 0;
+    const int pad_right = input_shape.n_cols % 2 == 0;
+
+    fs[0][0] = process_tile<pad_top, pad_left, 0, 0>;
+    fs[0][1] = process_tile<pad_top, 0, 0, 0>;
+    fs[0][2] = (pad_right) ? process_tile<pad_top, 0, 0, 0> : process_tile<pad_top, 0, 0, 1>;
+
+    fs[1][0] = process_tile<0, pad_left, 0, 0>;
+    fs[1][1] = process_tile<0, 0, 0, 0>;
+    fs[1][2] = (pad_right) ? process_tile<0, 0, 0, 0> : process_tile<0, 0, 0, 1>;
+
+    if (input_shape.n_rows % 2 == 0) {
+      constexpr int pad_bottom = 0;
+      fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>;
+      fs[2][1] = process_tile<0, 0, pad_bottom, 0>;
+      fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 0> : process_tile<0, 0, pad_bottom, 1>;
+    } else {
+      constexpr int pad_bottom = 1;
+      fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>;
+      fs[2][1] = process_tile<0, 0, pad_bottom, 0>;
+      fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 0> : process_tile<0, 0, pad_bottom, 1>;
+    }
+  } else {
+    constexpr int pad_top = 1;
+    constexpr int pad_left = 1;
+    const int pad_right = input_shape.n_cols % 2 == 0;
+
+    fs[0][0] = process_tile<pad_top, pad_left, 0, 0>;
+    fs[0][1] = process_tile<pad_top, 0, 0, 0>;
+    fs[0][2] = (pad_right) ? process_tile<pad_top, 0, 0, 1> : process_tile<pad_top, 0, 0, 2>;
+
+    fs[1][0] = process_tile<0, pad_left, 0, 0>;
+    fs[1][1] = process_tile<0, 0, 0, 0>;
+    fs[1][2] = (pad_right) ? process_tile<0, 0, 0, 1> : process_tile<0, 0, 0, 2>;
+
+    if (input_shape.n_rows % 2 == 0) {
+      constexpr int pad_bottom = 1;
+      fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>;
+      fs[2][1] = process_tile<0, 0, pad_bottom, 0>;
+      fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 1> : process_tile<0, 0, pad_bottom, 2>;
+    } else {
+      constexpr int pad_bottom = 2;
+      fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>;
+      fs[2][1] = process_tile<0, 0, pad_bottom, 0>;
+      fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 1> : process_tile<0, 0, pad_bottom, 2>;
+    }
+  }
+
+  // Process each tile in turn
+  for (int batch = 0; batch < input_shape.n_batches; batch++) {
+    const T* const input_base_batch = inptr + batch*input_shape.n_rows*input_shape.n_cols*n_channels;
+
+    for (int tile_i = 0; tile_i < tile_M; tile_i++) {
+      const int row_offset = (tile_i == 0) ? 0 : ((padding_type == PADDING_VALID) ? 0 : 1);
+      const T* const input_base_row = input_base_batch + (2*tile_i - row_offset)*input_shape.n_cols*n_channels;
+
+      // Select the set of functions for the row
+      const int fs_i = (tile_i == 0) ? 0 : ((tile_i < tile_M - 1) ? 1 : 2);
+
+      for (int tile_j = 0; tile_j < tile_N; tile_j++) {
+        // Select the function for the column
+        const int fs_j = (tile_j == 0) ? 0 : ((tile_j < tile_N - 1) ? 1 : 2);
+        const auto f = fs[fs_i][fs_j];
+
+        // Get pointers into the input and outputs
+        const int col_offset = (tile_j == 0) ? 0 : ((padding_type == PADDING_VALID) ? 0 : 1);
+        const T* const input_base_col = input_base_row + (2*tile_j - col_offset)*n_channels;
+        T* const matrix_base = outptr_base + batch*matrix_batch_stride + (tile_i*tile_N + tile_j)*matrix_row_stride;
+        f(n_channels, input_base_col, input_row_stride, input_col_stride,
+          matrix_base, matrix_stride);
+      }
+    }
+  }
+}
+
+template <typename T>
+template <const int pad_top, const int pad_left, const int pad_bottom, const int pad_right>
+void winograd::Winograd2x2_3x3GemmInputChannelwise<T>::process_tile(
+  int n_channels,  // Number of channels in the tile
+  const T* const input_base,
+  const int input_row_stride,
+  const int input_col_stride,
+  T* const matrix_base,
+  const int matrix_stride
+) {
+  // Copy pointers
+  const T *inptr = input_base;
+  T *outptr = matrix_base;
+
+  // Process channels (modifies inptr, outptr and n_channels)
+  _process_tile<pad_top, pad_left, pad_bottom, pad_right, 4>(
+    n_channels, inptr, input_row_stride, input_col_stride,
+    outptr, matrix_stride
+  );
+  _process_tile<pad_top, pad_left, pad_bottom, pad_right, 2>(
+    n_channels, inptr, input_row_stride, input_col_stride,
+    outptr, matrix_stride
+  );
+  _process_tile<pad_top, pad_left, pad_bottom, pad_right, 1>(
+    n_channels, inptr, input_row_stride, input_col_stride,
+    outptr, matrix_stride
+  );
+}
+
+template <typename T>
+template <const int pad_top, const int pad_left, const int pad_bottom, const int pad_right, const int proc_channels>
+void winograd::Winograd2x2_3x3GemmInputChannelwise<T>::_process_tile(
+  int &n_channels,
+  const T* &inptr, const int input_row_stride, const int input_col_stride,
+  T* &outptr, const int matrix_stride
+) {
+  // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three
+  // offsets to access the intermediate matrices.
+  T* outptrs[4] = {
+    outptr,
+    outptr + matrix_stride * 4,
+    outptr + matrix_stride * 8,
+    outptr + matrix_stride * 12
+  };
+
+  // The matrix X; zeroed to account for padding.
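+  // (The load loop below writes only the non-padded cells of X, so any padded
+  // rows or columns keep their zero value throughout the channel loop.)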
+ T x[4][4]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + x[i][j] = 0; + } + } + + // The matrices X.T x and U + T XTx[4][4], U[4][4]; + + // Now progress through each channel + for (; n_channels >= proc_channels; n_channels -= proc_channels) { + for (int n = 0; n < proc_channels; n++) { + // Load the matrix X + for (int cell_i = pad_top, i = 0; cell_i < 4 - pad_bottom; cell_i++, i++) { + for (int cell_j = pad_left, j = 0; cell_j < 4 - pad_right; cell_j++, j++) { + x[cell_i][cell_j] = inptr[i*input_row_stride + j*input_col_stride]; + } + } + inptr++; + + // Compute the matrix X.T + for (int j = 0; j < 4; j++) { + XTx[0][j] = x[0][j] - x[2][j]; + XTx[1][j] = x[1][j] + x[2][j]; + XTx[2][j] = x[2][j] - x[1][j]; + XTx[3][j] = x[1][j] - x[3][j]; + } + + // Hence compute the matrix U + for (int i = 0; i < 4; i++) { + U[i][0] = XTx[i][0] - XTx[i][2]; + U[i][1] = XTx[i][1] + XTx[i][2]; + U[i][2] = XTx[i][2] - XTx[i][1]; + U[i][3] = XTx[i][1] - XTx[i][3]; + } + + // Store the matrix U + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + outptrs[i][j * matrix_stride] = U[i][j]; + } + outptrs[i]++; + } + } + } + + // Update the output pointer for future calls + outptr = outptrs[0]; +} diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp new file mode 100644 index 0000000000..a99cbe325b --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp @@ -0,0 +1,1498 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#pragma once +#include "../input_2x2_3x3.hpp" + +#ifdef __aarch64__ +namespace winograd { + +// Pad left by one column, pad right by one column, no upper or lower padding, 4 channels +template <> +template <> +inline void Winograd2x2_3x3GemmInput::process_tile_row<0, 1, 0, 1, 4>( + const int tile_N, // Number of tiles in the row + const float* const input, // Base input pointer (appropriate to batch, channel and row) + const int input_row_stride, // Stride between rows of the input + const int input_col_stride, // Stride between columns of the input + float* const matrix, // 1st output matrix (appropriate to batch, channel and row) + const int matrix_stride, // Stride between matrices + const int matrix_row_stride // Stride between rows of the output matrix +) { + /* SIMD register allocation + * ======================== + * + * In the following code we read 4x4 tiles of a matrix `x`, with which we + * compute another matrix `X.T x` where: + * + * / 1 0 0 0 \ + * X = | 0 1 -1 1 | + * | -1 1 1 0 | + * \ 0 0 0 -1 / + * + * Hence, `X.T` is a program which operates upon rows of the matrix `X`. + * We subsequently compute and store the matrix `U = (X.T x) X`. + * + * Importantly, each iteration of the loop below loads a new matrix `x'` + * where the final two columns of `x'` are the first two columns of the + * previous `x`. That is: + * + * x11 x12 x13 x14 + * x21 x22 x23 x24 + * x31 x32 x33 x34 + * x41 x42 x43 x44 + * + * x'11 x'12 x'13 x'14 + * x'21 x'22 x'23 x'24 + * x'31 x'32 x'33 x'34 + * x'41 x'42 x'43 x'44 + * + * Consequently, while the first iteration of the below loop must load 16 + * values for `x`, the second need load only 8. *Furthermore*, since we noted + * above that the operation `X.T x` was a program which operated upon *rows* + * of the matrix `x` it follows that that the relation that `x'[i][1] = + * x[i][3]` and `x'[i][2] = x[i][4]` applies also the matrices `X.T x'` and + * `X.T x`. That is: + * + * (X.T x)11 (X.T x)12 (X.T x)13 (X.T x)14 + * (X.T x)21 (X.T x)22 (X.T x)23 (X.T x)24 + * (X.T x)31 (X.T x)32 (X.T x)33 (X.T x)34 + * (X.T x)41 (X.T x)42 (X.T x)43 (X.T x)44 + * + * (X.T x')11 (X.T x')12 (X.T x')13 (X.T x')14 + * (X.T x')12 (X.T x')12 (X.T x')12 (X.T x')12 + * (X.T x')13 (X.T x')13 (X.T x')13 (X.T x')13 + * (X.T x')14 (X.T x')14 (X.T x')14 (X.T x')14 + * + * Hence, as well as not needing to load new values for x'[i][1..2] it is + * also unnecessary to recompute values for (X.T x')[i][1..2]. + * + * Following this we break the registers into blocks `A` and `B` used by the + * two stages of the unrolled loop. These registers named such that the + * latter columns of `A` become the earlier columns of `B` and vice-versa: + * + * AXTx11 AXTx12 > AXTx13 AXTx14 | + * AXTx21 AXTx22 > AXTx23 AXTx24 | + * AXTx31 AXTx32 > AXTx33 AXTx34 | + * AXTx41 AXTx42 > AXTx43 AXTx44 | + * + * BXTx13 BXTx14 | BXTx11 BXTx12 > + * BXTx23 BXTx24 | BXTx21 BXTx22 > + * BXTx33 BXTx34 | BXTx31 BXTx32 > + * BXTx43 BXTx44 | BXTx41 BXTx42 > + * + * These 32 named registers require only 16 architectural registers. 1 + * additional architectural register is used as scratch space and 8 + * architectural registers are used to load in the values x[1..4][3,4]. 
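+   * For example (illustrative only): v16 is `AXTx13` during an 'A' iteration
+   * and `BXTx11` during the following 'B' iteration, so the column-3 result
+   * computed for one tile is already sitting in the register used as column 1
+   * of the next tile and never has to be copied.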
+ * + * Input and output addressing + * =========================== + * TODO Description + */ + const float *inptr0 = input; + const float *inptr1 = input + input_row_stride; + const float *inptr2 = input + input_row_stride * 2; + const float *inptr3 = input + input_row_stride * 3; + + float *outptr0 = matrix; + float *outptr4 = matrix + matrix_stride * 4; + float *outptr8 = matrix + matrix_stride * 8; + float *outptr12 = matrix + matrix_stride * 12; + + int tile_j = tile_N; // Tiles to process + + asm volatile ( + // Named SIMD registers according to the policy given above + // Registers into which to load the latter two columns of `x` + "x_13 .req v0\n qx_13 .req q0\n" "x_14 .req v4\n qx_14 .req q4\n" + "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n" + "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n" + "x_43 .req v3\n qx_43 .req q3\n" "x_44 .req v7\n qx_44 .req q7\n" + + // Registers for storing X.T x (both A and B halves) + "AXTx11 .req v8\n" "BXTx13 .req v8\n" + "AXTx12 .req v9\n" "BXTx14 .req v9\n" "qAXTx12 .req q9\n" + "AXTx21 .req v10\n" "BXTx23 .req v10\n" + "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n" + "AXTx31 .req v12\n" "BXTx33 .req v12\n" + "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n" + "AXTx41 .req v14\n" "BXTx43 .req v14\n" + "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n" + "AXTx13 .req v16\n" "BXTx11 .req v16\n" + "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n" + "AXTx23 .req v18\n" "BXTx21 .req v18\n" + "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n" + "AXTx33 .req v20\n" "BXTx31 .req v20\n" + "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n" + "AXTx43 .req v22\n" "BXTx41 .req v22\n" + "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n" + + // Result register (TODO Does using more registers yield better + // performance) + "U .req v24\n qU .req q24\n" + + // ---------------------------------------------------------------------- + // Head of loop + // Loads a complete 4x4 tile of x, computes X.T x, computes and stores + // `U = X.T x X`. Prepares for the 'A' half of the loop. + // NOTE: Since the first tile has the leftmost column padded we can + // skip 4 loads and 4 calculations for the matrix X.T x X. + + // Temporarily alias registers for computing the first (non-padded) + // column of x. + "x_12 .req v0\n qx_12 .req q0\n" + "x_22 .req v1\n qx_22 .req q1\n" + "x_32 .req v2\n qx_32 .req q2\n" + "x_42 .req v3\n qx_42 .req q3\n" + + "ldr qx_12, [%x[inptr0]]\n" + "ldr qx_22, [%x[inptr1]]\n" + "ldr qx_32, [%x[inptr2]]\n" + "ldr qx_42, [%x[inptr3]]\n" + + "fsub BXTx12.4s, x_12.4s, x_32.4s\n" + "fadd BXTx22.4s, x_22.4s, x_32.4s\n" + "fsub BXTx32.4s, x_32.4s, x_22.4s\n" + "fsub BXTx42.4s, x_22.4s, x_42.4s\n" + + ".unreq x_12\n .unreq qx_12\n" + ".unreq x_22\n .unreq qx_22\n" + ".unreq x_32\n .unreq qx_32\n" + ".unreq x_42\n .unreq qx_42\n" + + // Load and compute latter two columns of the first tile. Progress the + // input pointers (by three columns so that the each points are the + // second column of the next tile, that is, each points at the first + // column which must be read for the next tile. 
+ "ldr qx_13, [%x[inptr0], %x[colstride1]]\n" + "ldr qx_23, [%x[inptr1], %x[colstride1]]\n" + "ldr qx_33, [%x[inptr2], %x[colstride1]]\n" + "ldr qx_43, [%x[inptr3], %x[colstride1]]\n" + + "fsub BXTx13.4s, x_13.4s, x_33.4s\n" + "ldr qx_14, [%x[inptr0], %x[colstride2]]\n" + + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride2]]\n" + + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride2]]\n" + + "fsub BXTx43.4s, x_23.4s, x_43.4s\n" + "ldr qx_44, [%x[inptr3], %x[colstride2]]\n" + + "fsub BXTx14.4s, x_14.4s, x_34.4s\n" + "add %x[inptr0], %x[inptr0], %x[colstride3]\n" + + "fadd BXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride3]\n" + + "fsub BXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride3]\n" + + "fsub BXTx44.4s, x_24.4s, x_44.4s\n" + "add %x[inptr3], %x[inptr3], %x[colstride3]\n" + + // Compute and store U for the first tile + // First row + "fneg U.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, BXTx12.4s, BXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fneg U.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, BXTx22.4s, BXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fneg U.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, BXTx32.4s, BXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row, simultaneously load the first column of inputs for the + // next tile. + "fneg U.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "ldr qx_13, [%x[inptr0]]\n" + + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, BXTx42.4s, BXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "ldr qx_43, [%x[inptr3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + + // Update the loop counter, subtract two to account for both the head and + // the tail. + "subs %x[tile_j], %x[tile_j], #2\n" + "beq 2f\n" // Jump to "A" tail if out of tiles + + // ---------------------------------------------------------------------- + "1:" + // Start part A + // Load last column of this tile (the first column has already been + // loaded) and compute latter two columns of X.T x. 
+ "fsub AXTx13.4s, x_13.4s, x_33.4s\n" + "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" + "fadd AXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" + "fsub AXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" + "fsub AXTx43.4s, x_23.4s, x_43.4s\n" + "ldr qx_44, [%x[inptr3], %x[colstride1]]\n" + "fsub AXTx14.4s, x_14.4s, x_34.4s\n" + "add %x[inptr0], %x[inptr0], %x[colstride2]\n" + "fadd AXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride2]\n" + "fsub AXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride2]\n" + "fsub AXTx44.4s, x_24.4s, x_44.4s\n" + "add %x[inptr3], %x[inptr3], %x[colstride2]\n" + + // Compute and store U. + // First row + "fsub U.4s, AXTx11.4s, AXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, AXTx12.4s, AXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, AXTx13.4s, AXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, AXTx12.4s, AXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fsub U.4s, AXTx21.4s, AXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, AXTx22.4s, AXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, AXTx23.4s, AXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, AXTx22.4s, AXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fsub U.4s, AXTx31.4s, AXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, AXTx32.4s, AXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, AXTx33.4s, AXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, AXTx32.4s, AXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row + "fsub U.4s, AXTx41.4s, AXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "ldr qx_13, [%x[inptr0]]\n" + + "fadd U.4s, AXTx42.4s, AXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, AXTx43.4s, AXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, AXTx42.4s, AXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "ldr qx_43, [%x[inptr3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + + "subs %x[tile_j], %x[tile_j], #1\n" + "beq 3f\n" // Jump to 'B' tail + + // Start part B + // Load last column of this tile (the first column has already been + // loaded) and compute latter two columns of X.T x. + "fsub BXTx13.4s, x_13.4s, x_33.4s\n" + "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" + "fsub BXTx43.4s, x_23.4s, x_43.4s\n" + "ldr qx_44, [%x[inptr3], %x[colstride1]]\n" + "fsub BXTx14.4s, x_14.4s, x_34.4s\n" + "add %x[inptr0], %x[inptr0], %x[colstride2]\n" + "fadd BXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride2]\n" + "fsub BXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride2]\n" + "fsub BXTx44.4s, x_24.4s, x_44.4s\n" + "add %x[inptr3], %x[inptr3], %x[colstride2]\n" + + // Compute and store U. 
+ // First row + "fsub U.4s, BXTx11.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, BXTx12.4s, BXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fsub U.4s, BXTx21.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, BXTx22.4s, BXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fsub U.4s, BXTx31.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, BXTx32.4s, BXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row + "fsub U.4s, BXTx41.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "ldr qx_13, [%x[inptr0]]\n" + + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, BXTx42.4s, BXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "ldr qx_43, [%x[inptr3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + "subs %x[tile_j], %x[tile_j], #1\n" + "bne 1b\n" // Continue loop, otherwise flow into 'A' tail + + // ---------------------------------------------------------------------- + "2:" + // 'A' tail + // Since the final column is padding and the last-but-one column has + // already been loaded just compute the 3rd column of `X.T x'. + "fsub AXTx13.4s, x_13.4s, x_33.4s\n" + "fadd AXTx23.4s, x_23.4s, x_33.4s\n" + "fsub AXTx33.4s, x_33.4s, x_23.4s\n" + "fsub AXTx43.4s, x_23.4s, x_43.4s\n" + + // Compute and store U. Modified to account for the final column of X.T + // x containing padding. Note, it is also unnecessary to update the + // output pointers. 
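+    // (Since the final column of x is padding, the fourth column of X.T x is
+    // zero and the fourth column of U, e.g. AXTx12 - AXTx14 for the first row,
+    // reduces to the second column of X.T x; hence the plain `str qAXTx12`,
+    // `str qAXTx22`, ... stores below rather than an fsub.)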
+ // First row + "fsub U.4s, AXTx11.4s, AXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, AXTx12.4s, AXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, AXTx13.4s, AXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "str qAXTx12, [%x[outptr0], %x[mstride3]]\n" + + // Second row + "fsub U.4s, AXTx21.4s, AXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, AXTx22.4s, AXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, AXTx23.4s, AXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "str qAXTx22, [%x[outptr4], %x[mstride3]]\n" + + // Third row + "fsub U.4s, AXTx31.4s, AXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, AXTx32.4s, AXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, AXTx33.4s, AXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "str qAXTx32, [%x[outptr8], %x[mstride3]]\n" + + // Fourth row + "fsub U.4s, AXTx41.4s, AXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "fadd U.4s, AXTx42.4s, AXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, AXTx43.4s, AXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "str qAXTx42, [%x[outptr12], %x[mstride3]]\n" + + "b 4f\n" // Jump to end of function + + // ---------------------------------------------------------------------- + "3:" + // 'B' tail + // Since the final column is padding and the last-but-one column has + // already been loaded just compute the 3rd column of `X.T x'. + "fsub BXTx13.4s, x_13.4s, x_33.4s\n" + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "fsub BXTx43.4s, x_23.4s, x_43.4s\n" + + // Compute and store U. Modified to account for the final column of X.T + // x containing padding. Note, it is also unnecessary to update the + // output pointers. + // First row + "fsub U.4s, BXTx11.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "str qBXTx12, [%x[outptr0], %x[mstride3]]\n" + + // Second row + "fsub U.4s, BXTx21.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "str qBXTx22, [%x[outptr4], %x[mstride3]]\n" + + // Third row + "fsub U.4s, BXTx31.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "str qBXTx32, [%x[outptr8], %x[mstride3]]\n" + + // Fourth row + "fsub U.4s, BXTx41.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "str qBXTx42, [%x[outptr12], %x[mstride3]]\n" + + // ---------------------------------------------------------------------- + "4:" + // End of function + + // Clear names + ".unreq x_13\n" ".unreq qx_13\n" ".unreq x_14\n" ".unreq qx_14\n" + ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n" + ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n" + ".unreq x_43\n" ".unreq qx_43\n" ".unreq x_44\n" ".unreq qx_44\n" + ".unreq AXTx11\n" ".unreq BXTx13\n" + ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n" + ".unreq AXTx21\n" ".unreq BXTx23\n" + ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n" + ".unreq AXTx31\n" ".unreq BXTx33\n" + 
".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n" + ".unreq AXTx41\n" ".unreq BXTx43\n" + ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n" + ".unreq AXTx13\n" ".unreq BXTx11\n" + ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n" + ".unreq AXTx23\n" ".unreq BXTx21\n" + ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n" + ".unreq AXTx33\n" ".unreq BXTx31\n" + ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n" + ".unreq AXTx43\n" ".unreq BXTx41\n" + ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n" + ".unreq U\n" ".unreq qU\n" + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [inptr3] "+r" (inptr3), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr4), + [outptr8] "+r" (outptr8), + [outptr12] "+r" (outptr12), + [tile_j] "+r" (tile_j) // Tile counter + : [colstride1] "r" (1 * input_col_stride * sizeof(float)), + [colstride2] "r" (2 * input_col_stride * sizeof(float)), + [colstride3] "r" (3 * input_col_stride * sizeof(float)), + [mstride1] "r" (1 * matrix_stride * sizeof(float)), + [mstride2] "r" (2 * matrix_stride * sizeof(float)), + [mstride3] "r" (3 * matrix_stride * sizeof(float)), + [matrix_row_stride] "r" (matrix_row_stride * sizeof(float)) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24" + ); +} + +// Pad top, left and right by 1. +template <> +template <> +inline void Winograd2x2_3x3GemmInput::process_tile_row<1, 1, 0, 1, 4>( + const int tile_N, + const float* const input, + const int input_row_stride, + const int input_col_stride, + float* const matrix, + const int matrix_stride, + const int matrix_row_stride +) { + const float *inptr0 = input; + const float *inptr1 = input + input_row_stride; + const float *inptr2 = input + input_row_stride * 2; + + float *outptr0 = matrix; + float *outptr4 = matrix + matrix_stride * 4; + float *outptr8 = matrix + matrix_stride * 8; + float *outptr12 = matrix + matrix_stride * 12; + + int tile_j = tile_N; // Tiles to process + + asm volatile ( + // Named SIMD registers according to the policy given above + // Registers into which to load the latter two columns of `x` + // NOTE: We need only load the latter three rows since we know that the + // first row is padded. 
+ "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n" + "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n" + "x_43 .req v3\n qx_43 .req q3\n" "x_44 .req v7\n qx_44 .req q7\n" + + // Registers for storing X.T x (both A and B halves) + "AXTx11 .req v8\n" "BXTx13 .req v8\n" + "AXTx12 .req v9\n" "BXTx14 .req v9\n" "qAXTx12 .req q9\n" + "AXTx21 .req v10\n" "BXTx23 .req v10\n" + "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n" + "AXTx31 .req v12\n" "BXTx33 .req v12\n" + "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n" + "AXTx41 .req v14\n" "BXTx43 .req v14\n" + "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n" + "AXTx13 .req v16\n" "BXTx11 .req v16\n" + "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n" + "AXTx23 .req v18\n" "BXTx21 .req v18\n" + "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n" + "AXTx33 .req v20\n" "BXTx31 .req v20\n" + "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n" + "AXTx43 .req v22\n" "BXTx41 .req v22\n" + "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n" + + // Result register (TODO Does using more registers yield better + // performance) + "U .req v24\n qU .req q24\n" + + // ---------------------------------------------------------------------- + // Head of loop + // Loads a complete 4x4 tile of x, computes X.T x, computes and stores + // `U = X.T x X`. Prepares for the 'A' half of the loop. + // NOTE: Since the first tile has the leftmost column padded we can + // skip 4 loads and 4 calculations for the matrix X.T x X. + + // Temporarily alias registers for computing the first (non-padded) + // column of x. + "x_22 .req v1\n qx_22 .req q1\n" + "x_32 .req v2\n qx_32 .req q2\n" + "x_42 .req v3\n qx_42 .req q3\n" + + "ldr qx_22, [%x[inptr1]]\n" + "ldr qx_32, [%x[inptr2]]\n" + "ldr qx_42, [%x[inptr3]]\n" + + "fneg BXTx12.4s, x_32.4s\n" + "fadd BXTx22.4s, x_22.4s, x_32.4s\n" + "fsub BXTx32.4s, x_32.4s, x_22.4s\n" + "fsub BXTx42.4s, x_22.4s, x_42.4s\n" + + ".unreq x_22\n .unreq qx_22\n" + ".unreq x_32\n .unreq qx_32\n" + ".unreq x_42\n .unreq qx_42\n" + + // Load and compute latter two columns of the first tile. Progress the + // input pointers (by three columns so that the each points are the + // second column of the next tile, that is, each points at the first + // column which must be read for the next tile. 
+ "ldr qx_23, [%x[inptr1], %x[colstride1]]\n" + "ldr qx_33, [%x[inptr2], %x[colstride1]]\n" + "ldr qx_43, [%x[inptr3], %x[colstride1]]\n" + + "fneg BXTx13.4s, x_33.4s\n" + + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride2]]\n" + + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride2]]\n" + + "fsub BXTx43.4s, x_23.4s, x_43.4s\n" + "ldr qx_44, [%x[inptr3], %x[colstride2]]\n" + + "fneg BXTx14.4s, x_34.4s\n" + + "fadd BXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride3]\n" + + "fsub BXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride3]\n" + + "fsub BXTx44.4s, x_24.4s, x_44.4s\n" + "add %x[inptr3], %x[inptr3], %x[colstride3]\n" + + // Compute and store U for the first tile + // First row + "fneg U.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, BXTx12.4s, BXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fneg U.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, BXTx22.4s, BXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fneg U.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, BXTx32.4s, BXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row, simultaneously load the first column of inputs for the + // next tile. + "fneg U.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, BXTx42.4s, BXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "ldr qx_43, [%x[inptr3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + + // Update the loop counter, subtract two to account for both the head and + // the tail. + "subs %x[tile_j], %x[tile_j], #2\n" + "beq 2f\n" // Jump to "A" tail if out of tiles + + // ---------------------------------------------------------------------- + "1:" + // Start part A + // Load last column of this tile (the first column has already been + // loaded) and compute latter two columns of X.T x. + "fneg AXTx13.4s, x_33.4s\n" + "fadd AXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" + "fsub AXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" + "fsub AXTx43.4s, x_23.4s, x_43.4s\n" + "ldr qx_44, [%x[inptr3], %x[colstride1]]\n" + "fneg AXTx14.4s, x_34.4s\n" + "fadd AXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride2]\n" + "fsub AXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride2]\n" + "fsub AXTx44.4s, x_24.4s, x_44.4s\n" + "add %x[inptr3], %x[inptr3], %x[colstride2]\n" + + // Compute and store U. 
+ // First row + "fsub U.4s, AXTx11.4s, AXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, AXTx12.4s, AXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, AXTx13.4s, AXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, AXTx12.4s, AXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fsub U.4s, AXTx21.4s, AXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, AXTx22.4s, AXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, AXTx23.4s, AXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, AXTx22.4s, AXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fsub U.4s, AXTx31.4s, AXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, AXTx32.4s, AXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, AXTx33.4s, AXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, AXTx32.4s, AXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row + "fsub U.4s, AXTx41.4s, AXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + + "fadd U.4s, AXTx42.4s, AXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, AXTx43.4s, AXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, AXTx42.4s, AXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "ldr qx_43, [%x[inptr3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + + "subs %x[tile_j], %x[tile_j], #1\n" + "beq 3f\n" // Jump to 'B' tail + + // Start part B + // Load last column of this tile (the first column has already been + // loaded) and compute latter two columns of X.T x. + "fneg BXTx13.4s, x_33.4s\n" + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" + "fsub BXTx43.4s, x_23.4s, x_43.4s\n" + "ldr qx_44, [%x[inptr3], %x[colstride1]]\n" + "fneg BXTx14.4s, x_34.4s\n" + "fadd BXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride2]\n" + "fsub BXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride2]\n" + "fsub BXTx44.4s, x_24.4s, x_44.4s\n" + "add %x[inptr3], %x[inptr3], %x[colstride2]\n" + + // Compute and store U. 
+ // First row + "fsub U.4s, BXTx11.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, BXTx12.4s, BXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fsub U.4s, BXTx21.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, BXTx22.4s, BXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fsub U.4s, BXTx31.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, BXTx32.4s, BXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row + "fsub U.4s, BXTx41.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, BXTx42.4s, BXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "ldr qx_43, [%x[inptr3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + "subs %x[tile_j], %x[tile_j], #1\n" + "bne 1b\n" // Continue loop, otherwise flow into 'A' tail + + // ---------------------------------------------------------------------- + "2:" + // 'A' tail + // Since the final column is padding and the last-but-one column has + // already been loaded just compute the 3rd column of `X.T x'. + "fneg AXTx13.4s, x_33.4s\n" + "fadd AXTx23.4s, x_23.4s, x_33.4s\n" + "fsub AXTx33.4s, x_33.4s, x_23.4s\n" + "fsub AXTx43.4s, x_23.4s, x_43.4s\n" + + // Compute and store U. Modified to account for the final column of X.T + // x containing padding. Note, it is also unnecessary to update the + // output pointers. 
+ // First row + "fsub U.4s, AXTx11.4s, AXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, AXTx12.4s, AXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, AXTx13.4s, AXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "str qAXTx12, [%x[outptr0], %x[mstride3]]\n" + + // Second row + "fsub U.4s, AXTx21.4s, AXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, AXTx22.4s, AXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, AXTx23.4s, AXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "str qAXTx22, [%x[outptr4], %x[mstride3]]\n" + + // Third row + "fsub U.4s, AXTx31.4s, AXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, AXTx32.4s, AXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, AXTx33.4s, AXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "str qAXTx32, [%x[outptr8], %x[mstride3]]\n" + + // Fourth row + "fsub U.4s, AXTx41.4s, AXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "fadd U.4s, AXTx42.4s, AXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, AXTx43.4s, AXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "str qAXTx42, [%x[outptr12], %x[mstride3]]\n" + + "b 4f\n" // Jump to end of function + + // ---------------------------------------------------------------------- + "3:" + // 'B' tail + // Since the final column is padding and the last-but-one column has + // already been loaded just compute the 3rd column of `X.T x'. + "fneg BXTx13.4s, x_33.4s\n" + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "fsub BXTx43.4s, x_23.4s, x_43.4s\n" + + // Compute and store U. Modified to account for the final column of X.T + // x containing padding. Note, it is also unnecessary to update the + // output pointers. + // First row + "fsub U.4s, BXTx11.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "str qBXTx12, [%x[outptr0], %x[mstride3]]\n" + + // Second row + "fsub U.4s, BXTx21.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "str qBXTx22, [%x[outptr4], %x[mstride3]]\n" + + // Third row + "fsub U.4s, BXTx31.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "str qBXTx32, [%x[outptr8], %x[mstride3]]\n" + + // Fourth row + "fsub U.4s, BXTx41.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "str qBXTx42, [%x[outptr12], %x[mstride3]]\n" + + // ---------------------------------------------------------------------- + "4:" + // End of function + + // Clear names + ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n" + ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n" + ".unreq x_43\n" ".unreq qx_43\n" ".unreq x_44\n" ".unreq qx_44\n" + ".unreq AXTx11\n" ".unreq BXTx13\n" + ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n" + ".unreq AXTx21\n" ".unreq BXTx23\n" + ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n" + ".unreq AXTx31\n" ".unreq BXTx33\n" + ".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n" + ".unreq AXTx41\n" 
".unreq BXTx43\n" + ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n" + ".unreq AXTx13\n" ".unreq BXTx11\n" + ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n" + ".unreq AXTx23\n" ".unreq BXTx21\n" + ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n" + ".unreq AXTx33\n" ".unreq BXTx31\n" + ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n" + ".unreq AXTx43\n" ".unreq BXTx41\n" + ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n" + ".unreq U\n" ".unreq qU\n" + : [inptr1] "+r" (inptr0), // Offset to account for padded row + [inptr2] "+r" (inptr1), // Offset to account for padded row + [inptr3] "+r" (inptr2), // Offset to account for padded row + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr4), + [outptr8] "+r" (outptr8), + [outptr12] "+r" (outptr12), + [tile_j] "+r" (tile_j) // Tile counter + : [colstride1] "r" (1 * input_col_stride * sizeof(float)), + [colstride2] "r" (2 * input_col_stride * sizeof(float)), + [colstride3] "r" (3 * input_col_stride * sizeof(float)), + [mstride1] "r" (1 * matrix_stride * sizeof(float)), + [mstride2] "r" (2 * matrix_stride * sizeof(float)), + [mstride3] "r" (3 * matrix_stride * sizeof(float)), + [matrix_row_stride] "r" (matrix_row_stride * sizeof(float)) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24" + ); +} + +// Pad left, right and bottom by 1. +template <> +template <> +inline void Winograd2x2_3x3GemmInput::process_tile_row<0, 1, 1, 1, 4>( + const int tile_N, + const float* const input, + const int input_row_stride, + const int input_col_stride, + float* const matrix, + const int matrix_stride, + const int matrix_row_stride +) { + const float *inptr0 = input; + const float *inptr1 = input + input_row_stride; + const float *inptr2 = input + input_row_stride * 2; + + float *outptr0 = matrix; + float *outptr4 = matrix + matrix_stride * 4; + float *outptr8 = matrix + matrix_stride * 8; + float *outptr12 = matrix + matrix_stride * 12; + + int tile_j = tile_N; // Tiles to process + + asm volatile ( + // Named SIMD registers according to the policy given above + // Registers into which to load the latter two columns of `x` + // NOTE: Bottom row is not required since since it is padded. 
+ "x_13 .req v0\n qx_13 .req q0\n" "x_14 .req v4\n qx_14 .req q4\n" + "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n" + "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n" + + // Registers for storing X.T x (both A and B halves) + "AXTx11 .req v8\n" "BXTx13 .req v8\n" + "AXTx12 .req v9\n" "BXTx14 .req v9\n" "qAXTx12 .req q9\n" + "AXTx21 .req v10\n" "BXTx23 .req v10\n" + "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n" + "AXTx31 .req v12\n" "BXTx33 .req v12\n" + "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n" + "AXTx41 .req v14\n" "BXTx43 .req v14\n" + "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n" + "AXTx13 .req v16\n" "BXTx11 .req v16\n" + "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n" + "AXTx23 .req v18\n" "BXTx21 .req v18\n" + "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n" + "AXTx33 .req v20\n" "BXTx31 .req v20\n" + "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n" + "AXTx43 .req v22\n" "BXTx41 .req v22\n" + "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n" + + // Result register (TODO Does using more registers yield better + // performance) + "U .req v24\n qU .req q24\n" + + // ---------------------------------------------------------------------- + // Head of loop + // Loads a complete 4x4 tile of x, computes X.T x, computes and stores + // `U = X.T x X`. Prepares for the 'A' half of the loop. + // NOTE: Since the first tile has the leftmost column padded we can + // skip 4 loads and 4 calculations for the matrix X.T x X. + + // Temporarily alias registers for computing the first (non-padded) + // column of x. + "x_12 .req v0\n qx_12 .req q0\n" + "x_22 .req v1\n qx_22 .req q1\n" + "x_32 .req v2\n qx_32 .req q2\n" + + "ldr qx_12, [%x[inptr0]]\n" + "ldr qx_22, [%x[inptr1]]\n" + "ldr qx_32, [%x[inptr2]]\n" + + "fsub BXTx12.4s, x_12.4s, x_32.4s\n" + "fadd BXTx22.4s, x_22.4s, x_32.4s\n" + "fsub BXTx32.4s, x_32.4s, x_22.4s\n" + "mov BXTx42.16b, x_22.16b\n" // Probably should do better + + ".unreq x_12\n .unreq qx_12\n" + ".unreq x_22\n .unreq qx_22\n" + ".unreq x_32\n .unreq qx_32\n" + + // Load and compute latter two columns of the first tile. Progress the + // input pointers (by three columns so that the each points are the + // second column of the next tile, that is, each points at the first + // column which must be read for the next tile. 
+ "ldr qx_13, [%x[inptr0], %x[colstride1]]\n" + "ldr qx_23, [%x[inptr1], %x[colstride1]]\n" + "ldr qx_33, [%x[inptr2], %x[colstride1]]\n" + + "fsub BXTx13.4s, x_13.4s, x_33.4s\n" + "ldr qx_14, [%x[inptr0], %x[colstride2]]\n" + + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride2]]\n" + + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride2]]\n" + + "mov BXTx43.16b, x_23.16b\n" + "fsub BXTx14.4s, x_14.4s, x_34.4s\n" + "add %x[inptr0], %x[inptr0], %x[colstride3]\n" + + "fadd BXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride3]\n" + + "fsub BXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride3]\n" + + "mov BXTx44.16b, x_24.16b\n" + + // Compute and store U for the first tile + // First row + "fneg U.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, BXTx12.4s, BXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fneg U.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, BXTx22.4s, BXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fneg U.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, BXTx32.4s, BXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row, simultaneously load the first column of inputs for the + // next tile. + "fneg U.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "ldr qx_13, [%x[inptr0]]\n" + + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, BXTx42.4s, BXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + + // Update the loop counter, subtract two to account for both the head and + // the tail. + "subs %x[tile_j], %x[tile_j], #2\n" + "beq 2f\n" // Jump to "A" tail if out of tiles + + // ---------------------------------------------------------------------- + "1:" + // Start part A + // Load last column of this tile (the first column has already been + // loaded) and compute latter two columns of X.T x. + "fsub AXTx13.4s, x_13.4s, x_33.4s\n" + "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" + "fadd AXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" + "fsub AXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" + "mov AXTx43.16b, x_23.16b\n" + + "fsub AXTx14.4s, x_14.4s, x_34.4s\n" + "add %x[inptr0], %x[inptr0], %x[colstride2]\n" + "fadd AXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride2]\n" + "fsub AXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride2]\n" + "mov AXTx44.16b, x_24.16b\n" + + // Compute and store U. 
+ // First row + "fsub U.4s, AXTx11.4s, AXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, AXTx12.4s, AXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, AXTx13.4s, AXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, AXTx12.4s, AXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fsub U.4s, AXTx21.4s, AXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, AXTx22.4s, AXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, AXTx23.4s, AXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, AXTx22.4s, AXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fsub U.4s, AXTx31.4s, AXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, AXTx32.4s, AXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, AXTx33.4s, AXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, AXTx32.4s, AXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row + "fsub U.4s, AXTx41.4s, AXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "ldr qx_13, [%x[inptr0]]\n" + + "fadd U.4s, AXTx42.4s, AXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, AXTx43.4s, AXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, AXTx42.4s, AXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + + "subs %x[tile_j], %x[tile_j], #1\n" + "beq 3f\n" // Jump to 'B' tail + + // Start part B + // Load last column of this tile (the first column has already been + // loaded) and compute latter two columns of X.T x. + "fsub BXTx13.4s, x_13.4s, x_33.4s\n" + "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" + "mov BXTx43.16b, x_23.16b\n" + + "fsub BXTx14.4s, x_14.4s, x_34.4s\n" + "add %x[inptr0], %x[inptr0], %x[colstride2]\n" + "fadd BXTx24.4s, x_24.4s, x_34.4s\n" + "add %x[inptr1], %x[inptr1], %x[colstride2]\n" + "fsub BXTx34.4s, x_34.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], %x[colstride2]\n" + "mov BXTx44.16b, x_24.16b\n" + + // Compute and store U. 
+ // First row + "fsub U.4s, BXTx11.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, BXTx12.4s, BXTx14.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" + + // Second row + "fsub U.4s, BXTx21.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fsub U.4s, BXTx22.4s, BXTx24.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" + + // Third row + "fsub U.4s, BXTx31.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, BXTx32.4s, BXTx34.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" + + // Fourth row + "fsub U.4s, BXTx41.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "ldr qx_13, [%x[inptr0]]\n" + + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "ldr qx_23, [%x[inptr1]]\n" + + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "ldr qx_33, [%x[inptr2]]\n" + + "fsub U.4s, BXTx42.4s, BXTx44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + + "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" + "subs %x[tile_j], %x[tile_j], #1\n" + "bne 1b\n" // Continue loop, otherwise flow into 'A' tail + + // ---------------------------------------------------------------------- + "2:" + // 'A' tail + // Since the final column is padding and the last-but-one column has + // already been loaded just compute the 3rd column of `X.T x'. + "fsub AXTx13.4s, x_13.4s, x_33.4s\n" + "fadd AXTx23.4s, x_23.4s, x_33.4s\n" + "fsub AXTx33.4s, x_33.4s, x_23.4s\n" + "mov AXTx43.16b, x_23.16b\n" + + // Compute and store U. Modified to account for the final column of X.T + // x containing padding. Note, it is also unnecessary to update the + // output pointers. 
+ // First row + "fsub U.4s, AXTx11.4s, AXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, AXTx12.4s, AXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, AXTx13.4s, AXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "str qAXTx12, [%x[outptr0], %x[mstride3]]\n" + + // Second row + "fsub U.4s, AXTx21.4s, AXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, AXTx22.4s, AXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, AXTx23.4s, AXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "str qAXTx22, [%x[outptr4], %x[mstride3]]\n" + + // Third row + "fsub U.4s, AXTx31.4s, AXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, AXTx32.4s, AXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, AXTx33.4s, AXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "str qAXTx32, [%x[outptr8], %x[mstride3]]\n" + + // Fourth row + "fsub U.4s, AXTx41.4s, AXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "fadd U.4s, AXTx42.4s, AXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, AXTx43.4s, AXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "str qAXTx42, [%x[outptr12], %x[mstride3]]\n" + + "b 4f\n" // Jump to end of function + + // ---------------------------------------------------------------------- + "3:" + // 'B' tail + // Since the final column is padding and the last-but-one column has + // already been loaded just compute the 3rd column of `X.T x'. + "fsub BXTx13.4s, x_13.4s, x_33.4s\n" + "fadd BXTx23.4s, x_23.4s, x_33.4s\n" + "fsub BXTx33.4s, x_33.4s, x_23.4s\n" + "mov BXTx43.16b, x_23.16b\n" + + // Compute and store U. Modified to account for the final column of X.T + // x containing padding. Note, it is also unnecessary to update the + // output pointers. + // First row + "fsub U.4s, BXTx11.4s, BXTx13.4s\n" + "str qU, [%x[outptr0]]\n" + "fadd U.4s, BXTx12.4s, BXTx13.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, BXTx13.4s, BXTx12.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "str qBXTx12, [%x[outptr0], %x[mstride3]]\n" + + // Second row + "fsub U.4s, BXTx21.4s, BXTx23.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, BXTx22.4s, BXTx23.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fsub U.4s, BXTx23.4s, BXTx22.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "str qBXTx22, [%x[outptr4], %x[mstride3]]\n" + + // Third row + "fsub U.4s, BXTx31.4s, BXTx33.4s\n" + "str qU, [%x[outptr8]]\n" + "fadd U.4s, BXTx32.4s, BXTx33.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, BXTx33.4s, BXTx32.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "str qBXTx32, [%x[outptr8], %x[mstride3]]\n" + + // Fourth row + "fsub U.4s, BXTx41.4s, BXTx43.4s\n" + "str qU, [%x[outptr12]]\n" + "fadd U.4s, BXTx42.4s, BXTx43.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, BXTx43.4s, BXTx42.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "str qBXTx42, [%x[outptr12], %x[mstride3]]\n" + + // ---------------------------------------------------------------------- + "4:" + // End of function + + // Clear names + ".unreq x_13\n" ".unreq qx_13\n" ".unreq x_14\n" ".unreq qx_14\n" + ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n" + ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n" + ".unreq AXTx11\n" ".unreq BXTx13\n" + ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n" + ".unreq AXTx21\n" ".unreq BXTx23\n" + ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n" + ".unreq AXTx31\n" ".unreq BXTx33\n" + ".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n" + ".unreq AXTx41\n" 
".unreq BXTx43\n" + ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n" + ".unreq AXTx13\n" ".unreq BXTx11\n" + ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n" + ".unreq AXTx23\n" ".unreq BXTx21\n" + ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n" + ".unreq AXTx33\n" ".unreq BXTx31\n" + ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n" + ".unreq AXTx43\n" ".unreq BXTx41\n" + ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n" + ".unreq U\n" ".unreq qU\n" + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr4), + [outptr8] "+r" (outptr8), + [outptr12] "+r" (outptr12), + [tile_j] "+r" (tile_j) // Tile counter + : [colstride1] "r" (1 * input_col_stride * sizeof(float)), + [colstride2] "r" (2 * input_col_stride * sizeof(float)), + [colstride3] "r" (3 * input_col_stride * sizeof(float)), + [mstride1] "r" (1 * matrix_stride * sizeof(float)), + [mstride2] "r" (2 * matrix_stride * sizeof(float)), + [mstride3] "r" (3 * matrix_stride * sizeof(float)), + [matrix_row_stride] "r" (matrix_row_stride * sizeof(float)) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24" + ); +} +} +#endif // __aarch64__ diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp new file mode 100644 index 0000000000..ad1ad55291 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp @@ -0,0 +1,961 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once +#include "../input_2x2_3x3.hpp" + +#ifdef __aarch64__ + +namespace winograd { + +template <> +template <> +inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<0, 0, 0, 0, 4>( + int &n_channels, // Number of channels in the tile + const float* &inptr0, + const int input_row_stride, + const int input_col_stride, + float* &outptr0, + const int matrix_stride +) { + // We use 4 pointers to point to the starting position on each row and use + // three offsets to extract elements from each of the other 3 columns. 
+ auto inptr1 = inptr0 + 1*input_row_stride; + auto inptr2 = inptr0 + 2*input_row_stride; + auto inptr3 = inptr0 + 3*input_row_stride; + + // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three + // offsets to access the intermediate matrices. + auto outptr1 = outptr0 + matrix_stride * 4; + auto outptr2 = outptr0 + matrix_stride * 8; + auto outptr3 = outptr0 + matrix_stride * 12; + + for (; n_channels > 3; n_channels -= 4) { + asm volatile ( + "X_11 .req v0\n" "qX_11 .req q0\n" + "X_12 .req v1\n" "qX_12 .req q1\n" + "X_13 .req v2\n" "qX_13 .req q2\n" + "X_14 .req v3\n" "qX_14 .req q3\n" + "X_21 .req v4\n" "qX_21 .req q4\n" + "X_22 .req v5\n" "qX_22 .req q5\n" + "X_23 .req v6\n" "qX_23 .req q6\n" + "X_24 .req v7\n" "qX_24 .req q7\n" + "X_31 .req v8\n" "qX_31 .req q8\n" + "X_32 .req v9\n" "qX_32 .req q9\n" + "X_33 .req v10\n" "qX_33 .req q10\n" + "X_34 .req v11\n" "qX_34 .req q11\n" + "X_41 .req v12\n" "qX_41 .req q12\n" + "X_42 .req v13\n" "qX_42 .req q13\n" + "X_43 .req v14\n" "qX_43 .req q14\n" + "X_44 .req v15\n" "qX_44 .req q15\n" + "xX_11 .req v16\n" + "xX_12 .req v17\n" + "xX_13 .req v18\n" + "xX_14 .req v19\n" + "xX_21 .req v20\n" + "xX_22 .req v21\n" + "xX_23 .req v22\n" + "xX_24 .req v23\n" + "xX_31 .req v24\n" + "xX_32 .req v25\n" + "xX_33 .req v26\n" + "xX_34 .req v27\n" + "xX_41 .req v28\n" + "xX_42 .req v29\n" + "xX_43 .req v30\n" + "xX_44 .req v31\n" + " U .req v0\n" + "qU .req q0\n" + + // Load the tile, and compute compute the matrix xX + "ldr qX_11, [%x[inptr0]]\n" + "ldr qX_12, [%x[inptr0], %x[colstride1]]\n" + "ldr qX_13, [%x[inptr0], %x[colstride2]]\n" + "ldr qX_14, [%x[inptr0], %x[colstride3]]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "ldr qX_21, [%x[inptr1]]\n" + "fsub xX_11.4s, x_11.4s, x_13.4s\n" + "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" + "fadd xX_12.4s, x_12.4s, x_13.4s\n" + "ldr qX_23, [%x[inptr1], %x[colstride2]]\n" + "fsub xX_13.4s, x_13.4s, x_12.4s\n" + "ldr qX_24, [%x[inptr1], %x[colstride3]]\n" + "fsub xX_14.4s, x_12.4s, x_14.4s\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "ldr qX_31, [%x[inptr2]]\n" + "fsub xX_21.4s, x_21.4s, x_23.4s\n" + "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" + "fadd xX_22.4s, x_22.4s, x_23.4s\n" + "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" + "fsub xX_23.4s, x_23.4s, x_22.4s\n" + "ldr qX_34, [%x[inptr2], %x[colstride3]]\n" + "fsub xX_24.4s, x_22.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + "ldr qX_41, [%x[inptr3]]\n" + "fsub xX_31.4s, x_31.4s, x_33.4s\n" + "ldr qX_42, [%x[inptr3], %x[colstride1]]\n" + "fadd xX_32.4s, x_32.4s, x_33.4s\n" + "ldr qX_43, [%x[inptr3], %x[colstride2]]\n" + "fsub xX_33.4s, x_33.4s, x_32.4s\n" + "ldr qX_44, [%x[inptr3], %x[colstride3]]\n" + "fsub xX_34.4s, x_32.4s, x_34.4s\n" + "add %x[inptr3], %x[inptr3], #0x10\n" + + // Complete computing xX while beginning to compute and store + // $U = X.T x X$ + + "fsub xX_41.4s, x_41.4s, x_43.4s\n" + + "fsub U.4s, xX_11.4s, xX_31.4s\n" + "str qU, [%x[outptr0]]\n" + "fsub U.4s, xX_12.4s, xX_32.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, xX_13.4s, xX_33.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, xX_14.4s, xX_34.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd xX_42.4s, x_42.4s, x_43.4s\n" + + "fadd U.4s, xX_21.4s, xX_31.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, xX_22.4s, xX_32.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fadd U.4s, xX_23.4s, xX_33.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fadd U.4s, xX_24.4s, xX_34.4s\n" + 
"str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fsub xX_43.4s, x_43.4s, x_42.4s\n" + + "fsub U.4s, xX_31.4s, xX_21.4s\n" + "str qU, [%x[outptr8]]\n" + "fsub U.4s, xX_32.4s, xX_22.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, xX_33.4s, xX_23.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, xX_34.4s, xX_24.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "fsub xX_44.4s, x_42.4s, x_44.4s\n" + + "fsub U.4s, xX_21.4s, xX_41.4s\n" + "str qU, [%x[outptr12]]\n" + "fsub U.4s, xX_22.4s, xX_42.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, xX_23.4s, xX_43.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "fsub U.4s, xX_24.4s, xX_44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + ".unreq qU\n" + ".unreq U\n" + ".unreq X_11\n" ".unreq qX_11\n" + ".unreq X_12\n" ".unreq qX_12\n" + ".unreq X_13\n" ".unreq qX_13\n" + ".unreq X_14\n" ".unreq qX_14\n" + ".unreq X_21\n" ".unreq qX_21\n" + ".unreq X_22\n" ".unreq qX_22\n" + ".unreq X_23\n" ".unreq qX_23\n" + ".unreq X_24\n" ".unreq qX_24\n" + ".unreq X_31\n" ".unreq qX_31\n" + ".unreq X_32\n" ".unreq qX_32\n" + ".unreq X_33\n" ".unreq qX_33\n" + ".unreq X_34\n" ".unreq qX_34\n" + ".unreq X_41\n" ".unreq qX_41\n" + ".unreq X_42\n" ".unreq qX_42\n" + ".unreq X_43\n" ".unreq qX_43\n" + ".unreq X_44\n" ".unreq qX_44\n" + ".unreq xX_11\n" + ".unreq xX_12\n" + ".unreq xX_13\n" + ".unreq xX_14\n" + ".unreq xX_21\n" + ".unreq xX_22\n" + ".unreq xX_23\n" + ".unreq xX_24\n" + ".unreq xX_31\n" + ".unreq xX_32\n" + ".unreq xX_33\n" + ".unreq xX_34\n" + ".unreq xX_41\n" + ".unreq xX_42\n" + ".unreq xX_43\n" + ".unreq xX_44\n" + + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [inptr3] "+r" (inptr3), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr1), + [outptr8] "+r" (outptr2), + [outptr12] "+r" (outptr3) + : [colstride1] "r" (input_col_stride * sizeof(float)), + [colstride2] "r" (input_col_stride * sizeof(float) * 2), + [colstride3] "r" (input_col_stride * sizeof(float) * 3), + [mstride1] "r" (matrix_stride * sizeof(float)), + [mstride2] "r" (matrix_stride * sizeof(float) * 2), + [mstride3] "r" (matrix_stride * sizeof(float) * 3) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31" + ); + } +} + +// Pad top by 1 +template <> +template <> +inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<1, 0, 0, 0, 4>( + int &n_channels, // Number of channels in the tile + const float* &inptr0, + const int input_row_stride, + const int input_col_stride, + float* &outptr0, + const int matrix_stride +) { + // We use 4 pointers to point to the starting position on each row and use + // three offsets to extract elements from each of the other 3 columns. + auto inptr1 = inptr0 + 0*input_row_stride; + auto inptr2 = inptr0 + 1*input_row_stride; + + // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three + // offsets to access the intermediate matrices. 
+ auto outptr1 = outptr0 + matrix_stride * 4; + auto outptr2 = outptr0 + matrix_stride * 8; + auto outptr3 = outptr0 + matrix_stride * 12; + + for (; n_channels > 3; n_channels -= 4) { + asm volatile ( + "X_21 .req v4\n" "qX_21 .req q4\n" + "X_22 .req v5\n" "qX_22 .req q5\n" + "X_23 .req v6\n" "qX_23 .req q6\n" + "X_24 .req v7\n" "qX_24 .req q7\n" + "X_31 .req v8\n" "qX_31 .req q8\n" + "X_32 .req v9\n" "qX_32 .req q9\n" + "X_33 .req v10\n" "qX_33 .req q10\n" + "X_34 .req v11\n" "qX_34 .req q11\n" + "X_41 .req v12\n" "qX_41 .req q12\n" + "X_42 .req v13\n" "qX_42 .req q13\n" + "X_43 .req v14\n" "qX_43 .req q14\n" + "X_44 .req v15\n" "qX_44 .req q15\n" + "xX_21 .req v20\n" + "xX_22 .req v21\n" + "xX_23 .req v22\n" + "xX_24 .req v23\n" + "xX_31 .req v24\n" + "xX_32 .req v25\n" + "xX_33 .req v26\n" + "xX_34 .req v27\n" + "xX_41 .req v28\n" + "xX_42 .req v29\n" + "xX_43 .req v30\n" + "xX_44 .req v31\n" + " U .req v0\n" + "qU .req q0\n" + + // Load the tile, and compute compute the matrix xX + "ldr qX_21, [%x[inptr1]]\n" + "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" + "ldr qX_23, [%x[inptr1], %x[colstride2]]\n" + "ldr qX_24, [%x[inptr1], %x[colstride3]]\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "ldr qX_31, [%x[inptr2]]\n" + "fsub xX_21.4s, x_21.4s, x_23.4s\n" + "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" + "fadd xX_22.4s, x_22.4s, x_23.4s\n" + "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" + "fsub xX_23.4s, x_23.4s, x_22.4s\n" + "ldr qX_34, [%x[inptr2], %x[colstride3]]\n" + "fsub xX_24.4s, x_22.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + "ldr qX_41, [%x[inptr3]]\n" + "fsub xX_31.4s, x_31.4s, x_33.4s\n" + "ldr qX_42, [%x[inptr3], %x[colstride1]]\n" + "fadd xX_32.4s, x_32.4s, x_33.4s\n" + "ldr qX_43, [%x[inptr3], %x[colstride2]]\n" + "fsub xX_33.4s, x_33.4s, x_32.4s\n" + "ldr qX_44, [%x[inptr3], %x[colstride3]]\n" + "fsub xX_34.4s, x_32.4s, x_34.4s\n" + "add %x[inptr3], %x[inptr3], #0x10\n" + + // Complete computing xX while beginning to compute and store + // $U = X.T x X$ + + "fsub xX_41.4s, x_41.4s, x_43.4s\n" + + "fneg U.4s, xX_31.4s\n" + "str qU, [%x[outptr0]]\n" + "fneg U.4s, xX_32.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fneg U.4s, xX_33.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fneg U.4s, xX_34.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd xX_42.4s, x_42.4s, x_43.4s\n" + + "fadd U.4s, xX_21.4s, xX_31.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, xX_22.4s, xX_32.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fadd U.4s, xX_23.4s, xX_33.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fadd U.4s, xX_24.4s, xX_34.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fsub xX_43.4s, x_43.4s, x_42.4s\n" + + "fsub U.4s, xX_31.4s, xX_21.4s\n" + "str qU, [%x[outptr8]]\n" + "fsub U.4s, xX_32.4s, xX_22.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, xX_33.4s, xX_23.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, xX_34.4s, xX_24.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "fsub xX_44.4s, x_42.4s, x_44.4s\n" + + "fsub U.4s, xX_21.4s, xX_41.4s\n" + "str qU, [%x[outptr12]]\n" + "fsub U.4s, xX_22.4s, xX_42.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, xX_23.4s, xX_43.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "fsub U.4s, xX_24.4s, xX_44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + ".unreq qU\n" + ".unreq U\n" + ".unreq 
X_21\n" ".unreq qX_21\n" + ".unreq X_22\n" ".unreq qX_22\n" + ".unreq X_23\n" ".unreq qX_23\n" + ".unreq X_24\n" ".unreq qX_24\n" + ".unreq X_31\n" ".unreq qX_31\n" + ".unreq X_32\n" ".unreq qX_32\n" + ".unreq X_33\n" ".unreq qX_33\n" + ".unreq X_34\n" ".unreq qX_34\n" + ".unreq X_41\n" ".unreq qX_41\n" + ".unreq X_42\n" ".unreq qX_42\n" + ".unreq X_43\n" ".unreq qX_43\n" + ".unreq X_44\n" ".unreq qX_44\n" + ".unreq xX_21\n" + ".unreq xX_22\n" + ".unreq xX_23\n" + ".unreq xX_24\n" + ".unreq xX_31\n" + ".unreq xX_32\n" + ".unreq xX_33\n" + ".unreq xX_34\n" + ".unreq xX_41\n" + ".unreq xX_42\n" + ".unreq xX_43\n" + ".unreq xX_44\n" + + : [inptr1] "+r" (inptr0), // Offset for missing row + [inptr2] "+r" (inptr1), // Offset for missing row + [inptr3] "+r" (inptr2), // Offset for missing row + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr1), + [outptr8] "+r" (outptr2), + [outptr12] "+r" (outptr3) + : [colstride1] "r" (input_col_stride * sizeof(float)), + [colstride2] "r" (input_col_stride * sizeof(float) * 2), + [colstride3] "r" (input_col_stride * sizeof(float) * 3), + [mstride1] "r" (matrix_stride * sizeof(float)), + [mstride2] "r" (matrix_stride * sizeof(float) * 2), + [mstride3] "r" (matrix_stride * sizeof(float) * 3) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31" + ); + } +} + +// Pad left by 1 +template <> +template <> +inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<0, 1, 0, 0, 4>( + int &n_channels, // Number of channels in the tile + const float* &inptr0, + const int input_row_stride, + const int input_col_stride, + float* &outptr0, + const int matrix_stride +) { + // We use 4 pointers to point to the starting position on each row and use + // three offsets to extract elements from each of the other 3 columns. + auto inptr1 = inptr0 + 1*input_row_stride; + auto inptr2 = inptr0 + 2*input_row_stride; + auto inptr3 = inptr0 + 3*input_row_stride; + + // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three + // offsets to access the intermediate matrices. 
+ auto outptr1 = outptr0 + matrix_stride * 4; + auto outptr2 = outptr0 + matrix_stride * 8; + auto outptr3 = outptr0 + matrix_stride * 12; + + for (; n_channels > 3; n_channels -= 4) { + asm volatile ( + "X_12 .req v1\n" "qX_12 .req q1\n" + "X_13 .req v2\n" "qX_13 .req q2\n" + "X_14 .req v3\n" "qX_14 .req q3\n" + "X_22 .req v5\n" "qX_22 .req q5\n" + "X_23 .req v6\n" "qX_23 .req q6\n" + "X_24 .req v7\n" "qX_24 .req q7\n" + "X_32 .req v9\n" "qX_32 .req q9\n" + "X_33 .req v10\n" "qX_33 .req q10\n" + "X_34 .req v11\n" "qX_34 .req q11\n" + "X_42 .req v13\n" "qX_42 .req q13\n" + "X_43 .req v14\n" "qX_43 .req q14\n" + "X_44 .req v15\n" "qX_44 .req q15\n" + "xX_11 .req v16\n" + "xX_12 .req v17\n" + "xX_13 .req v18\n" + "xX_14 .req v19\n" + "xX_21 .req v20\n" + "xX_22 .req v21\n" + "xX_23 .req v22\n" + "xX_24 .req v23\n" + "xX_31 .req v24\n" + "xX_32 .req v25\n" + "xX_33 .req v26\n" + "xX_34 .req v27\n" + "xX_41 .req v28\n" + "xX_42 .req v29\n" + "xX_43 .req v30\n" + "xX_44 .req v31\n" + " U .req v0\n" + "qU .req q0\n" + + // Load the tile, and compute compute the matrix xX + "ldr qX_12, [%x[inptr0]]\n" + "ldr qX_13, [%x[inptr0], %x[colstride1]]\n" + "ldr qX_14, [%x[inptr0], %x[colstride2]]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "fneg xX_11.4s, x_13.4s\n" + "ldr qX_22, [%x[inptr1]]\n" + "fadd xX_12.4s, x_12.4s, x_13.4s\n" + "ldr qX_23, [%x[inptr1], %x[colstride1]]\n" + "fsub xX_13.4s, x_13.4s, x_12.4s\n" + "ldr qX_24, [%x[inptr1], %x[colstride2]]\n" + "fsub xX_14.4s, x_12.4s, x_14.4s\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "fneg xX_21.4s, x_23.4s\n" + "ldr qX_32, [%x[inptr2]]\n" + "fadd xX_22.4s, x_22.4s, x_23.4s\n" + "ldr qX_33, [%x[inptr2], %x[colstride1]]\n" + "fsub xX_23.4s, x_23.4s, x_22.4s\n" + "ldr qX_34, [%x[inptr2], %x[colstride2]]\n" + "fsub xX_24.4s, x_22.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + "fneg xX_31.4s, x_33.4s\n" + "ldr qX_42, [%x[inptr3]]\n" + "fadd xX_32.4s, x_32.4s, x_33.4s\n" + "ldr qX_43, [%x[inptr3], %x[colstride1]]\n" + "fsub xX_33.4s, x_33.4s, x_32.4s\n" + "ldr qX_44, [%x[inptr3], %x[colstride2]]\n" + "fsub xX_34.4s, x_32.4s, x_34.4s\n" + "add %x[inptr3], %x[inptr3], #0x10\n" + + // Complete computing xX while beginning to compute and store + // $U = X.T x X$ + + "fneg xX_41.4s, x_43.4s\n" + + "fsub U.4s, xX_11.4s, xX_31.4s\n" + "str qU, [%x[outptr0]]\n" + "fsub U.4s, xX_12.4s, xX_32.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, xX_13.4s, xX_33.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, xX_14.4s, xX_34.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd xX_42.4s, x_42.4s, x_43.4s\n" + + "fadd U.4s, xX_21.4s, xX_31.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, xX_22.4s, xX_32.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fadd U.4s, xX_23.4s, xX_33.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fadd U.4s, xX_24.4s, xX_34.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fsub xX_43.4s, x_43.4s, x_42.4s\n" + + "fsub U.4s, xX_31.4s, xX_21.4s\n" + "str qU, [%x[outptr8]]\n" + "fsub U.4s, xX_32.4s, xX_22.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, xX_33.4s, xX_23.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, xX_34.4s, xX_24.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "fsub xX_44.4s, x_42.4s, x_44.4s\n" + + "fsub U.4s, xX_21.4s, xX_41.4s\n" + "str qU, [%x[outptr12]]\n" + "fsub U.4s, xX_22.4s, xX_42.4s\n" + "str qU, 
[%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, xX_23.4s, xX_43.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "fsub U.4s, xX_24.4s, xX_44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + ".unreq X_12\n" ".unreq qX_12\n" + ".unreq X_13\n" ".unreq qX_13\n" + ".unreq X_14\n" ".unreq qX_14\n" + ".unreq X_22\n" ".unreq qX_22\n" + ".unreq X_23\n" ".unreq qX_23\n" + ".unreq X_24\n" ".unreq qX_24\n" + ".unreq X_32\n" ".unreq qX_32\n" + ".unreq X_33\n" ".unreq qX_33\n" + ".unreq X_34\n" ".unreq qX_34\n" + ".unreq X_42\n" ".unreq qX_42\n" + ".unreq X_43\n" ".unreq qX_43\n" + ".unreq X_44\n" ".unreq qX_44\n" + ".unreq xX_11\n" + ".unreq xX_12\n" + ".unreq xX_13\n" + ".unreq xX_14\n" + ".unreq xX_21\n" + ".unreq xX_22\n" + ".unreq xX_23\n" + ".unreq xX_24\n" + ".unreq xX_31\n" + ".unreq xX_32\n" + ".unreq xX_33\n" + ".unreq xX_34\n" + ".unreq xX_41\n" + ".unreq xX_42\n" + ".unreq xX_43\n" + ".unreq xX_44\n" + ".unreq U\n" + ".unreq qU\n" + + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [inptr3] "+r" (inptr3), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr1), + [outptr8] "+r" (outptr2), + [outptr12] "+r" (outptr3) + : [colstride1] "r" (input_col_stride * sizeof(float)), + [colstride2] "r" (input_col_stride * sizeof(float) * 2), + [mstride1] "r" (matrix_stride * sizeof(float)), + [mstride2] "r" (matrix_stride * sizeof(float) * 2), + [mstride3] "r" (matrix_stride * sizeof(float) * 3) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31" + ); + } +} + +// Pad bottom by 1 +template <> +template <> +inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<0, 0, 1, 0, 4>( + int &n_channels, // Number of channels in the tile + const float* &inptr0, + const int input_row_stride, + const int input_col_stride, + float* &outptr0, + const int matrix_stride +) { + // We use 4 pointers to point to the starting position on each row and use + // three offsets to extract elements from each of the other 3 columns. + auto inptr1 = inptr0 + 1*input_row_stride; + auto inptr2 = inptr0 + 2*input_row_stride; + + // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three + // offsets to access the intermediate matrices. 
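+ // The bottom row of the tile is padding: only three input rows are read
+ // in the loop below, and since row 4 of xX is zero the last row of
+ // U = X.T x X is just row 2 of xX, which is stored to outptr12 unchanged.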
+ auto outptr1 = outptr0 + matrix_stride * 4; + auto outptr2 = outptr0 + matrix_stride * 8; + auto outptr3 = outptr0 + matrix_stride * 12; + + for (; n_channels > 3; n_channels -= 4) { + asm volatile ( + "X_11 .req v0\n" "qX_11 .req q0\n" + "X_12 .req v1\n" "qX_12 .req q1\n" + "X_13 .req v2\n" "qX_13 .req q2\n" + "X_14 .req v3\n" "qX_14 .req q3\n" + "X_21 .req v4\n" "qX_21 .req q4\n" + "X_22 .req v5\n" "qX_22 .req q5\n" + "X_23 .req v6\n" "qX_23 .req q6\n" + "X_24 .req v7\n" "qX_24 .req q7\n" + "X_31 .req v8\n" "qX_31 .req q8\n" + "X_32 .req v9\n" "qX_32 .req q9\n" + "X_33 .req v10\n" "qX_33 .req q10\n" + "X_34 .req v11\n" "qX_34 .req q11\n" + "xX_11 .req v16\n" + "xX_12 .req v17\n" + "xX_13 .req v18\n" + "xX_14 .req v19\n" + "xX_21 .req v20\n" "qxX_21 .req q20\n" + "xX_22 .req v21\n" "qxX_22 .req q21\n" + "xX_23 .req v22\n" "qxX_23 .req q22\n" + "xX_24 .req v23\n" "qxX_24 .req q23\n" + "xX_31 .req v24\n" + "xX_32 .req v25\n" + "xX_33 .req v26\n" + "xX_34 .req v27\n" + " U .req v0\n" + "qU .req q0\n" + + // Load the tile, and compute compute the matrix xX + "ldr qX_11, [%x[inptr0]]\n" + "ldr qX_12, [%x[inptr0], %x[colstride1]]\n" + "ldr qX_13, [%x[inptr0], %x[colstride2]]\n" + "ldr qX_14, [%x[inptr0], %x[colstride3]]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "ldr qX_21, [%x[inptr1]]\n" + "fsub xX_11.4s, x_11.4s, x_13.4s\n" + "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" + "fadd xX_12.4s, x_12.4s, x_13.4s\n" + "ldr qX_23, [%x[inptr1], %x[colstride2]]\n" + "fsub xX_13.4s, x_13.4s, x_12.4s\n" + "ldr qX_24, [%x[inptr1], %x[colstride3]]\n" + "fsub xX_14.4s, x_12.4s, x_14.4s\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "ldr qX_31, [%x[inptr2]]\n" + "fsub xX_21.4s, x_21.4s, x_23.4s\n" + "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" + "fadd xX_22.4s, x_22.4s, x_23.4s\n" + "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" + "fsub xX_23.4s, x_23.4s, x_22.4s\n" + "ldr qX_34, [%x[inptr2], %x[colstride3]]\n" + "fsub xX_24.4s, x_22.4s, x_24.4s\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + "fsub xX_31.4s, x_31.4s, x_33.4s\n" + "fadd xX_32.4s, x_32.4s, x_33.4s\n" + "fsub xX_33.4s, x_33.4s, x_32.4s\n" + "fsub xX_34.4s, x_32.4s, x_34.4s\n" + + // Complete computing xX while beginning to compute and store + // $U = X.T x X$ + + "fsub U.4s, xX_11.4s, xX_31.4s\n" + "str qU, [%x[outptr0]]\n" + "fsub U.4s, xX_12.4s, xX_32.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, xX_13.4s, xX_33.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, xX_14.4s, xX_34.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd U.4s, xX_21.4s, xX_31.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, xX_22.4s, xX_32.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fadd U.4s, xX_23.4s, xX_33.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fadd U.4s, xX_24.4s, xX_34.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fsub U.4s, xX_31.4s, xX_21.4s\n" + "str qU, [%x[outptr8]]\n" + "fsub U.4s, xX_32.4s, xX_22.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, xX_33.4s, xX_23.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, xX_34.4s, xX_24.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "str qxX_21, [%x[outptr12]]\n" + "str qxX_22, [%x[outptr12], %x[mstride1]]\n" + "str qxX_23, [%x[outptr12], %x[mstride2]]\n" + "str qxX_24, [%x[outptr12], %x[mstride3]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + ".unreq qU\n" + ".unreq U\n" + ".unreq X_11\n" ".unreq 
qX_11\n" + ".unreq X_12\n" ".unreq qX_12\n" + ".unreq X_13\n" ".unreq qX_13\n" + ".unreq X_14\n" ".unreq qX_14\n" + ".unreq X_21\n" ".unreq qX_21\n" + ".unreq X_22\n" ".unreq qX_22\n" + ".unreq X_23\n" ".unreq qX_23\n" + ".unreq X_24\n" ".unreq qX_24\n" + ".unreq X_31\n" ".unreq qX_31\n" + ".unreq X_32\n" ".unreq qX_32\n" + ".unreq X_33\n" ".unreq qX_33\n" + ".unreq X_34\n" ".unreq qX_34\n" + ".unreq xX_11\n" + ".unreq xX_12\n" + ".unreq xX_13\n" + ".unreq xX_14\n" + ".unreq xX_21\n" ".unreq qxX_21\n" + ".unreq xX_22\n" ".unreq qxX_22\n" + ".unreq xX_23\n" ".unreq qxX_23\n" + ".unreq xX_24\n" ".unreq qxX_24\n" + ".unreq xX_31\n" + ".unreq xX_32\n" + ".unreq xX_33\n" + ".unreq xX_34\n" + + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr1), + [outptr8] "+r" (outptr2), + [outptr12] "+r" (outptr3) + : [colstride1] "r" (input_col_stride * sizeof(float)), + [colstride2] "r" (input_col_stride * sizeof(float) * 2), + [colstride3] "r" (input_col_stride * sizeof(float) * 3), + [mstride1] "r" (matrix_stride * sizeof(float)), + [mstride2] "r" (matrix_stride * sizeof(float) * 2), + [mstride3] "r" (matrix_stride * sizeof(float) * 3) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31" + ); + } +} + +// Pad right by 1 +template <> +template <> +inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<0, 0, 0, 1, 4>( + int &n_channels, // Number of channels in the tile + const float* &inptr0, + const int input_row_stride, + const int input_col_stride, + float* &outptr0, + const int matrix_stride +) { + // We use 4 pointers to point to the starting position on each row and use + // three offsets to extract elements from each of the other 3 columns. + auto inptr1 = inptr0 + 1*input_row_stride; + auto inptr2 = inptr0 + 2*input_row_stride; + auto inptr3 = inptr0 + 3*input_row_stride; + + // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three + // offsets to access the intermediate matrices. 
+ auto outptr1 = outptr0 + matrix_stride * 4; + auto outptr2 = outptr0 + matrix_stride * 8; + auto outptr3 = outptr0 + matrix_stride * 12; + + for (; n_channels > 3; n_channels -= 4) { + asm volatile ( + "X_11 .req v0\n" "qX_11 .req q0\n" + "X_12 .req v1\n" "qX_12 .req q1\n" + "X_13 .req v2\n" "qX_13 .req q2\n" + "X_21 .req v4\n" "qX_21 .req q4\n" + "X_22 .req v5\n" "qX_22 .req q5\n" + "X_23 .req v6\n" "qX_23 .req q6\n" + "X_31 .req v8\n" "qX_31 .req q8\n" + "X_32 .req v9\n" "qX_32 .req q9\n" + "X_33 .req v10\n" "qX_33 .req q10\n" + "X_41 .req v12\n" "qX_41 .req q12\n" + "X_42 .req v13\n" "qX_42 .req q13\n" + "X_43 .req v14\n" "qX_43 .req q14\n" + "xX_11 .req v16\n" + "xX_12 .req v17\n" + "xX_13 .req v18\n" + "xX_14 .req x_12\n" + "xX_21 .req v20\n" + "xX_22 .req v21\n" + "xX_23 .req v22\n" + "xX_24 .req x_22\n" + "xX_31 .req v24\n" + "xX_32 .req v25\n" + "xX_33 .req v26\n" + "xX_34 .req x_32\n" + "xX_41 .req v28\n" + "xX_42 .req v29\n" + "xX_43 .req v30\n" + "xX_44 .req x_42\n" + " U .req v0\n" + "qU .req q0\n" + + // Load the tile, and compute compute the matrix xX + "ldr qX_11, [%x[inptr0]]\n" + "ldr qX_12, [%x[inptr0], %x[colstride1]]\n" + "ldr qX_13, [%x[inptr0], %x[colstride2]]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "ldr qX_21, [%x[inptr1]]\n" + "fsub xX_11.4s, x_11.4s, x_13.4s\n" + "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" + "fadd xX_12.4s, x_12.4s, x_13.4s\n" + "ldr qX_23, [%x[inptr1], %x[colstride2]]\n" + "fsub xX_13.4s, x_13.4s, x_12.4s\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "ldr qX_31, [%x[inptr2]]\n" + "fsub xX_21.4s, x_21.4s, x_23.4s\n" + "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" + "fadd xX_22.4s, x_22.4s, x_23.4s\n" + "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" + "fsub xX_23.4s, x_23.4s, x_22.4s\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + "ldr qX_41, [%x[inptr3]]\n" + "fsub xX_31.4s, x_31.4s, x_33.4s\n" + "ldr qX_42, [%x[inptr3], %x[colstride1]]\n" + "fadd xX_32.4s, x_32.4s, x_33.4s\n" + "ldr qX_43, [%x[inptr3], %x[colstride2]]\n" + "fsub xX_33.4s, x_33.4s, x_32.4s\n" + "add %x[inptr3], %x[inptr3], #0x10\n" + + // Complete computing xX while beginning to compute and store + // $U = X.T x X$ + + "fsub xX_41.4s, x_41.4s, x_43.4s\n" + + "fsub U.4s, xX_11.4s, xX_31.4s\n" + "str qU, [%x[outptr0]]\n" + "fsub U.4s, xX_12.4s, xX_32.4s\n" + "str qU, [%x[outptr0], %x[mstride1]]\n" + "fsub U.4s, xX_13.4s, xX_33.4s\n" + "str qU, [%x[outptr0], %x[mstride2]]\n" + "fsub U.4s, xX_14.4s, xX_34.4s\n" + "str qU, [%x[outptr0], %x[mstride3]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd xX_42.4s, x_42.4s, x_43.4s\n" + + "fadd U.4s, xX_21.4s, xX_31.4s\n" + "str qU, [%x[outptr4]]\n" + "fadd U.4s, xX_22.4s, xX_32.4s\n" + "str qU, [%x[outptr4], %x[mstride1]]\n" + "fadd U.4s, xX_23.4s, xX_33.4s\n" + "str qU, [%x[outptr4], %x[mstride2]]\n" + "fadd U.4s, xX_24.4s, xX_34.4s\n" + "str qU, [%x[outptr4], %x[mstride3]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fsub xX_43.4s, x_43.4s, x_42.4s\n" + + "fsub U.4s, xX_31.4s, xX_21.4s\n" + "str qU, [%x[outptr8]]\n" + "fsub U.4s, xX_32.4s, xX_22.4s\n" + "str qU, [%x[outptr8], %x[mstride1]]\n" + "fsub U.4s, xX_33.4s, xX_23.4s\n" + "str qU, [%x[outptr8], %x[mstride2]]\n" + "fsub U.4s, xX_34.4s, xX_24.4s\n" + "str qU, [%x[outptr8], %x[mstride3]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "fsub U.4s, xX_21.4s, xX_41.4s\n" + "str qU, [%x[outptr12]]\n" + "fsub U.4s, xX_22.4s, xX_42.4s\n" + "str qU, [%x[outptr12], %x[mstride1]]\n" + "fsub U.4s, xX_23.4s, xX_43.4s\n" + "str qU, [%x[outptr12], %x[mstride2]]\n" + "fsub U.4s, 
xX_24.4s, xX_44.4s\n" + "str qU, [%x[outptr12], %x[mstride3]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + ".unreq qU\n" + ".unreq U\n" + ".unreq X_11\n" ".unreq qX_11\n" + ".unreq X_12\n" ".unreq qX_12\n" + ".unreq X_13\n" ".unreq qX_13\n" + ".unreq X_21\n" ".unreq qX_21\n" + ".unreq X_22\n" ".unreq qX_22\n" + ".unreq X_23\n" ".unreq qX_23\n" + ".unreq X_31\n" ".unreq qX_31\n" + ".unreq X_32\n" ".unreq qX_32\n" + ".unreq X_33\n" ".unreq qX_33\n" + ".unreq X_41\n" ".unreq qX_41\n" + ".unreq X_42\n" ".unreq qX_42\n" + ".unreq X_43\n" ".unreq qX_43\n" + ".unreq xX_11\n" + ".unreq xX_12\n" + ".unreq xX_13\n" + ".unreq xX_14\n" + ".unreq xX_21\n" + ".unreq xX_22\n" + ".unreq xX_23\n" + ".unreq xX_24\n" + ".unreq xX_31\n" + ".unreq xX_32\n" + ".unreq xX_33\n" + ".unreq xX_34\n" + ".unreq xX_41\n" + ".unreq xX_42\n" + ".unreq xX_43\n" + ".unreq xX_44\n" + + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [inptr3] "+r" (inptr3), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr1), + [outptr8] "+r" (outptr2), + [outptr12] "+r" (outptr3) + : [colstride1] "r" (input_col_stride * sizeof(float)), + [colstride2] "r" (input_col_stride * sizeof(float) * 2), + [mstride1] "r" (matrix_stride * sizeof(float)), + [mstride2] "r" (matrix_stride * sizeof(float) * 2), + [mstride3] "r" (matrix_stride * sizeof(float) * 3) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31" + ); + } +} +} +#endif diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp new file mode 100644 index 0000000000..033442aa14 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +namespace winograd { + /* Transform a kernel into the Winograd domain. + * + * NOTE: It is assumed that the kernel is in the form [height x width x + * input_channels x output_channel]. 
+ */
+ template <typename T>
+ struct winograd2x2_3x3_gemm_kernel_transform_impl{
+ static void execute(
+ const KernelShape &shape,
+ const T* const kernel,
+ T* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride
+ );
+
+ protected:
+ template <int output_channel_tail>
+ static void transform_kernel(
+ const T* const kernel,
+ const int n_input_channels,
+ const int n_output_channels,
+ T* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride
+ );
+ };
+}
+
+/*****************************************************************************/
+/* Transform a fp32 kernel into the Winograd domain.
+ */
+#include "kernel_2x2_3x3/a64_float.hpp" // AArch64 specialisations
+
+namespace winograd
+{
+template <>
+inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::execute(
+ const KernelShape &shape,
+ const float* const kernel,
+ float* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride
+) {
+ // Delegate based on tail size
+ const int n_input_channels = shape.n_input_channels;
+ const int n_output_channels = shape.n_output_channels;
+
+ switch (n_output_channels % 4) {
+ case 0:
+ transform_kernel<0>(
+ kernel, n_input_channels, n_output_channels,
+ matrix_base, matrix_stride, matrix_row_stride
+ );
+ break;
+ case 1:
+ transform_kernel<1>(
+ kernel, n_input_channels, n_output_channels,
+ matrix_base, matrix_stride, matrix_row_stride
+ );
+ break;
+ case 2:
+ transform_kernel<2>(
+ kernel, n_input_channels, n_output_channels,
+ matrix_base, matrix_stride, matrix_row_stride
+ );
+ break;
+ case 3:
+ transform_kernel<3>(
+ kernel, n_input_channels, n_output_channels,
+ matrix_base, matrix_stride, matrix_row_stride
+ );
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Cannot happen");
+ break;
+ }
+}
+
+template <>
+template <int output_channel_tail>
+inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::transform_kernel(
+ const float* const kernel,
+ const int n_input_channels,
+ const int n_output_channels,
+ float* const matrix_base,
+ const int mstride,
+ const int matrix_row_stride
+) {
+ // Use one input pointer for each row of the kernel, use two additional
+ // offsets to extract columns.
+ const int kernel_col_stride = n_input_channels * n_output_channels;
+ const int kernel_row_stride = 3 * kernel_col_stride;
+ const float *inptr0 = kernel;
+ const float *inptr1 = kernel + kernel_row_stride;
+ const float *inptr2 = kernel + kernel_row_stride*2;
+
+ // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three
+ // offsets to extract further matrices.
+ float *outptr0 = matrix_base;
+ float *outptr4 = matrix_base + mstride * 4;
+ float *outptr8 = matrix_base + mstride * 8;
+ float *outptr12 = matrix_base + mstride * 12;
+
+ // For every input channel
+ for (int in_c = 0; in_c < n_input_channels; in_c++) {
+ // For every output channel
+ for (int c = 0; c < n_output_channels; c++) {
+ // Read in the kernel
+ float w11 = inptr0[0], w12 = inptr0[kernel_col_stride], w13 = inptr0[kernel_col_stride*2];
+ float w21 = inptr1[0], w22 = inptr1[kernel_col_stride], w23 = inptr1[kernel_col_stride*2];
+ float w31 = inptr2[0], w32 = inptr2[kernel_col_stride], w33 = inptr2[kernel_col_stride*2];
+
+ // Progress input pointers
+ inptr0++;
+ inptr1++;
+ inptr2++;
+
+ // Compute the kernel W w; note we need only compute the middle two rows
+ // (2 and 3) because the first and last rows are merely copies of values
+ // from the matrix w.
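+ // Written out, the 4x3 transform applied to the rows here is
+ //     [  1     0     0  ]
+ // W = [ 1/2   1/2   1/2 ]
+ //     [ 1/2  -1/2   1/2 ]
+ //     [  0     0     1  ]
+ // and the weights are then mapped to U = W w W.T by applying the same
+ // combinations to the columns of W w in the second step below.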
+ float Ww11 = w11, Ww12 = w12, Ww13 = w13; + float Ww21 = 0.5*(w11 + w21 + w31), Ww22 = 0.5*(w12 + w22 + w32), Ww23 = 0.5*(w13 + w23 + w33); + float Ww31 = 0.5*(w11 - w21 + w31), Ww32 = 0.5*(w12 - w22 + w32), Ww33 = 0.5*(w13 - w23 + w33); + float Ww41 = w31, Ww42 = w32, Ww43 = w33; + + // Hence compute W w W.T; again note we need compute only the middle two + // columns since the first and last columns are copies of the first and + // last columns of the previous matrix. + float WwWT11 = Ww11, WwWT12 = 0.5*(Ww11 + Ww12 + Ww13), WwWT13 = 0.5*(Ww11 - Ww12 + Ww13), WwWT14 = Ww13; + float WwWT21 = Ww21, WwWT22 = 0.5*(Ww21 + Ww22 + Ww23), WwWT23 = 0.5*(Ww21 - Ww22 + Ww23), WwWT24 = Ww23; + float WwWT31 = Ww31, WwWT32 = 0.5*(Ww31 + Ww32 + Ww33), WwWT33 = 0.5*(Ww31 - Ww32 + Ww33), WwWT34 = Ww33; + float WwWT41 = Ww41, WwWT42 = 0.5*(Ww41 + Ww42 + Ww43), WwWT43 = 0.5*(Ww41 - Ww42 + Ww43), WwWT44 = Ww43; + + // Store the computed weights + outptr0[0 * mstride] = WwWT11; + outptr0[1 * mstride] = WwWT12; + outptr0[2 * mstride] = WwWT13; + outptr0[3 * mstride] = WwWT14; + + outptr4[0 * mstride] = WwWT21; + outptr4[1 * mstride] = WwWT22; + outptr4[2 * mstride] = WwWT23; + outptr4[3 * mstride] = WwWT24; + + outptr8[0 * mstride] = WwWT31; + outptr8[1 * mstride] = WwWT32; + outptr8[2 * mstride] = WwWT33; + outptr8[3 * mstride] = WwWT34; + + outptr12[0 * mstride] = WwWT41; + outptr12[1 * mstride] = WwWT42; + outptr12[2 * mstride] = WwWT43; + outptr12[3 * mstride] = WwWT44; + + // Progress output pointers + outptr0++; + outptr4++; + outptr8++; + outptr12++; + } + + // Progression to complete stride + outptr0 += matrix_row_stride - n_output_channels; + outptr4 += matrix_row_stride - n_output_channels; + outptr8 += matrix_row_stride - n_output_channels; + outptr12 += matrix_row_stride - n_output_channels; + } +} +} diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp new file mode 100644 index 0000000000..3dd62d1ac1 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp @@ -0,0 +1,822 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#pragma once + +#ifdef __aarch64__ +namespace winograd { +template <> +template <> +inline void winograd2x2_3x3_gemm_kernel_transform_impl::transform_kernel<0>( + const float* const kernel, + const int n_input_channels, + const int n_output_channels, + float* const matrix_base, + const int mstride, + const int matrix_row_stride +) { + // Use one input pointer for each row of the kernel, use two additional + // offsets to extract columns. + const int kernel_col_stride = n_input_channels * n_output_channels; + const int kernel_row_stride = 3 * kernel_col_stride; + const float *inptr0 = kernel; + const float *inptr1 = kernel + kernel_row_stride; + const float *inptr2 = kernel + kernel_row_stride*2; + + // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three + // offsets to extract further matrices. + float *outptr0 = matrix_base; + float *outptr4 = matrix_base + mstride * 4; + float *outptr8 = matrix_base + mstride * 8; + float *outptr12 = matrix_base + mstride * 12; + + // For every input channel + for (int in_c = 0; in_c < n_input_channels; in_c++) { + int n_remaining_channels = n_output_channels; + + asm volatile ( + // Registers into which to read the kernel + "w_11 .req v0\n" "qw_11 .req q0\n" + "w_12 .req v1\n" "qw_12 .req q1\n" + "w_13 .req v2\n" "qw_13 .req q2\n" + "w_21 .req v3\n" "qw_21 .req q3\n" + "w_22 .req v4\n" "qw_22 .req q4\n" + "w_23 .req v5\n" "qw_23 .req q5\n" + "w_31 .req v6\n" "qw_31 .req q6\n" + "w_32 .req v7\n" "qw_32 .req q7\n" + "w_33 .req v8\n" "qw_33 .req q8\n" + + // Transformed matrix Ww + "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n" + "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n" + "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n" + "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n" + + // Output matrix U = WwWT + "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n" + "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n" + "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n" + "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n" + + // Storage view of output matrices + "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n" + "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n" + "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n" + "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n" + + "half .req v23\n" // {0.5, ..., 0.5} + "dup half.4s, %w[one_half]\n" + "scratch .req v24\n" + + "1:" + // Load tile of the kernel + "ldr qw_11, [%x[inptr0]]\n" + "str qU11, [%x[outptr0]]\n" + "ldr qw_12, [%x[inptr0], %x[colstride1]]\n" + "ldr qw_13, [%x[inptr0], %x[colstride2]]\n" + "str qU14, [%x[outptr0], %x[mstride3]]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "ldr qw_21, [%x[inptr1]]\n" + "ldr qw_22, [%x[inptr1], %x[colstride1]]\n" + "ldr qw_23, [%x[inptr1], %x[colstride2]]\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "ldr qw_31, [%x[inptr2]]\n" + "str qU41, [%x[outptr12]]\n" + "ldr qw_32, [%x[inptr2], %x[colstride1]]\n" + "ldr qw_33, [%x[inptr2], %x[colstride2]]\n" + "str qU44, [%x[outptr12], %x[mstride3]]\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + // Compute 2nd and 3rd rows of Ww + "fadd scratch.4s, w_11.4s, w_31.4s\n" + "fmul Ww21.4s, scratch.4s, half.4s\n" + "fmla Ww21.4s, w_21.4s, half.4s\n" + "str qU21, [%x[outptr4]]\n" + "fmul Ww31.4s, scratch.4s, half.4s\n" + "fmls Ww31.4s, w_21.4s, half.4s\n" + "str qU31, [%x[outptr8]]\n" + + "fadd scratch.4s, w_12.4s, w_32.4s\n" + 
"fmul Ww22.4s, scratch.4s, half.4s\n" + "fmla Ww22.4s, w_22.4s, half.4s\n" + "fmul Ww32.4s, scratch.4s, half.4s\n" + "fmls Ww32.4s, w_22.4s, half.4s\n" + + "fadd scratch.4s, w_13.4s, w_33.4s\n" + "fmul Ww23.4s, scratch.4s, half.4s\n" + "fmla Ww23.4s, w_23.4s, half.4s\n" + "str qU24, [%x[outptr4], %x[mstride3]]\n" + "fmul Ww33.4s, scratch.4s, half.4s\n" + "fmls Ww33.4s, w_23.4s, half.4s\n" + "str qU34, [%x[outptr8], %x[mstride3]]\n" + + // Compute and store U, only need to compute the 2nd and 3rd columns + // of U and update output pointers + "fadd scratch.4s, Ww11.4s, Ww13.4s\n" + "fmul U12.4s, scratch.4s, half.4s\n" + "fmla U12.4s, Ww12.4s, half.4s\n" + "str qU12, [%x[outptr0], %x[mstride1]]\n" + "fmul U13.4s, scratch.4s, half.4s\n" + "fmls U13.4s, Ww12.4s, half.4s\n" + "str qU13, [%x[outptr0], %x[mstride2]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd scratch.4s, Ww21.4s, Ww23.4s\n" + "fmul U22.4s, scratch.4s, half.4s\n" + "fmla U22.4s, Ww22.4s, half.4s\n" + "str qU22, [%x[outptr4], %x[mstride1]]\n" + "fmul U23.4s, scratch.4s, half.4s\n" + "fmls U23.4s, Ww22.4s, half.4s\n" + "str qU23, [%x[outptr4], %x[mstride2]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fadd scratch.4s, Ww31.4s, Ww33.4s\n" + "fmul U32.4s, scratch.4s, half.4s\n" + "fmla U32.4s, Ww32.4s, half.4s\n" + "str qU32, [%x[outptr8], %x[mstride1]]\n" + "fmul U33.4s, scratch.4s, half.4s\n" + "fmls U33.4s, Ww32.4s, half.4s\n" + "str qU33, [%x[outptr8], %x[mstride2]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "fadd scratch.4s, Ww41.4s, Ww43.4s\n" + "fmul U42.4s, scratch.4s, half.4s\n" + "fmla U42.4s, Ww42.4s, half.4s\n" + "str qU42, [%x[outptr12], %x[mstride1]]\n" + "fmul U43.4s, scratch.4s, half.4s\n" + "fmls U43.4s, Ww42.4s, half.4s\n" + "str qU43, [%x[outptr12], %x[mstride2]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n" + "bne 1b\n" + + // Clear aliases + ".unreq half\n" + ".unreq scratch\n" + ".unreq w_11\n" ".unreq qw_11\n" + ".unreq w_12\n" ".unreq qw_12\n" + ".unreq w_13\n" ".unreq qw_13\n" + ".unreq w_21\n" ".unreq qw_21\n" + ".unreq w_22\n" ".unreq qw_22\n" + ".unreq w_23\n" ".unreq qw_23\n" + ".unreq w_31\n" ".unreq qw_31\n" + ".unreq w_32\n" ".unreq qw_32\n" + ".unreq w_33\n" ".unreq qw_33\n" + ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n" + ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n" + ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n" + ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n" + ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n" + ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n" + ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n" + ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n" + ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n" + ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n" + ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n" + ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq qU44\n" + + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr4), + [outptr8] "+r" (outptr8), + [outptr12] "+r" (outptr12), + [n_remaining_channels] "+r" (n_remaining_channels) + : [mstride1] "r" (sizeof(float) * mstride), + [mstride2] "r" (sizeof(float) * mstride * 2), + [mstride3] "r" (sizeof(float) * mstride * 3), + [colstride1] "r" (sizeof(float) * kernel_col_stride), + [colstride2] "r" (sizeof(float) * kernel_col_stride * 2), + [one_half] "r" (0.5f) + : "v0", "v1", 
"v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24" + ); + + // Progression to complete stride + outptr0 += matrix_row_stride - n_output_channels; + outptr4 += matrix_row_stride - n_output_channels; + outptr8 += matrix_row_stride - n_output_channels; + outptr12 += matrix_row_stride - n_output_channels; + } +} + +template <> +template <> +inline void winograd2x2_3x3_gemm_kernel_transform_impl::transform_kernel<2>( + const float* const kernel, + const int n_input_channels, + const int n_output_channels, + float* const matrix_base, + const int mstride, + const int matrix_row_stride +) { + // Use one input pointer for each row of the kernel, use two additional + // offsets to extract columns. + const int kernel_col_stride = n_input_channels * n_output_channels; + const int kernel_row_stride = 3 * kernel_col_stride; + const float *inptr0 = kernel; + const float *inptr1 = kernel + kernel_row_stride; + const float *inptr2 = kernel + kernel_row_stride*2; + + // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three + // offsets to extract further matrices. + float *outptr0 = matrix_base; + float *outptr4 = matrix_base + mstride * 4; + float *outptr8 = matrix_base + mstride * 8; + float *outptr12 = matrix_base + mstride * 12; + + // For every input channel + for (int in_c = 0; in_c < n_input_channels; in_c++) { + int n_remaining_channels = n_output_channels; + + asm volatile ( + // Registers into which to read the kernel + "w_11 .req v0\n" "qw_11 .req q0\n" "dw_11 .req d0\n" + "w_12 .req v1\n" "qw_12 .req q1\n" "dw_12 .req d1\n" + "w_13 .req v2\n" "qw_13 .req q2\n" "dw_13 .req d2\n" + "w_21 .req v3\n" "qw_21 .req q3\n" "dw_21 .req d3\n" + "w_22 .req v4\n" "qw_22 .req q4\n" "dw_22 .req d4\n" + "w_23 .req v5\n" "qw_23 .req q5\n" "dw_23 .req d5\n" + "w_31 .req v6\n" "qw_31 .req q6\n" "dw_31 .req d6\n" + "w_32 .req v7\n" "qw_32 .req q7\n" "dw_32 .req d7\n" + "w_33 .req v8\n" "qw_33 .req q8\n" "dw_33 .req d8\n" + + // Transformed matrix Ww + "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n" + "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n" + "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n" + "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n" + + // Output matrix U = WwWT + "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n" + "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n" + "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n" + "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n" + + // Storage view of output matrices + "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n" + "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n" + "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n" + "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n" + + "dU11 .req d0\n" "dU12 .req d15\n" "dU13 .req d16\n" "dU14 .req d2\n" + "dU21 .req d9\n" "dU22 .req d17\n" "dU23 .req d18\n" "dU24 .req d11\n" + "dU31 .req d12\n" "dU32 .req d19\n" "dU33 .req d20\n" "dU34 .req d14\n" + "dU41 .req d6\n" "dU42 .req d21\n" "dU43 .req d22\n" "dU44 .req d8\n" + + "half .req v23\n" // {0.5, ..., 0.5} + "dup half.4s, %w[one_half]\n" + "scratch .req v24\n" + + // Subtract the tail from the number of remaining channels and jump to + // the tail if necessary. 
+ "subs %x[n_remaining_channels], %x[n_remaining_channels], #2\n" + "beq 2f\n" + + "1:" + // Load tile of the kernel + "ldr qw_11, [%x[inptr0]]\n" + "str qU11, [%x[outptr0]]\n" + "ldr qw_12, [%x[inptr0], %x[colstride1]]\n" + "ldr qw_13, [%x[inptr0], %x[colstride2]]\n" + "str qU14, [%x[outptr0], %x[mstride3]]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "ldr qw_21, [%x[inptr1]]\n" + "ldr qw_22, [%x[inptr1], %x[colstride1]]\n" + "ldr qw_23, [%x[inptr1], %x[colstride2]]\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "ldr qw_31, [%x[inptr2]]\n" + "str qU41, [%x[outptr12]]\n" + "ldr qw_32, [%x[inptr2], %x[colstride1]]\n" + "ldr qw_33, [%x[inptr2], %x[colstride2]]\n" + "str qU44, [%x[outptr12], %x[mstride3]]\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + // Compute 2nd and 3rd rows of Ww + "fadd scratch.4s, w_11.4s, w_31.4s\n" + "fmul Ww21.4s, scratch.4s, half.4s\n" + "fmla Ww21.4s, w_21.4s, half.4s\n" + "str qU21, [%x[outptr4]]\n" + "fmul Ww31.4s, scratch.4s, half.4s\n" + "fmls Ww31.4s, w_21.4s, half.4s\n" + "str qU31, [%x[outptr8]]\n" + + "fadd scratch.4s, w_12.4s, w_32.4s\n" + "fmul Ww22.4s, scratch.4s, half.4s\n" + "fmla Ww22.4s, w_22.4s, half.4s\n" + "fmul Ww32.4s, scratch.4s, half.4s\n" + "fmls Ww32.4s, w_22.4s, half.4s\n" + + "fadd scratch.4s, w_13.4s, w_33.4s\n" + "fmul Ww23.4s, scratch.4s, half.4s\n" + "fmla Ww23.4s, w_23.4s, half.4s\n" + "str qU24, [%x[outptr4], %x[mstride3]]\n" + "fmul Ww33.4s, scratch.4s, half.4s\n" + "fmls Ww33.4s, w_23.4s, half.4s\n" + "str qU34, [%x[outptr8], %x[mstride3]]\n" + + // Compute and store U, only need to compute the 2nd and 3rd columns + // of U and update output pointers + "fadd scratch.4s, Ww11.4s, Ww13.4s\n" + "fmul U12.4s, scratch.4s, half.4s\n" + "fmla U12.4s, Ww12.4s, half.4s\n" + "str qU12, [%x[outptr0], %x[mstride1]]\n" + "fmul U13.4s, scratch.4s, half.4s\n" + "fmls U13.4s, Ww12.4s, half.4s\n" + "str qU13, [%x[outptr0], %x[mstride2]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd scratch.4s, Ww21.4s, Ww23.4s\n" + "fmul U22.4s, scratch.4s, half.4s\n" + "fmla U22.4s, Ww22.4s, half.4s\n" + "str qU22, [%x[outptr4], %x[mstride1]]\n" + "fmul U23.4s, scratch.4s, half.4s\n" + "fmls U23.4s, Ww22.4s, half.4s\n" + "str qU23, [%x[outptr4], %x[mstride2]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fadd scratch.4s, Ww31.4s, Ww33.4s\n" + "fmul U32.4s, scratch.4s, half.4s\n" + "fmla U32.4s, Ww32.4s, half.4s\n" + "str qU32, [%x[outptr8], %x[mstride1]]\n" + "fmul U33.4s, scratch.4s, half.4s\n" + "fmls U33.4s, Ww32.4s, half.4s\n" + "str qU33, [%x[outptr8], %x[mstride2]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "fadd scratch.4s, Ww41.4s, Ww43.4s\n" + "fmul U42.4s, scratch.4s, half.4s\n" + "fmla U42.4s, Ww42.4s, half.4s\n" + "str qU42, [%x[outptr12], %x[mstride1]]\n" + "fmul U43.4s, scratch.4s, half.4s\n" + "fmls U43.4s, Ww42.4s, half.4s\n" + "str qU43, [%x[outptr12], %x[mstride2]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n" + "bne 1b\n" + + // Tail size 2 + "2:" + // Load tile of the kernel + "ldr dw_11, [%x[inptr0]]\n" + "str dU11, [%x[outptr0]]\n" + "ldr dw_12, [%x[inptr0], %x[colstride1]]\n" + "ldr dw_13, [%x[inptr0], %x[colstride2]]\n" + "str dU14, [%x[outptr0], %x[mstride3]]\n" + "add %x[inptr0], %x[inptr0], #0x08\n" + + "ldr dw_21, [%x[inptr1]]\n" + "ldr dw_22, [%x[inptr1], %x[colstride1]]\n" + "ldr dw_23, [%x[inptr1], %x[colstride2]]\n" + "add %x[inptr1], %x[inptr1], #0x08\n" + + "ldr dw_31, [%x[inptr2]]\n" + "str dU41, [%x[outptr12]]\n" + "ldr dw_32, [%x[inptr2], 
%x[colstride1]]\n" + "ldr dw_33, [%x[inptr2], %x[colstride2]]\n" + "str dU44, [%x[outptr12], %x[mstride3]]\n" + "add %x[inptr2], %x[inptr2], #0x08\n" + + // Compute 2nd and 3rd rows of Ww + "fadd scratch.2s, w_11.2s, w_31.2s\n" + "fmul Ww21.2s, scratch.2s, half.2s\n" + "fmla Ww21.2s, w_21.2s, half.2s\n" + "str dU21, [%x[outptr4]]\n" + "fmul Ww31.2s, scratch.2s, half.2s\n" + "fmls Ww31.2s, w_21.2s, half.2s\n" + "str dU31, [%x[outptr8]]\n" + + "fadd scratch.2s, w_12.2s, w_32.2s\n" + "fmul Ww22.2s, scratch.2s, half.2s\n" + "fmla Ww22.2s, w_22.2s, half.2s\n" + "fmul Ww32.2s, scratch.2s, half.2s\n" + "fmls Ww32.2s, w_22.2s, half.2s\n" + + "fadd scratch.2s, w_13.2s, w_33.2s\n" + "fmul Ww23.2s, scratch.2s, half.2s\n" + "fmla Ww23.2s, w_23.2s, half.2s\n" + "str dU24, [%x[outptr4], %x[mstride3]]\n" + "fmul Ww33.2s, scratch.2s, half.2s\n" + "fmls Ww33.2s, w_23.2s, half.2s\n" + "str dU34, [%x[outptr8], %x[mstride3]]\n" + + // Compute and store U, only need to compute the 2nd and 3rd columns of + // U and update output pointers + "fadd scratch.2s, Ww11.2s, Ww13.2s\n" + "fmul U12.2s, scratch.2s, half.2s\n" + "fmla U12.2s, Ww12.2s, half.2s\n" + "str dU12, [%x[outptr0], %x[mstride1]]\n" + "fmul U13.2s, scratch.2s, half.2s\n" + "fmls U13.2s, Ww12.2s, half.2s\n" + "str dU13, [%x[outptr0], %x[mstride2]]\n" + "add %x[outptr0], %x[outptr0], #0x08\n" + + "fadd scratch.2s, Ww21.2s, Ww23.2s\n" + "fmul U22.2s, scratch.2s, half.2s\n" + "fmla U22.2s, Ww22.2s, half.2s\n" + "str dU22, [%x[outptr4], %x[mstride1]]\n" + "fmul U23.2s, scratch.2s, half.2s\n" + "fmls U23.2s, Ww22.2s, half.2s\n" + "str dU23, [%x[outptr4], %x[mstride2]]\n" + "add %x[outptr4], %x[outptr4], #0x08\n" + + "fadd scratch.2s, Ww31.2s, Ww33.2s\n" + "fmul U32.2s, scratch.2s, half.2s\n" + "fmla U32.2s, Ww32.2s, half.2s\n" + "str dU32, [%x[outptr8], %x[mstride1]]\n" + "fmul U33.2s, scratch.2s, half.2s\n" + "fmls U33.2s, Ww32.2s, half.2s\n" + "str dU33, [%x[outptr8], %x[mstride2]]\n" + "add %x[outptr8], %x[outptr8], #0x08\n" + + "fadd scratch.2s, Ww41.2s, Ww43.2s\n" + "fmul U42.2s, scratch.2s, half.2s\n" + "fmla U42.2s, Ww42.2s, half.2s\n" + "str dU42, [%x[outptr12], %x[mstride1]]\n" + "fmul U43.2s, scratch.2s, half.2s\n" + "fmls U43.2s, Ww42.2s, half.2s\n" + "str dU43, [%x[outptr12], %x[mstride2]]\n" + "add %x[outptr12], %x[outptr12], #0x08\n" + + // Clear aliases + ".unreq half\n" + ".unreq scratch\n" + ".unreq w_11\n" ".unreq qw_11\n" ".unreq dw_11\n" + ".unreq w_12\n" ".unreq qw_12\n" ".unreq dw_12\n" + ".unreq w_13\n" ".unreq qw_13\n" ".unreq dw_13\n" + ".unreq w_21\n" ".unreq qw_21\n" ".unreq dw_21\n" + ".unreq w_22\n" ".unreq qw_22\n" ".unreq dw_22\n" + ".unreq w_23\n" ".unreq qw_23\n" ".unreq dw_23\n" + ".unreq w_31\n" ".unreq qw_31\n" ".unreq dw_31\n" + ".unreq w_32\n" ".unreq qw_32\n" ".unreq dw_32\n" + ".unreq w_33\n" ".unreq qw_33\n" ".unreq dw_33\n" + ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n" + ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n" + ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n" + ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n" + ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n" + ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n" + ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n" + ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n" + ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n" + ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n" + ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n" + ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq 
qU44\n" + ".unreq dU11\n" ".unreq dU12\n" ".unreq dU13\n" ".unreq dU14\n" + ".unreq dU21\n" ".unreq dU22\n" ".unreq dU23\n" ".unreq dU24\n" + ".unreq dU31\n" ".unreq dU32\n" ".unreq dU33\n" ".unreq dU34\n" + ".unreq dU41\n" ".unreq dU42\n" ".unreq dU43\n" ".unreq dU44\n" + + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr4), + [outptr8] "+r" (outptr8), + [outptr12] "+r" (outptr12), + [n_remaining_channels] "+r" (n_remaining_channels) + : [mstride1] "r" (sizeof(float) * mstride), + [mstride2] "r" (sizeof(float) * mstride * 2), + [mstride3] "r" (sizeof(float) * mstride * 3), + [colstride1] "r" (sizeof(float) * kernel_col_stride), + [colstride2] "r" (sizeof(float) * kernel_col_stride * 2), + [one_half] "r" (0.5f) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24" + ); + + // Progression to complete stride + outptr0 += matrix_row_stride - n_output_channels; + outptr4 += matrix_row_stride - n_output_channels; + outptr8 += matrix_row_stride - n_output_channels; + outptr12 += matrix_row_stride - n_output_channels; + } +} + +template <> +template <> +inline void winograd2x2_3x3_gemm_kernel_transform_impl::transform_kernel<1>( + const float* const kernel, + const int n_input_channels, + const int n_output_channels, + float* const matrix_base, + const int mstride, + const int matrix_row_stride +) { + // Use one input pointer for each row of the kernel, use two additional + // offsets to extract columns. + const int kernel_col_stride = n_input_channels * n_output_channels; + const int kernel_row_stride = 3 * kernel_col_stride; + const float *inptr0 = kernel; + const float *inptr1 = kernel + kernel_row_stride; + const float *inptr2 = kernel + kernel_row_stride*2; + + // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three + // offsets to extract further matrices. 
+ float *outptr0 = matrix_base; + float *outptr4 = matrix_base + mstride * 4; + float *outptr8 = matrix_base + mstride * 8; + float *outptr12 = matrix_base + mstride * 12; + + // For every input channel + for (int in_c = 0; in_c < n_input_channels; in_c++) { + int n_remaining_channels = n_output_channels; + + asm volatile ( + // Registers into which to read the kernel + "w_11 .req v0\n" "qw_11 .req q0\n" "sw_11 .req s0\n" + "w_12 .req v1\n" "qw_12 .req q1\n" "sw_12 .req s1\n" + "w_13 .req v2\n" "qw_13 .req q2\n" "sw_13 .req s2\n" + "w_21 .req v3\n" "qw_21 .req q3\n" "sw_21 .req s3\n" + "w_22 .req v4\n" "qw_22 .req q4\n" "sw_22 .req s4\n" + "w_23 .req v5\n" "qw_23 .req q5\n" "sw_23 .req s5\n" + "w_31 .req v6\n" "qw_31 .req q6\n" "sw_31 .req s6\n" + "w_32 .req v7\n" "qw_32 .req q7\n" "sw_32 .req s7\n" + "w_33 .req v8\n" "qw_33 .req q8\n" "sw_33 .req s8\n" + + // Transformed matrix Ww + "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n" + "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n" + "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n" + "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n" + + // Output matrix U = WwWT + "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n" + "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n" + "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n" + "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n" + + // Storage view of output matrices + "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n" + "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n" + "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n" + "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n" + + "sU11 .req s0\n" "sU12 .req s15\n" "sU13 .req s16\n" "sU14 .req s2\n" + "sU21 .req s9\n" "sU22 .req s17\n" "sU23 .req s18\n" "sU24 .req s11\n" + "sU31 .req s12\n" "sU32 .req s19\n" "sU33 .req s20\n" "sU34 .req s14\n" + "sU41 .req s6\n" "sU42 .req s21\n" "sU43 .req s22\n" "sU44 .req s8\n" + + "half .req v23\n" // {0.5, ..., 0.5} + "dup half.4s, %w[one_half]\n" + "scratch .req v24\n" + + // Subtract the tail from the number of remaining channels and jump to + // the tail if necessary. 
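+      // This specialisation covers output-channel counts with a remainder of
+      // one modulo four: the single trailing channel is peeled off the counter
+      // here, the main loop at label 1 then transforms four channels per
+      // iteration using 128-bit q-register accesses, and the tail at label 2
+      // finishes the last channel with 32-bit s-register accesses.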
+ "subs %x[n_remaining_channels], %x[n_remaining_channels], #1\n" + "beq 2f\n" + + "1:" + // Load tile of the kernel + "ldr qw_11, [%x[inptr0]]\n" + "str qU11, [%x[outptr0]]\n" + "ldr qw_12, [%x[inptr0], %x[colstride1]]\n" + "ldr qw_13, [%x[inptr0], %x[colstride2]]\n" + "str qU14, [%x[outptr0], %x[mstride3]]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "ldr qw_21, [%x[inptr1]]\n" + "ldr qw_22, [%x[inptr1], %x[colstride1]]\n" + "ldr qw_23, [%x[inptr1], %x[colstride2]]\n" + "add %x[inptr1], %x[inptr1], #0x10\n" + + "ldr qw_31, [%x[inptr2]]\n" + "str qU41, [%x[outptr12]]\n" + "ldr qw_32, [%x[inptr2], %x[colstride1]]\n" + "ldr qw_33, [%x[inptr2], %x[colstride2]]\n" + "str qU44, [%x[outptr12], %x[mstride3]]\n" + "add %x[inptr2], %x[inptr2], #0x10\n" + + // Compute 2nd and 3rd rows of Ww + "fadd scratch.4s, w_11.4s, w_31.4s\n" + "fmul Ww21.4s, scratch.4s, half.4s\n" + "fmla Ww21.4s, w_21.4s, half.4s\n" + "str qU21, [%x[outptr4]]\n" + "fmul Ww31.4s, scratch.4s, half.4s\n" + "fmls Ww31.4s, w_21.4s, half.4s\n" + "str qU31, [%x[outptr8]]\n" + + "fadd scratch.4s, w_12.4s, w_32.4s\n" + "fmul Ww22.4s, scratch.4s, half.4s\n" + "fmla Ww22.4s, w_22.4s, half.4s\n" + "fmul Ww32.4s, scratch.4s, half.4s\n" + "fmls Ww32.4s, w_22.4s, half.4s\n" + + "fadd scratch.4s, w_13.4s, w_33.4s\n" + "fmul Ww23.4s, scratch.4s, half.4s\n" + "fmla Ww23.4s, w_23.4s, half.4s\n" + "str qU24, [%x[outptr4], %x[mstride3]]\n" + "fmul Ww33.4s, scratch.4s, half.4s\n" + "fmls Ww33.4s, w_23.4s, half.4s\n" + "str qU34, [%x[outptr8], %x[mstride3]]\n" + + // Compute and store U, only need to compute the 2nd and 3rd columns + // of U and update output pointers + "fadd scratch.4s, Ww11.4s, Ww13.4s\n" + "fmul U12.4s, scratch.4s, half.4s\n" + "fmla U12.4s, Ww12.4s, half.4s\n" + "str qU12, [%x[outptr0], %x[mstride1]]\n" + "fmul U13.4s, scratch.4s, half.4s\n" + "fmls U13.4s, Ww12.4s, half.4s\n" + "str qU13, [%x[outptr0], %x[mstride2]]\n" + "add %x[outptr0], %x[outptr0], #0x10\n" + + "fadd scratch.4s, Ww21.4s, Ww23.4s\n" + "fmul U22.4s, scratch.4s, half.4s\n" + "fmla U22.4s, Ww22.4s, half.4s\n" + "str qU22, [%x[outptr4], %x[mstride1]]\n" + "fmul U23.4s, scratch.4s, half.4s\n" + "fmls U23.4s, Ww22.4s, half.4s\n" + "str qU23, [%x[outptr4], %x[mstride2]]\n" + "add %x[outptr4], %x[outptr4], #0x10\n" + + "fadd scratch.4s, Ww31.4s, Ww33.4s\n" + "fmul U32.4s, scratch.4s, half.4s\n" + "fmla U32.4s, Ww32.4s, half.4s\n" + "str qU32, [%x[outptr8], %x[mstride1]]\n" + "fmul U33.4s, scratch.4s, half.4s\n" + "fmls U33.4s, Ww32.4s, half.4s\n" + "str qU33, [%x[outptr8], %x[mstride2]]\n" + "add %x[outptr8], %x[outptr8], #0x10\n" + + "fadd scratch.4s, Ww41.4s, Ww43.4s\n" + "fmul U42.4s, scratch.4s, half.4s\n" + "fmla U42.4s, Ww42.4s, half.4s\n" + "str qU42, [%x[outptr12], %x[mstride1]]\n" + "fmul U43.4s, scratch.4s, half.4s\n" + "fmls U43.4s, Ww42.4s, half.4s\n" + "str qU43, [%x[outptr12], %x[mstride2]]\n" + "add %x[outptr12], %x[outptr12], #0x10\n" + + "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n" + "bne 1b\n" + + // Tail size 1 + "2:" + // Load tile of the kernel + "ldr sw_11, [%x[inptr0]]\n" + "str sU11, [%x[outptr0]]\n" + "ldr sw_12, [%x[inptr0], %x[colstride1]]\n" + "ldr sw_13, [%x[inptr0], %x[colstride2]]\n" + "str sU14, [%x[outptr0], %x[mstride3]]\n" + "add %x[inptr0], %x[inptr0], #0x04\n" + + "ldr sw_21, [%x[inptr1]]\n" + "ldr sw_22, [%x[inptr1], %x[colstride1]]\n" + "ldr sw_23, [%x[inptr1], %x[colstride2]]\n" + "add %x[inptr1], %x[inptr1], #0x04\n" + + "ldr sw_31, [%x[inptr2]]\n" + "str sU41, [%x[outptr12]]\n" + "ldr sw_32, [%x[inptr2], 
%x[colstride1]]\n" + "ldr sw_33, [%x[inptr2], %x[colstride2]]\n" + "str sU44, [%x[outptr12], %x[mstride3]]\n" + "add %x[inptr2], %x[inptr2], #0x04\n" + + // Compute 2nd and 3rd rows of Ww + "fadd scratch.2s, w_11.2s, w_31.2s\n" + "fmul Ww21.2s, scratch.2s, half.2s\n" + "fmla Ww21.2s, w_21.2s, half.2s\n" + "str sU21, [%x[outptr4]]\n" + "fmul Ww31.2s, scratch.2s, half.2s\n" + "fmls Ww31.2s, w_21.2s, half.2s\n" + "str sU31, [%x[outptr8]]\n" + + "fadd scratch.2s, w_12.2s, w_32.2s\n" + "fmul Ww22.2s, scratch.2s, half.2s\n" + "fmla Ww22.2s, w_22.2s, half.2s\n" + "fmul Ww32.2s, scratch.2s, half.2s\n" + "fmls Ww32.2s, w_22.2s, half.2s\n" + + "fadd scratch.2s, w_13.2s, w_33.2s\n" + "fmul Ww23.2s, scratch.2s, half.2s\n" + "fmla Ww23.2s, w_23.2s, half.2s\n" + "str sU24, [%x[outptr4], %x[mstride3]]\n" + "fmul Ww33.2s, scratch.2s, half.2s\n" + "fmls Ww33.2s, w_23.2s, half.2s\n" + "str sU34, [%x[outptr8], %x[mstride3]]\n" + + // Compute and store U, only need to compute the 2nd and 3rd columns of + // U and update output pointers + "fadd scratch.2s, Ww11.2s, Ww13.2s\n" + "fmul U12.2s, scratch.2s, half.2s\n" + "fmla U12.2s, Ww12.2s, half.2s\n" + "str sU12, [%x[outptr0], %x[mstride1]]\n" + "fmul U13.2s, scratch.2s, half.2s\n" + "fmls U13.2s, Ww12.2s, half.2s\n" + "str sU13, [%x[outptr0], %x[mstride2]]\n" + "add %x[outptr0], %x[outptr0], #0x04\n" + + "fadd scratch.2s, Ww21.2s, Ww23.2s\n" + "fmul U22.2s, scratch.2s, half.2s\n" + "fmla U22.2s, Ww22.2s, half.2s\n" + "str sU22, [%x[outptr4], %x[mstride1]]\n" + "fmul U23.2s, scratch.2s, half.2s\n" + "fmls U23.2s, Ww22.2s, half.2s\n" + "str sU23, [%x[outptr4], %x[mstride2]]\n" + "add %x[outptr4], %x[outptr4], #0x04\n" + + "fadd scratch.2s, Ww31.2s, Ww33.2s\n" + "fmul U32.2s, scratch.2s, half.2s\n" + "fmla U32.2s, Ww32.2s, half.2s\n" + "str sU32, [%x[outptr8], %x[mstride1]]\n" + "fmul U33.2s, scratch.2s, half.2s\n" + "fmls U33.2s, Ww32.2s, half.2s\n" + "str sU33, [%x[outptr8], %x[mstride2]]\n" + "add %x[outptr8], %x[outptr8], #0x04\n" + + "fadd scratch.2s, Ww41.2s, Ww43.2s\n" + "fmul U42.2s, scratch.2s, half.2s\n" + "fmla U42.2s, Ww42.2s, half.2s\n" + "str sU42, [%x[outptr12], %x[mstride1]]\n" + "fmul U43.2s, scratch.2s, half.2s\n" + "fmls U43.2s, Ww42.2s, half.2s\n" + "str sU43, [%x[outptr12], %x[mstride2]]\n" + "add %x[outptr12], %x[outptr12], #0x04\n" + + // Clear aliases + ".unreq half\n" + ".unreq scratch\n" + ".unreq w_11\n" ".unreq qw_11\n" ".unreq sw_11\n" + ".unreq w_12\n" ".unreq qw_12\n" ".unreq sw_12\n" + ".unreq w_13\n" ".unreq qw_13\n" ".unreq sw_13\n" + ".unreq w_21\n" ".unreq qw_21\n" ".unreq sw_21\n" + ".unreq w_22\n" ".unreq qw_22\n" ".unreq sw_22\n" + ".unreq w_23\n" ".unreq qw_23\n" ".unreq sw_23\n" + ".unreq w_31\n" ".unreq qw_31\n" ".unreq sw_31\n" + ".unreq w_32\n" ".unreq qw_32\n" ".unreq sw_32\n" + ".unreq w_33\n" ".unreq qw_33\n" ".unreq sw_33\n" + ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n" + ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n" + ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n" + ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n" + ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n" + ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n" + ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n" + ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n" + ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n" + ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n" + ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n" + ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq 
qU44\n" + ".unreq sU11\n" ".unreq sU12\n" ".unreq sU13\n" ".unreq sU14\n" + ".unreq sU21\n" ".unreq sU22\n" ".unreq sU23\n" ".unreq sU24\n" + ".unreq sU31\n" ".unreq sU32\n" ".unreq sU33\n" ".unreq sU34\n" + ".unreq sU41\n" ".unreq sU42\n" ".unreq sU43\n" ".unreq sU44\n" + + : [inptr0] "+r" (inptr0), + [inptr1] "+r" (inptr1), + [inptr2] "+r" (inptr2), + [outptr0] "+r" (outptr0), + [outptr4] "+r" (outptr4), + [outptr8] "+r" (outptr8), + [outptr12] "+r" (outptr12), + [n_remaining_channels] "+r" (n_remaining_channels) + : [mstride1] "r" (sizeof(float) * mstride), + [mstride2] "r" (sizeof(float) * mstride * 2), + [mstride3] "r" (sizeof(float) * mstride * 3), + [colstride1] "r" (sizeof(float) * kernel_col_stride), + [colstride2] "r" (sizeof(float) * kernel_col_stride * 2), + [one_half] "r" (0.5f) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24" + ); + + // Progression to complete stride + outptr0 += matrix_row_stride - n_output_channels; + outptr4 += matrix_row_stride - n_output_channels; + outptr8 += matrix_row_stride - n_output_channels; + outptr12 += matrix_row_stride - n_output_channels; + } +} +} +#endif // __aarch64__ diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp new file mode 100644 index 0000000000..0992c0bb44 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp @@ -0,0 +1,356 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +namespace winograd { + /* Transform from the Winograd domain back to the spatial domain. + */ + template + struct Winograd2x2_3x3GemmOutput { + static void execute( + const Tensor4DShape &output_shape, + T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + T* const output + ); + + protected: + /* Specialised implementation method. */ + template + static void _execute( + const Tensor4DShape &output_shape, + T *output, + const T *input, + const int matrix_stride, + const int matrix_row_stride + ); + }; + + /* Two-stage implementation of the transformation from the Winograd domain. + * + * First computes Z.F and then computes (Z.F).Z^T. 
+ */ + template + struct Winograd2x2_3x3GemmOutput_TwoStage { + static void execute( + const Tensor4DShape &output_shape, + T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + T* const output + ); + + protected: + template + static void compute_zf( + const int n_rows, const int n_channels, + T* const zf, const T* const input[16] + ); + + template + static void compute_zfzT( + const Tensor4DShape &output_shape, + T* const output, const T* const zf + ); + }; +} + +#include "output_2x2_3x3/a64_float.hpp" +// #include "output_2x2_3x3/a64_float_two_stage.hpp" + +/*****************************************************************************/ +/* +template +void winograd::Winograd2x2_3x3GemmOutput::execute( + const Tensor4DShape &output_shape, + const int tile_M, + const int tile_N, + T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + T* const output +) { + T* const antipadding = reinterpret_cast(malloc(sizeof(T) * output_shape.n_channels)); + + // Get input pointers + const T* inptrs[16]; + for (int i = 0; i < 16; i++) { + inptrs[i] = matrices[i]; + } + + for (int batch = 0; batch < output_shape.n_batches; batch++) { + for (int tile_i = 0; tile_i < tile_M; tile_i++) { + for (int tile_j = 0; tile_j < tile_N; tile_j++) { + // Get pointers for each of the 4 output cells required for this computation + T* outptrs[4]; + for (int cell_i = 0, c = 0; cell_i < 2; cell_i++) { + for (int cell_j = 0; cell_j < 2; cell_j++, c++) { + const int i = tile_i*2 + cell_i; + const int j = tile_j*2 + cell_j; + + if (i < output_shape.n_rows && j < output_shape.n_cols) { + outptrs[c] = output + ( + (batch*output_shape.n_rows + i) * output_shape.n_cols + + j) * output_shape.n_channels; + } else { + outptrs[c] = antipadding; + } + } // cell_j + } // cell_i + + for (int n = 0; n < output_shape.n_channels; n++) { + // Read 16 values and progress pointers + T v[16]; + for (int i = 0; i < 16; i++) { + v[i] = *(inptrs[i]++); + } + + // Compute output for 4 pixels + *(outptrs[0]++) = v[ 0] + v[ 1] + v[ 2] + + v[ 4] + v[ 5] + v[ 6] + + v[ 8] + v[ 9] + v[10]; + *(outptrs[1]++) = v[ 1] - v[ 2] - v[ 3] + + v[ 5] - v[ 6] - v[ 7] + + v[ 9] - v[10] - v[11]; + *(outptrs[2]++) = v[ 4] + v[ 5] + v[ 6] - + v[ 8] - v[ 9] - v[10] - + v[12] - v[13] - v[14]; + *(outptrs[3]++) = v[ 5] - v[ 6] - v[ 7] - + v[ 9] + v[10] + v[11] - + v[13] + v[14] + v[15]; + } // output_channel + } // tile_j + } // tile_i + } // batch + + free(antipadding); +} +*/ + +/*****************************************************************************/ +/* +template +void winograd::Winograd2x2_3x3GemmOutput_TwoStage::execute( + const Tensor4DShape &output_shape, + T* const matrices[16], T* const output +) { + // Allocate memory for the intermediate matrices + const int tile_M = iceildiv(output_shape.n_rows, 2); + const int tile_N = iceildiv(output_shape.n_cols, 2); + const int n_rows = output_shape.n_batches * tile_M * tile_N; + const int n_channels = output_shape.n_channels; + T* matrices_zf = reinterpret_cast( + calloc(8 * n_rows * n_channels, sizeof(T)) + ); + + // Perform the first stage transform, computing ZF. + // Specializations should dispatch to different methods based on tail size. + compute_zf<0>(n_rows, n_channels, matrices_zf, matrices); + + // Perform the second stage transform, finishing Z F Z^T - variable dispatch + // based on size of the output. Specialisations can also dispatch based on + // the tail-size of the channel. 
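+  // (The intermediate storage allocated above holds 8 * n_rows * n_channels
+  // values, i.e. half the size of the sixteen matrices consumed by the
+  // transform.)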
+ if (output_shape.n_rows % 2 && output_shape.n_cols % 2) { + compute_zfzT(output_shape, output, matrices_zf); + } else if (output_shape.n_rows % 2) { + compute_zfzT(output_shape, output, matrices_zf); + } else if (output_shape.n_cols % 2) { + compute_zfzT(output_shape, output, matrices_zf); + } else { + compute_zfzT(output_shape, output, matrices_zf); + } + + free(reinterpret_cast(matrices_zf)); +} + +template +template +void winograd::Winograd2x2_3x3GemmOutput_TwoStage::compute_zf( + const int n_rows, const int n_channels, + T* output, const T* const input[16] +) { + // Extract 8 output pointers + T* outptr[8]; + for (int i = 0; i < 8; i++) { + outptr[i] = output + i*n_rows*n_channels; + } + + // Copy the 16 input pointers + const T* inptr[16]; + for (int i = 0; i < 16; i++) { + inptr[i] = input[i]; + } + + // For every row of the matrices + for (int i = 0; i < n_rows; i++) { + // For every channel + for (int j = 0; j < n_channels; j++) { + // Extract values from the input matrices + T val[16]; + for (int n = 0; n < 16; n++) { + val[n] = *(inptr[n]++); + } + + // Compute output values + *(outptr[0]++) = val[0] + val[1] + val[2]; + *(outptr[1]++) = val[1] - val[2] - val[3]; + *(outptr[2]++) = val[4] + val[5] + val[6]; + *(outptr[3]++) = val[5] - val[6] - val[7]; + *(outptr[4]++) = val[8] + val[9] + val[10]; + *(outptr[5]++) = val[9] - val[10] - val[11]; + *(outptr[6]++) = val[12] + val[13] + val[14]; + *(outptr[7]++) = val[13] - val[14] - val[15]; + } + } +} + +template +template +void winograd::Winograd2x2_3x3GemmOutput_TwoStage::compute_zfzT( + const Tensor4DShape &output_shape, + T* const output, const T* const input +) { + // Sizing information + const int tile_M = output_shape.n_rows / 2; + const int tile_N = output_shape.n_cols / 2; + + const int n_rows = (output_shape.n_batches * + (tile_M + (tail_M ? 1 : 0)) * + (tile_N + (tail_N ? 1 : 0))); + const int n_channels = output_shape.n_channels; + + // Extract 8 input pointers + const T* inptr[8]; + for (int i = 0; i < 8; i++) { + inptr[i] = input + i*n_rows*n_channels; + } + + // Extract 4 output pointers + T* outptr00 = output; + T* outptr01 = outptr00 + n_channels; + T* outptr10 = outptr00 + output_shape.n_cols * n_channels; + T* outptr11 = outptr10 + n_channels; + + // Progress over the output tiles, generating output values. + for (int batch = 0; batch < output_shape.n_batches; batch++) { + for (int tile_i = 0; tile_i < tile_M; tile_i++) { + for (int tile_j = 0; tile_j < tile_N; tile_j++) { + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + T v[8]; + for (int i = 0; i < 8; i++) { + v[i] = *(inptr[i]++); + } + + // Compute the output values and progress the output pointers. + *(outptr00++) = v[0] + v[2] + v[4]; + *(outptr01++) = v[1] + v[3] + v[5]; + *(outptr10++) = v[2] - v[4] - v[6]; + *(outptr11++) = v[3] - v[5] - v[7]; + } + + // Progress the output pointers to the next column + outptr00 += n_channels; + outptr01 += n_channels; + outptr10 += n_channels; + outptr11 += n_channels; + } + + if (tail_N) { + // Only evaluate the left-most columns of the output + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + T v[8]; + for (int i = 0; i < 4; i++) { + v[i * 2] = *inptr[i * 2]; + } + for (int i = 0; i < 8; i++) { + inptr[i]++; + } + + // Compute the output values and progress the output pointers. 
+ *(outptr00++) = v[0] + v[2] + v[4]; + *(outptr10++) = v[2] - v[4] - v[6]; + } + + // Progress the output pointers to the next column + outptr01 += n_channels; // Account for being skipped above + outptr11 += n_channels; // Account for being skipped above + } + + // Progress the output pointers to the next row + outptr00 += output_shape.n_cols * n_channels; + outptr01 += output_shape.n_cols * n_channels; + outptr10 += output_shape.n_cols * n_channels; + outptr11 += output_shape.n_cols * n_channels; + } + + if (tail_M) { + // Only work on the upper row of the output + for (int tile_j = 0; tile_j < tile_N; tile_j++) { + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + T v[8]; + for (int i = 0; i < 8; i++) { + v[i] = *(inptr[i]++); + } + + // Compute the output values and progress the output pointers. + *(outptr00++) = v[0] + v[2] + v[4]; + *(outptr01++) = v[1] + v[3] + v[5]; + } + + // Progress the output pointers to the next column + outptr00 += n_channels; + outptr01 += n_channels; + outptr10 += 2 * n_channels; // Account for being skipped above + outptr11 += 2 * n_channels; // Account for being skipped above + } + + if (tail_N) { + // Only evaluate the upper-left cell of the output + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + T v[8]; + for (int i = 0; i < 3; i++) { + v[i * 2] = *inptr[i * 2]; + } + for (int i = 0; i < 8; i++) { + inptr[i]++; + } + + // Compute the output values and progress the output pointers. + *(outptr00++) = v[0] + v[2] + v[4]; + } + + // Progress the output pointers to the next column + outptr01 += n_channels; // Account for being skipped above + outptr10 += n_channels; // Account for being skipped above + outptr11 += n_channels; // Account for being skipped above + } + } + } +} +*/ diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp new file mode 100644 index 0000000000..5925f9d569 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp @@ -0,0 +1,650 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +/* Float implementation for AArch64. 
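+ *
+ * Three routines are provided: an assembly fast path for the fully even case
+ * (channel count a multiple of four, even numbers of output rows and columns),
+ * a generic scalar implementation covering the remaining row, column and
+ * channel tails, and an execute() entry point that dispatches between them
+ * based on the output shape.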
+ */ +#ifdef __aarch64__ +namespace winograd { + + +template <> +template <> +inline void Winograd2x2_3x3GemmOutput::_execute( + const Tensor4DShape &output_shape, + float *output, + const float *input, + const int mstride, + const int matrix_row_stride +) { + const int tile_M = output_shape.n_rows / 2; + const int tile_N = output_shape.n_cols / 2; + int batch = output_shape.n_batches; + float *outptr = output; + + const float *inptr0 = input; + const float *inptr4 = input + 4 * mstride; + const float *inptr8 = input + 8 * mstride; + const float *inptr12 = input + 12 * mstride; + + const size_t col_stride = sizeof(float) * output_shape.n_channels; + const size_t row_stride = col_stride * tile_N * 2; + + asm volatile ( + // Aliases for elements of the input matrix `F` + // V-register Q-register + "F11 .req v0\n" "qF11 .req q0\n" + "F12 .req v1\n" "qF12 .req q1\n" + "F13 .req v2\n" "qF13 .req q2\n" + "F14 .req v3\n" "qF14 .req q3\n" + "F21 .req v4\n" "qF21 .req q4\n" + "F22 .req v5\n" "qF22 .req q5\n" + "F23 .req v6\n" "qF23 .req q6\n" + "F24 .req v7\n" "qF24 .req q7\n" + "F31 .req v8\n" "qF31 .req q8\n" + "F32 .req v9\n" "qF32 .req q9\n" + "F33 .req v10\n" "qF33 .req q10\n" + "F34 .req v11\n" "qF34 .req q11\n" + "F41 .req v12\n" "qF41 .req q12\n" + "F42 .req v13\n" "qF42 .req q13\n" + "F43 .req v14\n" "qF43 .req q14\n" + "F44 .req v15\n" "qF44 .req q15\n" + + // Aliases for elements of the intermediate matrix `FZ` + "FZ11 .req v16\n" + "FZ12 .req v17\n" + "FZ21 .req v18\n" + "FZ22 .req v19\n" + "FZ31 .req v20\n" + "FZ32 .req v21\n" + "FZ41 .req v22\n" + "FZ42 .req v23\n" + + // Aliases for elements of the output matrix `f` (called `g` due to case + // insensitivity of aliases). + " g11 .req v24\n" + "qg11 .req q24\n" + " g12 .req v25\n" + "qg12 .req q25\n" + " g21 .req v26\n" + "qg21 .req q26\n" + " g22 .req v27\n" + "qg22 .req q27\n" + + // Prepare the various strides + "col_stride .req %x[col_stride]\n" + "row_stride .req %x[row_stride]\n" + "row_plus_col_stride .req %x[row_plus_col_stride]\n" + + "mstride1 .req %x[mstride1]\n" + "mstride2 .req %x[mstride2]\n" + "mstride3 .req %x[mstride3]\n" + + "tile_i .req x19\n" // Tile row counter + "tile_j .req x20\n" // Tile column counter + "channel .req x21\n" // Channel counter + + "1:" // Loop over batches + "mov tile_i, %x[tile_M]\n" // Reset tile row counter + + "2:" // Loop over rows of tiles + "mov tile_j, %x[tile_N]\n" // Reset tile column counter + + "3:" // Loop over columns of tiles + // Perform initial loads of the matrix `F` + "ldr qF11, [%x[inptr0]]\n" + "ldr qF12, [%x[inptr0], mstride1]\n" + "ldr qF13, [%x[inptr0], mstride2]\n" + "ldr qF14, [%x[inptr0], mstride3]\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + "ldr qF21, [%x[inptr4]]\n" + "ldr qF22, [%x[inptr4], mstride1]\n" + "subs channel, %x[n_channels], #4\n" // Reset channel counter + + "ldr qF23, [%x[inptr4], mstride2]\n" + "ldr qF24, [%x[inptr4], mstride3]\n" + "add %x[inptr4], %x[inptr4], #0x10\n" + "beq 5f\n" // Jump straight to tail if necessary + + "4:" // Loop over channels + "ldr qF31, [%x[inptr8]]\n" + "fadd FZ11.4s, F11.4s, F12.4s\n" + + "ldr qF32, [%x[inptr8], mstride1]\n" + "fsub FZ12.4s, F12.4s, F13.4s\n" + + "ldr qF33, [%x[inptr8], mstride2]\n" + "fadd FZ11.4s, FZ11.4s, F13.4s\n" + + "ldr qF34, [%x[inptr8], mstride3]\n" + "fsub FZ12.4s, FZ12.4s, F14.4s\n" + + "ldr qF41, [%x[inptr12]]\n" + "fadd FZ21.4s, F21.4s, F22.4s\n" + + "ldr qF42, [%x[inptr12], mstride1]\n" + "fsub FZ22.4s, F22.4s, F23.4s\n" + + "ldr qF43, [%x[inptr12], mstride2]\n" + "fadd FZ21.4s, FZ21.4s, 
F23.4s\n" + + "ldr qF44, [%x[inptr12], mstride3]\n" + "fsub FZ22.4s, FZ22.4s, F24.4s\n" + + "fadd FZ31.4s, F31.4s, F32.4s\n" + "add %x[inptr8], %x[inptr8], #0x10\n" + + "fsub FZ32.4s, F32.4s, F33.4s\n" + "add %x[inptr12], %x[inptr12], #0x10\n" + + "fadd FZ31.4s, FZ31.4s, F33.4s\n" + + "fsub FZ32.4s, FZ32.4s, F34.4s\n" + + "fadd g11.4s, FZ11.4s, FZ21.4s\n" + + "fadd g12.4s, FZ12.4s, FZ22.4s\n" + + "fadd g11.4s, g11.4s, FZ31.4s\n" + + "fadd g12.4s, g12.4s, FZ32.4s\n" + + "ldr qF11, [%x[inptr0]]\n" + "fadd FZ41.4s, F41.4s, F42.4s\n" + + "ldr qF12, [%x[inptr0], mstride1]\n" + "fsub g21.4s, FZ21.4s, FZ31.4s\n" + + "ldr qF13, [%x[inptr0], mstride2]\n" + "fsub FZ42.4s, F42.4s, F43.4s\n" + + "ldr qF14, [%x[inptr0], mstride3]\n" + "str qg11, [%x[outptr]]\n" + + "ldr qF21, [%x[inptr4]]\n" + "fadd FZ41.4s, FZ41.4s, F43.4s\n" + + "ldr qF22, [%x[inptr4], mstride1]\n" + "str qg12, [%x[outptr], col_stride]\n" + + "ldr qF23, [%x[inptr4], mstride2]\n" + "fsub FZ42.4s, FZ42.4s, F44.4s\n" + + "ldr qF24, [%x[inptr4], mstride3]\n" + "fsub g22.4s, FZ22.4s, FZ32.4s\n" + + "fsub g21.4s, g21.4s, FZ41.4s\n" + "add %x[inptr0], %x[inptr0], #0x10\n" + + "fsub g22.4s, g22.4s, FZ42.4s\n" + "add %x[inptr4], %x[inptr4], #0x10\n" + + "subs channel, channel, #4\n" + + "str qg21, [%x[outptr], row_stride]\n" + + "str qg22, [%x[outptr], row_plus_col_stride]\n" + + "add %x[outptr], %x[outptr], #0x10\n" + + "bne 4b\n" + + "5:" // Channel tail + "ldr qF31, [%x[inptr8]]\n" + "fadd FZ11.4s, F11.4s, F12.4s\n" + + "ldr qF32, [%x[inptr8], mstride1]\n" + "fsub FZ12.4s, F12.4s, F13.4s\n" + + "ldr qF33, [%x[inptr8], mstride2]\n" + "fadd FZ11.4s, FZ11.4s, F13.4s\n" + + "ldr qF34, [%x[inptr8], mstride3]\n" + "fsub FZ12.4s, FZ12.4s, F14.4s\n" + + "ldr qF41, [%x[inptr12]]\n" + "fadd FZ21.4s, F21.4s, F22.4s\n" + + "ldr qF42, [%x[inptr12], mstride1]\n" + "fsub FZ22.4s, F22.4s, F23.4s\n" + + "ldr qF43, [%x[inptr12], mstride2]\n" + "fadd FZ21.4s, FZ21.4s, F23.4s\n" + + "ldr qF44, [%x[inptr12], mstride3]\n" + "fsub FZ22.4s, FZ22.4s, F24.4s\n" + + "fadd FZ31.4s, F31.4s, F32.4s\n" + "add %x[inptr8], %x[inptr8], #0x10\n" + + "fsub FZ32.4s, F32.4s, F33.4s\n" + "add %x[inptr12], %x[inptr12], #0x10\n" + + "fadd FZ31.4s, FZ31.4s, F33.4s\n" + + "fsub FZ32.4s, FZ32.4s, F34.4s\n" + + "fadd g11.4s, FZ11.4s, FZ21.4s\n" + + "fadd g12.4s, FZ12.4s, FZ22.4s\n" + + "fadd g11.4s, g11.4s, FZ31.4s\n" + + "fadd g12.4s, g12.4s, FZ32.4s\n" + + "fadd FZ41.4s, F41.4s, F42.4s\n" + + "fsub g21.4s, FZ21.4s, FZ31.4s\n" + + "fsub FZ42.4s, F42.4s, F43.4s\n" + + "str qg11, [%x[outptr]]\n" + + "fadd FZ41.4s, FZ41.4s, F43.4s\n" + + "str qg12, [%x[outptr], col_stride]\n" + + "fsub FZ42.4s, FZ42.4s, F44.4s\n" + + "fsub g22.4s, FZ22.4s, FZ32.4s\n" + + "fsub g21.4s, g21.4s, FZ41.4s\n" + + "fsub g22.4s, g22.4s, FZ42.4s\n" + + "subs channel, channel, #4\n" + + "str qg21, [%x[outptr], row_stride]\n" + + // Progress input pointers to the next row of the matrix + "add %x[inptr0], %x[inptr0], %x[mrowpad]\n" + "add %x[inptr4], %x[inptr4], %x[mrowpad]\n" + "add %x[inptr8], %x[inptr8], %x[mrowpad]\n" + "add %x[inptr12], %x[inptr12], %x[mrowpad]\n" + + "str qg22, [%x[outptr], row_plus_col_stride]\n" + + "add %x[outptr], %x[outptr], #0x10\n" + + + "add %x[outptr], %x[outptr], col_stride\n" + "subs tile_j, tile_j, #1\n" + "bne 3b\n" + + "add %x[outptr], %x[outptr], row_stride\n" + "subs tile_i, tile_i, #1\n" + "bne 2b\n" + + "subs %[batch], %[batch], #1\n" + "bne 1b\n" + + ".unreq F11\n" ".unreq qF11\n" + ".unreq F12\n" ".unreq qF12\n" + ".unreq F13\n" ".unreq qF13\n" + ".unreq F14\n" ".unreq 
qF14\n" + ".unreq F21\n" ".unreq qF21\n" + ".unreq F22\n" ".unreq qF22\n" + ".unreq F23\n" ".unreq qF23\n" + ".unreq F24\n" ".unreq qF24\n" + ".unreq F31\n" ".unreq qF31\n" + ".unreq F32\n" ".unreq qF32\n" + ".unreq F33\n" ".unreq qF33\n" + ".unreq F34\n" ".unreq qF34\n" + ".unreq F41\n" ".unreq qF41\n" + ".unreq F42\n" ".unreq qF42\n" + ".unreq F43\n" ".unreq qF43\n" + ".unreq F44\n" ".unreq qF44\n" + + ".unreq FZ11\n" ".unreq FZ12\n" + ".unreq FZ21\n" ".unreq FZ22\n" + ".unreq FZ31\n" ".unreq FZ32\n" + ".unreq FZ41\n" ".unreq FZ42\n" + + ".unreq g11\n" ".unreq qg11\n" + ".unreq g12\n" ".unreq qg12\n" + ".unreq g21\n" ".unreq qg21\n" + ".unreq g22\n" ".unreq qg22\n" + + ".unreq col_stride\n" + ".unreq row_stride\n" + ".unreq row_plus_col_stride\n" + + ".unreq mstride1\n" + ".unreq mstride2\n" + ".unreq mstride3\n" + + ".unreq tile_i \n" + ".unreq tile_j \n" + ".unreq channel\n" + + : [batch] "+r" (batch), + [outptr] "+r" (outptr), + [inptr0] "+r" (inptr0), + [inptr4] "+r" (inptr4), + [inptr8] "+r" (inptr8), + [inptr12] "+r" (inptr12) + : [tile_M] "r" (tile_M), + [tile_N] "r" (tile_N), + [n_channels] "r" (output_shape.n_channels), + [col_stride] "r" (col_stride), + [row_stride] "r" (row_stride), + [row_plus_col_stride] "r" (row_stride + col_stride), + [mstride1] "r" (mstride * sizeof(float)), + [mstride2] "r" (2 * mstride * sizeof(float)), + [mstride3] "r" (3 * mstride * sizeof(float)), + [mrowpad] "r" ((matrix_row_stride - output_shape.n_channels) * sizeof(float)) + : "x19", "x20", "x21", + "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "q16", "q17", "q18", "q19", "q20", "q21", + "q22", "q23", "q24", "q25", "q26", "q27", + "cc", "memory" + ); +} + +template <> +template +inline void Winograd2x2_3x3GemmOutput::_execute( + const Tensor4DShape &output_shape, + float *output, + const float *input, + const int mstride, + const int matrix_row_stride +) { + // Compute basic information about the shape of the matrices + const int tile_M = output_shape.n_rows / 2; + const int tile_N = output_shape.n_cols / 2; + const int n_channels = output_shape.n_channels; + + // Extract 16 input pointers + const float* inptr[16]; + for (int i = 0; i < 16; i++) { + inptr[i] = input + i*mstride; + } + + // Extract 4 output pointers + float *outptr00 = output; + float *outptr01 = outptr00 + n_channels; + float *outptr10 = outptr00 + output_shape.n_cols * n_channels; + float *outptr11 = outptr10 + n_channels; + + // Progress over the output tiles, generating output values. + for (int batch = 0; batch < output_shape.n_batches; batch++) { + for (int tile_i = 0; tile_i < tile_M; tile_i++) { + for (int tile_j = 0; tile_j < tile_N; tile_j++) { + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + float F[4][4]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + F[i][j] = *(inptr[i*4 + j]++); + } + } + + // Compute the matrix F.Z + float ZF[4][2]; + ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; + ZF[0][1] = F[0][1] - F[0][2] - F[0][3]; + ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; + ZF[1][1] = F[1][1] - F[1][2] - F[1][3]; + ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; + ZF[2][1] = F[2][1] - F[2][2] - F[2][3]; + ZF[3][0] = F[3][0] + F[3][1] + F[3][2]; + ZF[3][1] = F[3][1] - F[3][2] - F[3][3]; + + // Hence compute the output matrix Z^T . 
(F.Z) + *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; + *(outptr01++) = ZF[0][1] + ZF[1][1] + ZF[2][1]; + *(outptr10++) = ZF[1][0] - ZF[2][0] - ZF[3][0]; + *(outptr11++) = ZF[1][1] - ZF[2][1] - ZF[3][1]; + } + + // Progress the input pointers to the next row + for (int i = 0; i < 16; i++) { + inptr[i] += matrix_row_stride - n_channels; + } + + // Progress the output pointers to the next column + outptr00 += n_channels; + outptr01 += n_channels; + outptr10 += n_channels; + outptr11 += n_channels; + } + + if (tail_N) { + // Only evaluate the left-most columns of the output + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + float F[4][3]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 3; j++) { + F[i][j] = *(inptr[i*4 + j]++); + } + } + for (int i = 0; i < 4; i++) { + inptr[i*4 + 3]++; + } + + // Compute the matrix F.Z + float ZF[4][1]; + ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; + ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; + ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; + ZF[3][0] = F[3][0] + F[3][1] + F[3][2]; + + // Hence compute the output matrix Z^T . (F.Z) + *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; + *(outptr10++) = ZF[1][0] - ZF[2][0] - ZF[3][0]; + } + + // Progress the input pointers to the next row + for (int i = 0; i < 16; i++) { + inptr[i] += matrix_row_stride - n_channels; + } + + // Progress the output pointers to the next column + outptr01 += n_channels; // Account for being skipped above + outptr11 += n_channels; // Account for being skipped above + } + + // Progress the output pointers to the next row + outptr00 += output_shape.n_cols * n_channels; + outptr01 += output_shape.n_cols * n_channels; + outptr10 += output_shape.n_cols * n_channels; + outptr11 += output_shape.n_cols * n_channels; + } + + if (tail_M) { + // Only work on the upper row of the output + for (int tile_j = 0; tile_j < tile_N; tile_j++) { + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + float F[3][4]; + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 4; j++) { + F[i][j] = *(inptr[i*4 + j]++); + } + } + for (int j = 0; j < 4; j++) { + inptr[12 + j]++; + } + + // Compute the matrix F.Z + float ZF[3][2]; + ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; + ZF[0][1] = F[0][1] - F[0][2] - F[0][3]; + ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; + ZF[1][1] = F[1][1] - F[1][2] - F[1][3]; + ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; + ZF[2][1] = F[2][1] - F[2][2] - F[2][3]; + + // Hence compute the output matrix Z^T . 
(F.Z) + *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; + *(outptr01++) = ZF[0][1] + ZF[1][1] + ZF[2][1]; + } + + // Progress the input pointers to the next row + for (int i = 0; i < 16; i++) { + inptr[i] += matrix_row_stride - n_channels; + } + + // Progress the output pointers to the next column + outptr00 += n_channels; + outptr01 += n_channels; + outptr10 += 2 * n_channels; // Account for being skipped above + outptr11 += 2 * n_channels; // Account for being skipped above + } + + if (tail_N) { + // Only evaluate the upper-left cell of the output + for (int channel = 0; channel < n_channels; channel++) { + // Read values from the input pointers + float F[3][3]; + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + F[i][j] = *(inptr[i*4 + j]); + } + } + for (int i = 0; i < 16; i++) { + inptr[i]++; + } + + // Compute the matrix F.Z + float ZF[3][1]; + ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; + ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; + ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; + + // Hence compute the output matrix Z^T . (F.Z) + *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; + } + + // Progress the input pointers to the next row + for (int i = 0; i < 16; i++) { + inptr[i] += matrix_row_stride - n_channels; + } + + // Progress the output pointers to the next column + outptr01 += n_channels; // Account for being skipped above + outptr10 += n_channels; // Account for being skipped above + outptr11 += n_channels; // Account for being skipped above + } + } + } +} + +/*****************************************************************************/ +template <> +inline void Winograd2x2_3x3GemmOutput::execute( + const Tensor4DShape &output_shape, + float* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + float* const output +) { + // Dispatch to an appropriate implementation based on the shape of the output + // tensor. 
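+  // The row and column parity of the output selects the tail_M/tail_N template
+  // arguments (an odd dimension leaves one final row or column of partial
+  // tiles), while the channel count modulo four selects the channel-tail
+  // specialisation. For example, a 7x7 output with 18 channels takes the first
+  // branch with a channel remainder of 2, whereas an 8x8 output with 16
+  // channels falls through to the fully even case at the bottom.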
+ if (output_shape.n_rows % 2 && output_shape.n_cols % 2) { + constexpr bool tail_M = true, tail_N = true; + switch (output_shape.n_channels % 4) { + case 0: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 1: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 2: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 3: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + default: + assert(0); + break; + } + } else if (output_shape.n_rows % 2) { + constexpr bool tail_M = true, tail_N = false; + switch (output_shape.n_channels % 4) { + case 0: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 1: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 2: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 3: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + default: + assert(0); + break; + } + } else if (output_shape.n_cols % 2) { + constexpr bool tail_M = false, tail_N = true; + switch (output_shape.n_channels % 4) { + case 0: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 1: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 2: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 3: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + default: + assert(0); + break; + + } + } else { + constexpr bool tail_M = false, tail_N = false; + switch (output_shape.n_channels % 4) { + case 0: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 1: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 2: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + case 3: + _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); + break; + default: + assert(0); + break; + + } + } +} +/*****************************************************************************/ + +} // namespace winograd +#endif // __aarch64__ diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp new file mode 100644 index 0000000000..f551b12b52 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp @@ -0,0 +1,655 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#ifdef __aarch64__ + +/*****************************************************************************/ +// Compute ZF specializations + +template <> +template <> +inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage::compute_zf<0>( + const int n_rows, const int n_channels, + float* output, const float* const input[16] +) { + // Make copies of some variables + int row = n_rows; + float* outptr = output; + const float* inptr = input[0]; + + // Perform the transformation + asm volatile ( + // "inptr0 .req %x[inptr]\n" + "inptr1 .req x0\n" + "inptr2 .req x1\n" + "inptr3 .req x2\n" + "inptr4 .req x3\n" + "inptr5 .req x4\n" + "inptr6 .req x5\n" + "inptr7 .req x6\n" + "inptr8 .req x7\n" + "inptr9 .req x8\n" + "inptr10 .req x9\n" + "inptr11 .req x10\n" + "inptr12 .req x11\n" + "inptr13 .req x12\n" + "inptr14 .req x13\n" + "inptr15 .req x14\n" + + // "outptr0 .req %x[outptr]\n" + "outptr1 .req x15\n" + "outptr2 .req x16\n" + "outptr3 .req x17\n" + "outptr4 .req x18\n" + "outptr5 .req x19\n" + "outptr6 .req x20\n" + "outptr7 .req x21\n" + + // Compute additional pointers into the input and output matrices. + "mstride .req x22\n" // Matrix stride + "mul mstride, %x[row], %x[n_channels]\n" + "lsl mstride, mstride, #2\n" // * sizeof(float) + + "add inptr1, %x[inptr], mstride\n" + "add inptr2, %x[inptr], mstride, LSL #1\n" + "add inptr3, inptr2, mstride\n" + "add inptr4, inptr3, mstride\n" + "add inptr5, inptr4, mstride\n" + "add inptr6, inptr5, mstride\n" + "add inptr7, inptr6, mstride\n" + "add inptr8, inptr7, mstride\n" + "add inptr9, inptr8, mstride\n" + "add inptr10, inptr9, mstride\n" + "add inptr11, inptr10, mstride\n" + "add inptr12, inptr11, mstride\n" + "add inptr13, inptr12, mstride\n" + "add inptr14, inptr13, mstride\n" + "add inptr15, inptr14, mstride\n" + + "add outptr1, %[outptr], mstride\n" + "add outptr2, outptr1, mstride\n" + "add outptr3, outptr2, mstride\n" + "add outptr4, outptr3, mstride\n" + "add outptr5, outptr4, mstride\n" + "add outptr6, outptr5, mstride\n" + "add outptr7, outptr6, mstride\n" + + ".unreq mstride\n" + + "column .req x22\n" // Column loop counter + + "1:" // Loop over rows + "ldr q0, [%x[inptr]], #0x10\n" + "ldr q1, [inptr1], #0x10\n" + "ldr q2, [inptr2], #0x10\n" + "ldr q3, [inptr3], #0x10\n" + "ldr q4, [inptr4], #0x10\n" + "ldr q5, [inptr5], #0x10\n" + "ldr q6, [inptr6], #0x10\n" + "ldr q7, [inptr7], #0x10\n" + "subs column, %x[n_channels], #0x4\n" + "beq 3f\n" + + "2:" // Loop over columns + "ldr q8, [inptr8], #0x10\n" + "prfm pldl1keep, [%x[inptr], #196]\n" + "fadd v16.4s, v0.4s, v1.4s\n" + + "ldr q9, [inptr9], #0x10\n" + "prfm pldl1keep, [inptr1, #196]\n" + "fsub v17.4s, v1.4s, v2.4s\n" + + "ldr q10, [inptr10], #0x10\n" + "prfm pldl1keep, [inptr2, #196]\n" + "fadd v16.4s, v16.4s, v2.4s\n" + + "ldr q11, [inptr11], #0x10\n" + "prfm pldl1keep, [inptr3, #196]\n" + "fsub v17.4s, v17.4s, v3.4s\n" + + "ldr q12, [inptr12], #0x10\n" + "prfm pldl1keep, [inptr4, #196]\n" + "str q16, [%x[outptr]], #0x10\n" + + "ldr q13, [inptr13], #0x10\n" + "prfm pldl1keep, 
[inptr5, #196]\n" + "str q17, [outptr1], #0x10\n" + + "ldr q14, [inptr14], #0x10\n" + "prfm pldl1keep, [inptr6, #196]\n" + "fadd v16.4s, v4.4s, v5.4s\n" + + "ldr q15, [inptr15], #0x10\n" + "prfm pldl1keep, [inptr7, #196]\n" + "fsub v17.4s, v5.4s, v6.4s\n" + + "ldr q0, [%x[inptr]], #0x10\n" + "prfm pldl1keep, [inptr8, #196]\n" + "fadd v16.4s, v16.4s, v6.4s\n" + + "ldr q1, [inptr1], #0x10\n" + "prfm pldl1keep, [inptr9, #196]\n" + "fsub v17.4s, v17.4s, v7.4s\n" + + "ldr q2, [inptr2], #0x10\n" + "prfm pldl1keep, [inptr10, #196]\n" + "str q16, [outptr2], #0x10\n" + + "ldr q3, [inptr3], #0x10\n" + "prfm pldl1keep, [inptr11, #196]\n" + "str q17, [outptr3], #0x10\n" + + "ldr q4, [inptr4], #0x10\n" + "prfm pldl1keep, [inptr12, #196]\n" + "fadd v16.4s, v8.4s, v9.4s\n" + + "ldr q5, [inptr5], #0x10\n" + "prfm pldl1keep, [inptr13, #196]\n" + "fsub v17.4s, v9.4s, v10.4s\n" + + "ldr q6, [inptr6], #0x10\n" + "prfm pldl1keep, [inptr14, #196]\n" + "fadd v16.4s, v16.4s, v10.4s\n" + + "ldr q7, [inptr7], #0x10\n" + "prfm pldl1keep, [inptr15, #196]\n" + "fsub v17.4s, v17.4s, v11.4s\n" + + "str q16, [outptr4], #0x10\n" + "fadd v16.4s, v12.4s, v13.4s\n" + "fsub v18.4s, v13.4s, v14.4s\n" + + "str q17, [outptr5], #0x10\n" + "fadd v16.4s, v16.4s, v14.4s\n" + "fsub v18.4s, v18.4s, v15.4s\n" + + "str q16, [outptr6], #0x10\n" + "subs column, column, #0x4\n" + + "str q18, [outptr7], #0x10\n" + "bne 2b\n" + + "3:" // Tail + "ldr q8, [inptr8], #0x10\n" + "prfm pldl1keep, [%x[inptr], #196]\n" + "fadd v16.4s, v0.4s, v1.4s\n" + + "ldr q9, [inptr9], #0x10\n" + "prfm pldl1keep, [inptr1, #196]\n" + "fsub v17.4s, v1.4s, v2.4s\n" + + "ldr q10, [inptr10], #0x10\n" + "prfm pldl1keep, [inptr2, #196]\n" + "fadd v16.4s, v16.4s, v2.4s\n" + + "ldr q11, [inptr11], #0x10\n" + "prfm pldl1keep, [inptr3, #196]\n" + "fsub v17.4s, v17.4s, v3.4s\n" + + "ldr q12, [inptr12], #0x10\n" + "prfm pldl1keep, [inptr4, #196]\n" + "str q16, [%x[outptr]], #0x10\n" + + "ldr q13, [inptr13], #0x10\n" + "prfm pldl1keep, [inptr5, #196]\n" + "str q17, [outptr1], #0x10\n" + + "ldr q14, [inptr14], #0x10\n" + "prfm pldl1keep, [inptr6, #196]\n" + "fadd v16.4s, v4.4s, v5.4s\n" + + "ldr q15, [inptr15], #0x10\n" + "prfm pldl1keep, [inptr7, #196]\n" + "fsub v17.4s, v5.4s, v6.4s\n" + + "prfm pldl1keep, [inptr8, #196]\n" + "prfm pldl1keep, [inptr9, #196]\n" + "fadd v16.4s, v16.4s, v6.4s\n" + + "prfm pldl1keep, [inptr10, #196]\n" + "prfm pldl1keep, [inptr11, #196]\n" + "fsub v17.4s, v17.4s, v7.4s\n" + + "prfm pldl1keep, [inptr12, #196]\n" + "prfm pldl1keep, [inptr13, #196]\n" + "str q16, [outptr2], #0x10\n" + + "prfm pldl1keep, [inptr14, #196]\n" + "prfm pldl1keep, [inptr15, #196]\n" + "str q17, [outptr3], #0x10\n" + + "fadd v16.4s, v8.4s, v9.4s\n" + "fsub v17.4s, v9.4s, v10.4s\n" + + "fadd v16.4s, v16.4s, v10.4s\n" + "fsub v17.4s, v17.4s, v11.4s\n" + + "str q16, [outptr4], #0x10\n" + "fadd v16.4s, v12.4s, v13.4s\n" + "fsub v18.4s, v13.4s, v14.4s\n" + + "str q17, [outptr5], #0x10\n" + "fadd v16.4s, v16.4s, v14.4s\n" + "fsub v18.4s, v18.4s, v15.4s\n" + + "str q16, [outptr6], #0x10\n" + "str q18, [outptr7], #0x10\n" + + "subs %x[row], %x[row], #0x1\n" + "bne 1b\n" + + ".unreq inptr1\n" + ".unreq inptr2\n" + ".unreq inptr3\n" + ".unreq inptr4\n" + ".unreq inptr5\n" + ".unreq inptr6\n" + ".unreq inptr7\n" + ".unreq inptr8\n" + ".unreq inptr9\n" + ".unreq inptr10\n" + ".unreq inptr11\n" + ".unreq inptr12\n" + ".unreq inptr13\n" + ".unreq inptr14\n" + ".unreq inptr15\n" + ".unreq outptr1\n" + ".unreq outptr2\n" + ".unreq outptr3\n" + ".unreq outptr4\n" + ".unreq outptr5\n" + 
".unreq outptr6\n" + ".unreq outptr7\n" + + : [row] "+r" (row), + [inptr] "+r" (inptr), + [outptr] "+r" (outptr) + : [n_channels] "r" (n_channels), + [sizeof_float] "i" (sizeof(float)) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "q16", "q17", "x0", "x1", "x2", "x3", "x4", + "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", + "x16", "x17", "x18", "x19", "x20", "x21", "x22", "cc", "memory" + ); +} + +/*****************************************************************************/ +// Compute ZFZ^T specializations + +template <> +template <> +inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage::compute_zfzT( + const Tensor4DShape &output_shape, + float* const output, const float* const input +) { + const int tile_M = output_shape.n_rows / 2; + const int tile_N = output_shape.n_cols / 2; + int batch = output_shape.n_batches; + float *outptr = output; + const float *inptr = input; + + asm volatile ( + // Compute input pointers + "inptr1 .req x0\n" + "inptr2 .req x1\n" + "inptr3 .req x2\n" + "inptr4 .req x3\n" + "inptr5 .req x4\n" + "inptr6 .req x5\n" + "inptr7 .req x6\n" + "inptr8 .req x7\n" + + "mstride .req x8\n" + "mul mstride, %x[tile_M], %x[tile_N]\n" + "mul mstride, mstride, %x[n_channels]\n" + "lsl mstride, mstride, #2\n" // * sizeof(float) + + "add inptr1, %[inptr], mstride\n" + "add inptr2, inptr1, mstride\n" + "add inptr3, inptr2, mstride\n" + "add inptr4, inptr3, mstride\n" + "add inptr5, inptr4, mstride\n" + "add inptr6, inptr5, mstride\n" + "add inptr7, inptr6, mstride\n" + "add inptr8, inptr7, mstride\n" + + ".unreq mstride\n" + + // Compute initial output pointers + "outptr01 .req x8\n" + "outptr10 .req x9\n" + "outptr11 .req x10\n" + + "add outptr01, %x[outptr], %x[n_channels], LSL #2\n" + "add outptr10, %x[outptr], %x[row_stride], LSL #2\n" + "add outptr11, outptr10, %x[n_channels], LSL #2\n" + + "tile_i .req x11\n" + "tile_j .req x12\n" + "channel .req x13\n" + + "1:" // Loop over batches + "mov tile_i, %x[tile_M]\n" + + "2:" // Loop over rows of output tiles + "mov tile_j, %x[tile_N]\n" + + "3:" // Loop over columns of output tiles + "ldr q0, [%x[inptr]], #0x10\n" + "ldr q2, [inptr2], #0x10\n" + "subs channel, %x[n_channels], #0x4\n" + + "ldr q1, [inptr1], #0x10\n" + "ldr q3, [inptr3], #0x10\n" + "beq 6f\n" + + "4:" + "ldr q4, [inptr4], #0x10\n" + "ldr q5, [inptr5], #0x10\n" + "fadd v16.4s, v0.4s, v2.4s\n" + + "ldr q6, [inptr6], #0x10\n" + "ldr q7, [inptr7], #0x10\n" + "fadd v17.4s, v1.4s, v3.4s\n" + + "ldr q8, [%x[inptr]], #0x10\n" + "ldr q10, [inptr2], #0x10\n" + "fadd v16.4s, v16.4s, v4.4s\n" + + "ldr q9, [inptr1], #0x10\n" + "ldr q11, [inptr3], #0x10\n" + "fadd v17.4s, v17.4s, v5.4s\n" + + "str q16, [%x[outptr]], #0x10\n" + "prfm pldl1strm, [%x[inptr], #196]\n" + "fsub v18.4s, v2.4s, v4.4s\n" + + "str q17, [outptr01], #0x10\n" + "prfm pldl1strm, [inptr2, #196]\n" + "fsub v19.4s, v3.4s, v5.4s\n" + + "prfm pldl1strm, [inptr1, #196]\n" + "prfm pldl1strm, [inptr3, #196]\n" + "fsub v18.4s, v18.4s, v6.4s\n" + + "prfm pldl1strm, [inptr4, #196]\n" + "prfm pldl1strm, [inptr5, #196]\n" + "fsub v19.4s, v19.4s, v7.4s\n" + + "str q18, [outptr10], #0x10\n" + "prfm pldl1strm, [inptr6, #196]\n" + "prfm pldl1strm, [inptr7, #196]\n" + + "subs channel, channel, #0x4\n" + + "str q19, [outptr11], #0x10\n" + "beq 6f\n" // Branch to tail + + "ldr q12, [inptr4], #0x10\n" + "ldr q13, [inptr5], #0x10\n" + "fadd v16.4s, v8.4s, v10.4s\n" + + "ldr q14, [inptr6], #0x10\n" + "ldr q15, [inptr7], #0x10\n" + "fadd 
v17.4s, v9.4s, v11.4s\n" + + "ldr q0, [%x[inptr]], #0x10\n" + "ldr q2, [inptr2], #0x10\n" + "fadd v16.4s, v16.4s, v12.4s\n" + + "ldr q1, [inptr1], #0x10\n" + "ldr q3, [inptr3], #0x10\n" + "fadd v17.4s, v17.4s, v13.4s\n" + + "str q16, [%x[outptr]], #0x10\n" + "prfm pldl1strm, [%x[inptr], #196]\n" + "fsub v18.4s, v10.4s, v12.4s\n" + + "str q17, [outptr01], #0x10\n" + "prfm pldl1strm, [inptr2, #196]\n" + "fsub v19.4s, v11.4s, v13.4s\n" + + "prfm pldl1strm, [inptr1, #196]\n" + "prfm pldl1strm, [inptr3, #196]\n" + "fsub v18.4s, v18.4s, v14.4s\n" + + "prfm pldl1strm, [inptr4, #196]\n" + "prfm pldl1strm, [inptr5, #196]\n" + "fsub v19.4s, v19.4s, v15.4s\n" + + "str q18, [outptr10], #0x10\n" + "prfm pldl1strm, [inptr6, #196]\n" + "prfm pldl1strm, [inptr7, #196]\n" + + "subs channel, channel, #0x4\n" + + "str q19, [outptr11], #0x10\n" + "bne 4b\n" // Continue loop + + "5:" // Tail + "ldr q12, [inptr4], #0x10\n" + "ldr q13, [inptr5], #0x10\n" + "fadd v16.4s, v8.4s, v10.4s\n" + + "ldr q14, [inptr6], #0x10\n" + "ldr q15, [inptr7], #0x10\n" + "fadd v17.4s, v9.4s, v11.4s\n" + + "fadd v16.4s, v16.4s, v12.4s\n" + + "fadd v17.4s, v17.4s, v13.4s\n" + + "str q16, [%x[outptr]], #0x10\n" + "fsub v18.4s, v10.4s, v12.4s\n" + "fsub v19.4s, v11.4s, v13.4s\n" + + "str q17, [outptr01], #0x10\n" + "fsub v18.4s, v18.4s, v14.4s\n" + "fsub v19.4s, v19.4s, v15.4s\n" + + "str q18, [outptr10], #0x10\n" + "str q19, [outptr11], #0x10\n" + "b 7f\n" + + "6:" // Tail + "ldr q4, [inptr4], #0x10\n" + "ldr q5, [inptr5], #0x10\n" + "fadd v16.4s, v0.4s, v2.4s\n" + + "ldr q6, [inptr6], #0x10\n" + "ldr q7, [inptr7], #0x10\n" + "fadd v17.4s, v1.4s, v3.4s\n" + + "fadd v16.4s, v16.4s, v4.4s\n" + + "fadd v17.4s, v17.4s, v5.4s\n" + + "str q16, [%x[outptr]], #0x10\n" + "fsub v18.4s, v2.4s, v4.4s\n" + "fsub v19.4s, v3.4s, v5.4s\n" + + "str q17, [outptr01], #0x10\n" + "fsub v18.4s, v18.4s, v6.4s\n" + "fsub v19.4s, v19.4s, v7.4s\n" + + "str q18, [outptr10], #0x10\n" + "str q19, [outptr11], #0x10\n" + + "7:" + "add %x[outptr], %x[outptr], %x[n_channels], LSL #2\n" + "add outptr01, outptr01, %x[n_channels], LSL #2\n" + "add outptr10, outptr10, %x[n_channels], LSL #2\n" + "add outptr11, outptr11, %x[n_channels], LSL #2\n" + + "subs tile_j, tile_j, #1\n" + "bne 3b\n" + + // Progress the output pointers to the new row + "add %x[outptr], %x[outptr], %x[row_stride], LSL #2\n" + "add outptr01, outptr01, %x[row_stride], LSL #2\n" + "add outptr10, outptr10, %x[row_stride], LSL #2\n" + "add outptr11, outptr11, %x[row_stride], LSL #2\n" + + "subs tile_i, tile_i, #1\n" + "bne 2b\n" + + "subs %[batch], %[batch], #1\n" + "bne 1b\n" + "5:" + + ".unreq inptr1\n" + ".unreq inptr2\n" + ".unreq inptr3\n" + ".unreq inptr4\n" + ".unreq inptr5\n" + ".unreq inptr6\n" + ".unreq inptr7\n" + ".unreq inptr8\n" + ".unreq outptr01\n" + ".unreq outptr10\n" + ".unreq outptr11\n" + : [batch] "+r" (batch), + [outptr] "+r" (outptr), + [inptr] "+r" (inptr) + : [tile_M] "r" (tile_M), + [tile_N] "r" (tile_N), + [n_channels] "r" (output_shape.n_channels), + [row_stride] "r" (output_shape.n_cols * output_shape.n_channels) + : "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", + "x12", "x13", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", + "cc", "memory" + ); +} +/*****************************************************************************/ + +/*****************************************************************************/ +template <> +inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage::execute( + const Tensor4DShape &output_shape, + float* 
const matrices[16], float* const output +) { + // profiler prof; + + // Allocate memory for the intermediate matrices + const int tile_M = iceildiv(output_shape.n_rows, 2); + const int tile_N = iceildiv(output_shape.n_cols, 2); + const int n_rows = output_shape.n_batches * tile_M * tile_N; + const int n_channels = output_shape.n_channels; + float* matrices_zf = reinterpret_cast( + calloc(8 * n_rows * n_channels, sizeof(float)) + ); + + // Perform the first stage transform, computing ZF. + const auto f_compute_zf = [&] () { + switch (n_channels % 4) { + case 0: + compute_zf<0>(n_rows, n_channels, matrices_zf, matrices); + break; + case 1: + compute_zf<1>(n_rows, n_channels, matrices_zf, matrices); + break; + case 2: + compute_zf<2>(n_rows, n_channels, matrices_zf, matrices); + break; + case 3: + compute_zf<3>(n_rows, n_channels, matrices_zf, matrices); + }; + }; + // prof("Compute ZF", f_compute_zf, 16 * n_rows * n_channels * sizeof(float), 0, 8 * n_rows * n_channels * sizeof(float)); + f_compute_zf(); + + // Perform the second stage transform, finishing Z F Z^T - variable dispatch + // based on size of the output and the channel tail. + const auto f_compute_zfzT = [&] () { + if (output_shape.n_rows % 2 && output_shape.n_cols % 2) { + constexpr bool tail_M = true, tail_N = true; + switch (n_channels % 4) { + case 0: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 1: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 2: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 3: + compute_zfzT(output_shape, output, matrices_zf); + } + } else if (output_shape.n_rows % 2) { + constexpr bool tail_M = true, tail_N = false; + switch (n_channels % 4) { + case 0: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 1: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 2: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 3: + compute_zfzT(output_shape, output, matrices_zf); + } + } else if (output_shape.n_cols % 2) { + constexpr bool tail_M = false, tail_N = true; + switch (n_channels % 4) { + case 0: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 1: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 2: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 3: + compute_zfzT(output_shape, output, matrices_zf); + } + } else { + constexpr bool tail_M = false, tail_N = false; + switch (n_channels % 4) { + case 0: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 1: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 2: + compute_zfzT(output_shape, output, matrices_zf); + break; + case 3: + compute_zfzT(output_shape, output, matrices_zf); + } + } + }; + // prof("Compute ZFZT", f_compute_zfzT, 8 * n_rows * n_channels * sizeof(float), 0, 4 * n_rows * n_channels * sizeof(float)); + f_compute_zfzT(); + + free(reinterpret_cast(matrices_zf)); +} +/*****************************************************************************/ + +#endif // __aarch64__ diff --git a/arm_compute/core/NEON/kernels/winograd/utils.hpp b/arm_compute/core/NEON/kernels/winograd/utils.hpp new file mode 100644 index 0000000000..14e709f028 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/utils.hpp @@ -0,0 +1,55 @@ + +/* + * Copyright (c) 2017 ARM Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once +#include + +inline double TimeInUs(void) { +#ifdef CYCLE_PROFILING + timespec t; + clock_gettime(CLOCK_THREAD_CPUTIME_ID, &t); + return 1e6*t.tv_sec + 1e-3*t.tv_nsec; +#else + return 0; +#endif +} + +inline int iceildiv(const int a, const int b) { + return (a + b - 1) / b; +} + +template +inline T roundup(const T a, const T b) { + return a + b - (a % b); +} + +inline void PrintMatrix(const float* const m, const int M, const int N, const int row_stride) { + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + printf("%.3f ", m[i*row_stride + j]); + } + printf("\n"); + } + printf("\n"); +} diff --git a/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp b/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp new file mode 100644 index 0000000000..c990cd0252 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp @@ -0,0 +1,346 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#pragma once +#include +#include +#include + +#include "alloc.hpp" +#include "gemm.hpp" +#include "profiler.hpp" +#include "utils.hpp" +#include "shims.hpp" + +#include "transforms.hpp" + +namespace winograd { + /***************************************************************************/ + /* Implementation of the Winograd F(2x2, 3x3, 4x4) algorithm using GEMM + * internally. + */ + template + class Winograd2x2_3x3GEMM { + public: + /* Instantiate a new Winograd operator. + */ + Winograd2x2_3x3GEMM(const KernelShape &kernel_shape, const Tensor4DShape input_shape, const PaddingType padding_type, void *kernel_storage); + virtual ~Winograd2x2_3x3GEMM(); + + /** Transform the weights into the Winograd domain. + */ + template > + void transform_weights(const TIn* const kernel, void *transform_working_space); + + /* Initializes matrices pointers, to be called once before execute() + */ + template > + void reshape_input(const Tensor4DShape &input_shape, const PaddingType padding_type, const TIn* const input, void* working_space); + + /* Apply the Winograd operator to some input. + */ + template > + void reshape_output(const Tensor4DShape& input_shape, const PaddingType padding_type, TOut* const output); + + + /* Apply the Winograd operator to some input. + */ + void execute(size_t first, size_t last); + + /* Get the memory required to transform the kernel. + */ + static inline size_t get_kernel_transform_working_size(const KernelShape &shape); + + /* Get the output shape of a convolution. + */ + static Tensor4DShape get_output_shape(const Tensor4DShape &input_shape, const KernelShape &k_shape, + const PaddingType padding_type); + + /* Get the memory required to instantiate a new Winograd operator. + */ + static size_t get_kernel_storage_size(const KernelShape &shape); + + /* Get the memory required to apply a Winograd operator to some input. + */ + static size_t get_working_space_size(const Tensor4DShape &input_shape,const KernelShape &k_shape, + const PaddingType padding); + + + Winograd2x2_3x3GEMM(const Winograd2x2_3x3GEMM &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + Winograd2x2_3x3GEMM &operator=(const Winograd2x2_3x3GEMM &) = delete; + /** Allow instances of this class to be moved */ + Winograd2x2_3x3GEMM(Winograd2x2_3x3GEMM &&) = default; + /** Allow instances of this class to be moved */ + Winograd2x2_3x3GEMM &operator=(Winograd2x2_3x3GEMM &&) = default; + + protected: + /* Get the memory required by a single "input" matrix. + */ + static size_t get_input_matrix_size(const Tensor4DShape &input_shape,const KernelShape &k_shape, + const PaddingType padding); + + /* Get the memory required by a single "output" matrix. + */ + static size_t get_output_matrix_size(const Tensor4DShape &input_shape, const KernelShape &k_shape, + const PaddingType padding); + + /* Get the memory required by a single "kernel" matrix. 
+ */ + static size_t get_kernel_matrix_size(const KernelShape &shape); + + const KernelShape kernel_shape; // Shape of applied kernel + const Tensor4DShape in_shape; + const PaddingType padding; + + const int kernel_matrix_row_stride; // Stride within kernel matrix + + const bool manage_kernel_storage; // Free kernel storage when done + void* const _kernel_storage; // Base pointer for kernel matrices + + profiler prof; // Profiler + + TIn *kernel_matrices[16]; // Prepared form of kernel + TIn *input_matrices[16]; + TOut *output_matrices[16]; + + + static const int M_BLOCK = 4; + static const int N_BLOCK = 16; + }; +} // namespace winograd + +template +size_t winograd::Winograd2x2_3x3GEMM::get_kernel_transform_working_size( + const KernelShape &shape +) +{ + // Need to re-order the kernel into HWIO form, require enough space to + // represent the tensor. + return sizeof(TIn) * shape.size(); +} + + +template +template +void winograd::Winograd2x2_3x3GEMM::transform_weights( + const TIn* const kernel, + void *transform_working_space +) +{ + const int kernel_matrix_size_bytes = get_kernel_matrix_size(kernel_shape); + int8_t* const ks_bytes = reinterpret_cast(_kernel_storage); + for (int i = 0; i < 16; i++) { + kernel_matrices[i] = reinterpret_cast( + ks_bytes + i*kernel_matrix_size_bytes); + } + + const TIn *kernel_hwio = kernel; + if( transform_working_space) + { + kernel_hwio = reinterpret_cast(transform_working_space); + ofm_ifm_h_w_to_h_w_ifm_ofm( + kernel, const_cast(kernel_hwio), + kernel_shape.n_output_channels, + kernel_shape.n_input_channels, + kernel_shape.n_rows, + kernel_shape.n_cols + ); + } + KernelTransform::execute( + kernel_shape, kernel_hwio, kernel_matrices[0], + kernel_matrix_size_bytes / sizeof(TIn), + kernel_matrix_row_stride + ); +} + +template +winograd::Winograd2x2_3x3GEMM::Winograd2x2_3x3GEMM( const KernelShape &kernel_shape, const Tensor4DShape input_shape, + const PaddingType padding_type, void *kernel_storage) + : kernel_shape(kernel_shape), in_shape(input_shape), padding(padding_type),kernel_matrix_row_stride(roundup(kernel_shape.n_output_channels, N_BLOCK)), manage_kernel_storage(false), + _kernel_storage(kernel_storage), prof() { + memset(kernel_matrices, 0x00, sizeof(TIn)*16); + memset(input_matrices, 0x00, sizeof(TIn)*16); + memset(output_matrices, 0x00, sizeof(TOut)*16); +} + +/*****************************************************************************/ +template +winograd::Winograd2x2_3x3GEMM::~Winograd2x2_3x3GEMM() {} + +/*****************************************************************************/ +template +template +void winograd::Winograd2x2_3x3GEMM::reshape_input( + const Tensor4DShape& input_shape, + const PaddingType padding_type, + const TIn* const input, + void *working_space +) { + assert(working_space); + int8_t* const ws_bytes = reinterpret_cast(working_space); + // Split the working space into that required for 16 input matrices and + // output matrices. 
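+    // The 16 input matrices occupy the first 16 * in_matrix_stride_bytes of this
+    // working space; the 16 output matrices are laid out immediately after them.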
+ const int in_matrix_stride_bytes = get_input_matrix_size(input_shape, kernel_shape, padding_type); + const int out_matrix_stride_bytes = get_output_matrix_size(input_shape, kernel_shape, padding_type); + + for (int i = 0; i < 16; i++) { + input_matrices[i] = reinterpret_cast( + ws_bytes + i*in_matrix_stride_bytes); + output_matrices[i] = reinterpret_cast( + ws_bytes + 16*in_matrix_stride_bytes + i*out_matrix_stride_bytes); + } + + // Compute shape for the GEMM + const auto output_shape = get_output_shape(input_shape,kernel_shape, padding_type); + const int tile_rows = iceildiv(output_shape.n_rows, 2); + const int tile_cols = iceildiv(output_shape.n_cols, 2); + const int K = kernel_shape.n_input_channels; + + const int in_matrix_row_stride = K; + const int in_matrix_batch_stride = tile_rows*tile_cols*in_matrix_row_stride; + + // Transform the input tensor into an appropriate form + auto input_prep = [&] () { + InputTransform::execute( + input, input_shape, padding_type, tile_rows, tile_cols, + input_matrices[0], in_matrix_stride_bytes / sizeof(TIn), + in_matrix_batch_stride, in_matrix_row_stride + ); + }; + prof( + "Input Prep", input_prep, + InputTransform::bytes_read(input_shape, output_shape), + InputTransform::flops_performed(input_shape, output_shape), + InputTransform::bytes_written(input_shape, output_shape) + ); + +} + +/*****************************************************************************/ +template +template +void winograd::Winograd2x2_3x3GEMM::reshape_output(const Tensor4DShape& input_shape, const PaddingType padding_type, TOut* const output) { + assert(output_matrices[0]); + const int out_matrix_stride_bytes = get_output_matrix_size(input_shape, kernel_shape, padding_type); + const auto output_shape = get_output_shape(input_shape,kernel_shape, padding_type); + const int out_matrix_row_stride = kernel_matrix_row_stride; + + // Transform the output tensor into an appropriate form + OutputTransform::execute( + output_shape, + output_matrices[0], + out_matrix_stride_bytes / sizeof(TOut), + out_matrix_row_stride, + output + ); +} + + +/*****************************************************************************/ +template +void winograd::Winograd2x2_3x3GEMM::execute( size_t first, size_t last ) { + assert(input_matrices[0] && kernel_matrices[0] && output_matrices[0]); + assert(first < 16 && last < 16 && first < last); + // Compute shape for the GEMM + const auto output_shape = get_output_shape(in_shape,kernel_shape, padding); + const int tile_rows = iceildiv(output_shape.n_rows, 2); + const int tile_cols = iceildiv(output_shape.n_cols, 2); + const int M = in_shape.n_batches * tile_rows * tile_cols; + const int K = kernel_shape.n_input_channels; + const int N = kernel_shape.n_output_channels; + + const int in_matrix_row_stride = K; + const int out_matrix_row_stride = kernel_matrix_row_stride; + // Perform the GEMMs + for (size_t i = first; i <= last; i++) { + BlockedGemm( + input_matrices[i], kernel_matrices[i], output_matrices[i], M, K, N, + in_matrix_row_stride, kernel_matrix_row_stride, out_matrix_row_stride + ); +// prof("GEMM", perform_gemm, 0, 2*M*K*N, 0); // TODO Memory + } + +} + +/*****************************************************************************/ +template +Tensor4DShape winograd::Winograd2x2_3x3GEMM::get_output_shape( + const Tensor4DShape &in_shape, const KernelShape &k_shape, const PaddingType padding) { + return Tensor4DShape { + in_shape.n_batches, + (padding == PADDING_SAME) ? in_shape.n_rows : in_shape.n_rows - 2, + (padding == PADDING_SAME) ? 
in_shape.n_cols : in_shape.n_cols - 2, + k_shape.n_output_channels + }; +} + +template +size_t winograd::Winograd2x2_3x3GEMM::get_kernel_storage_size( + const KernelShape &shape) { + return 16 * get_kernel_matrix_size(shape); +} + +template +size_t winograd::Winograd2x2_3x3GEMM::get_kernel_matrix_size( + const KernelShape &shape) { + const int K = shape.n_input_channels; + const int N = roundup(shape.n_output_channels, N_BLOCK); + return sizeof(TIn) * K * N; +} + +template +size_t winograd::Winograd2x2_3x3GEMM::get_working_space_size( + const Tensor4DShape& input_shape, const KernelShape &k_shape, const PaddingType padding_type +) { + return 16 * get_input_matrix_size(input_shape, k_shape, padding_type) + + 16 * get_output_matrix_size(input_shape, k_shape, padding_type); +} + +template +size_t winograd::Winograd2x2_3x3GEMM::get_input_matrix_size( + const Tensor4DShape& input_shape, const KernelShape &k_shape, const PaddingType padding_type +) { + // Compute shape for the GEMM + const auto output_shape = get_output_shape(input_shape, k_shape, padding_type); + const int tile_rows = iceildiv(output_shape.n_rows, 2); + const int tile_cols = iceildiv(output_shape.n_cols, 2); + const int M = roundup(tile_rows * tile_cols, M_BLOCK); + const int K = k_shape.n_input_channels; + + return input_shape.n_batches * M * K * sizeof(TIn); +} + +template +size_t winograd::Winograd2x2_3x3GEMM::get_output_matrix_size( + const Tensor4DShape& input_shape, const KernelShape &k_shape,const PaddingType padding_type +) { + // Compute shape for the GEMM + const auto output_shape = get_output_shape(input_shape, k_shape, padding_type); + const int tile_rows = iceildiv(output_shape.n_rows, 2); + const int tile_cols = iceildiv(output_shape.n_cols, 2); + const int M = roundup(tile_rows * tile_cols, M_BLOCK); + const int N = roundup(k_shape.n_output_channels, N_BLOCK); + + return input_shape.n_batches * M * N * sizeof(TOut); +} diff --git a/arm_compute/core/NEON/kernels/winograd/winograd_shim_nchw.hpp b/arm_compute/core/NEON/kernels/winograd/winograd_shim_nchw.hpp new file mode 100644 index 0000000000..4c7e291c58 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/winograd_shim_nchw.hpp @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#pragma once +#include +#include + +#include "alloc.hpp" +#include "gemm.hpp" +#include "profiler.hpp" +#include "utils.hpp" +#include "shims.hpp" +#include "winograd_gemm.hpp" + +#include "transforms.hpp" + +#ifndef ALLOC_ALIGN +#define ALLOC_ALIGN 64 +#endif // ALLOC_ALIGN + + +namespace winograd_shim_nchw { + /***************************************************************************/ + /* Implementation of the Winograd F(2x2, 3x3, 4x4) algorithm using GEMM + * internally. + */ + template + class Winograd2x2_3x3GEMM : public winograd::Winograd2x2_3x3GEMM { + public: + /* Instantiate a new Winograd operator. + */ + Winograd2x2_3x3GEMM(const KernelShape &kernel_shape, const Tensor4DShape input_shape, const PaddingType padding_type, void *kernel_storage); + + void nchw2nhwc( const Tensor4DShape& input_shape, const PaddingType padding_type, void *working_space, const TIn* const input); + void nhwc2nchw( const Tensor4DShape& input_shape, const PaddingType padding_type, void *working_space, TOut* const output); + + + std::pair get_nhwc_ptrs(const Tensor4DShape& input_shape,const PaddingType padding_type,void *working_space); + + static size_t get_working_space_size(const Tensor4DShape &input_shape,const KernelShape &k_shape, const PaddingType padding); + protected: + /* Get the memory required to store an NHWC copy of the input tensor. */ + static size_t get_working_nhwc_input_size(const Tensor4DShape &input_shape); + + /* Get the memory required to store an NHWC copy of the input tensor. */ + static size_t get_working_nhwc_output_size(const Tensor4DShape &output_shape, const KernelShape &k_shape, const PaddingType padding) ; + }; +} // namespace winograd + +/*****************************************************************************/ +template +winograd_shim_nchw::Winograd2x2_3x3GEMM::Winograd2x2_3x3GEMM( + const KernelShape &kernel_shape, const Tensor4DShape input_shape, + const PaddingType padding_type, void *kernel_storage +) : winograd::Winograd2x2_3x3GEMM(kernel_shape,input_shape,padding_type,kernel_storage) { +} + +/*****************************************************************************/ +template +void winograd_shim_nchw::Winograd2x2_3x3GEMM::nchw2nhwc(const Tensor4DShape& input_shape, const PaddingType padding_type, void *working_space, const TIn* const input) { + assert(working_space); + int8_t* const ws_bytes = reinterpret_cast(working_space); + + // Extract the top chunk of the working space to store the input and output + // tensors in NHWC format. 
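+    // The NHWC staging buffers live after the 16 input and 16 output matrices, i.e. at
+    // offset 16 * (in_matrix_stride_bytes + out_matrix_stride_bytes) into the working space.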
+ const int in_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM::get_input_matrix_size(input_shape, this->kernel_shape, padding_type); + const int out_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM::get_output_matrix_size(input_shape, this->kernel_shape, padding_type); + + // Allocate working space for the input and output in NHWC format + TIn* const input_nhwc = reinterpret_cast( + ws_bytes + 16*(in_matrix_stride_bytes + out_matrix_stride_bytes) + ); + + // Re-order the input tensor + this->prof( + "NCHW -> NHWC", + [input, input_shape, input_nhwc] () { + nchw_to_nhwc( + input, input_nhwc, + input_shape.n_batches, + input_shape.n_channels, + input_shape.n_rows, + input_shape.n_cols + ); + }, + input_shape.size(), 0, input_shape.size() + ); +} + +/*****************************************************************************/ +template +void winograd_shim_nchw::Winograd2x2_3x3GEMM::nhwc2nchw(const Tensor4DShape& input_shape, const PaddingType padding_type, + void *working_space, TOut* const output) { + + assert(working_space); + int8_t* const ws_bytes = reinterpret_cast(working_space); + + // Extract the top chunk of the working space to store the input and output + // tensors in NHWC format. + const int in_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM::get_input_matrix_size(input_shape, this->kernel_shape, padding_type); + const int out_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM::get_output_matrix_size(input_shape, this->kernel_shape, padding_type); + + TOut* const output_nhwc = reinterpret_cast(ws_bytes + 16*(in_matrix_stride_bytes + out_matrix_stride_bytes) + get_working_nhwc_input_size(input_shape)); + + // Re-order the output tensor into NCHW + const auto output_shape = winograd::Winograd2x2_3x3GEMM::get_output_shape(input_shape, this->kernel_shape, padding_type); + this->prof( + "NHWC -> NCHW", + [output_nhwc, output_shape, output] () { + nhwc_to_nchw( + output_nhwc, output, + output_shape.n_batches, + output_shape.n_rows, + output_shape.n_cols, + output_shape.n_channels + ); + }, + output_shape.size(), 0, output_shape.size() + ); +} + + +/*****************************************************************************/ +template +std::pair winograd_shim_nchw::Winograd2x2_3x3GEMM::get_nhwc_ptrs( + const Tensor4DShape& input_shape, + const PaddingType padding_type, + void *working_space +) { + assert(working_space); + int8_t* const ws_bytes = reinterpret_cast(working_space); + + // Extract the top chunk of the working space to store the input and output + // tensors in NHWC format. 
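+    // input_nhwc sits at the start of the NHWC staging area; output_nhwc follows it after
+    // get_working_nhwc_input_size(input_shape) bytes, as computed below.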
+ const int in_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM::get_input_matrix_size(input_shape, this->kernel_shape, padding_type); + const int out_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM::get_output_matrix_size(input_shape, this->kernel_shape, padding_type); + + // Allocate working space for the input and output in NHWC format + TIn* input_nhwc = reinterpret_cast(ws_bytes + 16*(in_matrix_stride_bytes + out_matrix_stride_bytes)); + TOut* output_nhwc = reinterpret_cast(ws_bytes + 16*(in_matrix_stride_bytes + out_matrix_stride_bytes) + get_working_nhwc_input_size(input_shape)); + return std::make_pair(output_nhwc,input_nhwc); +} + + + + +/*****************************************************************************/ +template +size_t winograd_shim_nchw::Winograd2x2_3x3GEMM::get_working_space_size( + const Tensor4DShape& input_shape, const KernelShape &k_shape, const PaddingType padding_type +) { + // TODO Add memory required for NHWC copies of input tensors + return winograd::Winograd2x2_3x3GEMM::get_working_space_size( + input_shape, k_shape, padding_type) + + get_working_nhwc_input_size(input_shape) + + get_working_nhwc_output_size(input_shape, k_shape, padding_type); +} + +template +size_t winograd_shim_nchw::Winograd2x2_3x3GEMM::get_working_nhwc_input_size( + const Tensor4DShape& input_shape +) { + return roundup(input_shape.size() * sizeof(TIn), static_cast(ALLOC_ALIGN)); +} + +template +size_t winograd_shim_nchw::Winograd2x2_3x3GEMM::get_working_nhwc_output_size( + const Tensor4DShape& input_shape, const KernelShape &k_shape, const PaddingType padding_type +) { + const auto output_shape = winograd::Winograd2x2_3x3GEMM::get_output_shape(input_shape,k_shape, padding_type); + return roundup(output_shape.size() * sizeof(TIn), static_cast(ALLOC_ALIGN)); +} diff --git a/arm_compute/runtime/NEON/NEFunctions.h b/arm_compute/runtime/NEON/NEFunctions.h index 5baaa50d40..2e8c084371 100644 --- a/arm_compute/runtime/NEON/NEFunctions.h +++ b/arm_compute/runtime/NEON/NEFunctions.h @@ -108,5 +108,6 @@ #include "arm_compute/runtime/NEON/functions/NETranspose.h" #include "arm_compute/runtime/NEON/functions/NEWarpAffine.h" #include "arm_compute/runtime/NEON/functions/NEWarpPerspective.h" +#include "arm_compute/runtime/NEON/functions/NEWinogradLayer.h" #endif /* __ARM_COMPUTE_NEFUNCTIONS_H__ */ diff --git a/arm_compute/runtime/NEON/functions/NEWinogradLayer.h b/arm_compute/runtime/NEON/functions/NEWinogradLayer.h new file mode 100644 index 0000000000..7dca4570e5 --- /dev/null +++ b/arm_compute/runtime/NEON/functions/NEWinogradLayer.h @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_NEWINOGRADLAYER_H__ +#define __ARM_COMPUTE_NEWINOGRADLAYER_H__ + +#include "arm_compute/runtime/IFunction.h" + +#include "arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h" +#include "arm_compute/core/Types.h" +#include "arm_compute/runtime/MemoryGroup.h" +#include "arm_compute/runtime/Tensor.h" + +#include + +namespace arm_compute +{ +class ITensor; +/** Basic function to simulate a convolution layer. This function calls the following NEON kernels: + */ +class NEWinogradLayer : public IFunction +{ +public: + /** Constructor */ + NEWinogradLayer(std::shared_ptr memory_manager = nullptr); + + /** Set the input and output tensors. + * + * @param[in] input Source tensor. 3 lower dimensions represent a single input [width, height, IFM], + * while every optional dimension from 4 and above represent a batch of inputs. + * Data types supported: F32. + * @param[in] weights Weights tensor. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM]. Data type supported: Same as @p input. + * Currently only 3x3 kernels are supported. + * @param[in] biases Not supported, biases will be ignored. + * @param[out] output Destination tensor. 3 lower dimensions represent a single output [width, height, OFM], while the rest represent batch of outputs. + * Data types supported: Same as @p input. + * @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo. Currently only unit strides are supported. + */ + void configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info); + + // Inherited methods overridden: + void run() override; + + /** Prevent instances of this class from being copied (As this class contains pointers) */ + NEWinogradLayer(const NEWinogradLayer &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + NEWinogradLayer &operator=(const NEWinogradLayer &) = delete; + +private: + using Winograd3x3F32 = NEWinogradLayerKernel::Winograd3x3F32; + + MemoryGroup _memory_group; + NEWinogradLayerKernel _winograd_kernel; + Tensor _weights_workspace; + Tensor _workspace; + Tensor _kernel_storage; + const ITensor *_input; + const ITensor *_weights; + ITensor *_output; + bool _reshaped_kernel; + std::unique_ptr _conv; +}; +} +#endif /* __ARM_COMPUTE_NEWINOGRADLAYER_H__ */ diff --git a/scripts/check_bad_style.sh b/scripts/check_bad_style.sh index e7f6f1af54..4cd69757d6 100755 --- a/scripts/check_bad_style.sh +++ b/scripts/check_bad_style.sh @@ -5,7 +5,7 @@ set -e DIRECTORIES="./arm_compute ./src ./examples ./tests ./utils ./support" -grep -HrnP --exclude-dir=assembly "/\*\*$" $DIRECTORIES | tee bad_style.log +grep -HrnP --exclude-dir=assembly --exclude-dir=winograd "/\*\*$" $DIRECTORIES | tee bad_style.log if (( `cat bad_style.log | wc -l` > 0 )) then echo "" @@ -13,7 +13,7 @@ then exit -1 fi -grep -Hnr --exclude-dir=assembly --exclude=Doxyfile "@brief" $DIRECTORIES | tee bad_style.log +grep -Hnr --exclude-dir=assembly --exclude-dir=winograd --exclude=Doxyfile "@brief" $DIRECTORIES | tee bad_style.log if (( `cat bad_style.log | wc -l` > 0 )) then echo "" @@ -21,7 +21,7 @@ then exit -1 fi -grep -HnRE --exclude-dir=assembly "\buint " 
--exclude-dir=cl_kernels --exclude-dir=cs_shaders $DIRECTORIES | tee bad_style.log +grep -HnRE --exclude-dir=assembly --exclude-dir=winograd "\buint " --exclude-dir=cl_kernels --exclude-dir=cs_shaders $DIRECTORIES | tee bad_style.log if [[ $(cat bad_style.log | wc -l) > 0 ]] then echo "" @@ -29,7 +29,7 @@ then exit -1 fi -grep -HnR --exclude-dir=assembly "float32_t" $DIRECTORIES | tee bad_style.log +grep -HnR --exclude-dir=assembly --exclude-dir=winograd "float32_t" $DIRECTORIES | tee bad_style.log if [[ $(cat bad_style.log | wc -l) > 0 ]] then echo "" @@ -37,7 +37,7 @@ then exit -1 fi -grep -Hnir --exclude-dir=assembly "arm[_ ]\?cv" $DIRECTORIES | tee bad_style.log +grep -Hnir --exclude-dir=assembly --exclude-dir=winograd "arm[_ ]\?cv" $DIRECTORIES | tee bad_style.log if [[ $(cat bad_style.log | wc -l) > 0 ]] then echo "" @@ -45,7 +45,7 @@ then exit -1 fi -grep -Hnir --exclude-dir=assembly "#.*if.*defined[^(]" $DIRECTORIES | tee bad_style.log +grep -Hnir --exclude-dir=assembly --exclude-dir=winograd "#.*if.*defined[^(]" $DIRECTORIES | tee bad_style.log if [[ $(cat bad_style.log | wc -l) > 0 ]] then echo "" @@ -53,7 +53,7 @@ then exit -1 fi -grep -Hnir --exclude-dir=assembly "#else$\|#endif$" $DIRECTORIES | tee bad_style.log +grep -Hnir --exclude-dir=assembly --exclude-dir=winograd "#else$\|#endif$" $DIRECTORIES | tee bad_style.log if [[ $(cat bad_style.log | wc -l) > 0 ]] then echo "" @@ -61,7 +61,7 @@ then exit -1 fi -grep -Hnir --exclude-dir=assembly "ARM_COMPUTE_AARCH64_V8_2" ./tests/validation/CL | tee bad_style.log +grep -Hnir --exclude-dir=assembly --exclude-dir=winograd "ARM_COMPUTE_AARCH64_V8_2" ./tests/validation/CL | tee bad_style.log if [[ $(cat bad_style.log | wc -l) > 0 ]] then echo "" diff --git a/scripts/clang_tidy_rules.py b/scripts/clang_tidy_rules.py index 9c012680d4..5b27dd5be5 100755 --- a/scripts/clang_tidy_rules.py +++ b/scripts/clang_tidy_rules.py @@ -42,6 +42,9 @@ def filter_clang_tidy_lines( lines ): if "/assembly/" in line: continue + if "/winograd/" in line: + continue + if "error:" in line: if (("Utils.cpp" in line and "'arm_compute_version.embed' file not found" in line) or ("cl2.hpp" in line and "cast from pointer to smaller type 'cl_context_properties' (aka 'int') loses information" in line) or diff --git a/src/core/NEON/kernels/NEWinogradLayerKernel.cpp b/src/core/NEON/kernels/NEWinogradLayerKernel.cpp new file mode 100644 index 0000000000..b9109dcff2 --- /dev/null +++ b/src/core/NEON/kernels/NEWinogradLayerKernel.cpp @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h" + +#include "arm_compute/core/Error.h" +#include "arm_compute/core/Helpers.h" +#include "arm_compute/core/ITensor.h" +#include "arm_compute/core/TensorInfo.h" + +namespace arm_compute +{ +NEWinogradLayerKernel::NEWinogradLayerKernel() + : _convolver(nullptr), _output(nullptr) +{ +} + +void NEWinogradLayerKernel::configure(ITensor *output, Winograd3x3F32 *convolver) +{ + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F32); + _convolver = convolver; + Window win = calculate_max_window(*output->info()); + INEKernel::configure(win); +} + +void NEWinogradLayerKernel::run(const Window &window, const ThreadInfo &info) +{ + ARM_COMPUTE_UNUSED(window); + ARM_COMPUTE_UNUSED(info); + ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); + ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window); + ARM_COMPUTE_ERROR_ON(info.num_threads < 1); + const size_t tid = info.thread_id; + const size_t num_threads = std::min(info.num_threads, 16); + const size_t num_gemms_per_thread = 16 / num_threads; + const size_t first_gemm = tid * num_gemms_per_thread; + const size_t last_gemm = (tid == (num_threads - 1)) ? 15 : first_gemm + num_gemms_per_thread - 1; + _convolver->execute(first_gemm, last_gemm); +} +} // namespace arm_compute diff --git a/src/runtime/NEON/functions/NEWinogradLayer.cpp b/src/runtime/NEON/functions/NEWinogradLayer.cpp new file mode 100644 index 0000000000..a9dec4ea0d --- /dev/null +++ b/src/runtime/NEON/functions/NEWinogradLayer.cpp @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#include "arm_compute/runtime/NEON/functions/NEWinogradLayer.h" + +#include "arm_compute/core/Utils.h" +#include "arm_compute/core/Validate.h" +#include "arm_compute/runtime/NEON/NEScheduler.h" +#include "support/ToolchainSupport.h" + +namespace +{ +inline Tensor4DShape internal_get_input_shape(const arm_compute::ITensor *input) +{ + const int in_width = input->info()->dimension(0); + const int in_height = input->info()->dimension(1); + const int in_batches = input->info()->dimension(3); + const int in_channels = input->info()->dimension(2); + return Tensor4DShape({ in_batches, in_height, in_width, in_channels }); +} +} /* namespace */ + +namespace arm_compute +{ +NEWinogradLayer::NEWinogradLayer(std::shared_ptr memory_manager) + : _memory_group(std::move(memory_manager)), _winograd_kernel(), _weights_workspace(), _workspace(), _kernel_storage(), _input(), _weights(), _output(), _reshaped_kernel(false), _conv() +{ +} /* arm_compute */ + +void NEWinogradLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info) +{ + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32); + ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights); + ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(1) != 3 || weights->info()->dimension(0) != 3, "Only 3x3 kernels are supported"); + ARM_COMPUTE_ERROR_ON(weights->info()->num_dimensions() > 4); + + if(biases != nullptr) + { + ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases); + ARM_COMPUTE_ERROR_ON(biases->info()->num_dimensions() > 1); + } + + _weights = weights; + _input = input; + _output = output; + + // Get parameters from conv_info + unsigned int stride_x = 0; + unsigned int stride_y = 0; + std::tie(stride_x, stride_y) = conv_info.stride(); + ARM_COMPUTE_ERROR_ON_MSG(stride_y != 1 || stride_x != 1, "Winograd layer only supports unit strides."); + + // Get convolved dimensions + auto padding = PADDING_VALID; + const int in_channels = input->info()->dimension(2); + + const int out_channels = output->info()->dimension(2); + const int weights_width = weights->info()->dimension(0); + const int weights_height = weights->info()->dimension(1); + + const KernelShape kernel_shape({ out_channels, weights_height, weights_width, in_channels }); + const Tensor4DShape in_shape(internal_get_input_shape(input)); + + // Get the memory required to instantiate a new Winograd operator. 
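+    // Winograd3x3F32::get_kernel_storage_size() already covers all 16 transformed kernel
+    // matrices; kstore_alignment - 1 extra bytes of slack are added when sizing the buffer.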
+ constexpr size_t kstore_alignment = 64; + const size_t kernel_storage_per_thread = Winograd3x3F32::get_kernel_storage_size(kernel_shape); + _kernel_storage.allocator()->init(TensorInfo(TensorShape{ (kernel_storage_per_thread + kstore_alignment - 1) }, 1, DataType::U8)); + _memory_group.manage(&_kernel_storage); + + // Get workspace size and allocate memory + constexpr size_t wspace_alignment = 64; + const size_t ws_size = Winograd3x3F32::get_working_space_size(in_shape, kernel_shape, padding); + _workspace.allocator()->init(TensorInfo(TensorShape{ (ws_size + wspace_alignment - 1) }, 1, DataType::U8)); + _memory_group.manage(&_workspace); + + // Workspace for weights transform + const size_t weights_transform_size = Winograd3x3F32::get_kernel_transform_working_size(kernel_shape); + _weights_workspace.allocator()->init(TensorInfo(TensorShape{ (weights_transform_size + wspace_alignment - 1) }, 1, DataType::U8)); + _memory_group.manage(&_weights_workspace); + + _kernel_storage.allocator()->allocate(); + _workspace.allocator()->allocate(); + _weights_workspace.allocator()->allocate(); + + // Create Winograd operator object + _conv = support::cpp14::make_unique<Winograd3x3F32>(kernel_shape, in_shape, padding, _kernel_storage.buffer()); + + // Configure the kernel, padding not needed so it's safe to call configure after allocation + _winograd_kernel.configure(output, _conv.get()); +} + +void NEWinogradLayer::run() +{ +#if defined(__aarch64__) + _memory_group.acquire(); + if(!_reshaped_kernel) + { + _conv->transform_weights(reinterpret_cast<const float *>(_weights->buffer()), reinterpret_cast<float *>(_weights_workspace.buffer())); + _reshaped_kernel = true; + } + const Tensor4DShape in_shape(internal_get_input_shape(_input)); + auto padding = PADDING_VALID; + + //Bring channels to the front as Winograd code expects the tensor to be in the format NHWC + _conv->nchw2nhwc(in_shape, padding, _workspace.buffer(), reinterpret_cast<const float *>(_input->buffer())); + + //Get ptrs into the workspace + std::pair<float *, float *> nhwc_ptrs = _conv->get_nhwc_ptrs(in_shape, padding, _workspace.buffer()); + + //Set up matrix pointers and transform the input tensor to the appropriate form before running GEMM. 
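+    // reshape_input() also records the 16 per-matrix pointers inside the working space,
+    // so it must run before the GEMM kernel below is scheduled.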
+ _conv->reshape_input(in_shape, padding, nhwc_ptrs.second, _workspace.buffer()); + + //Run 16 GEMMs in multiple threads, each kernel runs one or more GEMMs + NEScheduler::get().schedule(&_winograd_kernel, Window::DimY); + + //Transform the output to the appropriate form + _conv->reshape_output(in_shape, padding, nhwc_ptrs.first); + + //Transform back to NCHW + _conv->nhwc2nchw(in_shape, padding, _workspace.buffer(), reinterpret_cast(_output->buffer())); + + _memory_group.release(); +#else /* __aarch64__ */ + ARM_COMPUTE_UNUSED(_winograd_kernel); + ARM_COMPUTE_UNUSED(_workspace); + ARM_COMPUTE_UNUSED(_kernel_storage); + ARM_COMPUTE_UNUSED(_input); + ARM_COMPUTE_UNUSED(_weights); + ARM_COMPUTE_UNUSED(_output); + ARM_COMPUTE_UNUSED(_reshaped_kernel); + ARM_COMPUTE_UNUSED(_conv); + ARM_COMPUTE_ERROR("Winograd only supported for aarch64, recompile with arch=arm64-v8a."); +#endif /* __aarch64__ */ +} +} // namespace arm_compute diff --git a/tests/datasets/SmallConvolutionLayerDataset.h b/tests/datasets/SmallConvolutionLayerDataset.h index aa9d9f8899..ccdd6e16af 100644 --- a/tests/datasets/SmallConvolutionLayerDataset.h +++ b/tests/datasets/SmallConvolutionLayerDataset.h @@ -37,6 +37,18 @@ namespace test { namespace datasets { +class SmallWinogradLayerDataset final : public ConvolutionLayerDataset +{ +public: + SmallWinogradLayerDataset() + { + // Batch size 1 + add_config(TensorShape(8U, 8U, 2U), TensorShape(3U, 3U, 2U, 1U), TensorShape(1U), TensorShape(6U, 6U, 1U), PadStrideInfo(1, 1, 0, 0)); + // Batch size 4 + add_config(TensorShape(23U, 27U, 5U, 4U), TensorShape(3U, 3U, 5U, 21U), TensorShape(21U), TensorShape(21U, 25U, 21U, 4U), PadStrideInfo(1, 1, 0, 0)); + } +}; + class SmallConvolutionLayerDataset final : public ConvolutionLayerDataset { public: diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp index 5e14a7c3cc..575ffe17a9 100644 --- a/tests/validation/NEON/ConvolutionLayer.cpp +++ b/tests/validation/NEON/ConvolutionLayer.cpp @@ -23,6 +23,7 @@ */ #include "arm_compute/core/Types.h" #include "arm_compute/runtime/NEON/functions/NEConvolutionLayer.h" +#include "arm_compute/runtime/NEON/functions/NEWinogradLayer.h" #include "arm_compute/runtime/Tensor.h" #include "arm_compute/runtime/TensorAllocator.h" #include "tests/NEON/Accessor.h" @@ -34,6 +35,7 @@ #include "tests/framework/datasets/Datasets.h" #include "tests/validation/Validation.h" #include "tests/validation/fixtures/ConvolutionLayerFixture.h" +#include "tests/validation/fixtures/WinogradLayerFixture.h" namespace arm_compute { @@ -62,6 +64,23 @@ const auto CNNDataTypes = framework::dataset::make("DataType", } // namespace TEST_SUITE(NEON) + +#if defined(__aarch64__) +TEST_SUITE(WinogradLayer) +template +using NEWinogradLayerFixture = WinogradLayerValidationFixture; + +TEST_SUITE(FP32) +FIXTURE_DATA_TEST_CASE(RunSmall, NEWinogradLayerFixture, framework::DatasetMode::PRECOMMIT, datasets::SmallWinogradLayerDataset()) +{ + // Validate output + validate(Accessor(_target), _reference, tolerance_f32); +} + +TEST_SUITE_END() +TEST_SUITE_END() +#endif /* __aarch64__ */ + TEST_SUITE(ConvolutionLayer) DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(framework::dataset::concat(datasets::SmallConvolutionLayerDataset(), datasets::LargeConvolutionLayerDataset()), CNNDataTypes), diff --git a/tests/validation/fixtures/WinogradLayerFixture.h b/tests/validation/fixtures/WinogradLayerFixture.h new file mode 100644 index 0000000000..a5d6fc966d --- /dev/null +++ 
b/tests/validation/fixtures/WinogradLayerFixture.h @@ -0,0 +1,145 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_TEST_WINOGRAD_LAYER_FIXTURE +#define ARM_COMPUTE_TEST_WINOGRAD_LAYER_FIXTURE + +#include "arm_compute/core/TensorShape.h" +#include "arm_compute/core/Types.h" +#include "arm_compute/runtime/NEON/NEScheduler.h" +#include "tests/AssetsLibrary.h" +#include "tests/Globals.h" +#include "tests/IAccessor.h" +#include "tests/framework/Asserts.h" +#include "tests/framework/Fixture.h" +#include "tests/validation/CPP/ConvolutionLayer.h" +#include "tests/validation/CPP/Utils.h" +#include "tests/validation/Helpers.h" + +#include + +namespace arm_compute +{ +class NEWinogradLayer; + +namespace test +{ +namespace validation +{ +template +class WinogradLayerValidationFixture : public framework::Fixture +{ +public: + template + void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info) + { + _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info); + _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info); + } + +protected: + template + void fill(U &&tensor, int i, float min, float max) + { + switch(tensor.data_type()) + { + case DataType::F32: + { + std::uniform_real_distribution<> distribution(min, max); + library->fill(tensor, distribution, i); + break; + } + default: + { + ARM_COMPUTE_ERROR("Not supported"); + library->fill_tensor_uniform(tensor, i); + break; + } + } + } + + TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info) + { + // Create tensors + TensorType src = create_tensor(input_shape, DataType::F32, 1); + TensorType weights = create_tensor(weights_shape, DataType::F32, 1); + TensorType bias = create_tensor(bias_shape, DataType::F32, 1); + TensorType dst = create_tensor(output_shape, DataType::F32, 1); + + // Create and configure function + FunctionType conv; + conv.configure(&src, &weights, nullptr, &dst, info); + + ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(weights.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS); + 
ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); + + // Allocate tensors + src.allocator()->allocate(); + weights.allocator()->allocate(); + bias.allocator()->allocate(); + dst.allocator()->allocate(); + + ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!weights.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!bias.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS); + + // Fill tensors + fill(AccessorType(src), 0, -1.f, 1.f); + fill(AccessorType(weights), 1, -1.f, 1.f); + fill(AccessorType(bias), 2, 0.f, 0.f); + fill(AccessorType(dst), 3, -1.f, 1.f); + + // Compute NEWinogradLayer function + conv.run(); + + return dst; + } + + SimpleTensor compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info) + { + // Create reference + SimpleTensor src{ input_shape, DataType::F32, 1 }; + SimpleTensor weights{ weights_shape, DataType::F32, 1 }; + SimpleTensor bias{ bias_shape, DataType::F32, 1 }; + + // Fill reference + fill(src, 0, -1.f, 1.f); + fill(weights, 1, -1.f, 1.f); + fill(bias, 2, 0.f, 0.f); + + return reference::convolution_layer(src, weights, bias, output_shape, info); + } + + TensorType _target{}; + SimpleTensor _reference{}; + int _fractional_bits{}; + DataType _data_type{}; +}; + +} // namespace validation +} // namespace test +} // namespace arm_compute +#endif /* ARM_COMPUTE_TEST_WINOGRAD_LAYER_FIXTURE */ -- cgit v1.2.1
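Usage sketch (illustrative only, not part of the patch): the new NEWinogradLayer is driven like any other NEON function. The example below borrows the shapes of the first SmallWinogradLayerDataset entry (8x8x2 F32 input, 3x3x2x1 weights, unit stride, no padding), leaves the tensors unfilled for brevity, and only executes on aarch64 builds.

#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEWinogradLayer.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

int main()
{
    // Shapes taken from the first SmallWinogradLayerDataset configuration above.
    Tensor src{}, weights{}, dst{};
    src.allocator()->init(TensorInfo(TensorShape(8U, 8U, 2U), 1, DataType::F32));
    weights.allocator()->init(TensorInfo(TensorShape(3U, 3U, 2U, 1U), 1, DataType::F32));
    dst.allocator()->init(TensorInfo(TensorShape(6U, 6U, 1U), 1, DataType::F32));

    // Biases are ignored by the layer, so nullptr is passed; only unit strides are supported.
    NEWinogradLayer winograd;
    winograd.configure(&src, &weights, nullptr, &dst, PadStrideInfo(1, 1, 0, 0));

    // Allocate backing memory; fill src and weights here before running (omitted for brevity).
    src.allocator()->allocate();
    weights.allocator()->allocate();
    dst.allocator()->allocate();

    winograd.run(); // aarch64 only; errors out on other architectures
    return 0;
}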