-rw-r--r--  SConstruct | 2
-rw-r--r--  arm_compute/core/NEON/NEKernels.h | 1
-rw-r--r--  arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h | 70
-rw-r--r--  arm_compute/core/NEON/kernels/winograd/alloc.hpp | 30
-rw-r--r--  arm_compute/core/NEON/kernels/winograd/gemm.hpp | 127
-rw-r--r--  arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp | 355
-rw-r--r--  arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp | 1445
-rw-r--r--  arm_compute/core/NEON/kernels/winograd/perf.h | 32
-rw-r--r--  arm_compute/core/NEON/kernels/winograd/profiler.hpp | 244
-rw-r--r--  arm_compute/core/NEON/kernels/winograd/shims.hpp | 319
-rw-r--r--  arm_compute/core/NEON/kernels/winograd/tensor.hpp | 210
-rw-r--r--  arm_compute/core/NEON/kernels/winograd/transforms.hpp | 29
-rw-r--r--  arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp | 638
-rw-r--r--  arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp | 1498
-rw-r--r--  arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp | 961
-rw-r--r--  arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp | 195
-rw-r--r--  arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp | 822
-rw-r--r--  arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp | 356
-rw-r--r--  arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp | 650
-rw-r--r--  arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp | 655
-rw-r--r--  arm_compute/core/NEON/kernels/winograd/utils.hpp | 55
-rw-r--r--  arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp | 346
-rw-r--r--  arm_compute/core/NEON/kernels/winograd/winograd_shim_nchw.hpp | 192
-rw-r--r--  arm_compute/runtime/NEON/NEFunctions.h | 1
-rw-r--r--  arm_compute/runtime/NEON/functions/NEWinogradLayer.h | 84
-rwxr-xr-x  scripts/check_bad_style.sh | 16
-rwxr-xr-x  scripts/clang_tidy_rules.py | 3
-rw-r--r--  src/core/NEON/kernels/NEWinogradLayerKernel.cpp | 60
-rw-r--r--  src/runtime/NEON/functions/NEWinogradLayer.cpp | 155
-rw-r--r--  tests/datasets/SmallConvolutionLayerDataset.h | 12
-rw-r--r--  tests/validation/NEON/ConvolutionLayer.cpp | 19
-rw-r--r--  tests/validation/fixtures/WinogradLayerFixture.h | 145
32 files changed, 9719 insertions(+), 8 deletions(-)
diff --git a/SConstruct b/SConstruct
index 6f4835828a..e7504228d3 100644
--- a/SConstruct
+++ b/SConstruct
@@ -180,6 +180,8 @@ if not GetOption("help"):
if env['standalone']:
env.Append(CXXFLAGS = ['-fPIC'])
env.Append(LINKFLAGS = ['-static-libgcc','-static-libstdc++'])
+ if env['cppthreads']:
+ env.Append(LINKFLAGS = ['-lpthread'])
if env['Werror']:
env.Append(CXXFLAGS = ['-Werror'])
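For context: with cppthreads enabled the CPP scheduler drives kernels with std::thread, and a standalone build (statically linked libstdc++) no longer pulls in the pthread symbols implicitly, hence the explicit -lpthread. A minimal, hypothetical illustration of the same linking requirement outside the library:

    // sketch only -- any program spawning std::thread hits the same constraint when
    // built with, e.g.: g++ -static-libgcc -static-libstdc++ example.cpp -lpthread
    #include <cstdio>
    #include <thread>

    int main()
    {
        std::thread worker([] { std::printf("hello from a worker thread\n"); });
        worker.join();
        return 0;
    }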
diff --git a/arm_compute/core/NEON/NEKernels.h b/arm_compute/core/NEON/NEKernels.h
index 7fb5f78f13..281f06305f 100644
--- a/arm_compute/core/NEON/NEKernels.h
+++ b/arm_compute/core/NEON/NEKernels.h
@@ -111,6 +111,7 @@
#include "arm_compute/core/NEON/kernels/NETransposeKernel.h"
#include "arm_compute/core/NEON/kernels/NEWarpKernel.h"
#include "arm_compute/core/NEON/kernels/NEWeightsReshapeKernel.h"
+#include "arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h"
#include "arm_compute/core/NEON/kernels/arm32/NEGEMMAArch32Kernel.h"
#include "arm_compute/core/NEON/kernels/arm64/NEGEMMAArch64Kernel.h"
#include "arm_compute/core/NEON/kernels/arm64/NEGEMMLowpAArch64A53Kernel.h"
diff --git a/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h b/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h
new file mode 100644
index 0000000000..1e7ca64b8c
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_NEGEMMWINOGRADLAYERKERNEL_H__
+#define __ARM_COMPUTE_NEGEMMWINOGRADLAYERKERNEL_H__
+
+#include "arm_compute/core/NEON/INEKernel.h"
+
+#include "arm_compute/core/NEON/kernels/winograd/winograd_shim_nchw.hpp"
+
+namespace arm_compute
+{
+class ITensor;
+
+class NEWinogradLayerKernel : public INEKernel
+{
+public:
+ using Winograd3x3F32 = winograd_shim_nchw::Winograd2x2_3x3GEMM<float, float>;
+
+ /** Constructor */
+ NEWinogradLayerKernel();
+
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ NEWinogradLayerKernel(const NEWinogradLayerKernel &) = delete;
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ NEWinogradLayerKernel &operator=(const NEWinogradLayerKernel &) = delete;
+ /** Allow instances of this class to be moved */
+ NEWinogradLayerKernel(NEWinogradLayerKernel &&) = default;
+ /** Allow instances of this class to be moved */
+ NEWinogradLayerKernel &operator=(NEWinogradLayerKernel &&) = default;
+
+ virtual ~NEWinogradLayerKernel() = default;
+
+ /** Initialise the kernel
+ *
+ * @param[in,out] output Output tensor to store the result of matrix multiplication.
+ * @param[in]     convolver A pointer to the Winograd convolver; this object must already have been configured and be ready to execute the 16 GEMMs.
+ */
+ void configure(ITensor *output, Winograd3x3F32 *convolver);
+
+ // Inherited methods overridden:
+ void run(const Window &window, const ThreadInfo &info) override;
+
+protected:
+ Winograd3x3F32 *_convolver;
+ ITensor *_output;
+};
+
+} // namespace arm_compute
+#endif /*__ARM_COMPUTE_NEGEMMWINOGRADLAYERKERNEL_H__*/
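A rough usage sketch for the kernel declared above, assuming an already-configured Winograd3x3F32 convolver and an output ITensor (both are set up by NEWinogradLayer further down in this patch); the scheduling call follows the usual INEKernel pattern:

    // hedged sketch, not part of the patch: `output` and `convolver` are assumed to exist
    arm_compute::NEWinogradLayerKernel winograd_kernel;
    winograd_kernel.configure(&output, &convolver); // convolver: NEWinogradLayerKernel::Winograd3x3F32
    // run the 16 batched GEMMs through the NEON scheduler like any other kernel
    arm_compute::NEScheduler::get().schedule(&winograd_kernel, arm_compute::Window::DimX);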
diff --git a/arm_compute/core/NEON/kernels/winograd/alloc.hpp b/arm_compute/core/NEON/kernels/winograd/alloc.hpp
new file mode 100644
index 0000000000..ef6f2b5115
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/alloc.hpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+#ifdef ALLOC_ALIGN
+#define ALLOCATE(x) aligned_alloc(ALLOC_ALIGN, x)
+#else
+#define ALLOCATE(x) malloc(x)
+#endif
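A small, hypothetical usage example for the macro above: when ALLOC_ALIGN is defined the request is served by aligned_alloc, which (per C11) expects the size to be a multiple of the alignment; in either configuration the buffer is released with free():

    // illustrative only -- the 64-byte alignment and buffer size are made up
    // #define ALLOC_ALIGN 64     // would route ALLOCATE(x) through aligned_alloc(64, x)
    #include <cstdlib>
    #include "alloc.hpp"          // path relative to the winograd kernel sources

    int main()
    {
        float *buffer = static_cast<float *>(ALLOCATE(1024 * sizeof(float))); // 4096 bytes, a multiple of 64
        // ... fill and consume the buffer ...
        free(buffer);
        return 0;
    }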
diff --git a/arm_compute/core/NEON/kernels/winograd/gemm.hpp b/arm_compute/core/NEON/kernels/winograd/gemm.hpp
new file mode 100644
index 0000000000..564016a646
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/gemm.hpp
@@ -0,0 +1,127 @@
+
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+#include "utils.hpp"
+
+template <typename TIn, typename TOut>
+void Gemm(const TIn* const a, const TIn* const b, TOut *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride,
+ const bool a_transposed=false,
+ const bool b_transposed=false) {
+ // Array access methods
+ const auto A = [a, a_transposed, M, K, a_row_stride] (const int i, const int j) -> TIn {
+ return a[(!a_transposed) ? i*a_row_stride + j : i + j*M];
+ };
+
+ const auto B = [b, b_transposed, K, N, b_row_stride] (const int i, const int j) -> TIn {
+ return b[(!b_transposed) ? i*b_row_stride + j : i + j*N];
+ };
+
+ const auto C = [c, c_row_stride] (const int i, const int j) -> TOut& {
+ return c[i*c_row_stride + j];
+ };
+
+ // Perform the matrix multiplication
+ for (int i = 0; i < M; i++) {
+ for (int j = 0; j < N; j++) {
+ for (int k = 0; k < K; k++) {
+ C(i, j) += A(i, k) * B(k, j);
+ }
+ }
+ }
+}
+
+template <const int M_BLOCK, const int N_BLOCK, typename TIn, typename TOut>
+void BlockedGemm(
+ const TIn* const a, const TIn* const b, TOut *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+) {
+ // Array access methods
+ const auto A = [a, M, K, a_row_stride] (const int i, const int j) -> TIn {
+ return a[i*a_row_stride + j];
+ };
+
+ const auto B = [b, K, N, b_row_stride] (const int i, const int j) -> TIn {
+ return b[i*b_row_stride + j];
+ };
+
+ const auto C = [c, c_row_stride] (const int i, const int j) -> TOut& {
+ return c[i*c_row_stride + j];
+ };
+
+ const int M_BLOCKS = iceildiv(M, M_BLOCK);
+ const int N_BLOCKS = iceildiv(N, N_BLOCK);
+
+ // For each block of output rows
+ for (int mblock = 0; mblock < M_BLOCKS; mblock++) {
+ // For each block of output columns
+ for (int nblock = 0; nblock < N_BLOCKS; nblock++) {
+ // Create an appropriately sized block of accumulators
+ TOut accum[M_BLOCK][N_BLOCK];
+ for (int i = 0; i < M_BLOCK; i++) {
+ for (int j = 0; j < N_BLOCK; j++) {
+ accum[i][j] = static_cast<TOut>(0);
+ }
+ }
+
+ // Perform this portion of the matrix multiply
+ for (int k = 0; k < K; k++) {
+ // Load elements of A
+ TIn elems_a[M_BLOCK];
+ for (int i = 0; i < M_BLOCK; i++) {
+ elems_a[i] = A(mblock*M_BLOCK + i, k);
+ }
+
+ // Load elements of B
+ TIn elems_b[N_BLOCK];
+ for (int j = 0; j < N_BLOCK; j++) {
+ elems_b[j] = B(k, nblock*N_BLOCK + j);
+ }
+
+ // Perform the partial matrix multiply
+ for (int i = 0; i < M_BLOCK; i++) {
+ for (int j = 0; j < N_BLOCK; j++) {
+ accum[i][j] += elems_a[i] * elems_b[j];
+ }
+ }
+ }
+
+ // Store the partial product
+ for (int i = 0; i < M_BLOCK; i++) {
+ for (int j = 0; j < N_BLOCK; j++) {
+ C(mblock*M_BLOCK + i, nblock*N_BLOCK + j) = accum[i][j];
+ }
+ }
+ }
+ }
+}
+
+#include "gemm/a64_sgemm.hpp"
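Note that the reference Gemm above accumulates into C (C(i, j) += ...) without clearing it first, whereas BlockedGemm writes its zero-initialised accumulator tiles out with a plain store, so only the former needs a zeroed output. A minimal sketch exercising both on a small problem (sizes are arbitrary, chosen so M and N are multiples of the block shape):

    // illustrative only; row-major matrices with row stride equal to the column count
    #include <vector>
    #include "gemm.hpp"   // path relative to the winograd kernel sources

    int main()
    {
        const int M = 16, K = 9, N = 24;
        std::vector<float> a(M * K, 1.0f), b(K * N, 2.0f);
        std::vector<float> c_ref(M * N, 0.0f); // Gemm accumulates, so start from zero
        std::vector<float> c_blk(M * N);

        Gemm(a.data(), b.data(), c_ref.data(), M, K, N, K, N, N);
        BlockedGemm<4, 4, float, float>(a.data(), b.data(), c_blk.data(), M, K, N, K, N, N);
        // both results should now hold 2 * K == 18.0f in every element
        return 0;
    }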
diff --git a/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp
new file mode 100644
index 0000000000..e1b7488c31
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp
@@ -0,0 +1,355 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+#include <cassert>
+#include "../utils.hpp"
+
+#ifdef __aarch64__
+
+template <>
+inline void BlockedGemm<8, 12, float, float>(
+ const float* const a, const float* const b, float *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+) {
+ const int M_BLOCK = 8;
+ const int N_BLOCK = 12;
+
+ const int m_blocks = iceildiv(M, M_BLOCK);
+ const int n_blocks = iceildiv(N, N_BLOCK);
+
+ // For each block of output rows
+ for (int mblock = 0; mblock < m_blocks; mblock++) {
+ // For each block of output columns
+ for (int nblock = 0; nblock < n_blocks; nblock++) {
+ const float *aptr = a + mblock*M_BLOCK*a_row_stride;
+ const float *bptr = b + nblock*N_BLOCK;
+ float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
+ int k = K;
+
+ asm volatile (
+ // Create an 8x12 block of accumulators
+ " A_1 .req v27\n"
+ "sA_1 .req s27\n"
+ " A_2 .req v28\n"
+ "sA_2 .req s28\n"
+ " A_3 .req v29\n"
+ "sA_3 .req s29\n"
+ " A_4 .req v30\n"
+ "sA_4 .req s30\n"
+
+ " B_1 .req v24\n" " B_2 .req v25\n" " B_3 .req v26\n"
+ "qB_1 .req q24\n" "qB_2 .req q25\n" "qB_3 .req q26\n"
+
+ " C_11 .req v0\n" " C_12 .req v1\n" " C_13 .req v2\n"
+ " C_21 .req v3\n" " C_22 .req v4\n" " C_23 .req v5\n"
+ " C_31 .req v6\n" " C_32 .req v7\n" " C_33 .req v8\n"
+ " C_41 .req v9\n" " C_42 .req v10\n" " C_43 .req v11\n"
+ " C_51 .req v12\n" " C_52 .req v13\n" " C_53 .req v14\n"
+ " C_61 .req v15\n" " C_62 .req v16\n" " C_63 .req v17\n"
+ " C_71 .req v18\n" " C_72 .req v19\n" " C_73 .req v20\n"
+ " C_81 .req v21\n" " C_82 .req v22\n" " C_83 .req v23\n"
+
+ "qC_11 .req q0\n" "qC_12 .req q1\n" "qC_13 .req q2\n"
+ "qC_21 .req q3\n" "qC_22 .req q4\n" "qC_23 .req q5\n"
+ "qC_31 .req q6\n" "qC_32 .req q7\n" "qC_33 .req q8\n"
+ "qC_41 .req q9\n" "qC_42 .req q10\n" "qC_43 .req q11\n"
+ "qC_51 .req q12\n" "qC_52 .req q13\n" "qC_53 .req q14\n"
+ "qC_61 .req q15\n" "qC_62 .req q16\n" "qC_63 .req q17\n"
+ "qC_71 .req q18\n" "qC_72 .req q19\n" "qC_73 .req q20\n"
+ "qC_81 .req q21\n" "qC_82 .req q22\n" "qC_83 .req q23\n"
+
+ "aptr1 .req x17\n"
+ "aptr2 .req x18\n"
+ "aptr3 .req x19\n"
+ "aptr4 .req x20\n"
+ "aptr5 .req x21\n"
+ "aptr6 .req x22\n"
+ "aptr7 .req x23\n"
+
+ // Initialise accumulators with 0
+ // Initialise pointers
+ "movi C_11.4s, #0\n"
+ "add aptr1, %x[aptr], %x[a_row_stride]\n"
+ "movi C_12.4s, #0\n"
+ "add aptr2, aptr1, %x[a_row_stride]\n"
+ "movi C_13.4s, #0\n"
+ "add aptr3, aptr2, %x[a_row_stride]\n"
+ "movi C_21.4s, #0\n"
+ "add aptr4, aptr3, %x[a_row_stride]\n"
+ "movi C_22.4s, #0\n"
+ "add aptr5, aptr4, %x[a_row_stride]\n"
+ "movi C_23.4s, #0\n"
+ "add aptr6, aptr5, %x[a_row_stride]\n"
+ "movi C_31.4s, #0\n"
+ "add aptr7, aptr6, %x[a_row_stride]\n"
+ "movi C_32.4s, #0\n"
+ "ldr qB_1, [%x[bptr]]\n"
+ "movi C_33.4s, #0\n"
+ "ldr qB_2, [%x[bptr], #0x10]\n"
+ "movi C_41.4s, #0\n"
+ "prfm pldl1keep, [%x[bptr], #0x00]\n"
+ "movi C_42.4s, #0\n"
+ "prfm pldl1keep, [%x[bptr], #0x10]\n"
+ "movi C_43.4s, #0\n"
+ "prfm pldl1keep, [%x[bptr], #0x20]\n"
+ "movi C_51.4s, #0\n"
+ "prfm pldl1keep, [%x[aptr], #0x00]\n"
+ "movi C_52.4s, #0\n"
+ "prfm pldl1keep, [ aptr1, #0x00]\n"
+ "movi C_53.4s, #0\n"
+ "prfm pldl1keep, [ aptr2, #0x00]\n"
+ "movi C_61.4s, #0\n"
+ "prfm pldl1keep, [ aptr3, #0x00]\n"
+ "movi C_62.4s, #0\n"
+ "prfm pldl1keep, [ aptr4, #0x00]\n"
+ "movi C_63.4s, #0\n"
+ "prfm pldl1keep, [ aptr5, #0x00]\n"
+ "movi C_71.4s, #0\n"
+ "prfm pldl1keep, [ aptr6, #0x00]\n"
+ "movi C_72.4s, #0\n"
+ "prfm pldl1keep, [ aptr7, #0x00]\n"
+ "movi C_73.4s, #0\n"
+ "ldr sA_1, [%x[aptr]], #0x4\n"
+ "movi C_81.4s, #0\n"
+ "ldr sA_2, [ aptr1], #0x4\n"
+ "movi C_82.4s, #0\n"
+ "ldr sA_3, [ aptr2], #0x4\n"
+ "movi C_83.4s, #0\n"
+ "subs %x[k], %x[k], #1\n"
+ "beq 2f\n"
+
+ "1:"
+ "fmla C_11.4s, B_1.4s, A_1.s[0]\n"
+ "ldr qB_3, [%x[bptr], #0x20]\n"
+ "fmla C_12.4s, B_2.4s, A_1.s[0]\n"
+ "ldr sA_4, [ aptr3], #0x4\n"
+ "fmla C_13.4s, B_3.4s, A_1.s[0]\n"
+ "ldr sA_1, [ aptr4], #0x04\n"
+
+ "fmla C_21.4s, B_1.4s, A_2.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride]\n"
+ "fmla C_22.4s, B_2.4s, A_2.s[0]\n"
+ "prfm pldl1keep, [ aptr3, #0x10]\n"
+ "fmla C_23.4s, B_3.4s, A_2.s[0]\n"
+ "ldr sA_2, [ aptr5], #0x04\n"
+
+ "fmla C_31.4s, B_1.4s, A_3.s[0]\n"
+ "prfm pldl1keep, [%x[bptr], #0x00]\n"
+ "fmla C_32.4s, B_2.4s, A_3.s[0]\n"
+ "prfm pldl1keep, [%x[bptr], #0x10]\n"
+ "fmla C_33.4s, B_3.4s, A_3.s[0]\n"
+ "ldr sA_3, [ aptr6], #0x04\n"
+
+ "fmla C_41.4s, B_1.4s, A_4.s[0]\n"
+ "prfm pldl1keep, [%x[bptr], #0x20]\n"
+ "fmla C_42.4s, B_2.4s, A_4.s[0]\n"
+ "prfm pldl1keep, [ aptr4, #0x10]\n"
+ "fmla C_43.4s, B_3.4s, A_4.s[0]\n"
+ "ldr sA_4, [ aptr7], #0x04\n"
+
+ "fmla C_51.4s, B_1.4s, A_1.s[0]\n"
+ "prfm pldl1keep, [ aptr5, #0x10]\n"
+ "fmla C_52.4s, B_2.4s, A_1.s[0]\n"
+ "prfm pldl1keep, [ aptr6, #0x10]\n"
+ "fmla C_53.4s, B_3.4s, A_1.s[0]\n"
+ "ldr sA_1, [%x[aptr]], #0x04\n"
+
+ "fmla C_61.4s, B_1.4s, A_2.s[0]\n"
+ "prfm pldl1keep, [ aptr7, #0x10]\n"
+ "fmla C_62.4s, B_2.4s, A_2.s[0]\n"
+ "subs %x[k], %x[k], #1\n"
+ "fmla C_63.4s, B_3.4s, A_2.s[0]\n"
+ "ldr sA_2, [ aptr1], #0x04\n"
+
+ "fmla C_71.4s, B_1.4s, A_3.s[0]\n"
+ "prfm pldl1keep, [%x[aptr], #0x10]\n"
+ "fmla C_72.4s, B_2.4s, A_3.s[0]\n"
+ "prfm pldl1keep, [ aptr1, #0x10]\n"
+ "fmla C_73.4s, B_3.4s, A_3.s[0]\n"
+ "ldr sA_3, [ aptr2], #0x04\n"
+
+ "fmla C_81.4s, B_1.4s, A_4.s[0]\n"
+ "prfm pldl1keep, [ aptr2, #0x10]\n"
+ "fmla C_82.4s, B_2.4s, A_4.s[0]\n"
+ "ldp qB_1, qB_2, [%x[bptr]]\n"
+ "fmla C_83.4s, B_3.4s, A_4.s[0]\n"
+ "bne 1b\n"
+
+ "2:"
+ "fmla C_11.4s, B_1.4s, A_1.s[0]\n"
+ "ldr qB_3, [%x[bptr], #0x20]\n"
+ "fmla C_12.4s, B_2.4s, A_1.s[0]\n"
+ "stp qC_11, qC_12, [%x[cptr]]\n"
+ "fmla C_13.4s, B_3.4s, A_1.s[0]\n"
+ "str qC_13, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+ "ldr sA_1, [ aptr4], #0x04\n"
+
+ "fmla C_21.4s, B_1.4s, A_2.s[0]\n"
+ "ldr sA_4, [ aptr3], #0x4\n"
+ "fmla C_22.4s, B_2.4s, A_2.s[0]\n"
+ "stp qC_21, qC_22, [%x[cptr]]\n"
+ "fmla C_23.4s, B_3.4s, A_2.s[0]\n"
+ "str qC_23, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+ "ldr sA_2, [ aptr5], #0x04\n"
+
+ "fmla C_31.4s, B_1.4s, A_3.s[0]\n"
+ "fmla C_32.4s, B_2.4s, A_3.s[0]\n"
+ "stp qC_31, qC_32, [%x[cptr]]\n"
+ "fmla C_33.4s, B_3.4s, A_3.s[0]\n"
+ "str qC_33, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+ "ldr sA_3, [ aptr6], #0x04\n"
+
+ "fmla C_41.4s, B_1.4s, A_4.s[0]\n"
+ "fmla C_42.4s, B_2.4s, A_4.s[0]\n"
+ "stp qC_41, qC_42, [%x[cptr]]\n"
+ "fmla C_43.4s, B_3.4s, A_4.s[0]\n"
+ "str qC_43, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+ "ldr sA_4, [ aptr7], #0x04\n"
+
+ "fmla C_51.4s, B_1.4s, A_1.s[0]\n"
+ "fmla C_52.4s, B_2.4s, A_1.s[0]\n"
+ "stp qC_51, qC_52, [%x[cptr]]\n"
+ "fmla C_53.4s, B_3.4s, A_1.s[0]\n"
+ "str qC_53, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+
+ "fmla C_61.4s, B_1.4s, A_2.s[0]\n"
+ "fmla C_62.4s, B_2.4s, A_2.s[0]\n"
+ "stp qC_61, qC_62, [%x[cptr]]\n"
+ "fmla C_63.4s, B_3.4s, A_2.s[0]\n"
+ "str qC_63, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+
+ "fmla C_71.4s, B_1.4s, A_3.s[0]\n"
+ "fmla C_72.4s, B_2.4s, A_3.s[0]\n"
+ "stp qC_71, qC_72, [%x[cptr]]\n"
+ "fmla C_73.4s, B_3.4s, A_3.s[0]\n"
+ "str qC_73, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+
+ "fmla C_81.4s, B_1.4s, A_4.s[0]\n"
+ "fmla C_82.4s, B_2.4s, A_4.s[0]\n"
+ "stp qC_81, qC_82, [%x[cptr]]\n"
+ "fmla C_83.4s, B_3.4s, A_4.s[0]\n"
+ "str qC_83, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+
+ // Clear aliases
+ ".unreq aptr1\n"
+ ".unreq aptr2\n"
+ ".unreq aptr3\n"
+ ".unreq aptr4\n"
+ ".unreq aptr5\n"
+ ".unreq aptr6\n"
+ ".unreq aptr7\n"
+
+ ".unreq A_1\n" ".unreq A_2\n" ".unreq A_3\n" ".unreq A_4\n"
+ ".unreq sA_1\n" ".unreq sA_2\n" ".unreq sA_3\n" ".unreq sA_4\n"
+
+ ".unreq B_1\n" ".unreq B_2\n" ".unreq B_3\n"
+ ".unreq qB_1\n" ".unreq qB_2\n" ".unreq qB_3\n"
+
+ ".unreq C_11\n" ".unreq C_12\n" ".unreq C_13\n"
+ ".unreq C_21\n" ".unreq C_22\n" ".unreq C_23\n"
+ ".unreq C_31\n" ".unreq C_32\n" ".unreq C_33\n"
+ ".unreq C_41\n" ".unreq C_42\n" ".unreq C_43\n"
+ ".unreq C_51\n" ".unreq C_52\n" ".unreq C_53\n"
+ ".unreq C_61\n" ".unreq C_62\n" ".unreq C_63\n"
+ ".unreq C_71\n" ".unreq C_72\n" ".unreq C_73\n"
+ ".unreq C_81\n" ".unreq C_82\n" ".unreq C_83\n"
+
+ ".unreq qC_11\n" ".unreq qC_12\n" ".unreq qC_13\n"
+ ".unreq qC_21\n" ".unreq qC_22\n" ".unreq qC_23\n"
+ ".unreq qC_31\n" ".unreq qC_32\n" ".unreq qC_33\n"
+ ".unreq qC_41\n" ".unreq qC_42\n" ".unreq qC_43\n"
+ ".unreq qC_51\n" ".unreq qC_52\n" ".unreq qC_53\n"
+ ".unreq qC_61\n" ".unreq qC_62\n" ".unreq qC_63\n"
+ ".unreq qC_71\n" ".unreq qC_72\n" ".unreq qC_73\n"
+ ".unreq qC_81\n" ".unreq qC_82\n" ".unreq qC_83\n"
+ : [aptr] "+r" (aptr),
+ [bptr] "+r" (bptr),
+ [cptr] "+r" (cptr),
+ [k] "+r" (k)
+ : [a_row_stride] "r" (a_row_stride * sizeof(float)),
+ [b_row_stride] "r" (b_row_stride * sizeof(float)),
+ [c_row_stride] "r" (c_row_stride * sizeof(float))
+ : "cc", "memory",
+ "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+ "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
+ "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
+ "v29", "v30", "x17", "x18", "x19", "x20", "x21", "x22", "x23"
+ );
+ }
+ }
+}
+
+/*****************************************************************************/
+/* 4x16 blocked GEMM with specialised tails
+ */
+#include "a64_sgemm_4x16.hpp"
+
+template <>
+inline void BlockedGemm<4, 16, float, float>(
+ const float* const a, const float* const b, float *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+) {
+ // Despatch based on tail of K
+ switch (K % 4) {
+ case 3:
+ sgemm_4x16_impl<3>(
+ a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride
+ );
+ break;
+ case 2:
+ sgemm_4x16_impl<2>(
+ a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride
+ );
+ break;
+ case 1:
+ sgemm_4x16_impl<1>(
+ a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride
+ );
+ break;
+ case 0:
+ sgemm_4x16_impl<0>(
+ a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride
+ );
+ break;
+ default:
+ assert(0);
+ break;
+ }
+}
+
+#endif // __aarch64__
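The 4x16 kernel unrolls the K dimension by four, so the switch above only maps the runtime remainder K % 4 onto a compile-time tail length; each specialisation then executes (K - tail) / 4 full iterations followed by one tail iteration. A worked example of that split (the depth K = 11 is arbitrary):

    // illustrative arithmetic only
    const int K    = 11;
    const int tail = K % 4;          // 3 -> sgemm_4x16_impl<3> is selected
    const int k    = (K - tail) / 4; // 2 four-deep iterations before the 3-element tail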
diff --git a/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp
new file mode 100644
index 0000000000..e74610ef27
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp
@@ -0,0 +1,1445 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+template <const unsigned int tail>
+inline void sgemm_4x16_impl(
+ const float* const a, const float* const b, float *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+);
+
+template <>
+inline void sgemm_4x16_impl<0>(
+ const float* const a, const float* const b, float *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+) {
+ const int TAIL_SIZE = 0;
+ const int M_BLOCK = 4;
+ const int N_BLOCK = 16;
+
+ const int m_blocks = iceildiv(M, M_BLOCK);
+ const int n_blocks = iceildiv(N, N_BLOCK);
+
+ // For each block of output rows
+ for (int mblock = 0; mblock < m_blocks; mblock++) {
+ // For each block of output columns
+ for (int nblock = 0; nblock < n_blocks; nblock++) {
+ const float *aptr = a + mblock*M_BLOCK*a_row_stride;
+ const float *bptr = b + nblock*N_BLOCK;
+ float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
+ int k = (K - TAIL_SIZE) / 4;
+
+ asm volatile(
+ "aptr2 .req X20\n"
+ "aptr3 .req X21\n"
+ "aptr4 .req X22\n"
+ "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n"
+ "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n"
+ "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n"
+ "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n"
+ "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n"
+ "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n"
+ "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n"
+ "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n"
+ "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n"
+ "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n"
+ "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n"
+ "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n"
+ "vB1 .req v20\n" "qB1 .req q20\n"
+ "vB2 .req v21\n" "qB2 .req q21\n"
+ "vB3 .req v22\n" "qB3 .req q22\n"
+ "vB4 .req v23\n" "qB4 .req q23\n"
+
+ // Clear accumulators, initialise pointers
+ "movi vC11.4s, #0\n"
+ "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n"
+ "movi vC12.4s, #0\n"
+ "add aptr3, aptr2, %x[a_row_stride_bytes]\n"
+ "movi vC13.4s, #0\n"
+ "add aptr4, aptr3, %x[a_row_stride_bytes]\n"
+ "movi vC14.4s, #0\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "movi vC21.4s, #0\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "movi vC22.4s, #0\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "movi vC23.4s, #0\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "movi vC24.4s, #0\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "movi vC31.4s, #0\n"
+ "movi vC32.4s, #0\n"
+ "movi vC33.4s, #0\n"
+ "movi vC34.4s, #0\n"
+ "movi vC41.4s, #0\n"
+ "movi vC42.4s, #0\n"
+ "movi vC43.4s, #0\n"
+ "movi vC44.4s, #0\n"
+ "subs %x[k], %x[k], #1\n"
+ "beq 2f\n"
+
+ "1:" // Loop proper
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "subs %x[k], %x[k], #1\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+ "bne 1b\n"
+
+ "2:" // Tail
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "stp qC11, qC12, [%x[cptr], #0x00]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "stp qC13, qC14, [%x[cptr], #0x20]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "stp qC21, qC22, [%x[cptr], #0x00]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "stp qC23, qC24, [%x[cptr], #0x20]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "stp qC31, qC32, [%x[cptr], #0x00]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "stp qC33, qC34, [%x[cptr], #0x20]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "stp qC41, qC42, [%x[cptr], #0x00]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+ "stp qC43, qC44, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+
+ ".unreq vB4\n" ".unreq qB4\n"
+ ".unreq vB3\n" ".unreq qB3\n"
+ ".unreq vB2\n" ".unreq qB2\n"
+ ".unreq vB1\n" ".unreq qB1\n"
+ ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n"
+ ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n"
+ ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n"
+ ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n"
+ ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n"
+ ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n"
+ ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n"
+ ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n"
+ ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n"
+ ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n"
+ ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n"
+ ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n"
+ ".unreq aptr2\n"
+ ".unreq aptr3\n"
+ ".unreq aptr4\n"
+
+ : [aptr] "+r" (aptr),
+ [bptr] "+r" (bptr),
+ [cptr] "+r" (cptr),
+ [k] "+r" (k)
+ : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)),
+ [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)),
+ [c_row_stride_bytes] "r" (c_row_stride * sizeof(float))
+ : "cc", "memory", "x20", "x21", "x22",
+ "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+ "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
+ "v21", "v22", "v23"
+ );
+ }
+ }
+}
+
+template <>
+inline void sgemm_4x16_impl<1>(
+ const float* const a, const float* const b, float *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+) {
+ const int TAIL_SIZE = 1;
+ const int M_BLOCK = 4;
+ const int N_BLOCK = 16;
+
+ const int m_blocks = iceildiv(M, M_BLOCK);
+ const int n_blocks = iceildiv(N, N_BLOCK);
+
+ // For each block of output rows
+ for (int mblock = 0; mblock < m_blocks; mblock++) {
+ // For each block of output columns
+ for (int nblock = 0; nblock < n_blocks; nblock++) {
+ const float *aptr = a + mblock*M_BLOCK*a_row_stride;
+ const float *bptr = b + nblock*N_BLOCK;
+ float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
+ int k = (K - TAIL_SIZE) / 4;
+
+ asm volatile(
+ "aptr2 .req X20\n"
+ "aptr3 .req X21\n"
+ "aptr4 .req X22\n"
+ "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n"
+ "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n"
+ "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n"
+ "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n"
+ "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n"
+ "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n"
+ "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n"
+ "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n"
+ "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n"
+ "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n"
+ "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n"
+ "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n"
+ "vB1 .req v20\n" "qB1 .req q20\n"
+ "vB2 .req v21\n" "qB2 .req q21\n"
+ "vB3 .req v22\n" "qB3 .req q22\n"
+ "vB4 .req v23\n" "qB4 .req q23\n"
+
+ // Clear accumulators, initialise pointers
+ "movi vC11.4s, #0\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "movi vC12.4s, #0\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "movi vC13.4s, #0\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "movi vC14.4s, #0\n"
+ "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n"
+ "movi vC21.4s, #0\n"
+ "add aptr3, aptr2, %x[a_row_stride_bytes]\n"
+ "movi vC22.4s, #0\n"
+ "add aptr4, aptr3, %x[a_row_stride_bytes]\n"
+ "movi vC23.4s, #0\n"
+ "cbnz %x[k], 3f\n"
+
+ // Prepare for tail in K
+ "movi vC24.4s, #0\n"
+ "ldr sA1, [%x[aptr]], #0x04\n"
+ "movi vC31.4s, #0\n"
+ "ldr sA2, [ aptr2], #0x04\n"
+ "movi vC32.4s, #0\n"
+ "movi vC33.4s, #0\n"
+ "movi vC34.4s, #0\n"
+ "movi vC41.4s, #0\n"
+ "movi vC42.4s, #0\n"
+ "movi vC43.4s, #0\n"
+ "movi vC44.4s, #0\n"
+ "b 2f\n" // Jump to tail
+
+ "3:" // Prepare for loop over K
+ "movi vC24.4s, #0\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "movi vC31.4s, #0\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "movi vC32.4s, #0\n"
+ "movi vC33.4s, #0\n"
+ "movi vC34.4s, #0\n"
+ "movi vC41.4s, #0\n"
+ "movi vC42.4s, #0\n"
+ "movi vC43.4s, #0\n"
+ "movi vC44.4s, #0\n"
+ "subs %x[k], %x[k], #1\n"
+ "beq 4f\n"
+
+ "1:" // Loop proper
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "subs %x[k], %x[k], #1\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+ "bne 1b\n"
+
+ "4:" // Tail iteration
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "ldr sA1, [%x[aptr]], #0x10\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "ldr sA2, [ aptr2], #0x10\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+
+ "2:" // Common tail
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "stp qC11, qC12, [%x[cptr], #0x00]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "ldr sA3, [ aptr3], #0x10\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "stp qC13, qC14, [%x[cptr], #0x20]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "stp qC21, qC22, [%x[cptr], #0x00]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "ldr sA4, [ aptr4], #0x10\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "stp qC23, qC24, [%x[cptr], #0x20]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "stp qC31, qC32, [%x[cptr], #0x00]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "stp qC33, qC34, [%x[cptr], #0x20]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "stp qC41, qC42, [%x[cptr], #0x00]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+ "stp qC43, qC44, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+
+ ".unreq vB4\n" ".unreq qB4\n"
+ ".unreq vB3\n" ".unreq qB3\n"
+ ".unreq vB2\n" ".unreq qB2\n"
+ ".unreq vB1\n" ".unreq qB1\n"
+ ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n"
+ ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n"
+ ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n"
+ ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n"
+ ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n"
+ ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n"
+ ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n"
+ ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n"
+ ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n"
+ ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n"
+ ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n"
+ ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n"
+ ".unreq aptr2\n"
+ ".unreq aptr3\n"
+ ".unreq aptr4\n"
+
+ : [aptr] "+r" (aptr),
+ [bptr] "+r" (bptr),
+ [cptr] "+r" (cptr),
+ [k] "+r" (k)
+ : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)),
+ [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)),
+ [c_row_stride_bytes] "r" (c_row_stride * sizeof(float))
+ : "cc", "memory", "x20", "x21", "x22",
+ "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+ "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
+ "v21", "v22", "v23"
+ );
+ }
+ }
+}
+
+template <>
+inline void sgemm_4x16_impl<2>(
+ const float* const a, const float* const b, float *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+) {
+ const int TAIL_SIZE = 2;
+ const int M_BLOCK = 4;
+ const int N_BLOCK = 16;
+
+ const int m_blocks = iceildiv(M, M_BLOCK);
+ const int n_blocks = iceildiv(N, N_BLOCK);
+
+ // For each block of output rows
+ for (int mblock = 0; mblock < m_blocks; mblock++) {
+ // For each block of output columns
+ for (int nblock = 0; nblock < n_blocks; nblock++) {
+ const float *aptr = a + mblock*M_BLOCK*a_row_stride;
+ const float *bptr = b + nblock*N_BLOCK;
+ float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
+ int k = (K - TAIL_SIZE) / 4;
+
+ asm volatile(
+ "aptr2 .req X20\n"
+ "aptr3 .req X21\n"
+ "aptr4 .req X22\n"
+ "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n"
+ "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n"
+ "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n"
+ "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n"
+ "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n"
+ "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n"
+ "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n"
+ "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n"
+ "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n"
+ "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n"
+ "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n"
+ "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n"
+ "vB1 .req v20\n" "qB1 .req q20\n"
+ "vB2 .req v21\n" "qB2 .req q21\n"
+ "vB3 .req v22\n" "qB3 .req q22\n"
+ "vB4 .req v23\n" "qB4 .req q23\n"
+
+ // Clear accumulators, initialise pointers
+ "movi vC11.4s, #0\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "movi vC12.4s, #0\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "movi vC13.4s, #0\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "movi vC14.4s, #0\n"
+ "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n"
+ "movi vC21.4s, #0\n"
+ "add aptr3, aptr2, %x[a_row_stride_bytes]\n"
+ "movi vC22.4s, #0\n"
+ "add aptr4, aptr3, %x[a_row_stride_bytes]\n"
+ "movi vC23.4s, #0\n"
+ "cbnz %x[k], 3f\n"
+
+ // Prepare for tail in K
+ "movi vC24.4s, #0\n"
+ "ldr dA1, [%x[aptr]], #0x08\n"
+ "movi vC31.4s, #0\n"
+ "ldr dA2, [ aptr2], #0x08\n"
+ "movi vC32.4s, #0\n"
+ "movi vC33.4s, #0\n"
+ "movi vC34.4s, #0\n"
+ "movi vC41.4s, #0\n"
+ "movi vC42.4s, #0\n"
+ "movi vC43.4s, #0\n"
+ "movi vC44.4s, #0\n"
+ "b 2f\n" // Jump to tail
+
+ "3:" // Prepare for loop over K
+ "movi vC24.4s, #0\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "movi vC31.4s, #0\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "movi vC32.4s, #0\n"
+ "movi vC33.4s, #0\n"
+ "movi vC34.4s, #0\n"
+ "movi vC41.4s, #0\n"
+ "movi vC42.4s, #0\n"
+ "movi vC43.4s, #0\n"
+ "movi vC44.4s, #0\n"
+ "subs %x[k], %x[k], #1\n"
+ "beq 4f\n"
+
+ "1:" // Loop proper
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "subs %x[k], %x[k], #1\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+ "bne 1b\n"
+
+ "4:" // Tail iteration
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "ldr dA1, [%x[aptr]], #0x10\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "ldr dA2, [ aptr2], #0x10\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+
+ "2:" // Common tail
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr dA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr dA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "stp qC11, qC12, [%x[cptr], #0x00]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "stp qC13, qC14, [%x[cptr], #0x20]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "stp qC21, qC22, [%x[cptr], #0x00]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "stp qC23, qC24, [%x[cptr], #0x20]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "stp qC31, qC32, [%x[cptr], #0x00]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "stp qC33, qC34, [%x[cptr], #0x20]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "stp qC41, qC42, [%x[cptr], #0x00]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+ "stp qC43, qC44, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+
+ ".unreq vB4\n" ".unreq qB4\n"
+ ".unreq vB3\n" ".unreq qB3\n"
+ ".unreq vB2\n" ".unreq qB2\n"
+ ".unreq vB1\n" ".unreq qB1\n"
+ ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n"
+ ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n"
+ ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n"
+ ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n"
+ ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n"
+ ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n"
+ ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n"
+ ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n"
+ ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n"
+ ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n"
+ ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n"
+ ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n"
+ ".unreq aptr2\n"
+ ".unreq aptr3\n"
+ ".unreq aptr4\n"
+
+ : [aptr] "+r" (aptr),
+ [bptr] "+r" (bptr),
+ [cptr] "+r" (cptr),
+ [k] "+r" (k)
+ : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)),
+ [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)),
+ [c_row_stride_bytes] "r" (c_row_stride * sizeof(float))
+ : "cc", "memory", "x20", "x21", "x22",
+ "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+ "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
+ "v21", "v22", "v23"
+ );
+ }
+ }
+}
+
+template <>
+inline void sgemm_4x16_impl<3>(
+ const float* const a, const float* const b, float *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+) {
+ const int TAIL_SIZE = 3;
+ const int M_BLOCK = 4;
+ const int N_BLOCK = 16;
+
+ const int m_blocks = iceildiv(M, M_BLOCK);
+ const int n_blocks = iceildiv(N, N_BLOCK);
+
+ // For each block of output rows
+ for (int mblock = 0; mblock < m_blocks; mblock++) {
+ // For each block of output columns
+ for (int nblock = 0; nblock < n_blocks; nblock++) {
+ const float *aptr = a + mblock*M_BLOCK*a_row_stride;
+ const float *bptr = b + nblock*N_BLOCK;
+ float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
+ int k = (K - TAIL_SIZE) / 4;
+
+ asm volatile(
+ "aptr2 .req X20\n"
+ "aptr3 .req X21\n"
+ "aptr4 .req X22\n"
+ "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n"
+ "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n"
+ "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n"
+ "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n"
+ "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n"
+ "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n"
+ "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n"
+ "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n"
+ "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n"
+ "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n"
+ "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n"
+ "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n"
+ "vB1 .req v20\n" "qB1 .req q20\n"
+ "vB2 .req v21\n" "qB2 .req q21\n"
+ "vB3 .req v22\n" "qB3 .req q22\n"
+ "vB4 .req v23\n" "qB4 .req q23\n"
+
+ // Clear accumulators, initialise pointers
+ "movi vC11.4s, #0\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "movi vC12.4s, #0\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "movi vC13.4s, #0\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "movi vC14.4s, #0\n"
+ "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n"
+ "movi vC21.4s, #0\n"
+ "add aptr3, aptr2, %x[a_row_stride_bytes]\n"
+ "movi vC22.4s, #0\n"
+ "add aptr4, aptr3, %x[a_row_stride_bytes]\n"
+ "movi vC23.4s, #0\n"
+ "cbnz %x[k], 3f\n"
+
+ // Prepare for tail in K
+ "movi vC24.4s, #0\n"
+ "ldr dA1, [%x[aptr]], #0x08\n"
+ "movi vC31.4s, #0\n"
+ "ldr dA2, [ aptr2], #0x08\n"
+ "movi vC32.4s, #0\n"
+ "movi vC33.4s, #0\n"
+ "movi vC34.4s, #0\n"
+ "movi vC41.4s, #0\n"
+ "movi vC42.4s, #0\n"
+ "movi vC43.4s, #0\n"
+ "movi vC44.4s, #0\n"
+ "b 2f\n" // Jump to tail
+
+ "3:" // Prepare for loop over K
+ "movi vC24.4s, #0\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "movi vC31.4s, #0\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "movi vC32.4s, #0\n"
+ "movi vC33.4s, #0\n"
+ "movi vC34.4s, #0\n"
+ "movi vC41.4s, #0\n"
+ "movi vC42.4s, #0\n"
+ "movi vC43.4s, #0\n"
+ "movi vC44.4s, #0\n"
+ "subs %x[k], %x[k], #1\n"
+ "beq 4f\n"
+
+ "1:" // Loop proper
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "subs %x[k], %x[k], #1\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+ "bne 1b\n"
+
+ "4:" // Tail iteration
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "ldr dA1, [%x[aptr]], #0x10\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "ldr dA2, [ aptr2], #0x10\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+
+ "2:" // Common tail
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr dA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr dA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "ldr sA1, [%x[aptr]], #0x10\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "ldr sA2, [ aptr2], #0x10\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "stp qC11, qC12, [%x[cptr], #0x00]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "ldr sA3, [ aptr3], #0x10\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "stp qC13, qC14, [%x[cptr], #0x20]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "stp qC21, qC22, [%x[cptr], #0x00]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "ldr sA4, [ aptr4], #0x10\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "stp qC23, qC24, [%x[cptr], #0x20]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "stp qC31, qC32, [%x[cptr], #0x00]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "stp qC33, qC34, [%x[cptr], #0x20]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "stp qC41, qC42, [%x[cptr], #0x00]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+ "stp qC43, qC44, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+
+ ".unreq vB4\n" ".unreq qB4\n"
+ ".unreq vB3\n" ".unreq qB3\n"
+ ".unreq vB2\n" ".unreq qB2\n"
+ ".unreq vB1\n" ".unreq qB1\n"
+ ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n"
+ ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n"
+ ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n"
+ ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n"
+ ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n"
+ ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n"
+ ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n"
+ ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n"
+ ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n"
+ ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n"
+ ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n"
+ ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n"
+ ".unreq aptr2\n"
+ ".unreq aptr3\n"
+ ".unreq aptr4\n"
+
+ : [aptr] "+r" (aptr),
+ [bptr] "+r" (bptr),
+ [cptr] "+r" (cptr),
+ [k] "+r" (k)
+ : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)),
+ [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)),
+ [c_row_stride_bytes] "r" (c_row_stride * sizeof(float))
+ : "cc", "memory", "x20", "x21", "x22",
+ "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+ "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
+ "v21", "v22", "v23"
+ );
+ }
+ }
+}
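The two specialisations above unroll the same 4x16 micro-kernel: four rows of A (streamed through aptr and the aliased aptr2..aptr4) are broadcast against four quad-registers of B per K step, accumulating into the vC11..vC44 grid that is finally written out with "stp". As an illustrative scalar reference of what one complete block computes (the function name is mine, and it assumes whole 4x16 blocks with no M or N remainder):

// Scalar sketch of one 4x16 output block (M_BLOCK = 4, N_BLOCK = 16).
// Each vCrc register in the assembly holds four of these 16 columns.
inline void sgemm_block_4x16_reference(
    const float *aptr, const float *bptr, float *cptr,
    const int K,
    const int a_row_stride, const int b_row_stride, const int c_row_stride)
{
  float acc[4][16] = {};              // vC11..vC44, cleared by the "movi" block
  for (int k = 0; k < K; k++) {
    for (int r = 0; r < 4; r++) {     // rows read through aptr, aptr2..aptr4
      const float a = aptr[r * a_row_stride + k];
      for (int c = 0; c < 16; c++) {  // columns loaded as qB1..qB4
        acc[r][c] += a * bptr[k * b_row_stride + c];
      }
    }
  }
  for (int r = 0; r < 4; r++) {       // the closing "stp qCxy, ..." stores
    for (int c = 0; c < 16; c++) {
      cptr[r * c_row_stride + c] = acc[r][c];
    }
  }
}

The template parameter of sgemm_4x16_impl only changes how many trailing K iterations are peeled out of the unrolled-by-four main loop into the hand-written tail.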
diff --git a/arm_compute/core/NEON/kernels/winograd/perf.h b/arm_compute/core/NEON/kernels/winograd/perf.h
new file mode 100644
index 0000000000..11fb0c452f
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/perf.h
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+/* Prototypes from perf.c */
+
+void start_counter(int fd);
+long long get_counter(int fd);
+long long stop_counter(int fd);
+int open_instruction_counter(void);
+int open_cycle_counter(void);
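These prototypes only declare the counter helpers implemented in perf.c; the profiler class introduced next shows how they are meant to be composed. A minimal standalone sketch, assuming that implementation is linked in:

#include <unistd.h>
#include "perf.h"

long long cycles_for_region()
{
  const int fd = open_cycle_counter();       // perf_event cycle counter
  start_counter(fd);
  /* ... region of interest ... */
  const long long cycles = stop_counter(fd);
  close(fd);                                 // mirrors profiler::~profiler()
  return cycles;
}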
diff --git a/arm_compute/core/NEON/kernels/winograd/profiler.hpp b/arm_compute/core/NEON/kernels/winograd/profiler.hpp
new file mode 100644
index 0000000000..143192b589
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/profiler.hpp
@@ -0,0 +1,244 @@
+
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+#include <cstdio>
+#include <map>
+#include <vector>
+
+#include "perf.h"
+#include <unistd.h>
+
+class profiler {
+private:
+#ifdef CYCLE_PROFILING
+ struct ProfileEntry {
+ int event_id;
+ long int bytes_read, ops, bytes_written;
+ long int duration;
+ };
+
+ static const int maxevents = 10000;
+ ProfileEntry events[maxevents];
+ int currentevent;
+ int countfd;
+
+ std::map<const char *, int> event_ids;
+
+ int get_event_id(const char *id) {
+ if (!event_ids.count(id)) {
+ event_ids.emplace(id, event_ids.size());
+ }
+ return event_ids[id];
+ }
+#endif // CYCLE_PROFILING
+
+public:
+#ifdef CYCLE_PROFILING
+ profiler() {
+ currentevent = 0;
+ countfd = open_cycle_counter();
+ }
+
+ ~profiler() {
+ close(countfd);
+
+ // Compute performance from recorded events
+ struct ProfileResult {
+ ProfileResult() : total_calls(0),
+ total_duration(0),
+ total_bytes_read(0),
+ total_ops(0),
+ total_bytes_written(0) {
+ }
+
+ void operator+=(const ProfileEntry &rhs) {
+ total_calls++;
+ total_duration += rhs.duration;
+ total_bytes_read += rhs.bytes_read;
+ total_ops += rhs.ops;
+        total_bytes_written += rhs.bytes_written;
+ }
+
+ float avg_duration(void) const {
+ return static_cast<float>(total_duration) /
+ static_cast<float>(total_calls);
+ }
+
+ float bytes_read_per_cycle(void) const {
+ return static_cast<float>(total_bytes_read) /
+ static_cast<float>(total_duration);
+ }
+
+ float ops_per_cycle(void) const {
+ return static_cast<float>(total_ops) /
+ static_cast<float>(total_duration);
+ }
+
+ float bytes_written_per_cycle(void) const {
+ return static_cast<float>(total_bytes_written) /
+ static_cast<float>(total_duration);
+ }
+
+ long int total_calls,
+ total_duration,
+ total_bytes_read,
+ total_ops,
+ total_bytes_written;
+ };
+
+ std::vector<ProfileResult> totals;
+ totals.resize(event_ids.size());
+ for (int i = 0; i < currentevent; i++) {
+ const auto &event = events[i];
+ totals[event.event_id] += event;
+ }
+
+ // Get the longest label
+ int len_label = 0;
+ for (const auto &kv : event_ids) {
+ len_label = std::max(len_label, static_cast<int>(strlen(kv.first)));
+ }
+
+ // Get the longest values for every other field
+ const auto get_length_of_field =
+ [totals] (const char *title, auto f, auto len) -> size_t {
+ size_t l = strlen(title);
+ for (const auto &v : totals) {
+ l = std::max(l, len(f(v)));
+ }
+ return l;
+ };
+
+ // Get the strlen for an int
+ const auto intlen = [] (long int x) -> size_t {
+ size_t len = 0;
+ do {
+ x /= 10;
+ len++;
+ } while (x);
+ return len;
+ };
+
+ // Get the strlen for a float
+ const auto floatlen = [] (const int precision) {
+ return [precision] (float x) {
+ size_t len = 0;
+
+ if (!std::isfinite(x)) {
+ return static_cast<size_t>(3);
+ }
+
+ do {
+ x /= 10.0f;
+ len++;
+ } while (x > 1.0f);
+ return len + 1 + precision;
+ };
+ };
+
+ const int len_calls = get_length_of_field(
+ "Calls", [] (const auto &v) {return v.total_calls;},
+ intlen
+ );
+ const int len_duration = get_length_of_field(
+ "Duration", [] (const auto &v) {return v.total_duration;},
+ intlen
+ );
+ const int len_average_duration = get_length_of_field(
+ "Average", [] (const auto &v) {return v.avg_duration();},
+ floatlen(2)
+ );
+ const int len_reads_per_cycle = get_length_of_field(
+ "Reads / cycle",
+ [] (const auto &v) {return v.bytes_read_per_cycle();},
+ floatlen(6)
+ );
+ const int len_ops_per_cycle = get_length_of_field(
+ "Ops / cycle",
+ [] (const auto &v) {return v.ops_per_cycle();},
+ floatlen(6)
+ );
+ const int len_writes_per_cycle = get_length_of_field(
+ "Writes / cycle",
+ [] (const auto &v) {return v.bytes_written_per_cycle();},
+ floatlen(6)
+ );
+
+ // Print header
+ printf(
+ "%*s %*s %*s %*s %*s %*s %*s\n",
+ len_label, "",
+ len_calls, "Calls",
+ len_duration, "Duration",
+ len_average_duration, "Average",
+ len_reads_per_cycle, "Reads / cycle",
+ len_ops_per_cycle, "Ops / cycle",
+ len_writes_per_cycle, "Writes / cycle"
+ );
+ for (const auto &kv : event_ids) {
+ const auto id = kv.second;
+ printf(
+ "%*s %*ld %*ld %*.2f %*.6f %*.6f %*.6f\n",
+ len_label, kv.first,
+ len_calls, totals[id].total_calls,
+ len_duration, totals[id].total_duration,
+ len_average_duration, totals[id].avg_duration(),
+ len_reads_per_cycle, totals[id].bytes_read_per_cycle(),
+ len_ops_per_cycle, totals[id].ops_per_cycle(),
+ len_writes_per_cycle, totals[id].bytes_written_per_cycle()
+ );
+ }
+ printf("\n");
+ }
+#endif // CYCLE_PROFILING
+
+ template <typename T>
+ void operator() (const char * event,
+ T func,
+ long int bytes_read = 0,
+ long int ops = 0,
+ long int bytes_written = 0) {
+#ifdef CYCLE_PROFILING
+ if (currentevent==maxevents) {
+ func();
+ } else {
+ start_counter(countfd);
+ func();
+ long long cycs = stop_counter(countfd);
+
+ // Store the profiling data
+ events[currentevent++] = {
+ get_event_id(event), bytes_read, ops, bytes_written, cycs
+ };
+ }
+#else
+ func();
+#endif // CYCLE_PROFILING
+ }
+};
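The profiler times a callable against the cycle counter opened in its constructor; when built with CYCLE_PROFILING it aggregates events per label and prints a summary table on destruction, otherwise operator() simply invokes the callable. A small usage sketch (the work inside the lambda and the byte/op counts are placeholders):

#include "profiler.hpp"

int main()
{
  profiler prof;                      // opens the cycle counter under CYCLE_PROFILING
  volatile float acc = 0.0f;

  prof("dummy work",                  // label used to aggregate events
       [&] {                          // the callable being timed
         for (int i = 0; i < 1000; i++) acc += static_cast<float>(i);
       },
       /* bytes_read    */ 1000 * sizeof(float),
       /* ops           */ 1000,
       /* bytes_written */ sizeof(float));

  return 0;                           // the summary prints when prof is destroyed
}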
diff --git a/arm_compute/core/NEON/kernels/winograd/shims.hpp b/arm_compute/core/NEON/kernels/winograd/shims.hpp
new file mode 100644
index 0000000000..249e5757f0
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/shims.hpp
@@ -0,0 +1,319 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+/** Re-order a weight tensor from [Output feature map x Input feature map x
+ * Height x Width] format to [Height x Width x Input feature map x Output
+ * feature map] format.
+ */
+template <typename T>
+inline void ofm_ifm_h_w_to_h_w_ifm_ofm(
+ const T* const in, // Input in [Output x Input x Height x Width] form
+ T* const out, // Output in [Height x Width x Input x Output] form
+ const int n_output_feature_maps,
+ const int n_input_feature_maps,
+ const int n_rows,
+ const int n_cols,
+ int in_output_feature_map_stride=0,
+ int in_input_feature_map_stride=0,
+ int in_row_stride=0,
+ int out_row_stride=0,
+ int out_col_stride=0,
+ int out_input_feature_map_stride=0
+);
+
+/** Re-order a weight tensor from [Height x Width x Input feature map x Output
+ * feature map] format to [Output feature map x Input feature map x Height x
+ * Width] format.
+ */
+template <typename T>
+inline void h_w_ifm_ofm_to_ofm_ifm_h_w(
+ const T* const in, // Input in [Height x Width x Input x Output] form
+ T* const out, // Output in [Output x Input x Height x Width] form
+ const int n_rows,
+ const int n_cols,
+ const int n_input_feature_maps,
+ const int n_output_feature_maps,
+ int in_row_stride=0,
+ int in_col_stride=0,
+ int in_input_feature_map_stride=0,
+ int out_output_feature_map_stride=0,
+ int out_input_feature_map_stride=0,
+ int out_row_stride=0
+);
+
+
+/* Re-order a tensor from NCHW format to NHWC.
+ */
+template <typename T>
+inline void nchw_to_nhwc(
+ const T* const in,
+ T* const out,
+ const int n_batches,
+ const int n_channels,
+ const int n_rows,
+ const int n_cols,
+ int in_batch_stride=0,
+ int in_channel_stride=0,
+ int in_row_stride=0,
+ int out_batch_stride=0,
+ int out_row_stride=0,
+ int out_col_stride=0
+)
+{
+ // Fill in the stride values
+ in_row_stride = (in_row_stride) ? in_row_stride : n_cols;
+ in_channel_stride = (in_channel_stride) ? in_channel_stride
+ : n_rows * in_row_stride;
+ in_batch_stride = (in_batch_stride) ? in_batch_stride
+ : n_channels * in_channel_stride;
+
+ out_col_stride = (out_col_stride) ? out_col_stride : n_channels;
+ out_row_stride = (out_row_stride) ? out_row_stride : n_cols * out_col_stride;
+ out_batch_stride = (out_batch_stride) ? out_batch_stride
+ : n_rows * out_row_stride;
+
+ // Perform the re-ordering
+ for (int n = 0; n < n_batches; n++)
+ {
+ const T* const in_batch = in + n*in_batch_stride;
+ T* const out_batch = out + n*out_batch_stride;
+
+ for (int i = 0; i < n_rows; i++)
+ {
+ const T* const in_row = in_batch + i*in_row_stride;
+ T* const out_row = out_batch + i*out_row_stride;
+
+ for (int j = 0; j < n_cols; j++)
+ {
+ const T* const in_col = in_row + j;
+ T* const out_col = out_row + j*out_col_stride;
+
+ for (int c = 0; c < n_channels; c++)
+ {
+ const T* const in_channel = in_col + c*in_channel_stride;
+ out_col[c] = *(in_channel);
+ }
+ }
+ }
+ }
+}
+
+/* Re-order a tensor from NHWC format to NCHW.
+ */
+template <typename T>
+inline void nhwc_to_nchw(
+ const T* const in, // Input data in NHWC form
+ T* const out, // Output data in NCHW form
+ const int n_batches,
+ const int n_rows,
+ const int n_cols,
+ const int n_channels,
+ int in_batch_stride=0,
+ int in_row_stride=0,
+ int in_col_stride=0,
+ int out_batch_stride=0,
+ int out_channel_stride=0,
+ int out_row_stride=0
+)
+{
+ // Fill in stride values
+ in_col_stride = (in_col_stride) ? in_col_stride : n_channels;
+ in_row_stride = (in_row_stride) ? in_row_stride : n_cols * in_col_stride;
+ in_batch_stride = (in_batch_stride) ? in_batch_stride
+ : n_rows * in_row_stride;
+
+ out_row_stride = (out_row_stride) ? out_row_stride : n_cols;
+ out_channel_stride = (out_channel_stride) ? out_channel_stride
+ : n_rows * out_row_stride;
+ out_batch_stride = (out_batch_stride) ? out_batch_stride
+ : n_channels * out_channel_stride;
+
+ // Perform the re-ordering
+ // For every batch
+ for (int n = 0; n < n_batches; n++)
+ {
+ const T* const in_batch = in + n*in_batch_stride;
+ T* const out_batch = out + n*out_batch_stride;
+
+ // For every row
+ for (int i = 0; i < n_rows; i++)
+ {
+ const T* const in_i = in_batch + i*in_row_stride;
+ T* const out_i = out_batch + i*out_row_stride;
+
+ // For every column
+ for (int j = 0; j < n_cols; j++)
+ {
+ const T* const in_j = in_i + j*in_col_stride;
+ T* const out_j = out_i + j;
+
+ // For every channel
+ for (int c = 0; c < n_channels; c++)
+ {
+ const T* const in_channel = in_j + c;
+ T* const out_channel = out_j + c*out_channel_stride;
+ *(out_channel) = *(in_channel);
+ }
+ }
+ }
+ }
+}
+
+
+/*****************************************************************************/
+/* Generic weight re-order implementation.
+ */
+template <typename T>
+inline void ofm_ifm_h_w_to_h_w_ifm_ofm(
+ const T* const in, // Input in [Output x Input x Height x Width] form
+ T* const out, // Output in [Height x Width x Input x Output] form
+ const int n_output_feature_maps,
+ const int n_input_feature_maps,
+ const int n_rows,
+ const int n_cols,
+ int in_output_feature_map_stride,
+ int in_input_feature_map_stride,
+ int in_row_stride,
+ int out_row_stride,
+ int out_col_stride,
+ int out_input_feature_map_stride
+)
+{
+ // Fill in stride values
+ in_row_stride = (in_row_stride)
+ ? in_row_stride
+ : n_cols;
+ in_input_feature_map_stride = (in_input_feature_map_stride)
+ ? in_input_feature_map_stride
+ : n_rows * in_row_stride;
+ in_output_feature_map_stride = (in_output_feature_map_stride)
+ ? in_output_feature_map_stride
+ : n_input_feature_maps * in_input_feature_map_stride;
+
+ out_input_feature_map_stride = (out_input_feature_map_stride)
+ ? out_input_feature_map_stride
+ : n_output_feature_maps;
+ out_col_stride = (out_col_stride)
+ ? out_col_stride
+ : n_input_feature_maps * out_input_feature_map_stride;
+ out_row_stride = (out_row_stride)
+ ? out_row_stride
+ : n_cols * out_col_stride;
+
+ // Perform the re-ordering
+ for (int i = 0; i < n_rows; i++)
+ {
+ const T* const in_row = in + i * in_row_stride;
+ T* out_row = out + i * out_row_stride;
+
+ for (int j = 0; j < n_cols; j++)
+ {
+ const T* const in_col = in_row + j;
+ T* const out_col = out_row + j * out_col_stride;
+
+ for (int ifm = 0; ifm < n_input_feature_maps; ifm++)
+ {
+ const T* const in_ifm = in_col + ifm * in_input_feature_map_stride;
+ T* const out_ifm = out_col + ifm * out_input_feature_map_stride;
+
+ for (int ofm = 0; ofm < n_output_feature_maps; ofm++)
+ {
+ const T* const in_ofm = in_ifm + ofm * in_output_feature_map_stride;
+ T* const out_ofm = out_ifm + ofm;
+ *(out_ofm) = *(in_ofm);
+ }
+ }
+ }
+ }
+}
+
+/*****************************************************************************/
+/* Generic weight re-order implementation.
+ */
+template <typename T>
+inline void h_w_ifm_ofm_to_ofm_ifm_h_w(
+ const T* const in, // Input in [Height x Width x Input x Output] form
+ T* const out, // Output in [Output x Input x Height x Width] form
+ const int n_rows,
+ const int n_cols,
+ const int n_input_feature_maps,
+ const int n_output_feature_maps,
+ int in_row_stride,
+ int in_col_stride,
+ int in_input_feature_map_stride,
+ int out_output_feature_map_stride,
+ int out_input_feature_map_stride,
+ int out_row_stride
+)
+{
+ // Fill in the stride values
+ in_input_feature_map_stride = (in_input_feature_map_stride)
+ ? in_input_feature_map_stride
+ : n_output_feature_maps;
+ in_col_stride = (in_col_stride)
+ ? in_col_stride
+ : n_input_feature_maps * in_input_feature_map_stride;
+ in_row_stride = (in_row_stride)
+ ? in_row_stride
+ : n_cols * in_col_stride;
+
+ out_row_stride = (out_row_stride)
+ ? out_row_stride
+ : n_cols;
+ out_input_feature_map_stride = (out_input_feature_map_stride)
+ ? out_input_feature_map_stride
+ : n_rows * out_row_stride;
+ out_output_feature_map_stride = (out_output_feature_map_stride)
+ ? out_output_feature_map_stride
+ : n_input_feature_maps * out_input_feature_map_stride;
+
+ // Perform the re-ordering
+ for (int i = 0; i < n_rows; i++)
+ {
+ const T* const in_row = in + i * in_row_stride;
+ T* const out_row = out + i * out_row_stride;
+
+ for (int j = 0; j < n_cols; j++)
+ {
+ const T* const in_col = in_row + j * in_col_stride;
+ T* const out_col = out_row + j;
+
+ for (int ifm = 0; ifm < n_input_feature_maps; ifm++)
+ {
+ const T* const in_ifm = in_col + ifm * in_input_feature_map_stride;
+ T* const out_ifm = out_col + ifm * out_input_feature_map_stride;
+
+ for (int ofm = 0; ofm < n_output_feature_maps; ofm++)
+ {
+ const T* const in_ofm = in_ifm + ofm;
+ T* const out_ofm = out_ifm + ofm * out_output_feature_map_stride;
+ *(out_ofm) = *(in_ofm);
+ }
+ }
+ }
+ }
+}
+
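Because every stride parameter defaults to zero, the shims derive packed strides from the logical dimensions, so callers normally pass only the pointers and sizes. A hedged example with arbitrary sizes (assuming the header is on the include path) that round-trips an activation tensor and re-orders a weight tensor:

#include <cassert>
#include <cstddef>
#include <vector>

#include "shims.hpp"

int main()
{
  const int N = 1, C = 3, H = 4, W = 5;
  std::vector<float> nchw(N * C * H * W), nhwc(nchw.size()), back(nchw.size());
  for (std::size_t i = 0; i < nchw.size(); i++) nchw[i] = static_cast<float>(i);

  // Zero strides fall back to the packed defaults computed inside the shims.
  nchw_to_nhwc(nchw.data(), nhwc.data(), N, C, H, W);
  nhwc_to_nchw(nhwc.data(), back.data(), N, H, W, C);
  assert(back == nchw);  // the two re-orderings are inverses for packed tensors

  // Weights: [OFM x IFM x H x W] -> [H x W x IFM x OFM]
  const int OFM = 8, IFM = 3, KH = 3, KW = 3;
  std::vector<float> oihw(OFM * IFM * KH * KW), hwio(oihw.size());
  ofm_ifm_h_w_to_h_w_ifm_ofm(oihw.data(), hwio.data(), OFM, IFM, KH, KW);
  return 0;
}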
diff --git a/arm_compute/core/NEON/kernels/winograd/tensor.hpp b/arm_compute/core/NEON/kernels/winograd/tensor.hpp
new file mode 100644
index 0000000000..70ef65d2a5
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/tensor.hpp
@@ -0,0 +1,210 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+#include <cstdio>
+#include <cstdlib>
+#include <random>
+
+#include "alloc.hpp"
+
+/*****************************************************************************/
+/* Padding definitions */
+enum PaddingType {
+ PADDING_SAME, PADDING_VALID
+};
+
+/*****************************************************************************/
+/* Shape of a kernel */
+struct KernelShape {
+ int n_output_channels, n_rows, n_cols, n_input_channels;
+
+ int size(void) const {
+ return n_output_channels * n_rows * n_cols * n_input_channels;
+ }
+};
+
+struct Tensor4DShape {
+ int n_batches,
+ n_rows,
+ n_cols,
+ n_channels;
+
+ int size() const {
+ return n_batches * n_rows * n_cols * n_channels;
+ }
+
+ bool TestEq(const Tensor4DShape& other) const {
+ return (n_batches == other.n_batches &&
+ n_rows == other.n_rows &&
+ n_cols == other.n_cols &&
+ n_channels == other.n_channels);
+ }
+};
+
+template <typename ShapeT, typename T>
+class Tensor4D final {
+ public:
+ Tensor4D(ShapeT shape) :
+ _shape(shape),
+ _data(reinterpret_cast<T*>(ALLOCATE(size_bytes()))) {
+ Clear();
+ }
+
+ ~Tensor4D() {
+ free(_data);
+ }
+
+ T* ptr() const {
+ return _data;
+ }
+
+ const ShapeT& shape() const {
+ return _shape;
+ }
+
+ size_t size_bytes() const {
+ return _shape.size() * sizeof(T);
+ }
+
+ bool TestEq(Tensor4D<ShapeT, T>& other) const;
+ T& element(int, int, int, int) const;
+ void Print() const;
+
+ void Clear() {
+ Fill(static_cast<T>(0));
+ }
+
+ void Fill(T val) {
+ for (int i = 0; i < _shape.size(); i++)
+ _data[i] = val;
+ }
+
+ void TestPattern() {
+ for (int i = 0; i < _shape.size(); i++)
+ _data[i] = static_cast<T>(i);
+ }
+
+ void Rand(const int seed=2311) {
+ std::mt19937 gen(seed);
+ std::uniform_int_distribution<> dis(-50, +50);
+
+ for (int i = 0; i < _shape.size(); i++) {
+ _data[i] = static_cast<T>(dis(gen));
+ }
+ }
+  /** Prevent instances of this class from being copied (As this class contains pointers) */
+  Tensor4D(const Tensor4D &) = delete;
+  /** Prevent instances of this class from being copied (As this class contains pointers) */
+  Tensor4D &operator=(const Tensor4D &) = delete;
+ /** Allow instances of this class to be moved */
+ Tensor4D(Tensor4D &&) = default;
+ /** Allow instances of this class to be moved */
+ Tensor4D &operator=(Tensor4D &&) = default;
+
+
+ private:
+ const ShapeT _shape;
+ T* const _data;
+};
+
+
+template <>
+inline float& Tensor4D<Tensor4DShape, float>::element(int n, int i, int j, int c) const {
+ int index = ((n*_shape.n_rows + i)*_shape.n_cols + j)*_shape.n_channels + c;
+ return _data[index];
+}
+
+
+template <>
+inline float& Tensor4D<KernelShape, float>::element(int oc, int i, int j, int ic) const {
+ int index = ((i*_shape.n_cols + j)*_shape.n_input_channels + ic)*_shape.n_output_channels + oc;
+ return _data[index];
+}
+
+template <>
+inline bool Tensor4D<Tensor4DShape, float>::TestEq(Tensor4D<Tensor4DShape, float>& other) const {
+ // Test equivalence, printing errors
+ // First test the shapes are the same
+ if (!_shape.TestEq(other.shape())) {
+ printf("Tensors have different shapes.\n");
+ return false;
+ } else {
+ int incorrects = 0;
+
+ for (int n = 0; n < _shape.n_batches; n++) {
+ for (int i = 0; i < _shape.n_rows; i++) {
+ for (int j = 0; j < _shape.n_cols; j++) {
+ for (int c = 0; c < _shape.n_channels; c++) {
+ // Check elements for equivalence
+ const auto a = this->element(n, i, j, c);
+ const auto b = other.element(n, i, j, c);
+
+ if (a != b) {
+ printf("Difference at element {%d, %d, %d, %d}: %.3f != %.3f\n", n, i, j, c, a, b);
+
+ if (++incorrects > 100) {
+ printf("More than 100 incorrect values, stopping test.\n");
+ return false;
+ }
+ }
+ }
+ }
+ }
+ }
+
+ return incorrects == 0;
+ }
+}
+
+
+template <>
+inline void Tensor4D<Tensor4DShape, float>::Print() const {
+ for (int n = 0; n < _shape.n_batches; n++) {
+ for (int c = 0; c < _shape.n_channels; c++) {
+ for (int i = 0; i < _shape.n_rows; i++) {
+ for (int j = 0; j < _shape.n_cols; j++) {
+ printf("%5.2f ", element(n, i, j, c));
+ }
+ printf("\n");
+ }
+ printf("\n");
+ }
+ }
+}
+
+
+template <>
+inline void Tensor4D<KernelShape, float>::Print() const {
+ for (int oc = 0; oc < _shape.n_output_channels; oc++) {
+ for (int ic = 0; ic < _shape.n_input_channels; ic++) {
+ for (int i = 0; i < _shape.n_rows; i++) {
+ for (int j = 0; j < _shape.n_cols; j++) {
+ printf("%5.2f ", element(oc, i, j, ic));
+ }
+ printf("\n");
+ }
+ printf("\n");
+ }
+ }
+}
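Tensor4D couples one of the shape structs with an owning allocation (obtained through ALLOCATE from alloc.hpp) and exposes NHWC element access for Tensor4DShape and [H x W x IFM x OFM] access for KernelShape. A brief usage sketch with arbitrary sizes:

#include "tensor.hpp"

int main()
{
  // NHWC activations: 1 batch, 8x8 spatial, 4 channels (zero-initialised).
  Tensor4D<Tensor4DShape, float> activations({1, 8, 8, 4});
  activations.Rand();                           // small random integer values

  // Kernel with 16 output channels, a 3x3 window and 4 input channels.
  Tensor4D<KernelShape, float> weights({16, 3, 3, 4});
  weights.TestPattern();

  activations.element(0, 2, 3, 1) = 1.0f;       // (batch, row, col, channel)
  const float w = weights.element(5, 1, 1, 2);  // (ofm, row, col, ifm)
  (void)w;
  return 0;
}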
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms.hpp b/arm_compute/core/NEON/kernels/winograd/transforms.hpp
new file mode 100644
index 0000000000..8546ee9e2e
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/transforms.hpp
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "transforms/input_2x2_3x3.hpp"
+#include "transforms/kernel_2x2_3x3.hpp"
+#include "transforms/output_2x2_3x3.hpp"
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp
new file mode 100644
index 0000000000..7013c66ac0
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp
@@ -0,0 +1,638 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+#include "../tensor.hpp"
+
+namespace winograd {
+ /* Transform an input tensor into the Winograd domain.
+ */
+ template <typename T>
+ struct Winograd2x2_3x3GemmInput {
+ static void execute(
+ const T *inptr,
+ const Tensor4DShape& input_shape,
+ const PaddingType padding_type,
+ const int tile_M,
+ const int tile_N,
+ T *outptr_base,
+ const int matrix_stride,
+ const int matrix_batch_stride,
+ const int matrix_row_stride
+ );
+
+ static size_t bytes_read(const Tensor4DShape &input_shape,
+ const Tensor4DShape &output_shape) {
+ const int tile_rows = iceildiv(output_shape.n_rows, 2);
+ const int tile_cols = iceildiv(output_shape.n_cols, 2);
+ return input_shape.n_batches * tile_rows * (16 + 8*(tile_cols - 1)) * input_shape.n_channels * sizeof(T);
+ }
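+    /* The first tile in each row reads a full 4x4 patch (16 values per
+     * channel); every later tile reuses two columns of its neighbour and
+     * reads only a 4x2 strip (8 values), which gives the
+     * (16 + 8*(tile_cols - 1)) term above.
+     */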
+
+ static int flops_performed(const Tensor4DShape &input_shape,
+ const Tensor4DShape &output_shape) {
+ const int tile_rows = iceildiv(output_shape.n_rows, 2);
+ const int tile_cols = iceildiv(output_shape.n_cols, 2);
+ return input_shape.n_batches * tile_rows * (32 + 24*(tile_cols - 1)) * input_shape.n_channels;
+ }
+
+ static size_t bytes_written(const Tensor4DShape &input_shape,
+ const Tensor4DShape &output_shape) {
+ const int tile_rows = iceildiv(output_shape.n_rows, 2);
+ const int tile_cols = iceildiv(output_shape.n_cols, 2);
+ const int M = input_shape.n_batches * tile_rows * tile_cols;
+ return 16 * M * input_shape.n_channels * sizeof(T);
+ }
+
+ protected:
+ template <const PaddingType padding, const int pad_bottom, const int pad_right>
+ static void process_tile_tensor(
+ const int tile_M, // Number of rows of tiles
+ const int tile_N, // Number of columns of tiles
+ int n_channels, // Number of input channels
+ const T* const input, // Base input pointer (appropriate to batch and channel)
+ const int input_row_stride, // Stride between rows of the input
+ const int input_col_stride, // Stride between columns of the input
+ T* const matrix, // 1st output matrix (appropriate to batch and channel)
+ const int matrix_stride, // Stride between matrices
+ const int matrix_row_stride // Stride between rows of the output matrix
+ );
+
+ template <const int pad_top, const int pad_left,
+ const int pad_bottom, const int pad_right,
+ const int proc_channels>
+ static void process_tile_row(
+ const int tile_N, // Number of tiles in the row
+ const T* const input, // Base input pointer (appropriate to batch, channel and row)
+ const int input_row_stride, // Stride between rows of the input
+ const int input_col_stride, // Stride between columns of the input
+ T* const matrix, // 1st output matrix (appropriate to batch, channel and row)
+ const int matrix_stride, // Stride between matrices
+ const int matrix_row_stride // Stride between rows of the output matrix
+ );
+ };
+
+ template <typename T>
+ struct Winograd2x2_3x3GemmInputChannelwise {
+ static void execute(
+ const T *inptr,
+ const Tensor4DShape& input_shape,
+ const PaddingType padding_type,
+ const int tile_M,
+ const int tile_N,
+ T *outptr_base,
+ const int matrix_stride,
+ const int matrix_batch_stride,
+ const int matrix_row_stride
+ );
+
+ static size_t bytes_read(const Tensor4DShape &input_shape,
+ const Tensor4DShape &output_shape) {
+ // We read as many bytes as we write
+ return bytes_written(input_shape, output_shape);
+ }
+
+ static int flops_performed(const Tensor4DShape &input_shape,
+ const Tensor4DShape &output_shape) {
+ const int tile_rows = iceildiv(output_shape.n_rows, 2);
+ const int tile_cols = iceildiv(output_shape.n_cols, 2);
+ return input_shape.n_batches * tile_rows * 32 * tile_cols * input_shape.n_channels;
+ }
+
+ static size_t bytes_written(const Tensor4DShape &input_shape,
+ const Tensor4DShape &output_shape) {
+ return winograd::Winograd2x2_3x3GemmInput<T>::bytes_written(input_shape, output_shape);
+ }
+
+ protected:
+ typedef void (*tilefunc)(int, const T*, int, int, T*, int);
+ template <const int pad_top,
+ const int pad_left,
+ const int pad_bottom,
+ const int pad_right>
+ static void process_tile(
+ int n_channels, // Number of channels in the tile
+ const T* const input_base,
+ const int input_row_stride,
+ const int input_col_stride,
+ T* const matrix_base,
+ const int matrix_stride
+ );
+
+ private:
+ template <const int pad_top,
+ const int pad_left,
+ const int pad_bottom,
+ const int pad_right,
+ const int proc_channels>
+ static void _process_tile(
+ int &n_channels, const T* &inptr,
+ const int input_row_stride, const int input_col_stride,
+ T* &outptr, const int matrix_stride
+ );
+ };
+}
+
+/*****************************************************************************/
+// Include specialised implementations here
+#include "input_2x2_3x3/a64_float.hpp"
+#include "input_2x2_3x3/a64_float_channelwise.hpp"
+/*****************************************************************************/
+
+/*****************************************************************************/
+template <typename T>
+void winograd::Winograd2x2_3x3GemmInput<T>::execute(
+ const T *inptr_base,
+ const Tensor4DShape& input_shape,
+ const PaddingType padding_type,
+ const int tile_M,
+ const int tile_N,
+ T *outptr_base,
+ const int matrix_stride,
+ const int matrix_batch_stride,
+ const int matrix_row_stride
+) {
+ // Select an appropriate matrix processing method for the shape and padding
+ // of the input tensor.
+ typedef void (*tensorfunc)(int, int, int, const T*, int, int, T*, int, int);
+ const auto process_tensor = [&padding_type, &input_shape] () -> tensorfunc {
+ if (padding_type == PADDING_VALID) {
+ const int pad_bottom = input_shape.n_rows % 2;
+ const int pad_right = input_shape.n_cols % 2;
+
+ if (pad_bottom == 0 && pad_right == 0) {
+ return process_tile_tensor<PADDING_VALID, 0, 0>;
+ } else if (pad_bottom == 0 && pad_right == 1) {
+ return process_tile_tensor<PADDING_VALID, 0, 1>;
+ } else if (pad_bottom == 1 && pad_right == 0) {
+ return process_tile_tensor<PADDING_VALID, 1, 0>;
+ } else if (pad_bottom == 1 && pad_right == 1) {
+ return process_tile_tensor<PADDING_VALID, 1, 1>;
+ }
+ } else { // PADDING_SAME
+ const int pad_bottom = 1 + input_shape.n_rows % 2;
+ const int pad_right = 1 + input_shape.n_cols % 2;
+
+ if (pad_bottom == 1 && pad_right == 1) {
+ return process_tile_tensor<PADDING_SAME, 1, 1>;
+ } else if (pad_bottom == 1 && pad_right == 2) {
+ return process_tile_tensor<PADDING_SAME, 1, 2>;
+ } else if (pad_bottom == 2 && pad_right == 1) {
+ return process_tile_tensor<PADDING_SAME, 2, 1>;
+ } else if (pad_bottom == 2 && pad_right == 2) {
+ return process_tile_tensor<PADDING_SAME, 2, 2>;
+ }
+ }
+
+ printf("%s::%u Uncovered case.\n", __FILE__, __LINE__);
+ exit(-1);
+ return NULL; // No function found
+ } ();
+
+ // Compute strides
+ const int input_row_stride = input_shape.n_cols * input_shape.n_channels;
+ const int input_col_stride = input_shape.n_channels;
+
+ // Process each batch of the tensor in turn.
+ for (int batch = 0; batch < input_shape.n_batches; batch++) {
+ // Work out pointers
+ const T *inptr = inptr_base + (batch * input_shape.n_rows *
+ input_shape.n_cols * input_shape.n_channels);
+ T *outptr = outptr_base + batch * matrix_batch_stride;
+
+ // Delegate doing the actual work
+ process_tensor(
+ tile_M, tile_N, input_shape.n_channels,
+ inptr, input_row_stride, input_col_stride,
+ outptr, matrix_stride, matrix_row_stride
+ );
+ }
+}
+
+/*****************************************************************************/
+template <typename T>
+template <const PaddingType padding, const int pad_bottom, const int pad_right>
+void winograd::Winograd2x2_3x3GemmInput<T>::process_tile_tensor(
+ const int tile_M, // Number of rows of tiles
+ const int tile_N, // Number of columns of tiles
+ int n_channels, // Number of input channels
+ const T* const input, // Base input pointer (appropriate to batch and channel)
+ const int input_row_stride, // Stride between rows of the input
+ const int input_col_stride, // Stride between columns of the input
+ T* const matrix, // 1st output matrix (appropriate to batch and channel)
+ const int matrix_stride, // Stride between matrices
+ const int matrix_row_stride // Stride between rows of the output matrix
+) {
+ // Base row processing functions
+ typedef void (*rowfunc)(int, const T*, int, int, T*, int, int);
+ const rowfunc process_top_row[3] = {
+ (padding == PADDING_VALID)
+ ? process_tile_row<0, 0, 0, pad_right, 1>
+ : process_tile_row<1, 1, 0, pad_right, 1>,
+ (padding == PADDING_VALID)
+ ? process_tile_row<0, 0, 0, pad_right, 2>
+ : process_tile_row<1, 1, 0, pad_right, 2>,
+ (padding == PADDING_VALID)
+ ? process_tile_row<0, 0, 0, pad_right, 4>
+ : process_tile_row<1, 1, 0, pad_right, 4>,
+ };
+ const rowfunc process_middle_row[3] = {
+ (padding == PADDING_VALID)
+ ? process_tile_row<0, 0, 0, pad_right, 1>
+ : process_tile_row<0, 1, 0, pad_right, 1>,
+ (padding == PADDING_VALID)
+ ? process_tile_row<0, 0, 0, pad_right, 2>
+ : process_tile_row<0, 1, 0, pad_right, 2>,
+ (padding == PADDING_VALID)
+ ? process_tile_row<0, 0, 0, pad_right, 4>
+ : process_tile_row<0, 1, 0, pad_right, 4>,
+ };
+ const rowfunc process_bottom_row[3] = {
+ (padding == PADDING_VALID)
+ ? process_tile_row<0, 0, pad_bottom, pad_right, 1>
+ : process_tile_row<0, 1, pad_bottom, pad_right, 1>,
+ (padding == PADDING_VALID)
+ ? process_tile_row<0, 0, pad_bottom, pad_right, 2>
+ : process_tile_row<0, 1, pad_bottom, pad_right, 2>,
+ (padding == PADDING_VALID)
+ ? process_tile_row<0, 0, pad_bottom, pad_right, 4>
+ : process_tile_row<0, 1, pad_bottom, pad_right, 4>,
+ };
+
+ // Method to get an input pointer for the given tile row
+ const auto get_inptr = [&input, &input_row_stride] (const int tile_i) {
+ if (padding == PADDING_VALID) {
+ return input + 2 * tile_i * input_row_stride;
+ } else {
+ return input + (2 * tile_i - (tile_i ? 1 : 0)) * input_row_stride;
+ }
+ };
+
+ // Wrapper to process a row of tiles, covering all channels.
+ const auto process_row =
+ [tile_N, input_row_stride, input_col_stride, matrix_stride, matrix_row_stride, n_channels]
+ (const rowfunc f[3], const T *inptr, T *outptr) {
+ int rem_channels = n_channels;
+
+      // While channels remain to be processed, continue working along the
+      // row.
+ for (; rem_channels >= 4; rem_channels -= 4, inptr += 4, outptr += 4) {
+ f[2](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride);
+ }
+ for (; rem_channels >= 2; rem_channels -= 2, inptr += 2, outptr += 2) {
+ f[1](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride);
+ }
+ if (rem_channels) {
+ f[0](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride);
+ }
+ };
+
+ // Process all rows of tiles in the tensor
+ for (int tile_i = 0; tile_i < tile_M; tile_i++) {
+ T* const m_row = matrix + tile_i * tile_N * matrix_row_stride;
+ const T *row_inptr = get_inptr(tile_i);
+
+ if (tile_i == 0) {
+ // Top row of the input
+ process_row(process_top_row, row_inptr, m_row);
+ } else if (tile_i == tile_M - 1) {
+ // Bottom row of the input
+ process_row(process_bottom_row, row_inptr, m_row);
+ } else {
+ // Any other row of the input
+ process_row(process_middle_row, row_inptr, m_row);
+ }
+ }
+}
+
+/*****************************************************************************/
+template <typename T>
+template <const int pad_top, const int pad_left,
+ const int pad_bottom, const int pad_right,
+ const int proc_channels>
+void winograd::Winograd2x2_3x3GemmInput<T>::process_tile_row(
+ const int tile_N, // Number of tiles in the row
+ const T* const input, // Base input pointer (appropriate to batch, channel and row)
+ const int input_row_stride, // Stride between rows of the input
+ const int input_col_stride, // Stride between columns of the input
+ T* const matrix, // 1st output matrix (appropriate to batch, channel and row)
+ const int matrix_stride, // Stride between matrices
+ const int matrix_row_stride // Stride between rows of the output matrix
+) {
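+  // This routine computes the Winograd F(2x2, 3x3) input transform
+  // U = X.T x X for a row of 4x4 input tiles x, where (as implemented by the
+  // arithmetic below)
+  //
+  //         [ 1  0 -1  0 ]
+  //   X.T = [ 0  1  1  0 ]
+  //         [ 0 -1  1  0 ]
+  //         [ 0  1  0 -1 ]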
+ // Construct copies of the pointers
+ const T *inptr = input;
+ T *outptr = matrix;
+
+ // Storage for the tensors x, X.T x, and X.T x X.
+ T x[4][4][proc_channels], XTx[4][4][proc_channels], XTxX[4][4][proc_channels];
+
+ // For every tile in the row
+ for (int tile_j = 0; tile_j < tile_N; tile_j++) {
+ // Determine the padding for the tile
+ const int tile_pad_left = (tile_j == 0) ? pad_left : 0;
+ const int tile_pad_right = (tile_j == tile_N - 1) ? pad_right : 0;
+
+ // Load tile values. If this is the first tile in the row then we must load
+ // all values, otherwise we can just load the final two columns of the input.
+ for (int i = 0; i < 4; i++) {
+ for (int j = ((tile_j == 0) ? 0 : 2); j < 4; j++) {
+ // Fill with padding if required
+ if (i < pad_top || 4 - pad_bottom <= i ||
+ j < tile_pad_left || 4 - tile_pad_right <= j) {
+ for (int c = 0; c < proc_channels; c++) {
+ x[i][j][c] = static_cast<T>(0); // Padding
+ }
+ } else {
+ // Load values, note that the initial padding offsets the pointer we
+ // were provided.
+ for (int c = 0; c < proc_channels; c++) {
+ const int row_offset = (i - pad_top) * input_row_stride;
+ const int col_offset = (j - tile_pad_left) * input_col_stride;
+ x[i][j][c] = inptr[row_offset + col_offset + c];
+ }
+ }
+ }
+ }
+
+    // Compute the matrix X.T x. Note that operations can be elided depending
+    // on the padding. Furthermore, if this isn't the left-most tile we can
+    // skip half of the operations by copying results from the previous tile's
+    // X.T x. This latter optimisation is simplified by unrolling the outermost
+    // loop by two and by renaming the registers containing XTx.
+ if (tile_j == 0) {
+ for (int j = 0; j < 4; j++) {
+ for (int c = 0; c < proc_channels; c++) {
+ XTx[0][j][c] = x[0][j][c] - x[2][j][c];
+ XTx[1][j][c] = x[1][j][c] + x[2][j][c];
+ XTx[2][j][c] = -x[1][j][c] + x[2][j][c];
+ XTx[3][j][c] = x[1][j][c] - x[3][j][c];
+ }
+ }
+ } else {
+ for (int j = 0; j < 2; j++) {
+ for (int c = 0; c < proc_channels; c++) {
+ XTx[0][j][c] = XTx[0][j + 2][c];
+ XTx[1][j][c] = XTx[1][j + 2][c];
+ XTx[2][j][c] = XTx[2][j + 2][c];
+ XTx[3][j][c] = XTx[3][j + 2][c];
+ }
+ }
+ for (int j = 2; j < 4; j++) {
+ for (int c = 0; c < proc_channels; c++) {
+ XTx[0][j][c] = x[0][j][c] - x[2][j][c];
+ XTx[1][j][c] = x[1][j][c] + x[2][j][c];
+ XTx[2][j][c] = -x[1][j][c] + x[2][j][c];
+ XTx[3][j][c] = x[1][j][c] - x[3][j][c];
+ }
+ }
+ }
+
+    // Compute the matrix X.T x X. Note that operations can be elided
+    // depending on the padding.
+ for (int i = 0; i < 4; i++) {
+ for (int c = 0; c < proc_channels; c++) {
+ XTxX[i][0][c] = XTx[i][0][c] - XTx[i][2][c];
+ XTxX[i][1][c] = XTx[i][1][c] + XTx[i][2][c];
+ XTxX[i][2][c] = -XTx[i][1][c] + XTx[i][2][c];
+ XTxX[i][3][c] = XTx[i][1][c] - XTx[i][3][c];
+ }
+ }
+
+ // Store the output matrix (X.T x X)
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 4; j++) {
+ // Get a pointer to the relevant output matrix
+ T *mptr = outptr + (i*4 + j)*matrix_stride;
+
+ // Write out the channels
+ for (int c = 0; c < proc_channels; c++) {
+ mptr[c] = XTxX[i][j][c];
+ }
+ }
+ }
+
+ // Update the pointers
+ inptr += input_col_stride * ((tile_j == 0 && pad_left) ? 1 : 2);
+ outptr += matrix_row_stride;
+ }
+}
+
+/*****************************************************************************/
+template <typename T>
+void winograd::Winograd2x2_3x3GemmInputChannelwise<T>::execute(
+ const T *inptr,
+ const Tensor4DShape& input_shape,
+ const PaddingType padding_type,
+ const int tile_M,
+ const int tile_N,
+ T *outptr_base,
+ const int matrix_stride,
+ const int matrix_batch_stride,
+ const int matrix_row_stride
+) {
+ const int n_channels = input_shape.n_channels;
+ const int input_col_stride = n_channels;
+ const int input_row_stride = input_shape.n_cols * input_col_stride;
+
+ // Determine the padding and hence select appropriate methods for each tile.
+ tilefunc fs[3][3];
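+  // fs[i][j] selects the transform by tile position: i is the row class
+  // (0 = top, 1 = interior, 2 = bottom) and j is the column class
+  // (0 = left, 1 = interior, 2 = right).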
+
+ if (padding_type == PADDING_VALID) {
+ constexpr int pad_top = 0;
+ constexpr int pad_left = 0;
+ const int pad_right = input_shape.n_cols % 2 == 0;
+
+ fs[0][0] = process_tile<pad_top, pad_left, 0, 0>;
+ fs[0][1] = process_tile<pad_top, 0, 0, 0>;
+ fs[0][2] = (pad_right) ? process_tile<pad_top, 0, 0, 0> : process_tile<pad_top, 0, 0, 1>;
+
+ fs[1][0] = process_tile<0, pad_left, 0, 0>;
+ fs[1][1] = process_tile<0, 0, 0, 0>;
+ fs[1][2] = (pad_right) ? process_tile<0, 0, 0, 0> : process_tile<0, 0, 0, 1>;
+
+ if (input_shape.n_rows % 2 == 0) {
+ constexpr int pad_bottom = 0;
+ fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>;
+ fs[2][1] = process_tile<0, 0, pad_bottom, 0>;
+ fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 0> : process_tile<0, 0, pad_bottom, 1>;
+ } else {
+ constexpr int pad_bottom = 1;
+ fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>;
+ fs[2][1] = process_tile<0, 0, pad_bottom, 0>;
+ fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 0> : process_tile<0, 0, pad_bottom, 1>;
+ }
+ } else {
+ constexpr int pad_top = 1;
+ constexpr int pad_left = 1;
+ const int pad_right = input_shape.n_cols % 2 == 0;
+
+ fs[0][0] = process_tile<pad_top, pad_left, 0, 0>;
+ fs[0][1] = process_tile<pad_top, 0, 0, 0>;
+ fs[0][2] = (pad_right) ? process_tile<pad_top, 0, 0, 1> : process_tile<pad_top, 0, 0, 2>;
+
+ fs[1][0] = process_tile<0, pad_left, 0, 0>;
+ fs[1][1] = process_tile<0, 0, 0, 0>;
+ fs[1][2] = (pad_right) ? process_tile<0, 0, 0, 1> : process_tile<0, 0, 0, 2>;
+
+ if (input_shape.n_rows % 2 == 0) {
+ constexpr int pad_bottom = 1;
+ fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>;
+ fs[2][1] = process_tile<0, 0, pad_bottom, 0>;
+ fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 1> : process_tile<0, 0, pad_bottom, 2>;
+ } else {
+ constexpr int pad_bottom = 2;
+ fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>;
+ fs[2][1] = process_tile<0, 0, pad_bottom, 0>;
+ fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 1> : process_tile<0, 0, pad_bottom, 2>;
+ }
+ }
+
+ // Process each tile in turn
+ for (int batch = 0; batch < input_shape.n_batches; batch++) {
+ const T* const input_base_batch = inptr + batch*input_shape.n_rows*input_shape.n_cols*n_channels;
+
+ for (int tile_i = 0; tile_i < tile_M; tile_i++) {
+ const int row_offset = (tile_i == 0) ? 0 : ((padding_type == PADDING_VALID) ? 0 : 1);
+ const T* const input_base_row = input_base_batch + (2*tile_i - row_offset)*input_shape.n_cols*n_channels;
+
+ // Select the set of functions for the row
+ const int fs_i = (tile_i == 0) ? 0 : ((tile_i < tile_M - 1) ? 1 : 2);
+
+ for (int tile_j = 0; tile_j < tile_N; tile_j++) {
+ // Select the function for the column
+ const int fs_j = (tile_j == 0) ? 0 : ((tile_j < tile_N - 1) ? 1 : 2);
+ const auto f = fs[fs_i][fs_j];
+
+ // Get pointers into the input and outputs
+ const int col_offset = (tile_j == 0) ? 0 : ((padding_type == PADDING_VALID) ? 0 : 1);
+ const T* const input_base_col = input_base_row + (2*tile_j - col_offset)*n_channels;
+ T* const matrix_base = outptr_base + batch*matrix_batch_stride + (tile_i*tile_N + tile_j)*matrix_row_stride;
+ f(n_channels, input_base_col, input_row_stride, input_col_stride,
+ matrix_base, matrix_stride);
+ }
+ }
+ }
+}
+
+template <typename T>
+template <const int pad_top,
+ const int pad_left,
+ const int pad_bottom,
+ const int pad_right>
+void winograd::Winograd2x2_3x3GemmInputChannelwise<T>::process_tile(
+ int n_channels, // Number of channels in the tile
+ const T* const input_base,
+ const int input_row_stride,
+ const int input_col_stride,
+ T* const matrix_base,
+ const int matrix_stride
+) {
+ // Copy pointers
+ const T *inptr = input_base;
+ T *outptr = matrix_base;
+
+ // Process channels (modifies inptr, outptr and n_channels)
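+  // Channels are consumed in blocks of four, then two, then one; each call
+  // decrements n_channels and advances the pointers ready for the next call.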
+ _process_tile<pad_top, pad_left, pad_bottom, pad_right, 4>(
+ n_channels, inptr, input_row_stride, input_col_stride,
+ outptr, matrix_stride
+ );
+ _process_tile<pad_top, pad_left, pad_bottom, pad_right, 2>(
+ n_channels, inptr, input_row_stride, input_col_stride,
+ outptr, matrix_stride
+ );
+ _process_tile<pad_top, pad_left, pad_bottom, pad_right, 1>(
+ n_channels, inptr, input_row_stride, input_col_stride,
+ outptr, matrix_stride
+ );
+}
+
+template <typename T>
+template <const int pad_top,
+ const int pad_left,
+ const int pad_bottom,
+ const int pad_right,
+ const int proc_channels>
+void winograd::Winograd2x2_3x3GemmInputChannelwise<T>::_process_tile(
+ int &n_channels,
+ const T* &inptr, const int input_row_stride, const int input_col_stride,
+ T* &outptr, const int matrix_stride
+) {
+ // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three
+ // offsets to access the intermediate matrices.
+ T* outptrs[4] = {
+ outptr,
+ outptr + matrix_stride * 4,
+ outptr + matrix_stride * 8,
+ outptr + matrix_stride * 12
+ };
+
+ // The matrix X; zeroed to account for padding.
+ T x[4][4];
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 4; j++) {
+ x[i][j] = 0;
+ }
+ }
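+  // Zero-initialising x means the entries covered by padding stay zero; the
+  // load loop below only writes the unpadded region of the tile.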
+
+ // The matrices X.T x and U
+ T XTx[4][4], U[4][4];
+
+ // Now progress through each channel
+ for (; n_channels >= proc_channels; n_channels -= proc_channels) {
+ for (int n = 0; n < proc_channels; n++) {
+ // Load the matrix X
+ for (int cell_i = pad_top, i = 0; cell_i < 4 - pad_bottom; cell_i++, i++) {
+ for (int cell_j = pad_left, j = 0; cell_j < 4 - pad_right; cell_j++, j++) {
+ x[cell_i][cell_j] = inptr[i*input_row_stride + j*input_col_stride];
+ }
+ }
+ inptr++;
+
+      // Compute the matrix X.T x
+ for (int j = 0; j < 4; j++) {
+ XTx[0][j] = x[0][j] - x[2][j];
+ XTx[1][j] = x[1][j] + x[2][j];
+ XTx[2][j] = x[2][j] - x[1][j];
+ XTx[3][j] = x[1][j] - x[3][j];
+ }
+
+ // Hence compute the matrix U
+ for (int i = 0; i < 4; i++) {
+ U[i][0] = XTx[i][0] - XTx[i][2];
+ U[i][1] = XTx[i][1] + XTx[i][2];
+ U[i][2] = XTx[i][2] - XTx[i][1];
+ U[i][3] = XTx[i][1] - XTx[i][3];
+ }
+
+ // Store the matrix U
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 4; j++) {
+ outptrs[i][j * matrix_stride] = U[i][j];
+ }
+ outptrs[i]++;
+ }
+ }
+ }
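+  // n_channels and inptr are passed by reference and have already been
+  // advanced by the loop above; only the output pointer needs writing back.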
+
+ // Update the output pointer for future calls
+ outptr = outptrs[0];
+}
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp
new file mode 100644
index 0000000000..a99cbe325b
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp
@@ -0,0 +1,1498 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+#include "../input_2x2_3x3.hpp"
+
+#ifdef __aarch64__
+namespace winograd {
+
+// Pad left by one column, pad right by one column, no upper or lower padding, 4 channels
+template <>
+template <>
+inline void Winograd2x2_3x3GemmInput<float>::process_tile_row<0, 1, 0, 1, 4>(
+ const int tile_N, // Number of tiles in the row
+ const float* const input, // Base input pointer (appropriate to batch, channel and row)
+ const int input_row_stride, // Stride between rows of the input
+ const int input_col_stride, // Stride between columns of the input
+ float* const matrix, // 1st output matrix (appropriate to batch, channel and row)
+ const int matrix_stride, // Stride between matrices
+ const int matrix_row_stride // Stride between rows of the output matrix
+) {
+ /* SIMD register allocation
+ * ========================
+ *
+ * In the following code we read 4x4 tiles of a matrix `x`, with which we
+ * compute another matrix `X.T x` where:
+ *
+ * / 1 0 0 0 \
+ * X = | 0 1 -1 1 |
+ * | -1 1 1 0 |
+ * \ 0 0 0 -1 /
+ *
+ * Hence, `X.T` is a program which operates upon rows of the matrix `X`.
+ * We subsequently compute and store the matrix `U = (X.T x) X`.
+ *
+ * Importantly, each iteration of the loop below loads a new matrix `x'`
+ * where the final two columns of `x'` are the first two columns of the
+ * previous `x`. That is:
+ *
+ * x11 x12 x13 x14
+ * x21 x22 x23 x24
+ * x31 x32 x33 x34
+ * x41 x42 x43 x44
+ *
+ * x'11 x'12 x'13 x'14
+ * x'21 x'22 x'23 x'24
+ * x'31 x'32 x'33 x'34
+ * x'41 x'42 x'43 x'44
+ *
+ * Consequently, while the first iteration of the below loop must load 16
+ * values for `x`, the second need load only 8. *Furthermore*, since we noted
+ * above that the operation `X.T x` was a program which operated upon *rows*
+   * of the matrix `x`, it follows that the relation `x'[i][1] =
+   * x[i][3]` and `x'[i][2] = x[i][4]` also applies to the matrices `X.T x'` and
+ * `X.T x`. That is:
+ *
+ * (X.T x)11 (X.T x)12 (X.T x)13 (X.T x)14
+ * (X.T x)21 (X.T x)22 (X.T x)23 (X.T x)24
+ * (X.T x)31 (X.T x)32 (X.T x)33 (X.T x)34
+ * (X.T x)41 (X.T x)42 (X.T x)43 (X.T x)44
+ *
+ * (X.T x')11 (X.T x')12 (X.T x')13 (X.T x')14
+   *     (X.T x')21  (X.T x')22  (X.T x')23  (X.T x')24
+   *     (X.T x')31  (X.T x')32  (X.T x')33  (X.T x')34
+   *     (X.T x')41  (X.T x')42  (X.T x')43  (X.T x')44
+ *
+ * Hence, as well as not needing to load new values for x'[i][1..2] it is
+ * also unnecessary to recompute values for (X.T x')[i][1..2].
+ *
+ * Following this we break the registers into blocks `A` and `B` used by the
+   * two stages of the unrolled loop. These registers are named such that the
+ * latter columns of `A` become the earlier columns of `B` and vice-versa:
+ *
+ * AXTx11 AXTx12 > AXTx13 AXTx14 |
+ * AXTx21 AXTx22 > AXTx23 AXTx24 |
+ * AXTx31 AXTx32 > AXTx33 AXTx34 |
+ * AXTx41 AXTx42 > AXTx43 AXTx44 |
+ *
+ * BXTx13 BXTx14 | BXTx11 BXTx12 >
+ * BXTx23 BXTx24 | BXTx21 BXTx22 >
+ * BXTx33 BXTx34 | BXTx31 BXTx32 >
+ * BXTx43 BXTx44 | BXTx41 BXTx42 >
+ *
+   * These 32 named registers require only 16 architectural registers. One
+   * additional architectural register is used as scratch space and eight
+   * more are used to load the values x[1..4][3,4].
+ *
+ * Input and output addressing
+ * ===========================
+ * TODO Description
+ */
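+  // Control flow: a 'head' handles the first (left-padded) tile, the main
+  // loop then alternates between the register-renamed 'A' and 'B' halves, and
+  // an 'A' or 'B' tail handles the final (right-padded) tile.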
+ const float *inptr0 = input;
+ const float *inptr1 = input + input_row_stride;
+ const float *inptr2 = input + input_row_stride * 2;
+ const float *inptr3 = input + input_row_stride * 3;
+
+ float *outptr0 = matrix;
+ float *outptr4 = matrix + matrix_stride * 4;
+ float *outptr8 = matrix + matrix_stride * 8;
+ float *outptr12 = matrix + matrix_stride * 12;
+
+ int tile_j = tile_N; // Tiles to process
+
+ asm volatile (
+ // Named SIMD registers according to the policy given above
+ // Registers into which to load the latter two columns of `x`
+ "x_13 .req v0\n qx_13 .req q0\n" "x_14 .req v4\n qx_14 .req q4\n"
+ "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n"
+ "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n"
+ "x_43 .req v3\n qx_43 .req q3\n" "x_44 .req v7\n qx_44 .req q7\n"
+
+ // Registers for storing X.T x (both A and B halves)
+ "AXTx11 .req v8\n" "BXTx13 .req v8\n"
+ "AXTx12 .req v9\n" "BXTx14 .req v9\n" "qAXTx12 .req q9\n"
+ "AXTx21 .req v10\n" "BXTx23 .req v10\n"
+ "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n"
+ "AXTx31 .req v12\n" "BXTx33 .req v12\n"
+ "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n"
+ "AXTx41 .req v14\n" "BXTx43 .req v14\n"
+ "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n"
+ "AXTx13 .req v16\n" "BXTx11 .req v16\n"
+ "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n"
+ "AXTx23 .req v18\n" "BXTx21 .req v18\n"
+ "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n"
+ "AXTx33 .req v20\n" "BXTx31 .req v20\n"
+ "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n"
+ "AXTx43 .req v22\n" "BXTx41 .req v22\n"
+ "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n"
+
+ // Result register (TODO Does using more registers yield better
+ // performance)
+ "U .req v24\n qU .req q24\n"
+
+ // ----------------------------------------------------------------------
+ // Head of loop
+ // Loads a complete 4x4 tile of x, computes X.T x, computes and stores
+ // `U = X.T x X`. Prepares for the 'A' half of the loop.
+ // NOTE: Since the first tile has the leftmost column padded we can
+ // skip 4 loads and 4 calculations for the matrix X.T x X.
+
+ // Temporarily alias registers for computing the first (non-padded)
+ // column of x.
+ "x_12 .req v0\n qx_12 .req q0\n"
+ "x_22 .req v1\n qx_22 .req q1\n"
+ "x_32 .req v2\n qx_32 .req q2\n"
+ "x_42 .req v3\n qx_42 .req q3\n"
+
+ "ldr qx_12, [%x[inptr0]]\n"
+ "ldr qx_22, [%x[inptr1]]\n"
+ "ldr qx_32, [%x[inptr2]]\n"
+ "ldr qx_42, [%x[inptr3]]\n"
+
+ "fsub BXTx12.4s, x_12.4s, x_32.4s\n"
+ "fadd BXTx22.4s, x_22.4s, x_32.4s\n"
+ "fsub BXTx32.4s, x_32.4s, x_22.4s\n"
+ "fsub BXTx42.4s, x_22.4s, x_42.4s\n"
+
+ ".unreq x_12\n .unreq qx_12\n"
+ ".unreq x_22\n .unreq qx_22\n"
+ ".unreq x_32\n .unreq qx_32\n"
+ ".unreq x_42\n .unreq qx_42\n"
+
+    // Load and compute the latter two columns of the first tile. Advance the
+    // input pointers by three columns so that each points at the first column
+    // which must be read for the next tile (its first two columns overlap the
+    // current tile and have already been loaded).
+ "ldr qx_13, [%x[inptr0], %x[colstride1]]\n"
+ "ldr qx_23, [%x[inptr1], %x[colstride1]]\n"
+ "ldr qx_33, [%x[inptr2], %x[colstride1]]\n"
+ "ldr qx_43, [%x[inptr3], %x[colstride1]]\n"
+
+ "fsub BXTx13.4s, x_13.4s, x_33.4s\n"
+ "ldr qx_14, [%x[inptr0], %x[colstride2]]\n"
+
+ "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
+ "ldr qx_24, [%x[inptr1], %x[colstride2]]\n"
+
+ "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
+ "ldr qx_34, [%x[inptr2], %x[colstride2]]\n"
+
+ "fsub BXTx43.4s, x_23.4s, x_43.4s\n"
+ "ldr qx_44, [%x[inptr3], %x[colstride2]]\n"
+
+ "fsub BXTx14.4s, x_14.4s, x_34.4s\n"
+ "add %x[inptr0], %x[inptr0], %x[colstride3]\n"
+
+ "fadd BXTx24.4s, x_24.4s, x_34.4s\n"
+ "add %x[inptr1], %x[inptr1], %x[colstride3]\n"
+
+ "fsub BXTx34.4s, x_34.4s, x_24.4s\n"
+ "add %x[inptr2], %x[inptr2], %x[colstride3]\n"
+
+ "fsub BXTx44.4s, x_24.4s, x_44.4s\n"
+ "add %x[inptr3], %x[inptr3], %x[colstride3]\n"
+
+ // Compute and store U for the first tile
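+    // The leftmost column of this tile is padding, so the first column of
+    // X.T x is identically zero (no BXTx*1 registers are computed) and the
+    // first column of U reduces to the negation of the third column of X.T x;
+    // hence the fneg instructions below.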
+ // First row
+ "fneg U.4s, BXTx13.4s\n"
+ "str qU, [%x[outptr0]]\n"
+ "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
+ "str qU, [%x[outptr0], %x[mstride1]]\n"
+ "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
+ "str qU, [%x[outptr0], %x[mstride2]]\n"
+ "fsub U.4s, BXTx12.4s, BXTx14.4s\n"
+ "str qU, [%x[outptr0], %x[mstride3]]\n"
+ "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
+
+ // Second row
+ "fneg U.4s, BXTx23.4s\n"
+ "str qU, [%x[outptr4]]\n"
+ "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
+ "str qU, [%x[outptr4], %x[mstride1]]\n"
+ "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
+ "str qU, [%x[outptr4], %x[mstride2]]\n"
+ "fsub U.4s, BXTx22.4s, BXTx24.4s\n"
+ "str qU, [%x[outptr4], %x[mstride3]]\n"
+ "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
+
+ // Third row
+ "fneg U.4s, BXTx33.4s\n"
+ "str qU, [%x[outptr8]]\n"
+ "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
+ "str qU, [%x[outptr8], %x[mstride1]]\n"
+ "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
+ "str qU, [%x[outptr8], %x[mstride2]]\n"
+ "fsub U.4s, BXTx32.4s, BXTx34.4s\n"
+ "str qU, [%x[outptr8], %x[mstride3]]\n"
+ "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
+
+ // Fourth row, simultaneously load the first column of inputs for the
+ // next tile.
+ "fneg U.4s, BXTx43.4s\n"
+ "str qU, [%x[outptr12]]\n"
+ "ldr qx_13, [%x[inptr0]]\n"
+
+ "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
+ "str qU, [%x[outptr12], %x[mstride1]]\n"
+ "ldr qx_23, [%x[inptr1]]\n"
+
+ "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
+ "str qU, [%x[outptr12], %x[mstride2]]\n"
+ "ldr qx_33, [%x[inptr2]]\n"
+
+ "fsub U.4s, BXTx42.4s, BXTx44.4s\n"
+ "str qU, [%x[outptr12], %x[mstride3]]\n"
+ "ldr qx_43, [%x[inptr3]]\n"
+
+ "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
+
+ // Update the loop counter, subtract two to account for both the head and
+ // the tail.
+ "subs %x[tile_j], %x[tile_j], #2\n"
+ "beq 2f\n" // Jump to "A" tail if out of tiles
+
+ // ----------------------------------------------------------------------
+ "1:"
+ // Start part A
+ // Load last column of this tile (the first column has already been
+ // loaded) and compute latter two columns of X.T x.
+ "fsub AXTx13.4s, x_13.4s, x_33.4s\n"
+ "ldr qx_14, [%x[inptr0], %x[colstride1]]\n"
+ "fadd AXTx23.4s, x_23.4s, x_33.4s\n"
+ "ldr qx_24, [%x[inptr1], %x[colstride1]]\n"
+ "fsub AXTx33.4s, x_33.4s, x_23.4s\n"
+ "ldr qx_34, [%x[inptr2], %x[colstride1]]\n"
+ "fsub AXTx43.4s, x_23.4s, x_43.4s\n"
+ "ldr qx_44, [%x[inptr3], %x[colstride1]]\n"
+ "fsub AXTx14.4s, x_14.4s, x_34.4s\n"
+ "add %x[inptr0], %x[inptr0], %x[colstride2]\n"
+ "fadd AXTx24.4s, x_24.4s, x_34.4s\n"
+ "add %x[inptr1], %x[inptr1], %x[colstride2]\n"
+ "fsub AXTx34.4s, x_34.4s, x_24.4s\n"
+ "add %x[inptr2], %x[inptr2], %x[colstride2]\n"
+ "fsub AXTx44.4s, x_24.4s, x_44.4s\n"
+ "add %x[inptr3], %x[inptr3], %x[colstride2]\n"
+
+ // Compute and store U.
+ // First row
+ "fsub U.4s, AXTx11.4s, AXTx13.4s\n"
+ "str qU, [%x[outptr0]]\n"
+ "fadd U.4s, AXTx12.4s, AXTx13.4s\n"
+ "str qU, [%x[outptr0], %x[mstride1]]\n"
+ "fsub U.4s, AXTx13.4s, AXTx12.4s\n"
+ "str qU, [%x[outptr0], %x[mstride2]]\n"
+ "fsub U.4s, AXTx12.4s, AXTx14.4s\n"
+ "str qU, [%x[outptr0], %x[mstride3]]\n"
+ "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
+
+ // Second row
+ "fsub U.4s, AXTx21.4s, AXTx23.4s\n"
+ "str qU, [%x[outptr4]]\n"
+ "fadd U.4s, AXTx22.4s, AXTx23.4s\n"
+ "str qU, [%x[outptr4], %x[mstride1]]\n"
+ "fsub U.4s, AXTx23.4s, AXTx22.4s\n"
+ "str qU, [%x[outptr4], %x[mstride2]]\n"
+ "fsub U.4s, AXTx22.4s, AXTx24.4s\n"
+ "str qU, [%x[outptr4], %x[mstride3]]\n"
+ "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
+
+ // Third row
+ "fsub U.4s, AXTx31.4s, AXTx33.4s\n"
+ "str qU, [%x[outptr8]]\n"
+ "fadd U.4s, AXTx32.4s, AXTx33.4s\n"
+ "str qU, [%x[outptr8], %x[mstride1]]\n"
+ "fsub U.4s, AXTx33.4s, AXTx32.4s\n"
+ "str qU, [%x[outptr8], %x[mstride2]]\n"
+ "fsub U.4s, AXTx32.4s, AXTx34.4s\n"
+ "str qU, [%x[outptr8], %x[mstride3]]\n"
+ "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
+
+ // Fourth row
+ "fsub U.4s, AXTx41.4s, AXTx43.4s\n"
+ "str qU, [%x[outptr12]]\n"
+ "ldr qx_13, [%x[inptr0]]\n"
+
+ "fadd U.4s, AXTx42.4s, AXTx43.4s\n"
+ "str qU, [%x[outptr12], %x[mstride1]]\n"
+ "ldr qx_23, [%x[inptr1]]\n"
+
+ "fsub U.4s, AXTx43.4s, AXTx42.4s\n"
+ "str qU, [%x[outptr12], %x[mstride2]]\n"
+ "ldr qx_33, [%x[inptr2]]\n"
+
+ "fsub U.4s, AXTx42.4s, AXTx44.4s\n"
+ "str qU, [%x[outptr12], %x[mstride3]]\n"
+ "ldr qx_43, [%x[inptr3]]\n"
+
+ "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
+
+ "subs %x[tile_j], %x[tile_j], #1\n"
+ "beq 3f\n" // Jump to 'B' tail
+
+ // Start part B
+ // Load last column of this tile (the first column has already been
+ // loaded) and compute latter two columns of X.T x.
+ "fsub BXTx13.4s, x_13.4s, x_33.4s\n"
+ "ldr qx_14, [%x[inptr0], %x[colstride1]]\n"
+ "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
+ "ldr qx_24, [%x[inptr1], %x[colstride1]]\n"
+ "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
+ "ldr qx_34, [%x[inptr2], %x[colstride1]]\n"
+ "fsub BXTx43.4s, x_23.4s, x_43.4s\n"
+ "ldr qx_44, [%x[inptr3], %x[colstride1]]\n"
+ "fsub BXTx14.4s, x_14.4s, x_34.4s\n"
+ "add %x[inptr0], %x[inptr0], %x[colstride2]\n"
+ "fadd BXTx24.4s, x_24.4s, x_34.4s\n"
+ "add %x[inptr1], %x[inptr1], %x[colstride2]\n"
+ "fsub BXTx34.4s, x_34.4s, x_24.4s\n"
+ "add %x[inptr2], %x[inptr2], %x[colstride2]\n"
+ "fsub BXTx44.4s, x_24.4s, x_44.4s\n"
+ "add %x[inptr3], %x[inptr3], %x[colstride2]\n"
+
+ // Compute and store U.
+ // First row
+ "fsub U.4s, BXTx11.4s, BXTx13.4s\n"
+ "str qU, [%x[outptr0]]\n"
+ "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
+ "str qU, [%x[outptr0], %x[mstride1]]\n"
+ "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
+ "str qU, [%x[outptr0], %x[mstride2]]\n"
+ "fsub U.4s, BXTx12.4s, BXTx14.4s\n"
+ "str qU, [%x[outptr0], %x[mstride3]]\n"
+ "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
+
+ // Second row
+ "fsub U.4s, BXTx21.4s, BXTx23.4s\n"
+ "str qU, [%x[outptr4]]\n"
+ "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
+ "str qU, [%x[outptr4], %x[mstride1]]\n"
+ "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
+ "str qU, [%x[outptr4], %x[mstride2]]\n"
+ "fsub U.4s, BXTx22.4s, BXTx24.4s\n"
+ "str qU, [%x[outptr4], %x[mstride3]]\n"
+ "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
+
+ // Third row
+ "fsub U.4s, BXTx31.4s, BXTx33.4s\n"
+ "str qU, [%x[outptr8]]\n"
+ "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
+ "str qU, [%x[outptr8], %x[mstride1]]\n"
+ "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
+ "str qU, [%x[outptr8], %x[mstride2]]\n"
+ "fsub U.4s, BXTx32.4s, BXTx34.4s\n"
+ "str qU, [%x[outptr8], %x[mstride3]]\n"
+ "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
+
+ // Fourth row
+ "fsub U.4s, BXTx41.4s, BXTx43.4s\n"
+ "str qU, [%x[outptr12]]\n"
+ "ldr qx_13, [%x[inptr0]]\n"
+
+ "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
+ "str qU, [%x[outptr12], %x[mstride1]]\n"
+ "ldr qx_23, [%x[inptr1]]\n"
+
+ "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
+ "str qU, [%x[outptr12], %x[mstride2]]\n"
+ "ldr qx_33, [%x[inptr2]]\n"
+
+ "fsub U.4s, BXTx42.4s, BXTx44.4s\n"
+ "str qU, [%x[outptr12], %x[mstride3]]\n"
+ "ldr qx_43, [%x[inptr3]]\n"
+
+ "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
+ "subs %x[tile_j], %x[tile_j], #1\n"
+ "bne 1b\n" // Continue loop, otherwise flow into 'A' tail
+
+ // ----------------------------------------------------------------------
+ "2:"
+ // 'A' tail
+ // Since the final column is padding and the last-but-one column has
+ // already been loaded just compute the 3rd column of `X.T x'.
+ "fsub AXTx13.4s, x_13.4s, x_33.4s\n"
+ "fadd AXTx23.4s, x_23.4s, x_33.4s\n"
+ "fsub AXTx33.4s, x_33.4s, x_23.4s\n"
+ "fsub AXTx43.4s, x_23.4s, x_43.4s\n"
+
+ // Compute and store U. Modified to account for the final column of X.T
+ // x containing padding. Note, it is also unnecessary to update the
+ // output pointers.
+ // First row
+ "fsub U.4s, AXTx11.4s, AXTx13.4s\n"
+ "str qU, [%x[outptr0]]\n"
+ "fadd U.4s, AXTx12.4s, AXTx13.4s\n"
+ "str qU, [%x[outptr0], %x[mstride1]]\n"
+ "fsub U.4s, AXTx13.4s, AXTx12.4s\n"
+ "str qU, [%x[outptr0], %x[mstride2]]\n"
+ "str qAXTx12, [%x[outptr0], %x[mstride3]]\n"
+
+ // Second row
+ "fsub U.4s, AXTx21.4s, AXTx23.4s\n"
+ "str qU, [%x[outptr4]]\n"
+ "fadd U.4s, AXTx22.4s, AXTx23.4s\n"
+ "str qU, [%x[outptr4], %x[mstride1]]\n"
+ "fsub U.4s, AXTx23.4s, AXTx22.4s\n"
+ "str qU, [%x[outptr4], %x[mstride2]]\n"
+ "str qAXTx22, [%x[outptr4], %x[mstride3]]\n"
+
+ // Third row
+ "fsub U.4s, AXTx31.4s, AXTx33.4s\n"
+ "str qU, [%x[outptr8]]\n"
+ "fadd U.4s, AXTx32.4s, AXTx33.4s\n"
+ "str qU, [%x[outptr8], %x[mstride1]]\n"
+ "fsub U.4s, AXTx33.4s, AXTx32.4s\n"
+ "str qU, [%x[outptr8], %x[mstride2]]\n"
+ "str qAXTx32, [%x[outptr8], %x[mstride3]]\n"
+
+ // Fourth row
+ "fsub U.4s, AXTx41.4s, AXTx43.4s\n"
+ "str qU, [%x[outptr12]]\n"
+ "fadd U.4s, AXTx42.4s, AXTx43.4s\n"
+ "str qU, [%x[outptr12], %x[mstride1]]\n"
+ "fsub U.4s, AXTx43.4s, AXTx42.4s\n"
+ "str qU, [%x[outptr12], %x[mstride2]]\n"
+ "str qAXTx42, [%x[outptr12], %x[mstride3]]\n"
+
+ "b 4f\n" // Jump to end of function
+
+ // ----------------------------------------------------------------------
+ "3:"
+ // 'B' tail
+ // Since the final column is padding and the last-but-one column has
+ // already been loaded just compute the 3rd column of `X.T x'.
+ "fsub BXTx13.4s, x_13.4s, x_33.4s\n"
+ "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
+ "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
+ "fsub BXTx43.4s, x_23.4s, x_43.4s\n"
+
+ // Compute and store U. Modified to account for the final column of X.T
+ // x containing padding. Note, it is also unnecessary to update the
+ // output pointers.
+ // First row
+ "fsub U.4s, BXTx11.4s, BXTx13.4s\n"
+ "str qU, [%x[outptr0]]\n"
+ "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
+ "str qU, [%x[outptr0], %x[mstride1]]\n"
+ "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
+ "str qU, [%x[outptr0], %x[mstride2]]\n"
+ "str qBXTx12, [%x[outptr0], %x[mstride3]]\n"
+
+ // Second row
+ "fsub U.4s, BXTx21.4s, BXTx23.4s\n"
+ "str qU, [%x[outptr4]]\n"
+ "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
+ "str qU, [%x[outptr4], %x[mstride1]]\n"
+ "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
+ "str qU, [%x[outptr4], %x[mstride2]]\n"
+ "str qBXTx22, [%x[outptr4], %x[mstride3]]\n"
+
+ // Third row
+ "fsub U.4s, BXTx31.4s, BXTx33.4s\n"
+ "str qU, [%x[outptr8]]\n"
+ "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
+ "str qU, [%x[outptr8], %x[mstride1]]\n"
+ "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
+ "str qU, [%x[outptr8], %x[mstride2]]\n"
+ "str qBXTx32, [%x[outptr8], %x[mstride3]]\n"
+
+ // Fourth row
+ "fsub U.4s, BXTx41.4s, BXTx43.4s\n"
+ "str qU, [%x[outptr12]]\n"
+ "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
+ "str qU, [%x[outptr12], %x[mstride1]]\n"
+ "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
+ "str qU, [%x[outptr12], %x[mstride2]]\n"
+ "str qBXTx42, [%x[outptr12], %x[mstride3]]\n"
+
+ // ----------------------------------------------------------------------
+ "4:"
+ // End of function
+
+ // Clear names
+ ".unreq x_13\n" ".unreq qx_13\n" ".unreq x_14\n" ".unreq qx_14\n"
+ ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n"
+ ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n"
+ ".unreq x_43\n" ".unreq qx_43\n" ".unreq x_44\n" ".unreq qx_44\n"
+ ".unreq AXTx11\n" ".unreq BXTx13\n"
+ ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n"
+ ".unreq AXTx21\n" ".unreq BXTx23\n"
+ ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n"
+ ".unreq AXTx31\n" ".unreq BXTx33\n"
+ ".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n"
+ ".unreq AXTx41\n" ".unreq BXTx43\n"
+ ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n"
+ ".unreq AXTx13\n" ".unreq BXTx11\n"
+ ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n"
+ ".unreq AXTx23\n" ".unreq BXTx21\n"
+ ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n"
+ ".unreq AXTx33\n" ".unreq BXTx31\n"
+ ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n"
+ ".unreq AXTx43\n" ".unreq BXTx41\n"
+ ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n"
+ ".unreq U\n" ".unreq qU\n"
+ : [inptr0] "+r" (inptr0),
+ [inptr1] "+r" (inptr1),
+ [inptr2] "+r" (inptr2),
+ [inptr3] "+r" (inptr3),
+ [outptr0] "+r" (outptr0),
+ [outptr4] "+r" (outptr4),
+ [outptr8] "+r" (outptr8),
+ [outptr12] "+r" (outptr12),
+ [tile_j] "+r" (tile_j) // Tile counter
+ : [colstride1] "r" (1 * input_col_stride * sizeof(float)),
+ [colstride2] "r" (2 * input_col_stride * sizeof(float)),
+ [colstride3] "r" (3 * input_col_stride * sizeof(float)),
+ [mstride1] "r" (1 * matrix_stride * sizeof(float)),
+ [mstride2] "r" (2 * matrix_stride * sizeof(float)),
+ [mstride3] "r" (3 * matrix_stride * sizeof(float)),
+ [matrix_row_stride] "r" (matrix_row_stride * sizeof(float))
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
+ "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
+ "v22", "v23", "v24"
+ );
+}
+
+// Pad top, left and right by 1.
+template <>
+template <>
+inline void Winograd2x2_3x3GemmInput<float>::process_tile_row<1, 1, 0, 1, 4>(
+ const int tile_N,
+ const float* const input,
+ const int input_row_stride,
+ const int input_col_stride,
+ float* const matrix,
+ const int matrix_stride,
+ const int matrix_row_stride
+) {
+ const float *inptr0 = input;
+ const float *inptr1 = input + input_row_stride;
+ const float *inptr2 = input + input_row_stride * 2;
+
+ float *outptr0 = matrix;
+ float *outptr4 = matrix + matrix_stride * 4;
+ float *outptr8 = matrix + matrix_stride * 8;
+ float *outptr12 = matrix + matrix_stride * 12;
+
+ int tile_j = tile_N; // Tiles to process
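+  // The structure below mirrors the <0, 1, 0, 1, 4> specialisation above
+  // (head, alternating 'A'/'B' loop halves, then an 'A' or 'B' tail); the
+  // padded top row of x is never loaded, so row 1 of X.T x is formed with
+  // fneg rather than fsub.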
+
+ asm volatile (
+ // Named SIMD registers according to the policy given above
+ // Registers into which to load the latter two columns of `x`
+ // NOTE: We need only load the latter three rows since we know that the
+ // first row is padded.
+ "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n"
+ "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n"
+ "x_43 .req v3\n qx_43 .req q3\n" "x_44 .req v7\n qx_44 .req q7\n"
+
+ // Registers for storing X.T x (both A and B halves)
+ "AXTx11 .req v8\n" "BXTx13 .req v8\n"
+ "AXTx12 .req v9\n" "BXTx14 .req v9\n" "qAXTx12 .req q9\n"
+ "AXTx21 .req v10\n" "BXTx23 .req v10\n"
+ "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n"
+ "AXTx31 .req v12\n" "BXTx33 .req v12\n"
+ "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n"
+ "AXTx41 .req v14\n" "BXTx43 .req v14\n"
+ "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n"
+ "AXTx13 .req v16\n" "BXTx11 .req v16\n"
+ "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n"
+ "AXTx23 .req v18\n" "BXTx21 .req v18\n"
+ "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n"
+ "AXTx33 .req v20\n" "BXTx31 .req v20\n"
+ "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n"
+ "AXTx43 .req v22\n" "BXTx41 .req v22\n"
+ "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n"
+
+ // Result register (TODO Does using more registers yield better
+ // performance)
+ "U .req v24\n qU .req q24\n"
+
+ // ----------------------------------------------------------------------
+ // Head of loop
+    // Loads the unpadded portion of the first 4x4 tile of x, computes X.T x,
+    // computes and stores `U = X.T x X`. Prepares for the 'A' half of the loop.
+    // NOTE: Since the first tile has both its top row and its leftmost column
+    // padded we can skip the corresponding loads and calculations.
+
+ // Temporarily alias registers for computing the first (non-padded)
+ // column of x.
+ "x_22 .req v1\n qx_22 .req q1\n"
+ "x_32 .req v2\n qx_32 .req q2\n"
+ "x_42 .req v3\n qx_42 .req q3\n"
+
+ "ldr qx_22, [%x[inptr1]]\n"
+ "ldr qx_32, [%x[inptr2]]\n"
+ "ldr qx_42, [%x[inptr3]]\n"
+
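+    // The top row of x is padding, so row 1 of X.T x (x_1j - x_3j) reduces to
+    // the negation of row 3 of x; hence the fneg.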
+ "fneg BXTx12.4s, x_32.4s\n"
+ "fadd BXTx22.4s, x_22.4s, x_32.4s\n"
+ "fsub BXTx32.4s, x_32.4s, x_22.4s\n"
+ "fsub BXTx42.4s, x_22.4s, x_42.4s\n"
+
+ ".unreq x_22\n .unreq qx_22\n"
+ ".unreq x_32\n .unreq qx_32\n"
+ ".unreq x_42\n .unreq qx_42\n"
+
+    // Load and compute the latter two columns of the first tile. Advance the
+    // input pointers by three columns so that each points at the first column
+    // which must be read for the next tile (its first two columns overlap the
+    // current tile and have already been loaded).
+ "ldr qx_23, [%x[inptr1], %x[colstride1]]\n"
+ "ldr qx_33, [%x[inptr2], %x[colstride1]]\n"
+ "ldr qx_43, [%x[inptr3], %x[colstride1]]\n"
+
+ "fneg BXTx13.4s, x_33.4s\n"
+
+ "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
+ "ldr qx_24, [%x[inptr1], %x[colstride2]]\n"
+
+ "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
+ "ldr qx_34, [%x[inptr2], %x[colstride2]]\n"
+
+ "fsub BXTx43.4s, x_23.4s, x_43.4s\n"
+ "ldr qx_44, [%x[inptr3], %x[colstride2]]\n"
+
+ "fneg BXTx14.4s, x_34.4s\n"
+
+ "fadd BXTx24.4s, x_24.4s, x_34.4s\n"
+ "add %x[inptr1], %x[inptr1], %x[colstride3]\n"
+
+ "fsub BXTx34.4s, x_34.4s, x_24.4s\n"
+ "add %x[inptr2], %x[inptr2], %x[colstride3]\n"
+
+ "fsub BXTx44.4s, x_24.4s, x_44.4s\n"
+ "add %x[inptr3], %x[inptr3], %x[colstride3]\n"
+
+ // Compute and store U for the first tile
+ // First row
+ "fneg U.4s, BXTx13.4s\n"
+ "str qU, [%x[outptr0]]\n"
+ "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
+ "str qU, [%x[outptr0], %x[mstride1]]\n"
+ "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
+ "str qU, [%x[outptr0], %x[mstride2]]\n"
+ "fsub U.4s, BXTx12.4s, BXTx14.4s\n"
+ "str qU, [%x[outptr0], %x[mstride3]]\n"
+ "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
+
+ // Second row
+ "fneg U.4s, BXTx23.4s\n"
+ "str qU, [%x[outptr4]]\n"
+ "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
+ "str qU, [%x[outptr4], %x[mstride1]]\n"
+ "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
+ "str qU, [%x[outptr4], %x[mstride2]]\n"
+ "fsub U.4s, BXTx22.4s, BXTx24.4s\n"
+ "str qU, [%x[outptr4], %x[mstride3]]\n"
+ "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
+
+ // Third row
+ "fneg U.4s, BXTx33.4s\n"
+ "str qU, [%x[outptr8]]\n"
+ "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
+ "str qU, [%x[outptr8], %x[mstride1]]\n"
+ "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
+ "str qU, [%x[outptr8], %x[mstride2]]\n"
+ "fsub U.4s, BXTx32.4s, BXTx34.4s\n"
+ "str qU, [%x[outptr8], %x[mstride3]]\n"
+ "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
+
+ // Fourth row, simultaneously load the first column of inputs for the
+ // next tile.
+ "fneg U.4s, BXTx43.4s\n"
+ "str qU, [%x[outptr12]]\n"
+
+ "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
+ "str qU, [%x[outptr12], %x[mstride1]]\n"
+ "ldr qx_23, [%x[inptr1]]\n"
+
+ "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
+ "str qU, [%x[outptr12], %x[mstride2]]\n"
+ "ldr qx_33, [%x[inptr2]]\n"
+
+ "fsub U.4s, BXTx42.4s, BXTx44.4s\n"
+ "str qU, [%x[outptr12], %x[mstride3]]\n"
+ "ldr qx_43, [%x[inptr3]]\n"
+
+ "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
+
+ // Update the loop counter, subtract two to account for both the head and
+ // the tail.
+ "subs %x[tile_j], %x[tile_j], #2\n"
+ "beq 2f\n" // Jump to "A" tail if out of tiles
+
+ // ----------------------------------------------------------------------
+ "1:"
+ // Start part A
+ // Load last column of this tile (the first column has already been
+ // loaded) and compute latter two columns of X.T x.
+ "fneg AXTx13.4s, x_33.4s\n"
+ "fadd AXTx23.4s, x_23.4s, x_33.4s\n"
+ "ldr qx_24, [%x[inptr1], %x[colstride1]]\n"
+ "fsub AXTx33.4s, x_33.4s, x_23.4s\n"
+ "ldr qx_34, [%x[inptr2], %x[colstride1]]\n"
+ "fsub AXTx43.4s, x_23.4s, x_43.4s\n"
+ "ldr qx_44, [%x[inptr3], %x[colstride1]]\n"
+ "fneg AXTx14.4s, x_34.4s\n"
+ "fadd AXTx24.4s, x_24.4s, x_34.4s\n"
+ "add %x[inptr1], %x[inptr1], %x[colstride2]\n"
+ "fsub AXTx34.4s, x_34.4s, x_24.4s\n"
+ "add %x[inptr2], %x[inptr2], %x[colstride2]\n"
+ "fsub AXTx44.4s, x_24.4s, x_44.4s\n"
+ "add %x[inptr3], %x[inptr3], %x[colstride2]\n"
+
+ // Compute and store U.
+ // First row
+ "fsub U.4s, AXTx11.4s, AXTx13.4s\n"
+ "str qU, [%x[outptr0]]\n"
+ "fadd U.4s, AXTx12.4s, AXTx13.4s\n"
+ "str qU, [%x[outptr0], %x[mstride1]]\n"
+ "fsub U.4s, AXTx13.4s, AXTx12.4s\n"
+ "str qU, [%x[outptr0], %x[mstride2]]\n"
+ "fsub U.4s, AXTx12.4s, AXTx14.4s\n"
+ "str qU, [%x[outptr0], %x[mstride3]]\n"
+ "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
+
+ // Second row
+ "fsub U.4s, AXTx21.4s, AXTx23.4s\n"
+ "str qU, [%x[outptr4]]\n"
+ "fadd U.4s, AXTx22.4s, AXTx23.4s\n"
+ "str qU, [%x[outptr4], %x[mstride1]]\n"
+ "fsub U.4s, AXTx23.4s, AXTx22.4s\n"
+ "str qU, [%x[outptr4], %x[mstride2]]\n"
+ "fsub U.4s, AXTx22.4s, AXTx24.4s\n"
+ "str qU, [%x[outptr4], %x[mstride3]]\n"
+ "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
+
+ // Third row
+ "fsub U.4s, AXTx31.4s, AXTx33.4s\n"
+ "str qU, [%x[outptr8]]\n"
+ "fadd U.4s, AXTx32.4s, AXTx33.4s\n"
+ "str qU, [%x[outptr8], %x[mstride1]]\n"
+ "fsub U.4s, AXTx33.4s, AXTx32.4s\n"
+ "str qU, [%x[outptr8], %x[mstride2]]\n"
+ "fsub U.4s, AXTx32.4s, AXTx34.4s\n"
+ "str qU, [%x[outptr8], %x[mstride3]]\n"
+ "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
+
+ // Fourth row
+ "fsub U.4s, AXTx41.4s, AXTx43.4s\n"
+ "str qU, [%x[outptr12]]\n"
+
+ "fadd U.4s, AXTx42.4s, AXTx43.4s\n"
+ "str qU, [%x[outptr12], %x[mstride1]]\n"
+ "ldr qx_23, [%x[inptr1]]\n"
+
+ "fsub U.4s, AXTx43.4s, AXTx42.4s\n"
+ "str qU, [%x[outptr12], %x[mstride2]]\n"
+ "ldr qx_33, [%x[inptr2]]\n"
+
+ "fsub U.4s, AXTx42.4s, AXTx44.4s\n"
+ "str qU, [%x[outptr12], %x[mstride3]]\n"
+ "ldr qx_43, [%x[inptr3]]\n"
+
+ "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
+
+ "subs %x[tile_j], %x[tile_j], #1\n"
+ "beq 3f\n" // Jump to 'B' tail
+
+ // Start part B
+ // Load last column of this tile (the first column has already been
+ // loaded) and compute latter two columns of X.T x.
+ "fneg BXTx13.4s, x_33.4s\n"
+ "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
+ "ldr qx_24, [%x[inptr1], %x[colstride1]]\n"
+ "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
+ "ldr qx_34, [%x[inptr2], %x[colstride1]]\n"
+ "fsub BXTx43.4s, x_23.4s, x_43.4s\n"
+ "ldr qx_44, [%x[inptr3], %x[colstride1]]\n"
+ "fneg BXTx14.4s, x_34.4s\n"
+ "fadd BXTx24.4s, x_24.4s, x_34.4s\n"
+ "add %x[inptr1], %x[inptr1], %x[colstride2]\n"
+ "fsub BXTx34.4s, x_34.4s, x_24.4s\n"
+ "add %x[inptr2], %x[inptr2], %x[colstride2]\n"
+ "fsub BXTx44.4s, x_24.4s, x_44.4s\n"
+ "add %x[inptr3], %x[inptr3], %x[colstride2]\n"
+
+ // Compute and store U.
+ // First row
+ "fsub U.4s, BXTx11.4s, BXTx13.4s\n"
+ "str qU, [%x[outptr0]]\n"
+ "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
+ "str qU, [%x[outptr0], %x[mstride1]]\n"
+ "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
+ "str qU, [%x[outptr0], %x[mstride2]]\n"
+ "fsub U.4s, BXTx12.4s, BXTx14.4s\n"
+ "str qU, [%x[outptr0], %x[mstride3]]\n"
+ "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
+
+ // Second row
+ "fsub U.4s, BXTx21.4s, BXTx23.4s\n"
+ "str qU, [%x[outptr4]]\n"
+ "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
+ "str qU, [%x[outptr4], %x[mstride1]]\n"
+ "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
+ "str qU, [%x[outptr4], %x[mstride2]]\n"
+ "fsub U.4s, BXTx22.4s, BXTx24.4s\n"
+ "str qU, [%x[outptr4], %x[mstride3]]\n"
+ "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
+
+ // Third row
+ "fsub U.4s, BXTx31.4s, BXTx33.4s\n"
+ "str qU, [%x[outptr8]]\n"
+ "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
+ "str qU, [%x[outptr8], %x[mstride1]]\n"
+ "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
+ "str qU, [%x[outptr8], %x[mstride2]]\n"
+ "fsub U.4s, BXTx32.4s, BXTx34.4s\n"
+ "str qU, [%x[outptr8], %x[mstride3]]\n"
+ "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
+
+ // Fourth row
+ "fsub U.4s, BXTx41.4s, BXTx43.4s\n"
+ "str qU, [%x[outptr12]]\n"
+
+ "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
+ "str qU, [%x[outptr12], %x[mstride1]]\n"
+ "ldr qx_23, [%x[inptr1]]\n"
+
+ "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
+ "str qU, [%x[outptr12], %x[mstride2]]\n"
+ "ldr qx_33, [%x[inptr2]]\n"
+
+ "fsub U.4s, BXTx42.4s, BXTx44.4s\n"
+ "str qU, [%x[outptr12], %x[mstride3]]\n"
+ "ldr qx_43, [%x[inptr3]]\n"
+
+ "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
+ "subs %x[tile_j], %x[tile_j], #1\n"
+ "bne 1b\n" // Continue loop, otherwise flow into 'A' tail
+
+ // ----------------------------------------------------------------------
+ "2:"
+ // 'A' tail
+ // Since the final column is padding and the last-but-one column has
+ // already been loaded just compute the 3rd column of `X.T x'.
+ "fneg AXTx13.4s, x_33.4s\n"
+ "fadd AXTx23.4s, x_23.4s, x_33.4s\n"
+ "fsub AXTx33.4s, x_33.4s, x_23.4s\n"
+ "fsub AXTx43.4s, x_23.4s, x_43.4s\n"
+
+ // Compute and store U. Modified to account for the final column of X.T
+ // x containing padding. Note, it is also unnecessary to update the
+ // output pointers.
+ // First row
+ "fsub U.4s, AXTx11.4s, AXTx13.4s\n"
+ "str qU, [%x[outptr0]]\n"
+ "fadd U.4s, AXTx12.4s, AXTx13.4s\n"
+ "str qU, [%x[outptr0], %x[mstride1]]\n"
+ "fsub U.4s, AXTx13.4s, AXTx12.4s\n"
+ "str qU, [%x[outptr0], %x[mstride2]]\n"
+ "str qAXTx12, [%x[outptr0], %x[mstride3]]\n"
+
+ // Second row
+ "fsub U.4s, AXTx21.4s, AXTx23.4s\n"
+ "str qU, [%x[outptr4]]\n"
+ "fadd U.4s, AXTx22.4s, AXTx23.4s\n"
+ "str qU, [%x[outptr4], %x[mstride1]]\n"
+ "fsub U.4s, AXTx23.4s, AXTx22.4s\n"
+ "str qU, [%x[outptr4], %x[mstride2]]\n"
+ "str qAXTx22, [%x[outptr4], %x[mstride3]]\n"
+
+ // Third row
+ "fsub U.4s, AXTx31.4s, AXTx33.4s\n"
+ "str qU, [%x[outptr8]]\n"
+ "fadd U.4s, AXTx32.4s, AXTx33.4s\n"
+ "str qU, [%x[outptr8], %x[mstride1]]\n"
+ "fsub U.4s, AXTx33.4s, AXTx32.4s\n"
+ "str qU, [%x[outptr8], %x[mstride2]]\n"
+ "str qAXTx32, [%x[outptr8], %x[mstride3]]\n"
+
+ // Fourth row
+ "fsub U.4s, AXTx41.4s, AXTx43.4s\n"
+ "str qU, [%x[outptr12]]\n"
+ "fadd U.4s, AXTx42.4s, AXTx43.4s\n"
+ "str qU, [%x[outptr12], %x[mstride1]]\n"
+ "fsub U.4s, AXTx43.4s, AXTx42.4s\n"
+ "str qU, [%x[outptr12], %x[mstride2]]\n"
+ "str qAXTx42, [%x[outptr12], %x[mstride3]]\n"
+
+ "b 4f\n" // Jump to end of function
+
+ // ----------------------------------------------------------------------
+ "3:"
+ // 'B' tail
+ // Since the final column is padding and the last-but-one column has
+ // already been loaded just compute the 3rd column of `X.T x'.
+ "fneg BXTx13.4s, x_33.4s\n"
+ "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
+ "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
+ "fsub BXTx43.4s, x_23.4s, x_43.4s\n"
+
+ // Compute and store U. Modified to account for the final column of X.T
+ // x containing padding. Note, it is also unnecessary to update the
+ // output pointers.
+ // First row
+ "fsub U.4s, BXTx11.4s, BXTx13.4s\n"
+ "str qU, [%x[outptr0]]\n"
+ "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
+ "str qU, [%x[outptr0], %x[mstride1]]\n"
+ "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
+ "str qU, [%x[outptr0], %x[mstride2]]\n"
+ "str qBXTx12, [%x[outptr0], %x[mstride3]]\n"
+
+ // Second row
+ "fsub U.4s, BXTx21.4s, BXTx23.4s\n"
+ "str qU, [%x[outptr4]]\n"
+ "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
+ "str qU, [%x[outptr4], %x[mstride1]]\n"
+ "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
+ "str qU, [%x[outptr4], %x[mstride2]]\n"
+ "str qBXTx22, [%x[outptr4], %x[mstride3]]\n"
+
+ // Third row
+ "fsub U.4s, BXTx31.4s, BXTx33.4s\n"
+ "str qU, [%x[outptr8]]\n"
+ "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
+ "str qU, [%x[outptr8], %x[mstride1]]\n"
+ "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
+ "str qU, [%x[outptr8], %x[mstride2]]\n"
+ "str qBXTx32, [%x[outptr8], %x[mstride3]]\n"
+
+ // Fourth row
+ "fsub U.4s, BXTx41.4s, BXTx43.4s\n"
+ "str qU, [%x[outptr12]]\n"
+ "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
+ "str qU, [%x[outptr12], %x[mstride1]]\n"
+ "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
+ "str qU, [%x[outptr12], %x[mstride2]]\n"
+ "str qBXTx42, [%x[outptr12], %x[mstride3]]\n"
+
+ // ----------------------------------------------------------------------
+ "4:"
+ // End of function
+
+ // Clear names
+ ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n"
+ ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n"
+ ".unreq x_43\n" ".unreq qx_43\n" ".unreq x_44\n" ".unreq qx_44\n"
+ ".unreq AXTx11\n" ".unreq BXTx13\n"
+ ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n"
+ ".unreq AXTx21\n" ".unreq BXTx23\n"
+ ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n"
+ ".unreq AXTx31\n" ".unreq BXTx33\n"
+ ".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n"
+ ".unreq AXTx41\n" ".unreq BXTx43\n"
+ ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n"
+ ".unreq AXTx13\n" ".unreq BXTx11\n"
+ ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n"
+ ".unreq AXTx23\n" ".unreq BXTx21\n"
+ ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n"
+ ".unreq AXTx33\n" ".unreq BXTx31\n"
+ ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n"
+ ".unreq AXTx43\n" ".unreq BXTx41\n"
+ ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n"
+ ".unreq U\n" ".unreq qU\n"
+ : [inptr1] "+r" (inptr0), // Offset to account for padded row
+ [inptr2] "+r" (inptr1), // Offset to account for padded row
+ [inptr3] "+r" (inptr2), // Offset to account for padded row
+ [outptr0] "+r" (outptr0),
+ [outptr4] "+r" (outptr4),
+ [outptr8] "+r" (outptr8),
+ [outptr12] "+r" (outptr12),
+ [tile_j] "+r" (tile_j) // Tile counter
+ : [colstride1] "r" (1 * input_col_stride * sizeof(float)),
+ [colstride2] "r" (2 * input_col_stride * sizeof(float)),
+ [colstride3] "r" (3 * input_col_stride * sizeof(float)),
+ [mstride1] "r" (1 * matrix_stride * sizeof(float)),
+ [mstride2] "r" (2 * matrix_stride * sizeof(float)),
+ [mstride3] "r" (3 * matrix_stride * sizeof(float)),
+ [matrix_row_stride] "r" (matrix_row_stride * sizeof(float))
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
+ "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
+ "v22", "v23", "v24"
+ );
+}
+
+// Pad left, right and bottom by 1.
+template <>
+template <>
+inline void Winograd2x2_3x3GemmInput<float>::process_tile_row<0, 1, 1, 1, 4>(
+ const int tile_N,
+ const float* const input,
+ const int input_row_stride,
+ const int input_col_stride,
+ float* const matrix,
+ const int matrix_stride,
+ const int matrix_row_stride
+) {
+ const float *inptr0 = input;
+ const float *inptr1 = input + input_row_stride;
+ const float *inptr2 = input + input_row_stride * 2;
+
+ float *outptr0 = matrix;
+ float *outptr4 = matrix + matrix_stride * 4;
+ float *outptr8 = matrix + matrix_stride * 8;
+ float *outptr12 = matrix + matrix_stride * 12;
+
+ int tile_j = tile_N; // Tiles to process
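+  // The structure below mirrors the <0, 1, 0, 1, 4> specialisation above; the
+  // padded bottom row of x is never loaded, so row 4 of X.T x (x_2j - x_4j)
+  // reduces to a plain copy of row 2 of x (the mov instructions below).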
+
+ asm volatile (
+ // Named SIMD registers according to the policy given above
+ // Registers into which to load the latter two columns of `x`
+    // NOTE: Bottom row is not required since it is padded.
+ "x_13 .req v0\n qx_13 .req q0\n" "x_14 .req v4\n qx_14 .req q4\n"
+ "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n"
+ "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n"
+
+ // Registers for storing X.T x (both A and B halves)
+ "AXTx11 .req v8\n" "BXTx13 .req v8\n"
+ "AXTx12 .req v9\n" "BXTx14 .req v9\n" "qAXTx12 .req q9\n"
+ "AXTx21 .req v10\n" "BXTx23 .req v10\n"
+ "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n"
+ "AXTx31 .req v12\n" "BXTx33 .req v12\n"
+ "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n"
+ "AXTx41 .req v14\n" "BXTx43 .req v14\n"
+ "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n"
+ "AXTx13 .req v16\n" "BXTx11 .req v16\n"
+ "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n"
+ "AXTx23 .req v18\n" "BXTx21 .req v18\n"
+ "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n"
+ "AXTx33 .req v20\n" "BXTx31 .req v20\n"
+ "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n"
+ "AXTx43 .req v22\n" "BXTx41 .req v22\n"
+ "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n"
+
+ // Result register (TODO Does using more registers yield better
+ // performance)
+ "U .req v24\n qU .req q24\n"
+
+ // ----------------------------------------------------------------------
+ // Head of loop
+    // Loads the unpadded portion of the first 4x4 tile of x, computes X.T x,
+    // computes and stores `U = X.T x X`. Prepares for the 'A' half of the loop.
+    // NOTE: Since the first tile has both its bottom row and its leftmost
+    // column padded we can skip the corresponding loads and calculations.
+
+ // Temporarily alias registers for computing the first (non-padded)
+ // column of x.
+ "x_12 .req v0\n qx_12 .req q0\n"
+ "x_22 .req v1\n qx_22 .req q1\n"
+ "x_32 .req v2\n qx_32 .req q2\n"
+
+ "ldr qx_12, [%x[inptr0]]\n"
+ "ldr qx_22, [%x[inptr1]]\n"
+ "ldr qx_32, [%x[inptr2]]\n"
+
+ "fsub BXTx12.4s, x_12.4s, x_32.4s\n"
+ "fadd BXTx22.4s, x_22.4s, x_32.4s\n"
+ "fsub BXTx32.4s, x_32.4s, x_22.4s\n"
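+    // The bottom row of x is padding, so row 4 of X.T x (x_2j - x_4j) reduces
+    // to a copy of row 2 of x; hence the mov.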
+ "mov BXTx42.16b, x_22.16b\n" // Probably should do better
+
+ ".unreq x_12\n .unreq qx_12\n"
+ ".unreq x_22\n .unreq qx_22\n"
+ ".unreq x_32\n .unreq qx_32\n"
+
+    // Load and compute the latter two columns of the first tile. Advance the
+    // input pointers by three columns so that each points at the first column
+    // which must be read for the next tile (its first two columns overlap the
+    // current tile and have already been loaded).
+ "ldr qx_13, [%x[inptr0], %x[colstride1]]\n"
+ "ldr qx_23, [%x[inptr1], %x[colstride1]]\n"
+ "ldr qx_33, [%x[inptr2], %x[colstride1]]\n"
+
+ "fsub BXTx13.4s, x_13.4s, x_33.4s\n"
+ "ldr qx_14, [%x[inptr0], %x[colstride2]]\n"
+
+ "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
+ "ldr qx_24, [%x[inptr1], %x[colstride2]]\n"
+
+ "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
+ "ldr qx_34, [%x[inptr2], %x[colstride2]]\n"
+
+ "mov BXTx43.16b, x_23.16b\n"
+ "fsub BXTx14.4s, x_14.4s, x_34.4s\n"
+ "add %x[inptr0], %x[inptr0], %x[colstride3]\n"
+
+ "fadd BXTx24.4s, x_24.4s, x_34.4s\n"
+ "add %x[inptr1], %x[inptr1], %x[colstride3]\n"
+
+ "fsub BXTx34.4s, x_34.4s, x_24.4s\n"
+ "add %x[inptr2], %x[inptr2], %x[colstride3]\n"
+
+ "mov BXTx44.16b, x_24.16b\n"
+
+ // Compute and store U for the first tile
+ // First row
+ "fneg U.4s, BXTx13.4s\n"
+ "str qU, [%x[outptr0]]\n"
+ "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
+ "str qU, [%x[outptr0], %x[mstride1]]\n"
+ "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
+ "str qU, [%x[outptr0], %x[mstride2]]\n"
+ "fsub U.4s, BXTx12.4s, BXTx14.4s\n"
+ "str qU, [%x[outptr0], %x[mstride3]]\n"
+ "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
+
+ // Second row
+ "fneg U.4s, BXTx23.4s\n"
+ "str qU, [%x[outptr4]]\n"
+ "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
+ "str qU, [%x[outptr4], %x[mstride1]]\n"
+ "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
+ "str qU, [%x[outptr4], %x[mstride2]]\n"
+ "fsub U.4s, BXTx22.4s, BXTx24.4s\n"
+ "str qU, [%x[outptr4], %x[mstride3]]\n"
+ "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
+
+ // Third row
+ "fneg U.4s, BXTx33.4s\n"
+ "str qU, [%x[outptr8]]\n"
+ "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
+ "str qU, [%x[outptr8], %x[mstride1]]\n"
+ "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
+ "str qU, [%x[outptr8], %x[mstride2]]\n"
+ "fsub U.4s, BXTx32.4s, BXTx34.4s\n"
+ "str qU, [%x[outptr8], %x[mstride3]]\n"
+ "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
+
+ // Fourth row, simultaneously load the first column of inputs for the
+ // next tile.
+ "fneg U.4s, BXTx43.4s\n"
+ "str qU, [%x[outptr12]]\n"
+ "ldr qx_13, [%x[inptr0]]\n"
+
+ "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
+ "str qU, [%x[outptr12], %x[mstride1]]\n"
+ "ldr qx_23, [%x[inptr1]]\n"
+
+ "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
+ "str qU, [%x[outptr12], %x[mstride2]]\n"
+ "ldr qx_33, [%x[inptr2]]\n"
+
+ "fsub U.4s, BXTx42.4s, BXTx44.4s\n"
+ "str qU, [%x[outptr12], %x[mstride3]]\n"
+
+ "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
+
+ // Update the loop counter, subtract two to account for both the head and
+ // the tail.
+ "subs %x[tile_j], %x[tile_j], #2\n"
+ "beq 2f\n" // Jump to "A" tail if out of tiles
+
+ // ----------------------------------------------------------------------
+ "1:"
+ // Start part A
+ // Load last column of this tile (the first column has already been
+ // loaded) and compute latter two columns of X.T x.
+ "fsub AXTx13.4s, x_13.4s, x_33.4s\n"
+ "ldr qx_14, [%x[inptr0], %x[colstride1]]\n"
+ "fadd AXTx23.4s, x_23.4s, x_33.4s\n"
+ "ldr qx_24, [%x[inptr1], %x[colstride1]]\n"
+ "fsub AXTx33.4s, x_33.4s, x_23.4s\n"
+ "ldr qx_34, [%x[inptr2], %x[colstride1]]\n"
+ "mov AXTx43.16b, x_23.16b\n"
+
+ "fsub AXTx14.4s, x_14.4s, x_34.4s\n"
+ "add %x[inptr0], %x[inptr0], %x[colstride2]\n"
+ "fadd AXTx24.4s, x_24.4s, x_34.4s\n"
+ "add %x[inptr1], %x[inptr1], %x[colstride2]\n"
+ "fsub AXTx34.4s, x_34.4s, x_24.4s\n"
+ "add %x[inptr2], %x[inptr2], %x[colstride2]\n"
+ "mov AXTx44.16b, x_24.16b\n"
+
+ // Compute and store U.
+ // First row
+ "fsub U.4s, AXTx11.4s, AXTx13.4s\n"
+ "str qU, [%x[outptr0]]\n"
+ "fadd U.4s, AXTx12.4s, AXTx13.4s\n"
+ "str qU, [%x[outptr0], %x[mstride1]]\n"
+ "fsub U.4s, AXTx13.4s, AXTx12.4s\n"
+ "str qU, [%x[outptr0], %x[mstride2]]\n"
+ "fsub U.4s, AXTx12.4s, AXTx14.4s\n"
+ "str qU, [%x[outptr0], %x[mstride3]]\n"
+ "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
+
+ // Second row
+ "fsub U.4s, AXTx21.4s, AXTx23.4s\n"
+ "str qU, [%x[outptr4]]\n"
+ "fadd U.4s, AXTx22.4s, AXTx23.4s\n"
+ "str qU, [%x[outptr4], %x[mstride1]]\n"
+ "fsub U.4s, AXTx23.4s, AXTx22.4s\n"
+ "str qU, [%x[outptr4], %x[mstride2]]\n"
+ "fsub U.4s, AXTx22.4s, AXTx24.4s\n"
+ "str qU, [%x[outptr4], %x[mstride3]]\n"
+ "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
+
+ // Third row
+ "fsub U.4s, AXTx31.4s, AXTx33.4s\n"
+ "str qU, [%x[outptr8]]\n"
+ "fadd U.4s, AXTx32.4s, AXTx33.4s\n"
+ "str qU, [%x[outptr8], %x[mstride1]]\n"
+ "fsub U.4s, AXTx33.4s, AXTx32.4s\n"
+ "str qU, [%x[outptr8], %x[mstride2]]\n"
+ "fsub U.4s, AXTx32.4s, AXTx34.4s\n"
+ "str qU, [%x[outptr8], %x[mstride3]]\n"
+ "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
+
+ // Fourth row
+ "fsub U.4s, AXTx41.4s, AXTx43.4s\n"
+ "str qU, [%x[outptr12]]\n"
+ "ldr qx_13, [%x[inptr0]]\n"
+
+ "fadd U.4s, AXTx42.4s, AXTx43.4s\n"
+ "str qU, [%x[outptr12], %x[mstride1]]\n"
+ "ldr qx_23, [%x[inptr1]]\n"
+
+ "fsub U.4s, AXTx43.4s, AXTx42.4s\n"
+ "str qU, [%x[outptr12], %x[mstride2]]\n"
+ "ldr qx_33, [%x[inptr2]]\n"
+
+ "fsub U.4s, AXTx42.4s, AXTx44.4s\n"
+ "str qU, [%x[outptr12], %x[mstride3]]\n"
+
+ "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
+
+ "subs %x[tile_j], %x[tile_j], #1\n"
+ "beq 3f\n" // Jump to 'B' tail
+
+ // Start part B
+      // Load the last column of this tile (the first column has already been
+      // loaded) and compute the latter two columns of X.T x.
+ "fsub BXTx13.4s, x_13.4s, x_33.4s\n"
+ "ldr qx_14, [%x[inptr0], %x[colstride1]]\n"
+ "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
+ "ldr qx_24, [%x[inptr1], %x[colstride1]]\n"
+ "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
+ "ldr qx_34, [%x[inptr2], %x[colstride1]]\n"
+ "mov BXTx43.16b, x_23.16b\n"
+
+ "fsub BXTx14.4s, x_14.4s, x_34.4s\n"
+ "add %x[inptr0], %x[inptr0], %x[colstride2]\n"
+ "fadd BXTx24.4s, x_24.4s, x_34.4s\n"
+ "add %x[inptr1], %x[inptr1], %x[colstride2]\n"
+ "fsub BXTx34.4s, x_34.4s, x_24.4s\n"
+ "add %x[inptr2], %x[inptr2], %x[colstride2]\n"
+ "mov BXTx44.16b, x_24.16b\n"
+
+ // Compute and store U.
+ // First row
+ "fsub U.4s, BXTx11.4s, BXTx13.4s\n"
+ "str qU, [%x[outptr0]]\n"
+ "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
+ "str qU, [%x[outptr0], %x[mstride1]]\n"
+ "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
+ "str qU, [%x[outptr0], %x[mstride2]]\n"
+ "fsub U.4s, BXTx12.4s, BXTx14.4s\n"
+ "str qU, [%x[outptr0], %x[mstride3]]\n"
+ "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
+
+ // Second row
+ "fsub U.4s, BXTx21.4s, BXTx23.4s\n"
+ "str qU, [%x[outptr4]]\n"
+ "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
+ "str qU, [%x[outptr4], %x[mstride1]]\n"
+ "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
+ "str qU, [%x[outptr4], %x[mstride2]]\n"
+ "fsub U.4s, BXTx22.4s, BXTx24.4s\n"
+ "str qU, [%x[outptr4], %x[mstride3]]\n"
+ "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
+
+ // Third row
+ "fsub U.4s, BXTx31.4s, BXTx33.4s\n"
+ "str qU, [%x[outptr8]]\n"
+ "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
+ "str qU, [%x[outptr8], %x[mstride1]]\n"
+ "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
+ "str qU, [%x[outptr8], %x[mstride2]]\n"
+ "fsub U.4s, BXTx32.4s, BXTx34.4s\n"
+ "str qU, [%x[outptr8], %x[mstride3]]\n"
+ "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
+
+ // Fourth row
+ "fsub U.4s, BXTx41.4s, BXTx43.4s\n"
+ "str qU, [%x[outptr12]]\n"
+ "ldr qx_13, [%x[inptr0]]\n"
+
+ "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
+ "str qU, [%x[outptr12], %x[mstride1]]\n"
+ "ldr qx_23, [%x[inptr1]]\n"
+
+ "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
+ "str qU, [%x[outptr12], %x[mstride2]]\n"
+ "ldr qx_33, [%x[inptr2]]\n"
+
+ "fsub U.4s, BXTx42.4s, BXTx44.4s\n"
+ "str qU, [%x[outptr12], %x[mstride3]]\n"
+
+ "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
+ "subs %x[tile_j], %x[tile_j], #1\n"
+ "bne 1b\n" // Continue loop, otherwise flow into 'A' tail
+
+ // ----------------------------------------------------------------------
+ "2:"
+ // 'A' tail
+ // Since the final column is padding and the last-but-one column has
+      // already been loaded, just compute the 3rd column of `X.T x'.
+ "fsub AXTx13.4s, x_13.4s, x_33.4s\n"
+ "fadd AXTx23.4s, x_23.4s, x_33.4s\n"
+ "fsub AXTx33.4s, x_33.4s, x_23.4s\n"
+ "mov AXTx43.16b, x_23.16b\n"
+
+ // Compute and store U. Modified to account for the final column of X.T
+ // x containing padding. Note, it is also unnecessary to update the
+ // output pointers.
+ // First row
+ "fsub U.4s, AXTx11.4s, AXTx13.4s\n"
+ "str qU, [%x[outptr0]]\n"
+ "fadd U.4s, AXTx12.4s, AXTx13.4s\n"
+ "str qU, [%x[outptr0], %x[mstride1]]\n"
+ "fsub U.4s, AXTx13.4s, AXTx12.4s\n"
+ "str qU, [%x[outptr0], %x[mstride2]]\n"
+ "str qAXTx12, [%x[outptr0], %x[mstride3]]\n"
+
+ // Second row
+ "fsub U.4s, AXTx21.4s, AXTx23.4s\n"
+ "str qU, [%x[outptr4]]\n"
+ "fadd U.4s, AXTx22.4s, AXTx23.4s\n"
+ "str qU, [%x[outptr4], %x[mstride1]]\n"
+ "fsub U.4s, AXTx23.4s, AXTx22.4s\n"
+ "str qU, [%x[outptr4], %x[mstride2]]\n"
+ "str qAXTx22, [%x[outptr4], %x[mstride3]]\n"
+
+ // Third row
+ "fsub U.4s, AXTx31.4s, AXTx33.4s\n"
+ "str qU, [%x[outptr8]]\n"
+ "fadd U.4s, AXTx32.4s, AXTx33.4s\n"
+ "str qU, [%x[outptr8], %x[mstride1]]\n"
+ "fsub U.4s, AXTx33.4s, AXTx32.4s\n"
+ "str qU, [%x[outptr8], %x[mstride2]]\n"
+ "str qAXTx32, [%x[outptr8], %x[mstride3]]\n"
+
+ // Fourth row
+ "fsub U.4s, AXTx41.4s, AXTx43.4s\n"
+ "str qU, [%x[outptr12]]\n"
+ "fadd U.4s, AXTx42.4s, AXTx43.4s\n"
+ "str qU, [%x[outptr12], %x[mstride1]]\n"
+ "fsub U.4s, AXTx43.4s, AXTx42.4s\n"
+ "str qU, [%x[outptr12], %x[mstride2]]\n"
+ "str qAXTx42, [%x[outptr12], %x[mstride3]]\n"
+
+ "b 4f\n" // Jump to end of function
+
+ // ----------------------------------------------------------------------
+ "3:"
+ // 'B' tail
+ // Since the final column is padding and the last-but-one column has
+      // already been loaded, just compute the 3rd column of `X.T x'.
+ "fsub BXTx13.4s, x_13.4s, x_33.4s\n"
+ "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
+ "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
+ "mov BXTx43.16b, x_23.16b\n"
+
+ // Compute and store U. Modified to account for the final column of X.T
+ // x containing padding. Note, it is also unnecessary to update the
+ // output pointers.
+ // First row
+ "fsub U.4s, BXTx11.4s, BXTx13.4s\n"
+ "str qU, [%x[outptr0]]\n"
+ "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
+ "str qU, [%x[outptr0], %x[mstride1]]\n"
+ "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
+ "str qU, [%x[outptr0], %x[mstride2]]\n"
+ "str qBXTx12, [%x[outptr0], %x[mstride3]]\n"
+
+ // Second row
+ "fsub U.4s, BXTx21.4s, BXTx23.4s\n"
+ "str qU, [%x[outptr4]]\n"
+ "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
+ "str qU, [%x[outptr4], %x[mstride1]]\n"
+ "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
+ "str qU, [%x[outptr4], %x[mstride2]]\n"
+ "str qBXTx22, [%x[outptr4], %x[mstride3]]\n"
+
+ // Third row
+ "fsub U.4s, BXTx31.4s, BXTx33.4s\n"
+ "str qU, [%x[outptr8]]\n"
+ "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
+ "str qU, [%x[outptr8], %x[mstride1]]\n"
+ "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
+ "str qU, [%x[outptr8], %x[mstride2]]\n"
+ "str qBXTx32, [%x[outptr8], %x[mstride3]]\n"
+
+ // Fourth row
+ "fsub U.4s, BXTx41.4s, BXTx43.4s\n"
+ "str qU, [%x[outptr12]]\n"
+ "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
+ "str qU, [%x[outptr12], %x[mstride1]]\n"
+ "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
+ "str qU, [%x[outptr12], %x[mstride2]]\n"
+ "str qBXTx42, [%x[outptr12], %x[mstride3]]\n"
+
+ // ----------------------------------------------------------------------
+ "4:"
+ // End of function
+
+ // Clear names
+ ".unreq x_13\n" ".unreq qx_13\n" ".unreq x_14\n" ".unreq qx_14\n"
+ ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n"
+ ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n"
+ ".unreq AXTx11\n" ".unreq BXTx13\n"
+ ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n"
+ ".unreq AXTx21\n" ".unreq BXTx23\n"
+ ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n"
+ ".unreq AXTx31\n" ".unreq BXTx33\n"
+ ".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n"
+ ".unreq AXTx41\n" ".unreq BXTx43\n"
+ ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n"
+ ".unreq AXTx13\n" ".unreq BXTx11\n"
+ ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n"
+ ".unreq AXTx23\n" ".unreq BXTx21\n"
+ ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n"
+ ".unreq AXTx33\n" ".unreq BXTx31\n"
+ ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n"
+ ".unreq AXTx43\n" ".unreq BXTx41\n"
+ ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n"
+ ".unreq U\n" ".unreq qU\n"
+ : [inptr0] "+r" (inptr0),
+ [inptr1] "+r" (inptr1),
+ [inptr2] "+r" (inptr2),
+ [outptr0] "+r" (outptr0),
+ [outptr4] "+r" (outptr4),
+ [outptr8] "+r" (outptr8),
+ [outptr12] "+r" (outptr12),
+ [tile_j] "+r" (tile_j) // Tile counter
+ : [colstride1] "r" (1 * input_col_stride * sizeof(float)),
+ [colstride2] "r" (2 * input_col_stride * sizeof(float)),
+ [colstride3] "r" (3 * input_col_stride * sizeof(float)),
+ [mstride1] "r" (1 * matrix_stride * sizeof(float)),
+ [mstride2] "r" (2 * matrix_stride * sizeof(float)),
+ [mstride3] "r" (3 * matrix_stride * sizeof(float)),
+ [matrix_row_stride] "r" (matrix_row_stride * sizeof(float))
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
+ "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
+ "v22", "v23", "v24"
+ );
+}
+}
+#endif // __aarch64__
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp
new file mode 100644
index 0000000000..ad1ad55291
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp
@@ -0,0 +1,961 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+#include "../input_2x2_3x3.hpp"
+
+#ifdef __aarch64__
+
+namespace winograd {
+
+template <>
+template <>
+inline void Winograd2x2_3x3GemmInputChannelwise<float>::_process_tile<0, 0, 0, 0, 4>(
+ int &n_channels, // Number of channels in the tile
+ const float* &inptr0,
+ const int input_row_stride,
+ const int input_col_stride,
+ float* &outptr0,
+ const int matrix_stride
+) {
+ // We use 4 pointers to point to the starting position on each row and use
+ // three offsets to extract elements from each of the other 3 columns.
+ auto inptr1 = inptr0 + 1*input_row_stride;
+ auto inptr2 = inptr0 + 2*input_row_stride;
+ auto inptr3 = inptr0 + 3*input_row_stride;
+
+ // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three
+ // offsets to access the intermediate matrices.
+ auto outptr1 = outptr0 + matrix_stride * 4;
+ auto outptr2 = outptr0 + matrix_stride * 8;
+ auto outptr3 = outptr0 + matrix_stride * 12;
+
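+  // Each iteration applies the F(2x2, 3x3) Winograd input transform
+  // U = X.T x X to one 4x4 data tile x, four channels at a time (one 128-bit
+  // q register per tile element). Here X is the usual data-transform matrix
+  // (often written B in the Winograd literature):
+  //
+  //         [ 1  0 -1  0 ]
+  //   X.T = [ 0  1  1  0 ]
+  //         [ 0 -1  1  0 ]
+  //         [ 0  1  0 -1 ]
+  //
+  // The xX_ij registers hold the intermediate product x X; the rows of U are
+  // then formed from them and written out to the 16 intermediate matrices.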
+ for (; n_channels > 3; n_channels -= 4) {
+ asm volatile (
+ "X_11 .req v0\n" "qX_11 .req q0\n"
+ "X_12 .req v1\n" "qX_12 .req q1\n"
+ "X_13 .req v2\n" "qX_13 .req q2\n"
+ "X_14 .req v3\n" "qX_14 .req q3\n"
+ "X_21 .req v4\n" "qX_21 .req q4\n"
+ "X_22 .req v5\n" "qX_22 .req q5\n"
+ "X_23 .req v6\n" "qX_23 .req q6\n"
+ "X_24 .req v7\n" "qX_24 .req q7\n"
+ "X_31 .req v8\n" "qX_31 .req q8\n"
+ "X_32 .req v9\n" "qX_32 .req q9\n"
+ "X_33 .req v10\n" "qX_33 .req q10\n"
+ "X_34 .req v11\n" "qX_34 .req q11\n"
+ "X_41 .req v12\n" "qX_41 .req q12\n"
+ "X_42 .req v13\n" "qX_42 .req q13\n"
+ "X_43 .req v14\n" "qX_43 .req q14\n"
+ "X_44 .req v15\n" "qX_44 .req q15\n"
+ "xX_11 .req v16\n"
+ "xX_12 .req v17\n"
+ "xX_13 .req v18\n"
+ "xX_14 .req v19\n"
+ "xX_21 .req v20\n"
+ "xX_22 .req v21\n"
+ "xX_23 .req v22\n"
+ "xX_24 .req v23\n"
+ "xX_31 .req v24\n"
+ "xX_32 .req v25\n"
+ "xX_33 .req v26\n"
+ "xX_34 .req v27\n"
+ "xX_41 .req v28\n"
+ "xX_42 .req v29\n"
+ "xX_43 .req v30\n"
+ "xX_44 .req v31\n"
+ " U .req v0\n"
+ "qU .req q0\n"
+
+      // Load the tile and compute the matrix xX
+ "ldr qX_11, [%x[inptr0]]\n"
+ "ldr qX_12, [%x[inptr0], %x[colstride1]]\n"
+ "ldr qX_13, [%x[inptr0], %x[colstride2]]\n"
+ "ldr qX_14, [%x[inptr0], %x[colstride3]]\n"
+ "add %x[inptr0], %x[inptr0], #0x10\n"
+
+ "ldr qX_21, [%x[inptr1]]\n"
+ "fsub xX_11.4s, x_11.4s, x_13.4s\n"
+ "ldr qX_22, [%x[inptr1], %x[colstride1]]\n"
+ "fadd xX_12.4s, x_12.4s, x_13.4s\n"
+ "ldr qX_23, [%x[inptr1], %x[colstride2]]\n"
+ "fsub xX_13.4s, x_13.4s, x_12.4s\n"
+ "ldr qX_24, [%x[inptr1], %x[colstride3]]\n"
+ "fsub xX_14.4s, x_12.4s, x_14.4s\n"
+ "add %x[inptr1], %x[inptr1], #0x10\n"
+
+ "ldr qX_31, [%x[inptr2]]\n"
+ "fsub xX_21.4s, x_21.4s, x_23.4s\n"
+ "ldr qX_32, [%x[inptr2], %x[colstride1]]\n"
+ "fadd xX_22.4s, x_22.4s, x_23.4s\n"
+ "ldr qX_33, [%x[inptr2], %x[colstride2]]\n"
+ "fsub xX_23.4s, x_23.4s, x_22.4s\n"
+ "ldr qX_34, [%x[inptr2], %x[colstride3]]\n"
+ "fsub xX_24.4s, x_22.4s, x_24.4s\n"
+ "add %x[inptr2], %x[inptr2], #0x10\n"
+
+ "ldr qX_41, [%x[inptr3]]\n"
+ "fsub xX_31.4s, x_31.4s, x_33.4s\n"
+ "ldr qX_42, [%x[inptr3], %x[colstride1]]\n"
+ "fadd xX_32.4s, x_32.4s, x_33.4s\n"
+ "ldr qX_43, [%x[inptr3], %x[colstride2]]\n"
+ "fsub xX_33.4s, x_33.4s, x_32.4s\n"
+ "ldr qX_44, [%x[inptr3], %x[colstride3]]\n"
+ "fsub xX_34.4s, x_32.4s, x_34.4s\n"
+ "add %x[inptr3], %x[inptr3], #0x10\n"
+
+ // Complete computing xX while beginning to compute and store
+ // $U = X.T x X$
+
+ "fsub xX_41.4s, x_41.4s, x_43.4s\n"
+
+ "fsub U.4s, xX_11.4s, xX_31.4s\n"
+ "str qU, [%x[outptr0]]\n"
+ "fsub U.4s, xX_12.4s, xX_32.4s\n"
+ "str qU, [%x[outptr0], %x[mstride1]]\n"
+ "fsub U.4s, xX_13.4s, xX_33.4s\n"
+ "str qU, [%x[outptr0], %x[mstride2]]\n"
+ "fsub U.4s, xX_14.4s, xX_34.4s\n"
+ "str qU, [%x[outptr0], %x[mstride3]]\n"
+ "add %x[outptr0], %x[outptr0], #0x10\n"
+
+ "fadd xX_42.4s, x_42.4s, x_43.4s\n"
+
+ "fadd U.4s, xX_21.4s, xX_31.4s\n"
+ "str qU, [%x[outptr4]]\n"
+ "fadd U.4s, xX_22.4s, xX_32.4s\n"
+ "str qU, [%x[outptr4], %x[mstride1]]\n"
+ "fadd U.4s, xX_23.4s, xX_33.4s\n"
+ "str qU, [%x[outptr4], %x[mstride2]]\n"
+ "fadd U.4s, xX_24.4s, xX_34.4s\n"
+ "str qU, [%x[outptr4], %x[mstride3]]\n"
+ "add %x[outptr4], %x[outptr4], #0x10\n"
+
+ "fsub xX_43.4s, x_43.4s, x_42.4s\n"
+
+ "fsub U.4s, xX_31.4s, xX_21.4s\n"
+ "str qU, [%x[outptr8]]\n"
+ "fsub U.4s, xX_32.4s, xX_22.4s\n"
+ "str qU, [%x[outptr8], %x[mstride1]]\n"
+ "fsub U.4s, xX_33.4s, xX_23.4s\n"
+ "str qU, [%x[outptr8], %x[mstride2]]\n"
+ "fsub U.4s, xX_34.4s, xX_24.4s\n"
+ "str qU, [%x[outptr8], %x[mstride3]]\n"
+ "add %x[outptr8], %x[outptr8], #0x10\n"
+
+ "fsub xX_44.4s, x_42.4s, x_44.4s\n"
+
+ "fsub U.4s, xX_21.4s, xX_41.4s\n"
+ "str qU, [%x[outptr12]]\n"
+ "fsub U.4s, xX_22.4s, xX_42.4s\n"
+ "str qU, [%x[outptr12], %x[mstride1]]\n"
+ "fsub U.4s, xX_23.4s, xX_43.4s\n"
+ "str qU, [%x[outptr12], %x[mstride2]]\n"
+ "fsub U.4s, xX_24.4s, xX_44.4s\n"
+ "str qU, [%x[outptr12], %x[mstride3]]\n"
+ "add %x[outptr12], %x[outptr12], #0x10\n"
+
+ ".unreq qU\n"
+ ".unreq U\n"
+ ".unreq X_11\n" ".unreq qX_11\n"
+ ".unreq X_12\n" ".unreq qX_12\n"
+ ".unreq X_13\n" ".unreq qX_13\n"
+ ".unreq X_14\n" ".unreq qX_14\n"
+ ".unreq X_21\n" ".unreq qX_21\n"
+ ".unreq X_22\n" ".unreq qX_22\n"
+ ".unreq X_23\n" ".unreq qX_23\n"
+ ".unreq X_24\n" ".unreq qX_24\n"
+ ".unreq X_31\n" ".unreq qX_31\n"
+ ".unreq X_32\n" ".unreq qX_32\n"
+ ".unreq X_33\n" ".unreq qX_33\n"
+ ".unreq X_34\n" ".unreq qX_34\n"
+ ".unreq X_41\n" ".unreq qX_41\n"
+ ".unreq X_42\n" ".unreq qX_42\n"
+ ".unreq X_43\n" ".unreq qX_43\n"
+ ".unreq X_44\n" ".unreq qX_44\n"
+ ".unreq xX_11\n"
+ ".unreq xX_12\n"
+ ".unreq xX_13\n"
+ ".unreq xX_14\n"
+ ".unreq xX_21\n"
+ ".unreq xX_22\n"
+ ".unreq xX_23\n"
+ ".unreq xX_24\n"
+ ".unreq xX_31\n"
+ ".unreq xX_32\n"
+ ".unreq xX_33\n"
+ ".unreq xX_34\n"
+ ".unreq xX_41\n"
+ ".unreq xX_42\n"
+ ".unreq xX_43\n"
+ ".unreq xX_44\n"
+
+ : [inptr0] "+r" (inptr0),
+ [inptr1] "+r" (inptr1),
+ [inptr2] "+r" (inptr2),
+ [inptr3] "+r" (inptr3),
+ [outptr0] "+r" (outptr0),
+ [outptr4] "+r" (outptr1),
+ [outptr8] "+r" (outptr2),
+ [outptr12] "+r" (outptr3)
+ : [colstride1] "r" (input_col_stride * sizeof(float)),
+ [colstride2] "r" (input_col_stride * sizeof(float) * 2),
+ [colstride3] "r" (input_col_stride * sizeof(float) * 3),
+ [mstride1] "r" (matrix_stride * sizeof(float)),
+ [mstride2] "r" (matrix_stride * sizeof(float) * 2),
+ [mstride3] "r" (matrix_stride * sizeof(float) * 3)
+ : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
+ "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
+ "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
+ "v30", "v31"
+ );
+ }
+}
+
+// Pad top by 1
+template <>
+template <>
+inline void Winograd2x2_3x3GemmInputChannelwise<float>::_process_tile<1, 0, 0, 0, 4>(
+ int &n_channels, // Number of channels in the tile
+ const float* &inptr0,
+ const int input_row_stride,
+ const int input_col_stride,
+ float* &outptr0,
+ const int matrix_stride
+) {
+  // The top row of the tile is padding, so we use three pointers (one per
+  // remaining row) and three offsets to extract elements from each of the
+  // other 3 columns.
+ auto inptr1 = inptr0 + 0*input_row_stride;
+ auto inptr2 = inptr0 + 1*input_row_stride;
+
+ // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three
+ // offsets to access the intermediate matrices.
+ auto outptr1 = outptr0 + matrix_stride * 4;
+ auto outptr2 = outptr0 + matrix_stride * 8;
+ auto outptr3 = outptr0 + matrix_stride * 12;
+
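+  // Because the top row of the tile is implicit padding (zeros), the first row
+  // of the intermediate product is zero and the first row of U reduces to
+  // U_1j = -xX_3j; hence the fneg instructions below in place of the usual
+  // subtraction.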
+ for (; n_channels > 3; n_channels -= 4) {
+ asm volatile (
+ "X_21 .req v4\n" "qX_21 .req q4\n"
+ "X_22 .req v5\n" "qX_22 .req q5\n"
+ "X_23 .req v6\n" "qX_23 .req q6\n"
+ "X_24 .req v7\n" "qX_24 .req q7\n"
+ "X_31 .req v8\n" "qX_31 .req q8\n"
+ "X_32 .req v9\n" "qX_32 .req q9\n"
+ "X_33 .req v10\n" "qX_33 .req q10\n"
+ "X_34 .req v11\n" "qX_34 .req q11\n"
+ "X_41 .req v12\n" "qX_41 .req q12\n"
+ "X_42 .req v13\n" "qX_42 .req q13\n"
+ "X_43 .req v14\n" "qX_43 .req q14\n"
+ "X_44 .req v15\n" "qX_44 .req q15\n"
+ "xX_21 .req v20\n"
+ "xX_22 .req v21\n"
+ "xX_23 .req v22\n"
+ "xX_24 .req v23\n"
+ "xX_31 .req v24\n"
+ "xX_32 .req v25\n"
+ "xX_33 .req v26\n"
+ "xX_34 .req v27\n"
+ "xX_41 .req v28\n"
+ "xX_42 .req v29\n"
+ "xX_43 .req v30\n"
+ "xX_44 .req v31\n"
+ " U .req v0\n"
+ "qU .req q0\n"
+
+      // Load the tile and compute the matrix xX
+ "ldr qX_21, [%x[inptr1]]\n"
+ "ldr qX_22, [%x[inptr1], %x[colstride1]]\n"
+ "ldr qX_23, [%x[inptr1], %x[colstride2]]\n"
+ "ldr qX_24, [%x[inptr1], %x[colstride3]]\n"
+ "add %x[inptr1], %x[inptr1], #0x10\n"
+
+ "ldr qX_31, [%x[inptr2]]\n"
+ "fsub xX_21.4s, x_21.4s, x_23.4s\n"
+ "ldr qX_32, [%x[inptr2], %x[colstride1]]\n"
+ "fadd xX_22.4s, x_22.4s, x_23.4s\n"
+ "ldr qX_33, [%x[inptr2], %x[colstride2]]\n"
+ "fsub xX_23.4s, x_23.4s, x_22.4s\n"
+ "ldr qX_34, [%x[inptr2], %x[colstride3]]\n"
+ "fsub xX_24.4s, x_22.4s, x_24.4s\n"
+ "add %x[inptr2], %x[inptr2], #0x10\n"
+
+ "ldr qX_41, [%x[inptr3]]\n"
+ "fsub xX_31.4s, x_31.4s, x_33.4s\n"
+ "ldr qX_42, [%x[inptr3], %x[colstride1]]\n"
+ "fadd xX_32.4s, x_32.4s, x_33.4s\n"
+ "ldr qX_43, [%x[inptr3], %x[colstride2]]\n"
+ "fsub xX_33.4s, x_33.4s, x_32.4s\n"
+ "ldr qX_44, [%x[inptr3], %x[colstride3]]\n"
+ "fsub xX_34.4s, x_32.4s, x_34.4s\n"
+ "add %x[inptr3], %x[inptr3], #0x10\n"
+
+ // Complete computing xX while beginning to compute and store
+ // $U = X.T x X$
+
+ "fsub xX_41.4s, x_41.4s, x_43.4s\n"
+
+ "fneg U.4s, xX_31.4s\n"
+ "str qU, [%x[outptr0]]\n"
+ "fneg U.4s, xX_32.4s\n"
+ "str qU, [%x[outptr0], %x[mstride1]]\n"
+ "fneg U.4s, xX_33.4s\n"
+ "str qU, [%x[outptr0], %x[mstride2]]\n"
+ "fneg U.4s, xX_34.4s\n"
+ "str qU, [%x[outptr0], %x[mstride3]]\n"
+ "add %x[outptr0], %x[outptr0], #0x10\n"
+
+ "fadd xX_42.4s, x_42.4s, x_43.4s\n"
+
+ "fadd U.4s, xX_21.4s, xX_31.4s\n"
+ "str qU, [%x[outptr4]]\n"
+ "fadd U.4s, xX_22.4s, xX_32.4s\n"
+ "str qU, [%x[outptr4], %x[mstride1]]\n"
+ "fadd U.4s, xX_23.4s, xX_33.4s\n"
+ "str qU, [%x[outptr4], %x[mstride2]]\n"
+ "fadd U.4s, xX_24.4s, xX_34.4s\n"
+ "str qU, [%x[outptr4], %x[mstride3]]\n"
+ "add %x[outptr4], %x[outptr4], #0x10\n"
+
+ "fsub xX_43.4s, x_43.4s, x_42.4s\n"
+
+ "fsub U.4s, xX_31.4s, xX_21.4s\n"
+ "str qU, [%x[outptr8]]\n"
+ "fsub U.4s, xX_32.4s, xX_22.4s\n"
+ "str qU, [%x[outptr8], %x[mstride1]]\n"
+ "fsub U.4s, xX_33.4s, xX_23.4s\n"
+ "str qU, [%x[outptr8], %x[mstride2]]\n"
+ "fsub U.4s, xX_34.4s, xX_24.4s\n"
+ "str qU, [%x[outptr8], %x[mstride3]]\n"
+ "add %x[outptr8], %x[outptr8], #0x10\n"
+
+ "fsub xX_44.4s, x_42.4s, x_44.4s\n"
+
+ "fsub U.4s, xX_21.4s, xX_41.4s\n"
+ "str qU, [%x[outptr12]]\n"
+ "fsub U.4s, xX_22.4s, xX_42.4s\n"
+ "str qU, [%x[outptr12], %x[mstride1]]\n"
+ "fsub U.4s, xX_23.4s, xX_43.4s\n"
+ "str qU, [%x[outptr12], %x[mstride2]]\n"
+ "fsub U.4s, xX_24.4s, xX_44.4s\n"
+ "str qU, [%x[outptr12], %x[mstride3]]\n"
+ "add %x[outptr12], %x[outptr12], #0x10\n"
+
+ ".unreq qU\n"
+ ".unreq U\n"
+ ".unreq X_21\n" ".unreq qX_21\n"
+ ".unreq X_22\n" ".unreq qX_22\n"
+ ".unreq X_23\n" ".unreq qX_23\n"
+ ".unreq X_24\n" ".unreq qX_24\n"
+ ".unreq X_31\n" ".unreq qX_31\n"
+ ".unreq X_32\n" ".unreq qX_32\n"
+ ".unreq X_33\n" ".unreq qX_33\n"
+ ".unreq X_34\n" ".unreq qX_34\n"
+ ".unreq X_41\n" ".unreq qX_41\n"
+ ".unreq X_42\n" ".unreq qX_42\n"
+ ".unreq X_43\n" ".unreq qX_43\n"
+ ".unreq X_44\n" ".unreq qX_44\n"
+ ".unreq xX_21\n"
+ ".unreq xX_22\n"
+ ".unreq xX_23\n"
+ ".unreq xX_24\n"
+ ".unreq xX_31\n"
+ ".unreq xX_32\n"
+ ".unreq xX_33\n"
+ ".unreq xX_34\n"
+ ".unreq xX_41\n"
+ ".unreq xX_42\n"
+ ".unreq xX_43\n"
+ ".unreq xX_44\n"
+
+ : [inptr1] "+r" (inptr0), // Offset for missing row
+ [inptr2] "+r" (inptr1), // Offset for missing row
+ [inptr3] "+r" (inptr2), // Offset for missing row
+ [outptr0] "+r" (outptr0),
+ [outptr4] "+r" (outptr1),
+ [outptr8] "+r" (outptr2),
+ [outptr12] "+r" (outptr3)
+ : [colstride1] "r" (input_col_stride * sizeof(float)),
+ [colstride2] "r" (input_col_stride * sizeof(float) * 2),
+ [colstride3] "r" (input_col_stride * sizeof(float) * 3),
+ [mstride1] "r" (matrix_stride * sizeof(float)),
+ [mstride2] "r" (matrix_stride * sizeof(float) * 2),
+ [mstride3] "r" (matrix_stride * sizeof(float) * 3)
+ : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
+ "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
+ "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
+ "v30", "v31"
+ );
+ }
+}
+
+// Pad left by 1
+template <>
+template <>
+inline void Winograd2x2_3x3GemmInputChannelwise<float>::_process_tile<0, 1, 0, 0, 4>(
+ int &n_channels, // Number of channels in the tile
+ const float* &inptr0,
+ const int input_row_stride,
+ const int input_col_stride,
+ float* &outptr0,
+ const int matrix_stride
+) {
+  // We use 4 pointers to point to the starting position on each row; the
+  // left-most column is padding, so two offsets suffice to extract elements
+  // from the remaining columns.
+ auto inptr1 = inptr0 + 1*input_row_stride;
+ auto inptr2 = inptr0 + 2*input_row_stride;
+ auto inptr3 = inptr0 + 3*input_row_stride;
+
+ // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three
+ // offsets to access the intermediate matrices.
+ auto outptr1 = outptr0 + matrix_stride * 4;
+ auto outptr2 = outptr0 + matrix_stride * 8;
+ auto outptr3 = outptr0 + matrix_stride * 12;
+
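+  // Because the left-most column of the tile is implicit padding (zeros),
+  // xX_i1 = x_i1 - x_i3 reduces to -x_i3; hence the fneg instructions used
+  // below to form the first column of the intermediate product.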
+ for (; n_channels > 3; n_channels -= 4) {
+ asm volatile (
+ "X_12 .req v1\n" "qX_12 .req q1\n"
+ "X_13 .req v2\n" "qX_13 .req q2\n"
+ "X_14 .req v3\n" "qX_14 .req q3\n"
+ "X_22 .req v5\n" "qX_22 .req q5\n"
+ "X_23 .req v6\n" "qX_23 .req q6\n"
+ "X_24 .req v7\n" "qX_24 .req q7\n"
+ "X_32 .req v9\n" "qX_32 .req q9\n"
+ "X_33 .req v10\n" "qX_33 .req q10\n"
+ "X_34 .req v11\n" "qX_34 .req q11\n"
+ "X_42 .req v13\n" "qX_42 .req q13\n"
+ "X_43 .req v14\n" "qX_43 .req q14\n"
+ "X_44 .req v15\n" "qX_44 .req q15\n"
+ "xX_11 .req v16\n"
+ "xX_12 .req v17\n"
+ "xX_13 .req v18\n"
+ "xX_14 .req v19\n"
+ "xX_21 .req v20\n"
+ "xX_22 .req v21\n"
+ "xX_23 .req v22\n"
+ "xX_24 .req v23\n"
+ "xX_31 .req v24\n"
+ "xX_32 .req v25\n"
+ "xX_33 .req v26\n"
+ "xX_34 .req v27\n"
+ "xX_41 .req v28\n"
+ "xX_42 .req v29\n"
+ "xX_43 .req v30\n"
+ "xX_44 .req v31\n"
+ " U .req v0\n"
+ "qU .req q0\n"
+
+      // Load the tile and compute the matrix xX
+ "ldr qX_12, [%x[inptr0]]\n"
+ "ldr qX_13, [%x[inptr0], %x[colstride1]]\n"
+ "ldr qX_14, [%x[inptr0], %x[colstride2]]\n"
+ "add %x[inptr0], %x[inptr0], #0x10\n"
+
+ "fneg xX_11.4s, x_13.4s\n"
+ "ldr qX_22, [%x[inptr1]]\n"
+ "fadd xX_12.4s, x_12.4s, x_13.4s\n"
+ "ldr qX_23, [%x[inptr1], %x[colstride1]]\n"
+ "fsub xX_13.4s, x_13.4s, x_12.4s\n"
+ "ldr qX_24, [%x[inptr1], %x[colstride2]]\n"
+ "fsub xX_14.4s, x_12.4s, x_14.4s\n"
+ "add %x[inptr1], %x[inptr1], #0x10\n"
+
+ "fneg xX_21.4s, x_23.4s\n"
+ "ldr qX_32, [%x[inptr2]]\n"
+ "fadd xX_22.4s, x_22.4s, x_23.4s\n"
+ "ldr qX_33, [%x[inptr2], %x[colstride1]]\n"
+ "fsub xX_23.4s, x_23.4s, x_22.4s\n"
+ "ldr qX_34, [%x[inptr2], %x[colstride2]]\n"
+ "fsub xX_24.4s, x_22.4s, x_24.4s\n"
+ "add %x[inptr2], %x[inptr2], #0x10\n"
+
+ "fneg xX_31.4s, x_33.4s\n"
+ "ldr qX_42, [%x[inptr3]]\n"
+ "fadd xX_32.4s, x_32.4s, x_33.4s\n"
+ "ldr qX_43, [%x[inptr3], %x[colstride1]]\n"
+ "fsub xX_33.4s, x_33.4s, x_32.4s\n"
+ "ldr qX_44, [%x[inptr3], %x[colstride2]]\n"
+ "fsub xX_34.4s, x_32.4s, x_34.4s\n"
+ "add %x[inptr3], %x[inptr3], #0x10\n"
+
+ // Complete computing xX while beginning to compute and store
+ // $U = X.T x X$
+
+ "fneg xX_41.4s, x_43.4s\n"
+
+ "fsub U.4s, xX_11.4s, xX_31.4s\n"
+ "str qU, [%x[outptr0]]\n"
+ "fsub U.4s, xX_12.4s, xX_32.4s\n"
+ "str qU, [%x[outptr0], %x[mstride1]]\n"
+ "fsub U.4s, xX_13.4s, xX_33.4s\n"
+ "str qU, [%x[outptr0], %x[mstride2]]\n"
+ "fsub U.4s, xX_14.4s, xX_34.4s\n"
+ "str qU, [%x[outptr0], %x[mstride3]]\n"
+ "add %x[outptr0], %x[outptr0], #0x10\n"
+
+ "fadd xX_42.4s, x_42.4s, x_43.4s\n"
+
+ "fadd U.4s, xX_21.4s, xX_31.4s\n"
+ "str qU, [%x[outptr4]]\n"
+ "fadd U.4s, xX_22.4s, xX_32.4s\n"
+ "str qU, [%x[outptr4], %x[mstride1]]\n"
+ "fadd U.4s, xX_23.4s, xX_33.4s\n"
+ "str qU, [%x[outptr4], %x[mstride2]]\n"
+ "fadd U.4s, xX_24.4s, xX_34.4s\n"
+ "str qU, [%x[outptr4], %x[mstride3]]\n"
+ "add %x[outptr4], %x[outptr4], #0x10\n"
+
+ "fsub xX_43.4s, x_43.4s, x_42.4s\n"
+
+ "fsub U.4s, xX_31.4s, xX_21.4s\n"
+ "str qU, [%x[outptr8]]\n"
+ "fsub U.4s, xX_32.4s, xX_22.4s\n"
+ "str qU, [%x[outptr8], %x[mstride1]]\n"
+ "fsub U.4s, xX_33.4s, xX_23.4s\n"
+ "str qU, [%x[outptr8], %x[mstride2]]\n"
+ "fsub U.4s, xX_34.4s, xX_24.4s\n"
+ "str qU, [%x[outptr8], %x[mstride3]]\n"
+ "add %x[outptr8], %x[outptr8], #0x10\n"
+
+ "fsub xX_44.4s, x_42.4s, x_44.4s\n"
+
+ "fsub U.4s, xX_21.4s, xX_41.4s\n"
+ "str qU, [%x[outptr12]]\n"
+ "fsub U.4s, xX_22.4s, xX_42.4s\n"
+ "str qU, [%x[outptr12], %x[mstride1]]\n"
+ "fsub U.4s, xX_23.4s, xX_43.4s\n"
+ "str qU, [%x[outptr12], %x[mstride2]]\n"
+ "fsub U.4s, xX_24.4s, xX_44.4s\n"
+ "str qU, [%x[outptr12], %x[mstride3]]\n"
+ "add %x[outptr12], %x[outptr12], #0x10\n"
+
+ ".unreq X_12\n" ".unreq qX_12\n"
+ ".unreq X_13\n" ".unreq qX_13\n"
+ ".unreq X_14\n" ".unreq qX_14\n"
+ ".unreq X_22\n" ".unreq qX_22\n"
+ ".unreq X_23\n" ".unreq qX_23\n"
+ ".unreq X_24\n" ".unreq qX_24\n"
+ ".unreq X_32\n" ".unreq qX_32\n"
+ ".unreq X_33\n" ".unreq qX_33\n"
+ ".unreq X_34\n" ".unreq qX_34\n"
+ ".unreq X_42\n" ".unreq qX_42\n"
+ ".unreq X_43\n" ".unreq qX_43\n"
+ ".unreq X_44\n" ".unreq qX_44\n"
+ ".unreq xX_11\n"
+ ".unreq xX_12\n"
+ ".unreq xX_13\n"
+ ".unreq xX_14\n"
+ ".unreq xX_21\n"
+ ".unreq xX_22\n"
+ ".unreq xX_23\n"
+ ".unreq xX_24\n"
+ ".unreq xX_31\n"
+ ".unreq xX_32\n"
+ ".unreq xX_33\n"
+ ".unreq xX_34\n"
+ ".unreq xX_41\n"
+ ".unreq xX_42\n"
+ ".unreq xX_43\n"
+ ".unreq xX_44\n"
+ ".unreq U\n"
+ ".unreq qU\n"
+
+ : [inptr0] "+r" (inptr0),
+ [inptr1] "+r" (inptr1),
+ [inptr2] "+r" (inptr2),
+ [inptr3] "+r" (inptr3),
+ [outptr0] "+r" (outptr0),
+ [outptr4] "+r" (outptr1),
+ [outptr8] "+r" (outptr2),
+ [outptr12] "+r" (outptr3)
+ : [colstride1] "r" (input_col_stride * sizeof(float)),
+ [colstride2] "r" (input_col_stride * sizeof(float) * 2),
+ [mstride1] "r" (matrix_stride * sizeof(float)),
+ [mstride2] "r" (matrix_stride * sizeof(float) * 2),
+ [mstride3] "r" (matrix_stride * sizeof(float) * 3)
+ : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
+ "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
+ "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
+ "v30", "v31"
+ );
+ }
+}
+
+// Pad bottom by 1
+template <>
+template <>
+inline void Winograd2x2_3x3GemmInputChannelwise<float>::_process_tile<0, 0, 1, 0, 4>(
+ int &n_channels, // Number of channels in the tile
+ const float* &inptr0,
+ const int input_row_stride,
+ const int input_col_stride,
+ float* &outptr0,
+ const int matrix_stride
+) {
+  // The bottom row of the tile is padding, so we use three pointers (one per
+  // remaining row) and three offsets to extract elements from each of the
+  // other 3 columns.
+ auto inptr1 = inptr0 + 1*input_row_stride;
+ auto inptr2 = inptr0 + 2*input_row_stride;
+
+ // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three
+ // offsets to access the intermediate matrices.
+ auto outptr1 = outptr0 + matrix_stride * 4;
+ auto outptr2 = outptr0 + matrix_stride * 8;
+ auto outptr3 = outptr0 + matrix_stride * 12;
+
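+  // Because the bottom row of the tile is implicit padding (zeros), the fourth
+  // row of the intermediate product is zero and the last row of U reduces to
+  // U_4j = xX_2j; the qxX_2j registers are therefore stored to outptr12
+  // directly.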
+ for (; n_channels > 3; n_channels -= 4) {
+ asm volatile (
+ "X_11 .req v0\n" "qX_11 .req q0\n"
+ "X_12 .req v1\n" "qX_12 .req q1\n"
+ "X_13 .req v2\n" "qX_13 .req q2\n"
+ "X_14 .req v3\n" "qX_14 .req q3\n"
+ "X_21 .req v4\n" "qX_21 .req q4\n"
+ "X_22 .req v5\n" "qX_22 .req q5\n"
+ "X_23 .req v6\n" "qX_23 .req q6\n"
+ "X_24 .req v7\n" "qX_24 .req q7\n"
+ "X_31 .req v8\n" "qX_31 .req q8\n"
+ "X_32 .req v9\n" "qX_32 .req q9\n"
+ "X_33 .req v10\n" "qX_33 .req q10\n"
+ "X_34 .req v11\n" "qX_34 .req q11\n"
+ "xX_11 .req v16\n"
+ "xX_12 .req v17\n"
+ "xX_13 .req v18\n"
+ "xX_14 .req v19\n"
+ "xX_21 .req v20\n" "qxX_21 .req q20\n"
+ "xX_22 .req v21\n" "qxX_22 .req q21\n"
+ "xX_23 .req v22\n" "qxX_23 .req q22\n"
+ "xX_24 .req v23\n" "qxX_24 .req q23\n"
+ "xX_31 .req v24\n"
+ "xX_32 .req v25\n"
+ "xX_33 .req v26\n"
+ "xX_34 .req v27\n"
+ " U .req v0\n"
+ "qU .req q0\n"
+
+      // Load the tile and compute the matrix xX
+ "ldr qX_11, [%x[inptr0]]\n"
+ "ldr qX_12, [%x[inptr0], %x[colstride1]]\n"
+ "ldr qX_13, [%x[inptr0], %x[colstride2]]\n"
+ "ldr qX_14, [%x[inptr0], %x[colstride3]]\n"
+ "add %x[inptr0], %x[inptr0], #0x10\n"
+
+ "ldr qX_21, [%x[inptr1]]\n"
+ "fsub xX_11.4s, x_11.4s, x_13.4s\n"
+ "ldr qX_22, [%x[inptr1], %x[colstride1]]\n"
+ "fadd xX_12.4s, x_12.4s, x_13.4s\n"
+ "ldr qX_23, [%x[inptr1], %x[colstride2]]\n"
+ "fsub xX_13.4s, x_13.4s, x_12.4s\n"
+ "ldr qX_24, [%x[inptr1], %x[colstride3]]\n"
+ "fsub xX_14.4s, x_12.4s, x_14.4s\n"
+ "add %x[inptr1], %x[inptr1], #0x10\n"
+
+ "ldr qX_31, [%x[inptr2]]\n"
+ "fsub xX_21.4s, x_21.4s, x_23.4s\n"
+ "ldr qX_32, [%x[inptr2], %x[colstride1]]\n"
+ "fadd xX_22.4s, x_22.4s, x_23.4s\n"
+ "ldr qX_33, [%x[inptr2], %x[colstride2]]\n"
+ "fsub xX_23.4s, x_23.4s, x_22.4s\n"
+ "ldr qX_34, [%x[inptr2], %x[colstride3]]\n"
+ "fsub xX_24.4s, x_22.4s, x_24.4s\n"
+ "add %x[inptr2], %x[inptr2], #0x10\n"
+
+ "fsub xX_31.4s, x_31.4s, x_33.4s\n"
+ "fadd xX_32.4s, x_32.4s, x_33.4s\n"
+ "fsub xX_33.4s, x_33.4s, x_32.4s\n"
+ "fsub xX_34.4s, x_32.4s, x_34.4s\n"
+
+ // Complete computing xX while beginning to compute and store
+ // $U = X.T x X$
+
+ "fsub U.4s, xX_11.4s, xX_31.4s\n"
+ "str qU, [%x[outptr0]]\n"
+ "fsub U.4s, xX_12.4s, xX_32.4s\n"
+ "str qU, [%x[outptr0], %x[mstride1]]\n"
+ "fsub U.4s, xX_13.4s, xX_33.4s\n"
+ "str qU, [%x[outptr0], %x[mstride2]]\n"
+ "fsub U.4s, xX_14.4s, xX_34.4s\n"
+ "str qU, [%x[outptr0], %x[mstride3]]\n"
+ "add %x[outptr0], %x[outptr0], #0x10\n"
+
+ "fadd U.4s, xX_21.4s, xX_31.4s\n"
+ "str qU, [%x[outptr4]]\n"
+ "fadd U.4s, xX_22.4s, xX_32.4s\n"
+ "str qU, [%x[outptr4], %x[mstride1]]\n"
+ "fadd U.4s, xX_23.4s, xX_33.4s\n"
+ "str qU, [%x[outptr4], %x[mstride2]]\n"
+ "fadd U.4s, xX_24.4s, xX_34.4s\n"
+ "str qU, [%x[outptr4], %x[mstride3]]\n"
+ "add %x[outptr4], %x[outptr4], #0x10\n"
+
+ "fsub U.4s, xX_31.4s, xX_21.4s\n"
+ "str qU, [%x[outptr8]]\n"
+ "fsub U.4s, xX_32.4s, xX_22.4s\n"
+ "str qU, [%x[outptr8], %x[mstride1]]\n"
+ "fsub U.4s, xX_33.4s, xX_23.4s\n"
+ "str qU, [%x[outptr8], %x[mstride2]]\n"
+ "fsub U.4s, xX_34.4s, xX_24.4s\n"
+ "str qU, [%x[outptr8], %x[mstride3]]\n"
+ "add %x[outptr8], %x[outptr8], #0x10\n"
+
+ "str qxX_21, [%x[outptr12]]\n"
+ "str qxX_22, [%x[outptr12], %x[mstride1]]\n"
+ "str qxX_23, [%x[outptr12], %x[mstride2]]\n"
+ "str qxX_24, [%x[outptr12], %x[mstride3]]\n"
+ "add %x[outptr12], %x[outptr12], #0x10\n"
+
+ ".unreq qU\n"
+ ".unreq U\n"
+ ".unreq X_11\n" ".unreq qX_11\n"
+ ".unreq X_12\n" ".unreq qX_12\n"
+ ".unreq X_13\n" ".unreq qX_13\n"
+ ".unreq X_14\n" ".unreq qX_14\n"
+ ".unreq X_21\n" ".unreq qX_21\n"
+ ".unreq X_22\n" ".unreq qX_22\n"
+ ".unreq X_23\n" ".unreq qX_23\n"
+ ".unreq X_24\n" ".unreq qX_24\n"
+ ".unreq X_31\n" ".unreq qX_31\n"
+ ".unreq X_32\n" ".unreq qX_32\n"
+ ".unreq X_33\n" ".unreq qX_33\n"
+ ".unreq X_34\n" ".unreq qX_34\n"
+ ".unreq xX_11\n"
+ ".unreq xX_12\n"
+ ".unreq xX_13\n"
+ ".unreq xX_14\n"
+ ".unreq xX_21\n" ".unreq qxX_21\n"
+ ".unreq xX_22\n" ".unreq qxX_22\n"
+ ".unreq xX_23\n" ".unreq qxX_23\n"
+ ".unreq xX_24\n" ".unreq qxX_24\n"
+ ".unreq xX_31\n"
+ ".unreq xX_32\n"
+ ".unreq xX_33\n"
+ ".unreq xX_34\n"
+
+ : [inptr0] "+r" (inptr0),
+ [inptr1] "+r" (inptr1),
+ [inptr2] "+r" (inptr2),
+ [outptr0] "+r" (outptr0),
+ [outptr4] "+r" (outptr1),
+ [outptr8] "+r" (outptr2),
+ [outptr12] "+r" (outptr3)
+ : [colstride1] "r" (input_col_stride * sizeof(float)),
+ [colstride2] "r" (input_col_stride * sizeof(float) * 2),
+ [colstride3] "r" (input_col_stride * sizeof(float) * 3),
+ [mstride1] "r" (matrix_stride * sizeof(float)),
+ [mstride2] "r" (matrix_stride * sizeof(float) * 2),
+ [mstride3] "r" (matrix_stride * sizeof(float) * 3)
+ : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
+ "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
+ "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
+ "v30", "v31"
+ );
+ }
+}
+
+// Pad right by 1
+template <>
+template <>
+inline void Winograd2x2_3x3GemmInputChannelwise<float>::_process_tile<0, 0, 0, 1, 4>(
+ int &n_channels, // Number of channels in the tile
+ const float* &inptr0,
+ const int input_row_stride,
+ const int input_col_stride,
+ float* &outptr0,
+ const int matrix_stride
+) {
+  // We use 4 pointers to point to the starting position on each row; the
+  // right-most column is padding, so two offsets suffice to extract elements
+  // from the remaining columns.
+ auto inptr1 = inptr0 + 1*input_row_stride;
+ auto inptr2 = inptr0 + 2*input_row_stride;
+ auto inptr3 = inptr0 + 3*input_row_stride;
+
+ // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three
+ // offsets to access the intermediate matrices.
+ auto outptr1 = outptr0 + matrix_stride * 4;
+ auto outptr2 = outptr0 + matrix_stride * 8;
+ auto outptr3 = outptr0 + matrix_stride * 12;
+
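+  // Because the right-most column of the tile is implicit padding (zeros),
+  // xX_i4 = x_i2 - x_i4 reduces to x_i2; the .req aliases below therefore map
+  // xX_i4 straight onto the x_i2 registers rather than computing it.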
+ for (; n_channels > 3; n_channels -= 4) {
+ asm volatile (
+ "X_11 .req v0\n" "qX_11 .req q0\n"
+ "X_12 .req v1\n" "qX_12 .req q1\n"
+ "X_13 .req v2\n" "qX_13 .req q2\n"
+ "X_21 .req v4\n" "qX_21 .req q4\n"
+ "X_22 .req v5\n" "qX_22 .req q5\n"
+ "X_23 .req v6\n" "qX_23 .req q6\n"
+ "X_31 .req v8\n" "qX_31 .req q8\n"
+ "X_32 .req v9\n" "qX_32 .req q9\n"
+ "X_33 .req v10\n" "qX_33 .req q10\n"
+ "X_41 .req v12\n" "qX_41 .req q12\n"
+ "X_42 .req v13\n" "qX_42 .req q13\n"
+ "X_43 .req v14\n" "qX_43 .req q14\n"
+ "xX_11 .req v16\n"
+ "xX_12 .req v17\n"
+ "xX_13 .req v18\n"
+ "xX_14 .req x_12\n"
+ "xX_21 .req v20\n"
+ "xX_22 .req v21\n"
+ "xX_23 .req v22\n"
+ "xX_24 .req x_22\n"
+ "xX_31 .req v24\n"
+ "xX_32 .req v25\n"
+ "xX_33 .req v26\n"
+ "xX_34 .req x_32\n"
+ "xX_41 .req v28\n"
+ "xX_42 .req v29\n"
+ "xX_43 .req v30\n"
+ "xX_44 .req x_42\n"
+ " U .req v0\n"
+ "qU .req q0\n"
+
+      // Load the tile and compute the matrix xX
+ "ldr qX_11, [%x[inptr0]]\n"
+ "ldr qX_12, [%x[inptr0], %x[colstride1]]\n"
+ "ldr qX_13, [%x[inptr0], %x[colstride2]]\n"
+ "add %x[inptr0], %x[inptr0], #0x10\n"
+
+ "ldr qX_21, [%x[inptr1]]\n"
+ "fsub xX_11.4s, x_11.4s, x_13.4s\n"
+ "ldr qX_22, [%x[inptr1], %x[colstride1]]\n"
+ "fadd xX_12.4s, x_12.4s, x_13.4s\n"
+ "ldr qX_23, [%x[inptr1], %x[colstride2]]\n"
+ "fsub xX_13.4s, x_13.4s, x_12.4s\n"
+ "add %x[inptr1], %x[inptr1], #0x10\n"
+
+ "ldr qX_31, [%x[inptr2]]\n"
+ "fsub xX_21.4s, x_21.4s, x_23.4s\n"
+ "ldr qX_32, [%x[inptr2], %x[colstride1]]\n"
+ "fadd xX_22.4s, x_22.4s, x_23.4s\n"
+ "ldr qX_33, [%x[inptr2], %x[colstride2]]\n"
+ "fsub xX_23.4s, x_23.4s, x_22.4s\n"
+ "add %x[inptr2], %x[inptr2], #0x10\n"
+
+ "ldr qX_41, [%x[inptr3]]\n"
+ "fsub xX_31.4s, x_31.4s, x_33.4s\n"
+ "ldr qX_42, [%x[inptr3], %x[colstride1]]\n"
+ "fadd xX_32.4s, x_32.4s, x_33.4s\n"
+ "ldr qX_43, [%x[inptr3], %x[colstride2]]\n"
+ "fsub xX_33.4s, x_33.4s, x_32.4s\n"
+ "add %x[inptr3], %x[inptr3], #0x10\n"
+
+ // Complete computing xX while beginning to compute and store
+ // $U = X.T x X$
+
+ "fsub xX_41.4s, x_41.4s, x_43.4s\n"
+
+ "fsub U.4s, xX_11.4s, xX_31.4s\n"
+ "str qU, [%x[outptr0]]\n"
+ "fsub U.4s, xX_12.4s, xX_32.4s\n"
+ "str qU, [%x[outptr0], %x[mstride1]]\n"
+ "fsub U.4s, xX_13.4s, xX_33.4s\n"
+ "str qU, [%x[outptr0], %x[mstride2]]\n"
+ "fsub U.4s, xX_14.4s, xX_34.4s\n"
+ "str qU, [%x[outptr0], %x[mstride3]]\n"
+ "add %x[outptr0], %x[outptr0], #0x10\n"
+
+ "fadd xX_42.4s, x_42.4s, x_43.4s\n"
+
+ "fadd U.4s, xX_21.4s, xX_31.4s\n"
+ "str qU, [%x[outptr4]]\n"
+ "fadd U.4s, xX_22.4s, xX_32.4s\n"
+ "str qU, [%x[outptr4], %x[mstride1]]\n"
+ "fadd U.4s, xX_23.4s, xX_33.4s\n"
+ "str qU, [%x[outptr4], %x[mstride2]]\n"
+ "fadd U.4s, xX_24.4s, xX_34.4s\n"
+ "str qU, [%x[outptr4], %x[mstride3]]\n"
+ "add %x[outptr4], %x[outptr4], #0x10\n"
+
+ "fsub xX_43.4s, x_43.4s, x_42.4s\n"
+
+ "fsub U.4s, xX_31.4s, xX_21.4s\n"
+ "str qU, [%x[outptr8]]\n"
+ "fsub U.4s, xX_32.4s, xX_22.4s\n"
+ "str qU, [%x[outptr8], %x[mstride1]]\n"
+ "fsub U.4s, xX_33.4s, xX_23.4s\n"
+ "str qU, [%x[outptr8], %x[mstride2]]\n"
+ "fsub U.4s, xX_34.4s, xX_24.4s\n"
+ "str qU, [%x[outptr8], %x[mstride3]]\n"
+ "add %x[outptr8], %x[outptr8], #0x10\n"
+
+ "fsub U.4s, xX_21.4s, xX_41.4s\n"
+ "str qU, [%x[outptr12]]\n"
+ "fsub U.4s, xX_22.4s, xX_42.4s\n"
+ "str qU, [%x[outptr12], %x[mstride1]]\n"
+ "fsub U.4s, xX_23.4s, xX_43.4s\n"
+ "str qU, [%x[outptr12], %x[mstride2]]\n"
+ "fsub U.4s, xX_24.4s, xX_44.4s\n"
+ "str qU, [%x[outptr12], %x[mstride3]]\n"
+ "add %x[outptr12], %x[outptr12], #0x10\n"
+
+ ".unreq qU\n"
+ ".unreq U\n"
+ ".unreq X_11\n" ".unreq qX_11\n"
+ ".unreq X_12\n" ".unreq qX_12\n"
+ ".unreq X_13\n" ".unreq qX_13\n"
+ ".unreq X_21\n" ".unreq qX_21\n"
+ ".unreq X_22\n" ".unreq qX_22\n"
+ ".unreq X_23\n" ".unreq qX_23\n"
+ ".unreq X_31\n" ".unreq qX_31\n"
+ ".unreq X_32\n" ".unreq qX_32\n"
+ ".unreq X_33\n" ".unreq qX_33\n"
+ ".unreq X_41\n" ".unreq qX_41\n"
+ ".unreq X_42\n" ".unreq qX_42\n"
+ ".unreq X_43\n" ".unreq qX_43\n"
+ ".unreq xX_11\n"
+ ".unreq xX_12\n"
+ ".unreq xX_13\n"
+ ".unreq xX_14\n"
+ ".unreq xX_21\n"
+ ".unreq xX_22\n"
+ ".unreq xX_23\n"
+ ".unreq xX_24\n"
+ ".unreq xX_31\n"
+ ".unreq xX_32\n"
+ ".unreq xX_33\n"
+ ".unreq xX_34\n"
+ ".unreq xX_41\n"
+ ".unreq xX_42\n"
+ ".unreq xX_43\n"
+ ".unreq xX_44\n"
+
+ : [inptr0] "+r" (inptr0),
+ [inptr1] "+r" (inptr1),
+ [inptr2] "+r" (inptr2),
+ [inptr3] "+r" (inptr3),
+ [outptr0] "+r" (outptr0),
+ [outptr4] "+r" (outptr1),
+ [outptr8] "+r" (outptr2),
+ [outptr12] "+r" (outptr3)
+ : [colstride1] "r" (input_col_stride * sizeof(float)),
+ [colstride2] "r" (input_col_stride * sizeof(float) * 2),
+ [mstride1] "r" (matrix_stride * sizeof(float)),
+ [mstride2] "r" (matrix_stride * sizeof(float) * 2),
+ [mstride3] "r" (matrix_stride * sizeof(float) * 3)
+ : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
+ "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
+ "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
+ "v30", "v31"
+ );
+ }
+}
+}
+#endif
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp
new file mode 100644
index 0000000000..033442aa14
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp
@@ -0,0 +1,195 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+namespace winograd {
+ /* Transform a kernel into the Winograd domain.
+ *
+ * NOTE: It is assumed that the kernel is in the form [height x width x
+ * input_channels x output_channel].
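+   *
+   * The transform computed is U = W w W.T for each 3x3 kernel w, where
+   *
+   *       [ 1    0    0  ]
+   *   W = [ 1/2  1/2  1/2]
+   *       [ 1/2 -1/2  1/2]
+   *       [ 0    0    1  ]
+   *
+   * (the matrix usually written G in the Winograd literature). The transformed
+   * kernels are combined with the similarly transformed input tiles as a set
+   * of 16 GEMMs over the channel dimensions before the output transform maps
+   * the results back to the spatial domain.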
+ */
+ template <typename T>
+ struct winograd2x2_3x3_gemm_kernel_transform_impl{
+ static void execute(
+ const KernelShape &shape,
+ const T* const kernel,
+ T* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride
+ );
+
+ protected:
+ template <const int output_channel_tail>
+ static void transform_kernel(
+ const T* const kernel,
+ const int n_input_channels,
+ const int n_output_channels,
+ T* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride
+ );
+ };
+}
+
+/*****************************************************************************/
+/* Transform a fp32 kernel into the Winograd domain.
+ */
+#include "kernel_2x2_3x3/a64_float.hpp" // AArch64 specialisations
+
+namespace winograd
+{
+template <>
+inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::execute(
+ const KernelShape &shape,
+ const float* const kernel,
+ float* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride
+) {
+ // Delegate based on tail size
+ const int n_input_channels = shape.n_input_channels;
+ const int n_output_channels = shape.n_output_channels;
+
+ switch (n_output_channels % 4) {
+ case 0:
+ transform_kernel<0>(
+ kernel, n_input_channels, n_output_channels,
+ matrix_base, matrix_stride, matrix_row_stride
+ );
+ break;
+ case 1:
+ transform_kernel<1>(
+ kernel, n_input_channels, n_output_channels,
+ matrix_base, matrix_stride, matrix_row_stride
+ );
+ break;
+ case 2:
+ transform_kernel<2>(
+ kernel, n_input_channels, n_output_channels,
+ matrix_base, matrix_stride, matrix_row_stride
+ );
+ break;
+ case 3:
+ transform_kernel<3>(
+ kernel, n_input_channels, n_output_channels,
+ matrix_base, matrix_stride, matrix_row_stride
+ );
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Cannot happen");
+ break;
+ }
+}
+
+template <>
+template<const int output_channel_tail>
+inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::transform_kernel(
+ const float* const kernel,
+ const int n_input_channels,
+ const int n_output_channels,
+ float* const matrix_base,
+ const int mstride,
+ const int matrix_row_stride
+) {
+ // Use one input pointer for each row of the kernel, use two additional
+ // offsets to extract columns.
+ const int kernel_col_stride = n_input_channels * n_output_channels;
+ const int kernel_row_stride = 3 * kernel_col_stride;
+ const float *inptr0 = kernel;
+ const float *inptr1 = kernel + kernel_row_stride;
+ const float *inptr2 = kernel + kernel_row_stride*2;
+
+ // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three
+ // offsets to extract further matrices.
+ float *outptr0 = matrix_base;
+ float *outptr4 = matrix_base + mstride * 4;
+ float *outptr8 = matrix_base + mstride * 8;
+ float *outptr12 = matrix_base + mstride * 12;
+
+ // For every input channel
+ for (int in_c = 0; in_c < n_input_channels; in_c++) {
+ // For every output channel
+ for (int c = 0; c < n_output_channels; c++) {
+ // Read in the kernel
+ float w11 = inptr0[0], w12 = inptr0[kernel_col_stride], w13 = inptr0[kernel_col_stride*2];
+ float w21 = inptr1[0], w22 = inptr1[kernel_col_stride], w23 = inptr1[kernel_col_stride*2];
+ float w31 = inptr2[0], w32 = inptr2[kernel_col_stride], w33 = inptr2[kernel_col_stride*2];
+
+ // Progress input pointers
+ inptr0++;
+ inptr1++;
+ inptr2++;
+
+      // Compute the product W w; note we need only compute the middle two rows
+      // (2 and 3) because the first and last rows are merely copies of the
+      // first and last rows of w.
+ float Ww11 = w11, Ww12 = w12, Ww13 = w13;
+ float Ww21 = 0.5*(w11 + w21 + w31), Ww22 = 0.5*(w12 + w22 + w32), Ww23 = 0.5*(w13 + w23 + w33);
+ float Ww31 = 0.5*(w11 - w21 + w31), Ww32 = 0.5*(w12 - w22 + w32), Ww33 = 0.5*(w13 - w23 + w33);
+ float Ww41 = w31, Ww42 = w32, Ww43 = w33;
+
+      // Hence compute W w W.T; again note we need only compute the middle two
+ // columns since the first and last columns are copies of the first and
+ // last columns of the previous matrix.
+ float WwWT11 = Ww11, WwWT12 = 0.5*(Ww11 + Ww12 + Ww13), WwWT13 = 0.5*(Ww11 - Ww12 + Ww13), WwWT14 = Ww13;
+ float WwWT21 = Ww21, WwWT22 = 0.5*(Ww21 + Ww22 + Ww23), WwWT23 = 0.5*(Ww21 - Ww22 + Ww23), WwWT24 = Ww23;
+ float WwWT31 = Ww31, WwWT32 = 0.5*(Ww31 + Ww32 + Ww33), WwWT33 = 0.5*(Ww31 - Ww32 + Ww33), WwWT34 = Ww33;
+ float WwWT41 = Ww41, WwWT42 = 0.5*(Ww41 + Ww42 + Ww43), WwWT43 = 0.5*(Ww41 - Ww42 + Ww43), WwWT44 = Ww43;
+
+ // Store the computed weights
+ outptr0[0 * mstride] = WwWT11;
+ outptr0[1 * mstride] = WwWT12;
+ outptr0[2 * mstride] = WwWT13;
+ outptr0[3 * mstride] = WwWT14;
+
+ outptr4[0 * mstride] = WwWT21;
+ outptr4[1 * mstride] = WwWT22;
+ outptr4[2 * mstride] = WwWT23;
+ outptr4[3 * mstride] = WwWT24;
+
+ outptr8[0 * mstride] = WwWT31;
+ outptr8[1 * mstride] = WwWT32;
+ outptr8[2 * mstride] = WwWT33;
+ outptr8[3 * mstride] = WwWT34;
+
+ outptr12[0 * mstride] = WwWT41;
+ outptr12[1 * mstride] = WwWT42;
+ outptr12[2 * mstride] = WwWT43;
+ outptr12[3 * mstride] = WwWT44;
+
+ // Progress output pointers
+ outptr0++;
+ outptr4++;
+ outptr8++;
+ outptr12++;
+ }
+
+ // Progression to complete stride
+ outptr0 += matrix_row_stride - n_output_channels;
+ outptr4 += matrix_row_stride - n_output_channels;
+ outptr8 += matrix_row_stride - n_output_channels;
+ outptr12 += matrix_row_stride - n_output_channels;
+ }
+}
+}
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp
new file mode 100644
index 0000000000..3dd62d1ac1
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp
@@ -0,0 +1,822 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+#ifdef __aarch64__
+namespace winograd {
+template <>
+template <>
+inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::transform_kernel<0>(
+ const float* const kernel,
+ const int n_input_channels,
+ const int n_output_channels,
+ float* const matrix_base,
+ const int mstride,
+ const int matrix_row_stride
+) {
+ // Use one input pointer for each row of the kernel, use two additional
+ // offsets to extract columns.
+ const int kernel_col_stride = n_input_channels * n_output_channels;
+ const int kernel_row_stride = 3 * kernel_col_stride;
+ const float *inptr0 = kernel;
+ const float *inptr1 = kernel + kernel_row_stride;
+ const float *inptr2 = kernel + kernel_row_stride*2;
+
+ // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three
+ // offsets to extract further matrices.
+ float *outptr0 = matrix_base;
+ float *outptr4 = matrix_base + mstride * 4;
+ float *outptr8 = matrix_base + mstride * 8;
+ float *outptr12 = matrix_base + mstride * 12;
+
+ // For every input channel
+ for (int in_c = 0; in_c < n_input_channels; in_c++) {
+ int n_remaining_channels = n_output_channels;
+
+ asm volatile (
+ // Registers into which to read the kernel
+ "w_11 .req v0\n" "qw_11 .req q0\n"
+ "w_12 .req v1\n" "qw_12 .req q1\n"
+ "w_13 .req v2\n" "qw_13 .req q2\n"
+ "w_21 .req v3\n" "qw_21 .req q3\n"
+ "w_22 .req v4\n" "qw_22 .req q4\n"
+ "w_23 .req v5\n" "qw_23 .req q5\n"
+ "w_31 .req v6\n" "qw_31 .req q6\n"
+ "w_32 .req v7\n" "qw_32 .req q7\n"
+ "w_33 .req v8\n" "qw_33 .req q8\n"
+
+ // Transformed matrix Ww
+ "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n"
+ "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n"
+ "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n"
+ "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n"
+
+ // Output matrix U = WwWT
+ "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n"
+ "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n"
+ "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n"
+ "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n"
+
+ // Storage view of output matrices
+ "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n"
+ "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n"
+ "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n"
+ "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n"
+
+ "half .req v23\n" // {0.5, ..., 0.5}
+ "dup half.4s, %w[one_half]\n"
+ "scratch .req v24\n"
+
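+      // Via the .req aliases above, the first and last columns of U share
+      // registers with w and Ww (for example U11 is w_11 and U44 is w_33), so
+      // those elements can be stored as soon as the corresponding kernel
+      // values have been loaded; the stores below are interleaved with the
+      // loads and arithmetic for this reason.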
+ "1:"
+ // Load tile of the kernel
+ "ldr qw_11, [%x[inptr0]]\n"
+ "str qU11, [%x[outptr0]]\n"
+ "ldr qw_12, [%x[inptr0], %x[colstride1]]\n"
+ "ldr qw_13, [%x[inptr0], %x[colstride2]]\n"
+ "str qU14, [%x[outptr0], %x[mstride3]]\n"
+ "add %x[inptr0], %x[inptr0], #0x10\n"
+
+ "ldr qw_21, [%x[inptr1]]\n"
+ "ldr qw_22, [%x[inptr1], %x[colstride1]]\n"
+ "ldr qw_23, [%x[inptr1], %x[colstride2]]\n"
+ "add %x[inptr1], %x[inptr1], #0x10\n"
+
+ "ldr qw_31, [%x[inptr2]]\n"
+ "str qU41, [%x[outptr12]]\n"
+ "ldr qw_32, [%x[inptr2], %x[colstride1]]\n"
+ "ldr qw_33, [%x[inptr2], %x[colstride2]]\n"
+ "str qU44, [%x[outptr12], %x[mstride3]]\n"
+ "add %x[inptr2], %x[inptr2], #0x10\n"
+
+ // Compute 2nd and 3rd rows of Ww
+ "fadd scratch.4s, w_11.4s, w_31.4s\n"
+ "fmul Ww21.4s, scratch.4s, half.4s\n"
+ "fmla Ww21.4s, w_21.4s, half.4s\n"
+ "str qU21, [%x[outptr4]]\n"
+ "fmul Ww31.4s, scratch.4s, half.4s\n"
+ "fmls Ww31.4s, w_21.4s, half.4s\n"
+ "str qU31, [%x[outptr8]]\n"
+
+ "fadd scratch.4s, w_12.4s, w_32.4s\n"
+ "fmul Ww22.4s, scratch.4s, half.4s\n"
+ "fmla Ww22.4s, w_22.4s, half.4s\n"
+ "fmul Ww32.4s, scratch.4s, half.4s\n"
+ "fmls Ww32.4s, w_22.4s, half.4s\n"
+
+ "fadd scratch.4s, w_13.4s, w_33.4s\n"
+ "fmul Ww23.4s, scratch.4s, half.4s\n"
+ "fmla Ww23.4s, w_23.4s, half.4s\n"
+ "str qU24, [%x[outptr4], %x[mstride3]]\n"
+ "fmul Ww33.4s, scratch.4s, half.4s\n"
+ "fmls Ww33.4s, w_23.4s, half.4s\n"
+ "str qU34, [%x[outptr8], %x[mstride3]]\n"
+
+      // Compute and store U; we need only compute the 2nd and 3rd columns
+      // of U and update the output pointers.
+ "fadd scratch.4s, Ww11.4s, Ww13.4s\n"
+ "fmul U12.4s, scratch.4s, half.4s\n"
+ "fmla U12.4s, Ww12.4s, half.4s\n"
+ "str qU12, [%x[outptr0], %x[mstride1]]\n"
+ "fmul U13.4s, scratch.4s, half.4s\n"
+ "fmls U13.4s, Ww12.4s, half.4s\n"
+ "str qU13, [%x[outptr0], %x[mstride2]]\n"
+ "add %x[outptr0], %x[outptr0], #0x10\n"
+
+ "fadd scratch.4s, Ww21.4s, Ww23.4s\n"
+ "fmul U22.4s, scratch.4s, half.4s\n"
+ "fmla U22.4s, Ww22.4s, half.4s\n"
+ "str qU22, [%x[outptr4], %x[mstride1]]\n"
+ "fmul U23.4s, scratch.4s, half.4s\n"
+ "fmls U23.4s, Ww22.4s, half.4s\n"
+ "str qU23, [%x[outptr4], %x[mstride2]]\n"
+ "add %x[outptr4], %x[outptr4], #0x10\n"
+
+ "fadd scratch.4s, Ww31.4s, Ww33.4s\n"
+ "fmul U32.4s, scratch.4s, half.4s\n"
+ "fmla U32.4s, Ww32.4s, half.4s\n"
+ "str qU32, [%x[outptr8], %x[mstride1]]\n"
+ "fmul U33.4s, scratch.4s, half.4s\n"
+ "fmls U33.4s, Ww32.4s, half.4s\n"
+ "str qU33, [%x[outptr8], %x[mstride2]]\n"
+ "add %x[outptr8], %x[outptr8], #0x10\n"
+
+ "fadd scratch.4s, Ww41.4s, Ww43.4s\n"
+ "fmul U42.4s, scratch.4s, half.4s\n"
+ "fmla U42.4s, Ww42.4s, half.4s\n"
+ "str qU42, [%x[outptr12], %x[mstride1]]\n"
+ "fmul U43.4s, scratch.4s, half.4s\n"
+ "fmls U43.4s, Ww42.4s, half.4s\n"
+ "str qU43, [%x[outptr12], %x[mstride2]]\n"
+ "add %x[outptr12], %x[outptr12], #0x10\n"
+
+ "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n"
+ "bne 1b\n"
+
+ // Clear aliases
+ ".unreq half\n"
+ ".unreq scratch\n"
+ ".unreq w_11\n" ".unreq qw_11\n"
+ ".unreq w_12\n" ".unreq qw_12\n"
+ ".unreq w_13\n" ".unreq qw_13\n"
+ ".unreq w_21\n" ".unreq qw_21\n"
+ ".unreq w_22\n" ".unreq qw_22\n"
+ ".unreq w_23\n" ".unreq qw_23\n"
+ ".unreq w_31\n" ".unreq qw_31\n"
+ ".unreq w_32\n" ".unreq qw_32\n"
+ ".unreq w_33\n" ".unreq qw_33\n"
+ ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n"
+ ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n"
+ ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n"
+ ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n"
+ ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n"
+ ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n"
+ ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n"
+ ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n"
+ ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n"
+ ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n"
+ ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n"
+ ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq qU44\n"
+
+ : [inptr0] "+r" (inptr0),
+ [inptr1] "+r" (inptr1),
+ [inptr2] "+r" (inptr2),
+ [outptr0] "+r" (outptr0),
+ [outptr4] "+r" (outptr4),
+ [outptr8] "+r" (outptr8),
+ [outptr12] "+r" (outptr12),
+ [n_remaining_channels] "+r" (n_remaining_channels)
+ : [mstride1] "r" (sizeof(float) * mstride),
+ [mstride2] "r" (sizeof(float) * mstride * 2),
+ [mstride3] "r" (sizeof(float) * mstride * 3),
+ [colstride1] "r" (sizeof(float) * kernel_col_stride),
+ [colstride2] "r" (sizeof(float) * kernel_col_stride * 2),
+ [one_half] "r" (0.5f)
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+ "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
+ "v20", "v21", "v22", "v23", "v24"
+ );
+
+ // Progression to complete stride
+ outptr0 += matrix_row_stride - n_output_channels;
+ outptr4 += matrix_row_stride - n_output_channels;
+ outptr8 += matrix_row_stride - n_output_channels;
+ outptr12 += matrix_row_stride - n_output_channels;
+ }
+}
+
+template <>
+template <>
+inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::transform_kernel<2>(
+ const float* const kernel,
+ const int n_input_channels,
+ const int n_output_channels,
+ float* const matrix_base,
+ const int mstride,
+ const int matrix_row_stride
+) {
+ // Use one input pointer for each row of the kernel, use two additional
+ // offsets to extract columns.
+ const int kernel_col_stride = n_input_channels * n_output_channels;
+ const int kernel_row_stride = 3 * kernel_col_stride;
+ const float *inptr0 = kernel;
+ const float *inptr1 = kernel + kernel_row_stride;
+ const float *inptr2 = kernel + kernel_row_stride*2;
+
+ // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three
+ // offsets to extract further matrices.
+ float *outptr0 = matrix_base;
+ float *outptr4 = matrix_base + mstride * 4;
+ float *outptr8 = matrix_base + mstride * 8;
+ float *outptr12 = matrix_base + mstride * 12;
+
+ // For every input channel
+ for (int in_c = 0; in_c < n_input_channels; in_c++) {
+ int n_remaining_channels = n_output_channels;
+
+ asm volatile (
+ // Registers into which to read the kernel
+ "w_11 .req v0\n" "qw_11 .req q0\n" "dw_11 .req d0\n"
+ "w_12 .req v1\n" "qw_12 .req q1\n" "dw_12 .req d1\n"
+ "w_13 .req v2\n" "qw_13 .req q2\n" "dw_13 .req d2\n"
+ "w_21 .req v3\n" "qw_21 .req q3\n" "dw_21 .req d3\n"
+ "w_22 .req v4\n" "qw_22 .req q4\n" "dw_22 .req d4\n"
+ "w_23 .req v5\n" "qw_23 .req q5\n" "dw_23 .req d5\n"
+ "w_31 .req v6\n" "qw_31 .req q6\n" "dw_31 .req d6\n"
+ "w_32 .req v7\n" "qw_32 .req q7\n" "dw_32 .req d7\n"
+ "w_33 .req v8\n" "qw_33 .req q8\n" "dw_33 .req d8\n"
+
+ // Transformed matrix Ww
+ "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n"
+ "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n"
+ "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n"
+ "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n"
+
+      // Output matrix U = (W.w).W^T
+ "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n"
+ "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n"
+ "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n"
+ "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n"
+
+ // Storage view of output matrices
+ "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n"
+ "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n"
+ "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n"
+ "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n"
+
+ "dU11 .req d0\n" "dU12 .req d15\n" "dU13 .req d16\n" "dU14 .req d2\n"
+ "dU21 .req d9\n" "dU22 .req d17\n" "dU23 .req d18\n" "dU24 .req d11\n"
+ "dU31 .req d12\n" "dU32 .req d19\n" "dU33 .req d20\n" "dU34 .req d14\n"
+ "dU41 .req d6\n" "dU42 .req d21\n" "dU43 .req d22\n" "dU44 .req d8\n"
+
+ "half .req v23\n" // {0.5, ..., 0.5}
+ "dup half.4s, %w[one_half]\n"
+ "scratch .req v24\n"
+
+ // Subtract the tail from the number of remaining channels and jump to
+ // the tail if necessary.
+ "subs %x[n_remaining_channels], %x[n_remaining_channels], #2\n"
+ "beq 2f\n"
+
+ "1:"
+ // Load tile of the kernel
+ "ldr qw_11, [%x[inptr0]]\n"
+ "str qU11, [%x[outptr0]]\n"
+ "ldr qw_12, [%x[inptr0], %x[colstride1]]\n"
+ "ldr qw_13, [%x[inptr0], %x[colstride2]]\n"
+ "str qU14, [%x[outptr0], %x[mstride3]]\n"
+ "add %x[inptr0], %x[inptr0], #0x10\n"
+
+ "ldr qw_21, [%x[inptr1]]\n"
+ "ldr qw_22, [%x[inptr1], %x[colstride1]]\n"
+ "ldr qw_23, [%x[inptr1], %x[colstride2]]\n"
+ "add %x[inptr1], %x[inptr1], #0x10\n"
+
+ "ldr qw_31, [%x[inptr2]]\n"
+ "str qU41, [%x[outptr12]]\n"
+ "ldr qw_32, [%x[inptr2], %x[colstride1]]\n"
+ "ldr qw_33, [%x[inptr2], %x[colstride2]]\n"
+ "str qU44, [%x[outptr12], %x[mstride3]]\n"
+ "add %x[inptr2], %x[inptr2], #0x10\n"
+
+ // Compute 2nd and 3rd rows of Ww
+ "fadd scratch.4s, w_11.4s, w_31.4s\n"
+ "fmul Ww21.4s, scratch.4s, half.4s\n"
+ "fmla Ww21.4s, w_21.4s, half.4s\n"
+ "str qU21, [%x[outptr4]]\n"
+ "fmul Ww31.4s, scratch.4s, half.4s\n"
+ "fmls Ww31.4s, w_21.4s, half.4s\n"
+ "str qU31, [%x[outptr8]]\n"
+
+ "fadd scratch.4s, w_12.4s, w_32.4s\n"
+ "fmul Ww22.4s, scratch.4s, half.4s\n"
+ "fmla Ww22.4s, w_22.4s, half.4s\n"
+ "fmul Ww32.4s, scratch.4s, half.4s\n"
+ "fmls Ww32.4s, w_22.4s, half.4s\n"
+
+ "fadd scratch.4s, w_13.4s, w_33.4s\n"
+ "fmul Ww23.4s, scratch.4s, half.4s\n"
+ "fmla Ww23.4s, w_23.4s, half.4s\n"
+ "str qU24, [%x[outptr4], %x[mstride3]]\n"
+ "fmul Ww33.4s, scratch.4s, half.4s\n"
+ "fmls Ww33.4s, w_23.4s, half.4s\n"
+ "str qU34, [%x[outptr8], %x[mstride3]]\n"
+
+      // Compute and store U; only the 2nd and 3rd columns of U need to be
+      // computed, then the output pointers are updated.
+ "fadd scratch.4s, Ww11.4s, Ww13.4s\n"
+ "fmul U12.4s, scratch.4s, half.4s\n"
+ "fmla U12.4s, Ww12.4s, half.4s\n"
+ "str qU12, [%x[outptr0], %x[mstride1]]\n"
+ "fmul U13.4s, scratch.4s, half.4s\n"
+ "fmls U13.4s, Ww12.4s, half.4s\n"
+ "str qU13, [%x[outptr0], %x[mstride2]]\n"
+ "add %x[outptr0], %x[outptr0], #0x10\n"
+
+ "fadd scratch.4s, Ww21.4s, Ww23.4s\n"
+ "fmul U22.4s, scratch.4s, half.4s\n"
+ "fmla U22.4s, Ww22.4s, half.4s\n"
+ "str qU22, [%x[outptr4], %x[mstride1]]\n"
+ "fmul U23.4s, scratch.4s, half.4s\n"
+ "fmls U23.4s, Ww22.4s, half.4s\n"
+ "str qU23, [%x[outptr4], %x[mstride2]]\n"
+ "add %x[outptr4], %x[outptr4], #0x10\n"
+
+ "fadd scratch.4s, Ww31.4s, Ww33.4s\n"
+ "fmul U32.4s, scratch.4s, half.4s\n"
+ "fmla U32.4s, Ww32.4s, half.4s\n"
+ "str qU32, [%x[outptr8], %x[mstride1]]\n"
+ "fmul U33.4s, scratch.4s, half.4s\n"
+ "fmls U33.4s, Ww32.4s, half.4s\n"
+ "str qU33, [%x[outptr8], %x[mstride2]]\n"
+ "add %x[outptr8], %x[outptr8], #0x10\n"
+
+ "fadd scratch.4s, Ww41.4s, Ww43.4s\n"
+ "fmul U42.4s, scratch.4s, half.4s\n"
+ "fmla U42.4s, Ww42.4s, half.4s\n"
+ "str qU42, [%x[outptr12], %x[mstride1]]\n"
+ "fmul U43.4s, scratch.4s, half.4s\n"
+ "fmls U43.4s, Ww42.4s, half.4s\n"
+ "str qU43, [%x[outptr12], %x[mstride2]]\n"
+ "add %x[outptr12], %x[outptr12], #0x10\n"
+
+ "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n"
+ "bne 1b\n"
+
+ // Tail size 2
+ "2:"
+ // Load tile of the kernel
+ "ldr dw_11, [%x[inptr0]]\n"
+ "str dU11, [%x[outptr0]]\n"
+ "ldr dw_12, [%x[inptr0], %x[colstride1]]\n"
+ "ldr dw_13, [%x[inptr0], %x[colstride2]]\n"
+ "str dU14, [%x[outptr0], %x[mstride3]]\n"
+ "add %x[inptr0], %x[inptr0], #0x08\n"
+
+ "ldr dw_21, [%x[inptr1]]\n"
+ "ldr dw_22, [%x[inptr1], %x[colstride1]]\n"
+ "ldr dw_23, [%x[inptr1], %x[colstride2]]\n"
+ "add %x[inptr1], %x[inptr1], #0x08\n"
+
+ "ldr dw_31, [%x[inptr2]]\n"
+ "str dU41, [%x[outptr12]]\n"
+ "ldr dw_32, [%x[inptr2], %x[colstride1]]\n"
+ "ldr dw_33, [%x[inptr2], %x[colstride2]]\n"
+ "str dU44, [%x[outptr12], %x[mstride3]]\n"
+ "add %x[inptr2], %x[inptr2], #0x08\n"
+
+ // Compute 2nd and 3rd rows of Ww
+ "fadd scratch.2s, w_11.2s, w_31.2s\n"
+ "fmul Ww21.2s, scratch.2s, half.2s\n"
+ "fmla Ww21.2s, w_21.2s, half.2s\n"
+ "str dU21, [%x[outptr4]]\n"
+ "fmul Ww31.2s, scratch.2s, half.2s\n"
+ "fmls Ww31.2s, w_21.2s, half.2s\n"
+ "str dU31, [%x[outptr8]]\n"
+
+ "fadd scratch.2s, w_12.2s, w_32.2s\n"
+ "fmul Ww22.2s, scratch.2s, half.2s\n"
+ "fmla Ww22.2s, w_22.2s, half.2s\n"
+ "fmul Ww32.2s, scratch.2s, half.2s\n"
+ "fmls Ww32.2s, w_22.2s, half.2s\n"
+
+ "fadd scratch.2s, w_13.2s, w_33.2s\n"
+ "fmul Ww23.2s, scratch.2s, half.2s\n"
+ "fmla Ww23.2s, w_23.2s, half.2s\n"
+ "str dU24, [%x[outptr4], %x[mstride3]]\n"
+ "fmul Ww33.2s, scratch.2s, half.2s\n"
+ "fmls Ww33.2s, w_23.2s, half.2s\n"
+ "str dU34, [%x[outptr8], %x[mstride3]]\n"
+
+      // Compute and store U; only the 2nd and 3rd columns of U need to be
+      // computed, then the output pointers are updated.
+ "fadd scratch.2s, Ww11.2s, Ww13.2s\n"
+ "fmul U12.2s, scratch.2s, half.2s\n"
+ "fmla U12.2s, Ww12.2s, half.2s\n"
+ "str dU12, [%x[outptr0], %x[mstride1]]\n"
+ "fmul U13.2s, scratch.2s, half.2s\n"
+ "fmls U13.2s, Ww12.2s, half.2s\n"
+ "str dU13, [%x[outptr0], %x[mstride2]]\n"
+ "add %x[outptr0], %x[outptr0], #0x08\n"
+
+ "fadd scratch.2s, Ww21.2s, Ww23.2s\n"
+ "fmul U22.2s, scratch.2s, half.2s\n"
+ "fmla U22.2s, Ww22.2s, half.2s\n"
+ "str dU22, [%x[outptr4], %x[mstride1]]\n"
+ "fmul U23.2s, scratch.2s, half.2s\n"
+ "fmls U23.2s, Ww22.2s, half.2s\n"
+ "str dU23, [%x[outptr4], %x[mstride2]]\n"
+ "add %x[outptr4], %x[outptr4], #0x08\n"
+
+ "fadd scratch.2s, Ww31.2s, Ww33.2s\n"
+ "fmul U32.2s, scratch.2s, half.2s\n"
+ "fmla U32.2s, Ww32.2s, half.2s\n"
+ "str dU32, [%x[outptr8], %x[mstride1]]\n"
+ "fmul U33.2s, scratch.2s, half.2s\n"
+ "fmls U33.2s, Ww32.2s, half.2s\n"
+ "str dU33, [%x[outptr8], %x[mstride2]]\n"
+ "add %x[outptr8], %x[outptr8], #0x08\n"
+
+ "fadd scratch.2s, Ww41.2s, Ww43.2s\n"
+ "fmul U42.2s, scratch.2s, half.2s\n"
+ "fmla U42.2s, Ww42.2s, half.2s\n"
+ "str dU42, [%x[outptr12], %x[mstride1]]\n"
+ "fmul U43.2s, scratch.2s, half.2s\n"
+ "fmls U43.2s, Ww42.2s, half.2s\n"
+ "str dU43, [%x[outptr12], %x[mstride2]]\n"
+ "add %x[outptr12], %x[outptr12], #0x08\n"
+
+ // Clear aliases
+ ".unreq half\n"
+ ".unreq scratch\n"
+ ".unreq w_11\n" ".unreq qw_11\n" ".unreq dw_11\n"
+ ".unreq w_12\n" ".unreq qw_12\n" ".unreq dw_12\n"
+ ".unreq w_13\n" ".unreq qw_13\n" ".unreq dw_13\n"
+ ".unreq w_21\n" ".unreq qw_21\n" ".unreq dw_21\n"
+ ".unreq w_22\n" ".unreq qw_22\n" ".unreq dw_22\n"
+ ".unreq w_23\n" ".unreq qw_23\n" ".unreq dw_23\n"
+ ".unreq w_31\n" ".unreq qw_31\n" ".unreq dw_31\n"
+ ".unreq w_32\n" ".unreq qw_32\n" ".unreq dw_32\n"
+ ".unreq w_33\n" ".unreq qw_33\n" ".unreq dw_33\n"
+ ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n"
+ ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n"
+ ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n"
+ ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n"
+ ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n"
+ ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n"
+ ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n"
+ ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n"
+ ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n"
+ ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n"
+ ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n"
+ ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq qU44\n"
+ ".unreq dU11\n" ".unreq dU12\n" ".unreq dU13\n" ".unreq dU14\n"
+ ".unreq dU21\n" ".unreq dU22\n" ".unreq dU23\n" ".unreq dU24\n"
+ ".unreq dU31\n" ".unreq dU32\n" ".unreq dU33\n" ".unreq dU34\n"
+ ".unreq dU41\n" ".unreq dU42\n" ".unreq dU43\n" ".unreq dU44\n"
+
+ : [inptr0] "+r" (inptr0),
+ [inptr1] "+r" (inptr1),
+ [inptr2] "+r" (inptr2),
+ [outptr0] "+r" (outptr0),
+ [outptr4] "+r" (outptr4),
+ [outptr8] "+r" (outptr8),
+ [outptr12] "+r" (outptr12),
+ [n_remaining_channels] "+r" (n_remaining_channels)
+ : [mstride1] "r" (sizeof(float) * mstride),
+ [mstride2] "r" (sizeof(float) * mstride * 2),
+ [mstride3] "r" (sizeof(float) * mstride * 3),
+ [colstride1] "r" (sizeof(float) * kernel_col_stride),
+ [colstride2] "r" (sizeof(float) * kernel_col_stride * 2),
+ [one_half] "r" (0.5f)
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+ "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
+ "v20", "v21", "v22", "v23", "v24"
+ );
+
+    // Advance the output pointers by the remainder of the matrix row stride
+ outptr0 += matrix_row_stride - n_output_channels;
+ outptr4 += matrix_row_stride - n_output_channels;
+ outptr8 += matrix_row_stride - n_output_channels;
+ outptr12 += matrix_row_stride - n_output_channels;
+ }
+}
+
+template <>
+template <>
+inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::transform_kernel<1>(
+ const float* const kernel,
+ const int n_input_channels,
+ const int n_output_channels,
+ float* const matrix_base,
+ const int mstride,
+ const int matrix_row_stride
+) {
+  // Use one input pointer for each row of the kernel, and two additional
+  // offsets to extract its columns.
+ const int kernel_col_stride = n_input_channels * n_output_channels;
+ const int kernel_row_stride = 3 * kernel_col_stride;
+ const float *inptr0 = kernel;
+ const float *inptr1 = kernel + kernel_row_stride;
+ const float *inptr2 = kernel + kernel_row_stride*2;
+
+  // Use four output pointers, for output matrices 0, 4, 8 and 12, and three
+  // offsets to reach the remaining matrices.
+ float *outptr0 = matrix_base;
+ float *outptr4 = matrix_base + mstride * 4;
+ float *outptr8 = matrix_base + mstride * 8;
+ float *outptr12 = matrix_base + mstride * 12;
+
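+  // Reading of the loop structure below: the main loop consumes four output
+  // channels per iteration using 128-bit q-register loads and stores, while
+  // the tail handles the final channel with 32-bit s-register loads/stores;
+  // the tail arithmetic still runs on 2-lane (.2s) vectors but only lane 0 is
+  // written back. The loop structure assumes n_output_channels % 4 == 1,
+  // consistent with the <1> channel-tail template parameter.
+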
+ // For every input channel
+ for (int in_c = 0; in_c < n_input_channels; in_c++) {
+ int n_remaining_channels = n_output_channels;
+
+ asm volatile (
+ // Registers into which to read the kernel
+ "w_11 .req v0\n" "qw_11 .req q0\n" "sw_11 .req s0\n"
+ "w_12 .req v1\n" "qw_12 .req q1\n" "sw_12 .req s1\n"
+ "w_13 .req v2\n" "qw_13 .req q2\n" "sw_13 .req s2\n"
+ "w_21 .req v3\n" "qw_21 .req q3\n" "sw_21 .req s3\n"
+ "w_22 .req v4\n" "qw_22 .req q4\n" "sw_22 .req s4\n"
+ "w_23 .req v5\n" "qw_23 .req q5\n" "sw_23 .req s5\n"
+ "w_31 .req v6\n" "qw_31 .req q6\n" "sw_31 .req s6\n"
+ "w_32 .req v7\n" "qw_32 .req q7\n" "sw_32 .req s7\n"
+ "w_33 .req v8\n" "qw_33 .req q8\n" "sw_33 .req s8\n"
+
+ // Transformed matrix Ww
+ "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n"
+ "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n"
+ "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n"
+ "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n"
+
+      // Output matrix U = (W.w).W^T
+ "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n"
+ "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n"
+ "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n"
+ "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n"
+
+ // Storage view of output matrices
+ "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n"
+ "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n"
+ "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n"
+ "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n"
+
+ "sU11 .req s0\n" "sU12 .req s15\n" "sU13 .req s16\n" "sU14 .req s2\n"
+ "sU21 .req s9\n" "sU22 .req s17\n" "sU23 .req s18\n" "sU24 .req s11\n"
+ "sU31 .req s12\n" "sU32 .req s19\n" "sU33 .req s20\n" "sU34 .req s14\n"
+ "sU41 .req s6\n" "sU42 .req s21\n" "sU43 .req s22\n" "sU44 .req s8\n"
+
+ "half .req v23\n" // {0.5, ..., 0.5}
+ "dup half.4s, %w[one_half]\n"
+ "scratch .req v24\n"
+
+ // Subtract the tail from the number of remaining channels and jump to
+ // the tail if necessary.
+ "subs %x[n_remaining_channels], %x[n_remaining_channels], #1\n"
+ "beq 2f\n"
+
+ "1:"
+ // Load tile of the kernel
+ "ldr qw_11, [%x[inptr0]]\n"
+ "str qU11, [%x[outptr0]]\n"
+ "ldr qw_12, [%x[inptr0], %x[colstride1]]\n"
+ "ldr qw_13, [%x[inptr0], %x[colstride2]]\n"
+ "str qU14, [%x[outptr0], %x[mstride3]]\n"
+ "add %x[inptr0], %x[inptr0], #0x10\n"
+
+ "ldr qw_21, [%x[inptr1]]\n"
+ "ldr qw_22, [%x[inptr1], %x[colstride1]]\n"
+ "ldr qw_23, [%x[inptr1], %x[colstride2]]\n"
+ "add %x[inptr1], %x[inptr1], #0x10\n"
+
+ "ldr qw_31, [%x[inptr2]]\n"
+ "str qU41, [%x[outptr12]]\n"
+ "ldr qw_32, [%x[inptr2], %x[colstride1]]\n"
+ "ldr qw_33, [%x[inptr2], %x[colstride2]]\n"
+ "str qU44, [%x[outptr12], %x[mstride3]]\n"
+ "add %x[inptr2], %x[inptr2], #0x10\n"
+
+ // Compute 2nd and 3rd rows of Ww
+ "fadd scratch.4s, w_11.4s, w_31.4s\n"
+ "fmul Ww21.4s, scratch.4s, half.4s\n"
+ "fmla Ww21.4s, w_21.4s, half.4s\n"
+ "str qU21, [%x[outptr4]]\n"
+ "fmul Ww31.4s, scratch.4s, half.4s\n"
+ "fmls Ww31.4s, w_21.4s, half.4s\n"
+ "str qU31, [%x[outptr8]]\n"
+
+ "fadd scratch.4s, w_12.4s, w_32.4s\n"
+ "fmul Ww22.4s, scratch.4s, half.4s\n"
+ "fmla Ww22.4s, w_22.4s, half.4s\n"
+ "fmul Ww32.4s, scratch.4s, half.4s\n"
+ "fmls Ww32.4s, w_22.4s, half.4s\n"
+
+ "fadd scratch.4s, w_13.4s, w_33.4s\n"
+ "fmul Ww23.4s, scratch.4s, half.4s\n"
+ "fmla Ww23.4s, w_23.4s, half.4s\n"
+ "str qU24, [%x[outptr4], %x[mstride3]]\n"
+ "fmul Ww33.4s, scratch.4s, half.4s\n"
+ "fmls Ww33.4s, w_23.4s, half.4s\n"
+ "str qU34, [%x[outptr8], %x[mstride3]]\n"
+
+      // Compute and store U; only the 2nd and 3rd columns of U need to be
+      // computed, then the output pointers are updated.
+ "fadd scratch.4s, Ww11.4s, Ww13.4s\n"
+ "fmul U12.4s, scratch.4s, half.4s\n"
+ "fmla U12.4s, Ww12.4s, half.4s\n"
+ "str qU12, [%x[outptr0], %x[mstride1]]\n"
+ "fmul U13.4s, scratch.4s, half.4s\n"
+ "fmls U13.4s, Ww12.4s, half.4s\n"
+ "str qU13, [%x[outptr0], %x[mstride2]]\n"
+ "add %x[outptr0], %x[outptr0], #0x10\n"
+
+ "fadd scratch.4s, Ww21.4s, Ww23.4s\n"
+ "fmul U22.4s, scratch.4s, half.4s\n"
+ "fmla U22.4s, Ww22.4s, half.4s\n"
+ "str qU22, [%x[outptr4], %x[mstride1]]\n"
+ "fmul U23.4s, scratch.4s, half.4s\n"
+ "fmls U23.4s, Ww22.4s, half.4s\n"
+ "str qU23, [%x[outptr4], %x[mstride2]]\n"
+ "add %x[outptr4], %x[outptr4], #0x10\n"
+
+ "fadd scratch.4s, Ww31.4s, Ww33.4s\n"
+ "fmul U32.4s, scratch.4s, half.4s\n"
+ "fmla U32.4s, Ww32.4s, half.4s\n"
+ "str qU32, [%x[outptr8], %x[mstride1]]\n"
+ "fmul U33.4s, scratch.4s, half.4s\n"
+ "fmls U33.4s, Ww32.4s, half.4s\n"
+ "str qU33, [%x[outptr8], %x[mstride2]]\n"
+ "add %x[outptr8], %x[outptr8], #0x10\n"
+
+ "fadd scratch.4s, Ww41.4s, Ww43.4s\n"
+ "fmul U42.4s, scratch.4s, half.4s\n"
+ "fmla U42.4s, Ww42.4s, half.4s\n"
+ "str qU42, [%x[outptr12], %x[mstride1]]\n"
+ "fmul U43.4s, scratch.4s, half.4s\n"
+ "fmls U43.4s, Ww42.4s, half.4s\n"
+ "str qU43, [%x[outptr12], %x[mstride2]]\n"
+ "add %x[outptr12], %x[outptr12], #0x10\n"
+
+ "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n"
+ "bne 1b\n"
+
+ // Tail size 1
+ "2:"
+ // Load tile of the kernel
+ "ldr sw_11, [%x[inptr0]]\n"
+ "str sU11, [%x[outptr0]]\n"
+ "ldr sw_12, [%x[inptr0], %x[colstride1]]\n"
+ "ldr sw_13, [%x[inptr0], %x[colstride2]]\n"
+ "str sU14, [%x[outptr0], %x[mstride3]]\n"
+ "add %x[inptr0], %x[inptr0], #0x04\n"
+
+ "ldr sw_21, [%x[inptr1]]\n"
+ "ldr sw_22, [%x[inptr1], %x[colstride1]]\n"
+ "ldr sw_23, [%x[inptr1], %x[colstride2]]\n"
+ "add %x[inptr1], %x[inptr1], #0x04\n"
+
+ "ldr sw_31, [%x[inptr2]]\n"
+ "str sU41, [%x[outptr12]]\n"
+ "ldr sw_32, [%x[inptr2], %x[colstride1]]\n"
+ "ldr sw_33, [%x[inptr2], %x[colstride2]]\n"
+ "str sU44, [%x[outptr12], %x[mstride3]]\n"
+ "add %x[inptr2], %x[inptr2], #0x04\n"
+
+ // Compute 2nd and 3rd rows of Ww
+ "fadd scratch.2s, w_11.2s, w_31.2s\n"
+ "fmul Ww21.2s, scratch.2s, half.2s\n"
+ "fmla Ww21.2s, w_21.2s, half.2s\n"
+ "str sU21, [%x[outptr4]]\n"
+ "fmul Ww31.2s, scratch.2s, half.2s\n"
+ "fmls Ww31.2s, w_21.2s, half.2s\n"
+ "str sU31, [%x[outptr8]]\n"
+
+ "fadd scratch.2s, w_12.2s, w_32.2s\n"
+ "fmul Ww22.2s, scratch.2s, half.2s\n"
+ "fmla Ww22.2s, w_22.2s, half.2s\n"
+ "fmul Ww32.2s, scratch.2s, half.2s\n"
+ "fmls Ww32.2s, w_22.2s, half.2s\n"
+
+ "fadd scratch.2s, w_13.2s, w_33.2s\n"
+ "fmul Ww23.2s, scratch.2s, half.2s\n"
+ "fmla Ww23.2s, w_23.2s, half.2s\n"
+ "str sU24, [%x[outptr4], %x[mstride3]]\n"
+ "fmul Ww33.2s, scratch.2s, half.2s\n"
+ "fmls Ww33.2s, w_23.2s, half.2s\n"
+ "str sU34, [%x[outptr8], %x[mstride3]]\n"
+
+      // Compute and store U; only the 2nd and 3rd columns of U need to be
+      // computed, then the output pointers are updated.
+ "fadd scratch.2s, Ww11.2s, Ww13.2s\n"
+ "fmul U12.2s, scratch.2s, half.2s\n"
+ "fmla U12.2s, Ww12.2s, half.2s\n"
+ "str sU12, [%x[outptr0], %x[mstride1]]\n"
+ "fmul U13.2s, scratch.2s, half.2s\n"
+ "fmls U13.2s, Ww12.2s, half.2s\n"
+ "str sU13, [%x[outptr0], %x[mstride2]]\n"
+ "add %x[outptr0], %x[outptr0], #0x04\n"
+
+ "fadd scratch.2s, Ww21.2s, Ww23.2s\n"
+ "fmul U22.2s, scratch.2s, half.2s\n"
+ "fmla U22.2s, Ww22.2s, half.2s\n"
+ "str sU22, [%x[outptr4], %x[mstride1]]\n"
+ "fmul U23.2s, scratch.2s, half.2s\n"
+ "fmls U23.2s, Ww22.2s, half.2s\n"
+ "str sU23, [%x[outptr4], %x[mstride2]]\n"
+ "add %x[outptr4], %x[outptr4], #0x04\n"
+
+ "fadd scratch.2s, Ww31.2s, Ww33.2s\n"
+ "fmul U32.2s, scratch.2s, half.2s\n"
+ "fmla U32.2s, Ww32.2s, half.2s\n"
+ "str sU32, [%x[outptr8], %x[mstride1]]\n"
+ "fmul U33.2s, scratch.2s, half.2s\n"
+ "fmls U33.2s, Ww32.2s, half.2s\n"
+ "str sU33, [%x[outptr8], %x[mstride2]]\n"
+ "add %x[outptr8], %x[outptr8], #0x04\n"
+
+ "fadd scratch.2s, Ww41.2s, Ww43.2s\n"
+ "fmul U42.2s, scratch.2s, half.2s\n"
+ "fmla U42.2s, Ww42.2s, half.2s\n"
+ "str sU42, [%x[outptr12], %x[mstride1]]\n"
+ "fmul U43.2s, scratch.2s, half.2s\n"
+ "fmls U43.2s, Ww42.2s, half.2s\n"
+ "str sU43, [%x[outptr12], %x[mstride2]]\n"
+ "add %x[outptr12], %x[outptr12], #0x04\n"
+
+ // Clear aliases
+ ".unreq half\n"
+ ".unreq scratch\n"
+ ".unreq w_11\n" ".unreq qw_11\n" ".unreq sw_11\n"
+ ".unreq w_12\n" ".unreq qw_12\n" ".unreq sw_12\n"
+ ".unreq w_13\n" ".unreq qw_13\n" ".unreq sw_13\n"
+ ".unreq w_21\n" ".unreq qw_21\n" ".unreq sw_21\n"
+ ".unreq w_22\n" ".unreq qw_22\n" ".unreq sw_22\n"
+ ".unreq w_23\n" ".unreq qw_23\n" ".unreq sw_23\n"
+ ".unreq w_31\n" ".unreq qw_31\n" ".unreq sw_31\n"
+ ".unreq w_32\n" ".unreq qw_32\n" ".unreq sw_32\n"
+ ".unreq w_33\n" ".unreq qw_33\n" ".unreq sw_33\n"
+ ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n"
+ ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n"
+ ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n"
+ ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n"
+ ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n"
+ ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n"
+ ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n"
+ ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n"
+ ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n"
+ ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n"
+ ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n"
+ ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq qU44\n"
+ ".unreq sU11\n" ".unreq sU12\n" ".unreq sU13\n" ".unreq sU14\n"
+ ".unreq sU21\n" ".unreq sU22\n" ".unreq sU23\n" ".unreq sU24\n"
+ ".unreq sU31\n" ".unreq sU32\n" ".unreq sU33\n" ".unreq sU34\n"
+ ".unreq sU41\n" ".unreq sU42\n" ".unreq sU43\n" ".unreq sU44\n"
+
+ : [inptr0] "+r" (inptr0),
+ [inptr1] "+r" (inptr1),
+ [inptr2] "+r" (inptr2),
+ [outptr0] "+r" (outptr0),
+ [outptr4] "+r" (outptr4),
+ [outptr8] "+r" (outptr8),
+ [outptr12] "+r" (outptr12),
+ [n_remaining_channels] "+r" (n_remaining_channels)
+ : [mstride1] "r" (sizeof(float) * mstride),
+ [mstride2] "r" (sizeof(float) * mstride * 2),
+ [mstride3] "r" (sizeof(float) * mstride * 3),
+ [colstride1] "r" (sizeof(float) * kernel_col_stride),
+ [colstride2] "r" (sizeof(float) * kernel_col_stride * 2),
+ [one_half] "r" (0.5f)
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+ "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
+ "v20", "v21", "v22", "v23", "v24"
+ );
+
+    // Advance the output pointers by the remainder of the matrix row stride
+ outptr0 += matrix_row_stride - n_output_channels;
+ outptr4 += matrix_row_stride - n_output_channels;
+ outptr8 += matrix_row_stride - n_output_channels;
+ outptr12 += matrix_row_stride - n_output_channels;
+ }
+}
+}
+#endif // __aarch64__
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp
new file mode 100644
index 0000000000..0992c0bb44
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp
@@ -0,0 +1,356 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+namespace winograd {
+ /* Transform from the Winograd domain back to the spatial domain.
+ */
+ template <typename T>
+ struct Winograd2x2_3x3GemmOutput {
+ static void execute(
+ const Tensor4DShape &output_shape,
+ T* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride,
+ T* const output
+ );
+
+ protected:
+ /* Specialised implementation method. */
+ template <bool tail_M, bool tail_N, int channel_tail>
+ static void _execute(
+ const Tensor4DShape &output_shape,
+ T *output,
+ const T *input,
+ const int matrix_stride,
+ const int matrix_row_stride
+ );
+ };
+
+ /* Two-stage implementation of the transformation from the Winograd domain.
+ *
+ * First computes Z.F and then computes (Z.F).Z^T.
+ */
+ template <typename T>
+ struct Winograd2x2_3x3GemmOutput_TwoStage {
+ static void execute(
+ const Tensor4DShape &output_shape,
+ T* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride,
+ T* const output
+ );
+
+ protected:
+ template <int channel_tail>
+ static void compute_zf(
+ const int n_rows, const int n_channels,
+ T* const zf, const T* const input[16]
+ );
+
+ template <bool tail_M, bool tail_N, int channel_tail>
+ static void compute_zfzT(
+ const Tensor4DShape &output_shape,
+ T* const output, const T* const zf
+ );
+ };
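+
+  // Data flow of the two-stage variant (as in the commented-out reference
+  // code further down this file): compute_zf first folds the 16
+  // Winograd-domain matrices into 8 intermediate matrices, and compute_zfzT
+  // then folds those 8 into the four output pixels of each 2x2 tile, so the
+  // intermediate buffer holds 8 * n_rows * n_channels values.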
+}
+
+#include "output_2x2_3x3/a64_float.hpp"
+// #include "output_2x2_3x3/a64_float_two_stage.hpp"
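+
+// Illustrative usage sketch, for exposition only: once the Winograd GEMMs
+// have written their 16 output matrices into one contiguous buffer, the
+// inverse transform can be driven as below. The function and variable names
+// here are hypothetical; the strides are counted in elements, matching the
+// `_execute` implementations included above, and the float specialisation is
+// only available on AArch64 builds.
+#ifdef __aarch64__
+namespace winograd {
+inline void example_apply_output_transform(
+  const Tensor4DShape &output_shape,  // NHWC shape of the final output
+  float* const winograd_matrices,     // base of the 16 concatenated matrices
+  const int matrix_stride,            // elements between consecutive matrices
+  const int matrix_row_stride,        // elements between rows of a matrix
+  float* const output                 // destination NHWC tensor
+) {
+  Winograd2x2_3x3GemmOutput<float>::execute(
+    output_shape, winograd_matrices, matrix_stride, matrix_row_stride, output
+  );
+}
+}  // namespace winograd
+#endif // __aarch64__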
+
+/*****************************************************************************/
+/*
+template <typename T>
+void winograd::Winograd2x2_3x3GemmOutput<T>::execute(
+ const Tensor4DShape &output_shape,
+ const int tile_M,
+ const int tile_N,
+ T* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride,
+ T* const output
+) {
+ T* const antipadding = reinterpret_cast<T *>(malloc(sizeof(T) * output_shape.n_channels));
+
+ // Get input pointers
+ const T* inptrs[16];
+ for (int i = 0; i < 16; i++) {
+ inptrs[i] = matrices[i];
+ }
+
+ for (int batch = 0; batch < output_shape.n_batches; batch++) {
+ for (int tile_i = 0; tile_i < tile_M; tile_i++) {
+ for (int tile_j = 0; tile_j < tile_N; tile_j++) {
+ // Get pointers for each of the 4 output cells required for this computation
+ T* outptrs[4];
+ for (int cell_i = 0, c = 0; cell_i < 2; cell_i++) {
+ for (int cell_j = 0; cell_j < 2; cell_j++, c++) {
+ const int i = tile_i*2 + cell_i;
+ const int j = tile_j*2 + cell_j;
+
+ if (i < output_shape.n_rows && j < output_shape.n_cols) {
+ outptrs[c] = output + (
+ (batch*output_shape.n_rows + i) * output_shape.n_cols +
+ j) * output_shape.n_channels;
+ } else {
+ outptrs[c] = antipadding;
+ }
+ } // cell_j
+ } // cell_i
+
+ for (int n = 0; n < output_shape.n_channels; n++) {
+ // Read 16 values and progress pointers
+ T v[16];
+ for (int i = 0; i < 16; i++) {
+ v[i] = *(inptrs[i]++);
+ }
+
+ // Compute output for 4 pixels
+ *(outptrs[0]++) = v[ 0] + v[ 1] + v[ 2] +
+ v[ 4] + v[ 5] + v[ 6] +
+ v[ 8] + v[ 9] + v[10];
+ *(outptrs[1]++) = v[ 1] - v[ 2] - v[ 3] +
+ v[ 5] - v[ 6] - v[ 7] +
+ v[ 9] - v[10] - v[11];
+ *(outptrs[2]++) = v[ 4] + v[ 5] + v[ 6] -
+ v[ 8] - v[ 9] - v[10] -
+ v[12] - v[13] - v[14];
+ *(outptrs[3]++) = v[ 5] - v[ 6] - v[ 7] -
+ v[ 9] + v[10] + v[11] -
+ v[13] + v[14] + v[15];
+ } // output_channel
+ } // tile_j
+ } // tile_i
+ } // batch
+
+ free(antipadding);
+}
+*/
+
+/*****************************************************************************/
+/*
+template <typename T>
+void winograd::Winograd2x2_3x3GemmOutput_TwoStage<T>::execute(
+ const Tensor4DShape &output_shape,
+ T* const matrices[16], T* const output
+) {
+ // Allocate memory for the intermediate matrices
+ const int tile_M = iceildiv(output_shape.n_rows, 2);
+ const int tile_N = iceildiv(output_shape.n_cols, 2);
+ const int n_rows = output_shape.n_batches * tile_M * tile_N;
+ const int n_channels = output_shape.n_channels;
+ T* matrices_zf = reinterpret_cast<T*>(
+ calloc(8 * n_rows * n_channels, sizeof(T))
+ );
+
+ // Perform the first stage transform, computing ZF.
+ // Specializations should dispatch to different methods based on tail size.
+ compute_zf<0>(n_rows, n_channels, matrices_zf, matrices);
+
+ // Perform the second stage transform, finishing Z F Z^T - variable dispatch
+ // based on size of the output. Specialisations can also dispatch based on
+ // the tail-size of the channel.
+ if (output_shape.n_rows % 2 && output_shape.n_cols % 2) {
+ compute_zfzT<true, true, 0>(output_shape, output, matrices_zf);
+ } else if (output_shape.n_rows % 2) {
+ compute_zfzT<true, false, 0>(output_shape, output, matrices_zf);
+ } else if (output_shape.n_cols % 2) {
+ compute_zfzT<false, true, 0>(output_shape, output, matrices_zf);
+ } else {
+ compute_zfzT<false, false, 0>(output_shape, output, matrices_zf);
+ }
+
+ free(reinterpret_cast<void*>(matrices_zf));
+}
+
+template <typename T>
+template <int channel_tail>
+void winograd::Winograd2x2_3x3GemmOutput_TwoStage<T>::compute_zf(
+ const int n_rows, const int n_channels,
+ T* output, const T* const input[16]
+) {
+ // Extract 8 output pointers
+ T* outptr[8];
+ for (int i = 0; i < 8; i++) {
+ outptr[i] = output + i*n_rows*n_channels;
+ }
+
+ // Copy the 16 input pointers
+ const T* inptr[16];
+ for (int i = 0; i < 16; i++) {
+ inptr[i] = input[i];
+ }
+
+ // For every row of the matrices
+ for (int i = 0; i < n_rows; i++) {
+ // For every channel
+ for (int j = 0; j < n_channels; j++) {
+ // Extract values from the input matrices
+ T val[16];
+ for (int n = 0; n < 16; n++) {
+ val[n] = *(inptr[n]++);
+ }
+
+ // Compute output values
+ *(outptr[0]++) = val[0] + val[1] + val[2];
+ *(outptr[1]++) = val[1] - val[2] - val[3];
+ *(outptr[2]++) = val[4] + val[5] + val[6];
+ *(outptr[3]++) = val[5] - val[6] - val[7];
+ *(outptr[4]++) = val[8] + val[9] + val[10];
+ *(outptr[5]++) = val[9] - val[10] - val[11];
+ *(outptr[6]++) = val[12] + val[13] + val[14];
+ *(outptr[7]++) = val[13] - val[14] - val[15];
+ }
+ }
+}
+
+template <typename T>
+template <bool tail_M, bool tail_N, int channel_tail>
+void winograd::Winograd2x2_3x3GemmOutput_TwoStage<T>::compute_zfzT(
+ const Tensor4DShape &output_shape,
+ T* const output, const T* const input
+) {
+ // Sizing information
+ const int tile_M = output_shape.n_rows / 2;
+ const int tile_N = output_shape.n_cols / 2;
+
+ const int n_rows = (output_shape.n_batches *
+ (tile_M + (tail_M ? 1 : 0)) *
+ (tile_N + (tail_N ? 1 : 0)));
+ const int n_channels = output_shape.n_channels;
+
+ // Extract 8 input pointers
+ const T* inptr[8];
+ for (int i = 0; i < 8; i++) {
+ inptr[i] = input + i*n_rows*n_channels;
+ }
+
+ // Extract 4 output pointers
+ T* outptr00 = output;
+ T* outptr01 = outptr00 + n_channels;
+ T* outptr10 = outptr00 + output_shape.n_cols * n_channels;
+ T* outptr11 = outptr10 + n_channels;
+
+ // Progress over the output tiles, generating output values.
+ for (int batch = 0; batch < output_shape.n_batches; batch++) {
+ for (int tile_i = 0; tile_i < tile_M; tile_i++) {
+ for (int tile_j = 0; tile_j < tile_N; tile_j++) {
+ for (int channel = 0; channel < n_channels; channel++) {
+ // Read values from the input pointers
+ T v[8];
+ for (int i = 0; i < 8; i++) {
+ v[i] = *(inptr[i]++);
+ }
+
+ // Compute the output values and progress the output pointers.
+ *(outptr00++) = v[0] + v[2] + v[4];
+ *(outptr01++) = v[1] + v[3] + v[5];
+ *(outptr10++) = v[2] - v[4] - v[6];
+ *(outptr11++) = v[3] - v[5] - v[7];
+ }
+
+ // Progress the output pointers to the next column
+ outptr00 += n_channels;
+ outptr01 += n_channels;
+ outptr10 += n_channels;
+ outptr11 += n_channels;
+ }
+
+ if (tail_N) {
+ // Only evaluate the left-most columns of the output
+ for (int channel = 0; channel < n_channels; channel++) {
+ // Read values from the input pointers
+ T v[8];
+ for (int i = 0; i < 4; i++) {
+ v[i * 2] = *inptr[i * 2];
+ }
+ for (int i = 0; i < 8; i++) {
+ inptr[i]++;
+ }
+
+ // Compute the output values and progress the output pointers.
+ *(outptr00++) = v[0] + v[2] + v[4];
+ *(outptr10++) = v[2] - v[4] - v[6];
+ }
+
+ // Progress the output pointers to the next column
+ outptr01 += n_channels; // Account for being skipped above
+ outptr11 += n_channels; // Account for being skipped above
+ }
+
+ // Progress the output pointers to the next row
+ outptr00 += output_shape.n_cols * n_channels;
+ outptr01 += output_shape.n_cols * n_channels;
+ outptr10 += output_shape.n_cols * n_channels;
+ outptr11 += output_shape.n_cols * n_channels;
+ }
+
+ if (tail_M) {
+ // Only work on the upper row of the output
+ for (int tile_j = 0; tile_j < tile_N; tile_j++) {
+ for (int channel = 0; channel < n_channels; channel++) {
+ // Read values from the input pointers
+ T v[8];
+ for (int i = 0; i < 8; i++) {
+ v[i] = *(inptr[i]++);
+ }
+
+ // Compute the output values and progress the output pointers.
+ *(outptr00++) = v[0] + v[2] + v[4];
+ *(outptr01++) = v[1] + v[3] + v[5];
+ }
+
+ // Progress the output pointers to the next column
+ outptr00 += n_channels;
+ outptr01 += n_channels;
+ outptr10 += 2 * n_channels; // Account for being skipped above
+ outptr11 += 2 * n_channels; // Account for being skipped above
+ }
+
+ if (tail_N) {
+ // Only evaluate the upper-left cell of the output
+ for (int channel = 0; channel < n_channels; channel++) {
+ // Read values from the input pointers
+ T v[8];
+ for (int i = 0; i < 3; i++) {
+ v[i * 2] = *inptr[i * 2];
+ }
+ for (int i = 0; i < 8; i++) {
+ inptr[i]++;
+ }
+
+ // Compute the output values and progress the output pointers.
+ *(outptr00++) = v[0] + v[2] + v[4];
+ }
+
+ // Progress the output pointers to the next column
+ outptr01 += n_channels; // Account for being skipped above
+ outptr10 += n_channels; // Account for being skipped above
+ outptr11 += n_channels; // Account for being skipped above
+ }
+ }
+ }
+}
+*/
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp
new file mode 100644
index 0000000000..5925f9d569
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp
@@ -0,0 +1,650 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+/* Float implementation for AArch64.
+ */
+#ifdef __aarch64__
+namespace winograd {
+
+
+template <>
+template <>
+inline void Winograd2x2_3x3GemmOutput<float>::_execute<false, false, 0>(
+ const Tensor4DShape &output_shape,
+ float *output,
+ const float *input,
+ const int mstride,
+ const int matrix_row_stride
+) {
+ const int tile_M = output_shape.n_rows / 2;
+ const int tile_N = output_shape.n_cols / 2;
+ int batch = output_shape.n_batches;
+ float *outptr = output;
+
+ const float *inptr0 = input;
+ const float *inptr4 = input + 4 * mstride;
+ const float *inptr8 = input + 8 * mstride;
+ const float *inptr12 = input + 12 * mstride;
+
+ const size_t col_stride = sizeof(float) * output_shape.n_channels;
+ const size_t row_stride = col_stride * tile_N * 2;
+
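+  // The assembly below applies the inverse transform g = Z^T.(F.Z) to each
+  // 4x4 tile F of GEMM outputs, four channels at a time, where (reading the
+  // fadd/fsub pattern in the loop body)
+  //
+  //       [ 1  0 ]
+  //   Z = [ 1  1 ]        Z^T = [ 1  1  1  0 ]
+  //       [ 1 -1 ]              [ 0  1 -1 -1 ]
+  //       [ 0 -1 ]
+  //
+  // The eight FZ values are formed first and the four output pixels
+  // g11, g12, g21 and g22 of the 2x2 tile are accumulated from them.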
+ asm volatile (
+ // Aliases for elements of the input matrix `F`
+ // V-register Q-register
+ "F11 .req v0\n" "qF11 .req q0\n"
+ "F12 .req v1\n" "qF12 .req q1\n"
+ "F13 .req v2\n" "qF13 .req q2\n"
+ "F14 .req v3\n" "qF14 .req q3\n"
+ "F21 .req v4\n" "qF21 .req q4\n"
+ "F22 .req v5\n" "qF22 .req q5\n"
+ "F23 .req v6\n" "qF23 .req q6\n"
+ "F24 .req v7\n" "qF24 .req q7\n"
+ "F31 .req v8\n" "qF31 .req q8\n"
+ "F32 .req v9\n" "qF32 .req q9\n"
+ "F33 .req v10\n" "qF33 .req q10\n"
+ "F34 .req v11\n" "qF34 .req q11\n"
+ "F41 .req v12\n" "qF41 .req q12\n"
+ "F42 .req v13\n" "qF42 .req q13\n"
+ "F43 .req v14\n" "qF43 .req q14\n"
+ "F44 .req v15\n" "qF44 .req q15\n"
+
+ // Aliases for elements of the intermediate matrix `FZ`
+ "FZ11 .req v16\n"
+ "FZ12 .req v17\n"
+ "FZ21 .req v18\n"
+ "FZ22 .req v19\n"
+ "FZ31 .req v20\n"
+ "FZ32 .req v21\n"
+ "FZ41 .req v22\n"
+ "FZ42 .req v23\n"
+
+ // Aliases for elements of the output matrix `f` (called `g` due to case
+ // insensitivity of aliases).
+ " g11 .req v24\n"
+ "qg11 .req q24\n"
+ " g12 .req v25\n"
+ "qg12 .req q25\n"
+ " g21 .req v26\n"
+ "qg21 .req q26\n"
+ " g22 .req v27\n"
+ "qg22 .req q27\n"
+
+ // Prepare the various strides
+ "col_stride .req %x[col_stride]\n"
+ "row_stride .req %x[row_stride]\n"
+ "row_plus_col_stride .req %x[row_plus_col_stride]\n"
+
+ "mstride1 .req %x[mstride1]\n"
+ "mstride2 .req %x[mstride2]\n"
+ "mstride3 .req %x[mstride3]\n"
+
+ "tile_i .req x19\n" // Tile row counter
+ "tile_j .req x20\n" // Tile column counter
+ "channel .req x21\n" // Channel counter
+
+ "1:" // Loop over batches
+ "mov tile_i, %x[tile_M]\n" // Reset tile row counter
+
+ "2:" // Loop over rows of tiles
+ "mov tile_j, %x[tile_N]\n" // Reset tile column counter
+
+ "3:" // Loop over columns of tiles
+ // Perform initial loads of the matrix `F`
+ "ldr qF11, [%x[inptr0]]\n"
+ "ldr qF12, [%x[inptr0], mstride1]\n"
+ "ldr qF13, [%x[inptr0], mstride2]\n"
+ "ldr qF14, [%x[inptr0], mstride3]\n"
+ "add %x[inptr0], %x[inptr0], #0x10\n"
+ "ldr qF21, [%x[inptr4]]\n"
+ "ldr qF22, [%x[inptr4], mstride1]\n"
+ "subs channel, %x[n_channels], #4\n" // Reset channel counter
+
+ "ldr qF23, [%x[inptr4], mstride2]\n"
+ "ldr qF24, [%x[inptr4], mstride3]\n"
+ "add %x[inptr4], %x[inptr4], #0x10\n"
+ "beq 5f\n" // Jump straight to tail if necessary
+
+ "4:" // Loop over channels
+ "ldr qF31, [%x[inptr8]]\n"
+ "fadd FZ11.4s, F11.4s, F12.4s\n"
+
+ "ldr qF32, [%x[inptr8], mstride1]\n"
+ "fsub FZ12.4s, F12.4s, F13.4s\n"
+
+ "ldr qF33, [%x[inptr8], mstride2]\n"
+ "fadd FZ11.4s, FZ11.4s, F13.4s\n"
+
+ "ldr qF34, [%x[inptr8], mstride3]\n"
+ "fsub FZ12.4s, FZ12.4s, F14.4s\n"
+
+ "ldr qF41, [%x[inptr12]]\n"
+ "fadd FZ21.4s, F21.4s, F22.4s\n"
+
+ "ldr qF42, [%x[inptr12], mstride1]\n"
+ "fsub FZ22.4s, F22.4s, F23.4s\n"
+
+ "ldr qF43, [%x[inptr12], mstride2]\n"
+ "fadd FZ21.4s, FZ21.4s, F23.4s\n"
+
+ "ldr qF44, [%x[inptr12], mstride3]\n"
+ "fsub FZ22.4s, FZ22.4s, F24.4s\n"
+
+ "fadd FZ31.4s, F31.4s, F32.4s\n"
+ "add %x[inptr8], %x[inptr8], #0x10\n"
+
+ "fsub FZ32.4s, F32.4s, F33.4s\n"
+ "add %x[inptr12], %x[inptr12], #0x10\n"
+
+ "fadd FZ31.4s, FZ31.4s, F33.4s\n"
+
+ "fsub FZ32.4s, FZ32.4s, F34.4s\n"
+
+ "fadd g11.4s, FZ11.4s, FZ21.4s\n"
+
+ "fadd g12.4s, FZ12.4s, FZ22.4s\n"
+
+ "fadd g11.4s, g11.4s, FZ31.4s\n"
+
+ "fadd g12.4s, g12.4s, FZ32.4s\n"
+
+ "ldr qF11, [%x[inptr0]]\n"
+ "fadd FZ41.4s, F41.4s, F42.4s\n"
+
+ "ldr qF12, [%x[inptr0], mstride1]\n"
+ "fsub g21.4s, FZ21.4s, FZ31.4s\n"
+
+ "ldr qF13, [%x[inptr0], mstride2]\n"
+ "fsub FZ42.4s, F42.4s, F43.4s\n"
+
+ "ldr qF14, [%x[inptr0], mstride3]\n"
+ "str qg11, [%x[outptr]]\n"
+
+ "ldr qF21, [%x[inptr4]]\n"
+ "fadd FZ41.4s, FZ41.4s, F43.4s\n"
+
+ "ldr qF22, [%x[inptr4], mstride1]\n"
+ "str qg12, [%x[outptr], col_stride]\n"
+
+ "ldr qF23, [%x[inptr4], mstride2]\n"
+ "fsub FZ42.4s, FZ42.4s, F44.4s\n"
+
+ "ldr qF24, [%x[inptr4], mstride3]\n"
+ "fsub g22.4s, FZ22.4s, FZ32.4s\n"
+
+ "fsub g21.4s, g21.4s, FZ41.4s\n"
+ "add %x[inptr0], %x[inptr0], #0x10\n"
+
+ "fsub g22.4s, g22.4s, FZ42.4s\n"
+ "add %x[inptr4], %x[inptr4], #0x10\n"
+
+ "subs channel, channel, #4\n"
+
+ "str qg21, [%x[outptr], row_stride]\n"
+
+ "str qg22, [%x[outptr], row_plus_col_stride]\n"
+
+ "add %x[outptr], %x[outptr], #0x10\n"
+
+ "bne 4b\n"
+
+ "5:" // Channel tail
+ "ldr qF31, [%x[inptr8]]\n"
+ "fadd FZ11.4s, F11.4s, F12.4s\n"
+
+ "ldr qF32, [%x[inptr8], mstride1]\n"
+ "fsub FZ12.4s, F12.4s, F13.4s\n"
+
+ "ldr qF33, [%x[inptr8], mstride2]\n"
+ "fadd FZ11.4s, FZ11.4s, F13.4s\n"
+
+ "ldr qF34, [%x[inptr8], mstride3]\n"
+ "fsub FZ12.4s, FZ12.4s, F14.4s\n"
+
+ "ldr qF41, [%x[inptr12]]\n"
+ "fadd FZ21.4s, F21.4s, F22.4s\n"
+
+ "ldr qF42, [%x[inptr12], mstride1]\n"
+ "fsub FZ22.4s, F22.4s, F23.4s\n"
+
+ "ldr qF43, [%x[inptr12], mstride2]\n"
+ "fadd FZ21.4s, FZ21.4s, F23.4s\n"
+
+ "ldr qF44, [%x[inptr12], mstride3]\n"
+ "fsub FZ22.4s, FZ22.4s, F24.4s\n"
+
+ "fadd FZ31.4s, F31.4s, F32.4s\n"
+ "add %x[inptr8], %x[inptr8], #0x10\n"
+
+ "fsub FZ32.4s, F32.4s, F33.4s\n"
+ "add %x[inptr12], %x[inptr12], #0x10\n"
+
+ "fadd FZ31.4s, FZ31.4s, F33.4s\n"
+
+ "fsub FZ32.4s, FZ32.4s, F34.4s\n"
+
+ "fadd g11.4s, FZ11.4s, FZ21.4s\n"
+
+ "fadd g12.4s, FZ12.4s, FZ22.4s\n"
+
+ "fadd g11.4s, g11.4s, FZ31.4s\n"
+
+ "fadd g12.4s, g12.4s, FZ32.4s\n"
+
+ "fadd FZ41.4s, F41.4s, F42.4s\n"
+
+ "fsub g21.4s, FZ21.4s, FZ31.4s\n"
+
+ "fsub FZ42.4s, F42.4s, F43.4s\n"
+
+ "str qg11, [%x[outptr]]\n"
+
+ "fadd FZ41.4s, FZ41.4s, F43.4s\n"
+
+ "str qg12, [%x[outptr], col_stride]\n"
+
+ "fsub FZ42.4s, FZ42.4s, F44.4s\n"
+
+ "fsub g22.4s, FZ22.4s, FZ32.4s\n"
+
+ "fsub g21.4s, g21.4s, FZ41.4s\n"
+
+ "fsub g22.4s, g22.4s, FZ42.4s\n"
+
+ "subs channel, channel, #4\n"
+
+ "str qg21, [%x[outptr], row_stride]\n"
+
+ // Progress input pointers to the next row of the matrix
+ "add %x[inptr0], %x[inptr0], %x[mrowpad]\n"
+ "add %x[inptr4], %x[inptr4], %x[mrowpad]\n"
+ "add %x[inptr8], %x[inptr8], %x[mrowpad]\n"
+ "add %x[inptr12], %x[inptr12], %x[mrowpad]\n"
+
+ "str qg22, [%x[outptr], row_plus_col_stride]\n"
+
+ "add %x[outptr], %x[outptr], #0x10\n"
+
+
+ "add %x[outptr], %x[outptr], col_stride\n"
+ "subs tile_j, tile_j, #1\n"
+ "bne 3b\n"
+
+ "add %x[outptr], %x[outptr], row_stride\n"
+ "subs tile_i, tile_i, #1\n"
+ "bne 2b\n"
+
+ "subs %[batch], %[batch], #1\n"
+ "bne 1b\n"
+
+ ".unreq F11\n" ".unreq qF11\n"
+ ".unreq F12\n" ".unreq qF12\n"
+ ".unreq F13\n" ".unreq qF13\n"
+ ".unreq F14\n" ".unreq qF14\n"
+ ".unreq F21\n" ".unreq qF21\n"
+ ".unreq F22\n" ".unreq qF22\n"
+ ".unreq F23\n" ".unreq qF23\n"
+ ".unreq F24\n" ".unreq qF24\n"
+ ".unreq F31\n" ".unreq qF31\n"
+ ".unreq F32\n" ".unreq qF32\n"
+ ".unreq F33\n" ".unreq qF33\n"
+ ".unreq F34\n" ".unreq qF34\n"
+ ".unreq F41\n" ".unreq qF41\n"
+ ".unreq F42\n" ".unreq qF42\n"
+ ".unreq F43\n" ".unreq qF43\n"
+ ".unreq F44\n" ".unreq qF44\n"
+
+ ".unreq FZ11\n" ".unreq FZ12\n"
+ ".unreq FZ21\n" ".unreq FZ22\n"
+ ".unreq FZ31\n" ".unreq FZ32\n"
+ ".unreq FZ41\n" ".unreq FZ42\n"
+
+ ".unreq g11\n" ".unreq qg11\n"
+ ".unreq g12\n" ".unreq qg12\n"
+ ".unreq g21\n" ".unreq qg21\n"
+ ".unreq g22\n" ".unreq qg22\n"
+
+ ".unreq col_stride\n"
+ ".unreq row_stride\n"
+ ".unreq row_plus_col_stride\n"
+
+ ".unreq mstride1\n"
+ ".unreq mstride2\n"
+ ".unreq mstride3\n"
+
+ ".unreq tile_i \n"
+ ".unreq tile_j \n"
+ ".unreq channel\n"
+
+ : [batch] "+r" (batch),
+ [outptr] "+r" (outptr),
+ [inptr0] "+r" (inptr0),
+ [inptr4] "+r" (inptr4),
+ [inptr8] "+r" (inptr8),
+ [inptr12] "+r" (inptr12)
+ : [tile_M] "r" (tile_M),
+ [tile_N] "r" (tile_N),
+ [n_channels] "r" (output_shape.n_channels),
+ [col_stride] "r" (col_stride),
+ [row_stride] "r" (row_stride),
+ [row_plus_col_stride] "r" (row_stride + col_stride),
+ [mstride1] "r" (mstride * sizeof(float)),
+ [mstride2] "r" (2 * mstride * sizeof(float)),
+ [mstride3] "r" (3 * mstride * sizeof(float)),
+ [mrowpad] "r" ((matrix_row_stride - output_shape.n_channels) * sizeof(float))
+ : "x19", "x20", "x21",
+ "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
+ "q12", "q13", "q14", "q15", "q16", "q17", "q18", "q19", "q20", "q21",
+ "q22", "q23", "q24", "q25", "q26", "q27",
+ "cc", "memory"
+ );
+}
+
+template <>
+template <bool tail_M, bool tail_N, const int channel_tail>
+inline void Winograd2x2_3x3GemmOutput<float>::_execute(
+ const Tensor4DShape &output_shape,
+ float *output,
+ const float *input,
+ const int mstride,
+ const int matrix_row_stride
+) {
+ // Compute basic information about the shape of the matrices
+ const int tile_M = output_shape.n_rows / 2;
+ const int tile_N = output_shape.n_cols / 2;
+ const int n_channels = output_shape.n_channels;
+
+ // Extract 16 input pointers
+ const float* inptr[16];
+ for (int i = 0; i < 16; i++) {
+ inptr[i] = input + i*mstride;
+ }
+
+ // Extract 4 output pointers
+ float *outptr00 = output;
+ float *outptr01 = outptr00 + n_channels;
+ float *outptr10 = outptr00 + output_shape.n_cols * n_channels;
+ float *outptr11 = outptr10 + n_channels;
+
+ // Progress over the output tiles, generating output values.
+ for (int batch = 0; batch < output_shape.n_batches; batch++) {
+ for (int tile_i = 0; tile_i < tile_M; tile_i++) {
+ for (int tile_j = 0; tile_j < tile_N; tile_j++) {
+ for (int channel = 0; channel < n_channels; channel++) {
+ // Read values from the input pointers
+ float F[4][4];
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 4; j++) {
+ F[i][j] = *(inptr[i*4 + j]++);
+ }
+ }
+
+ // Compute the matrix F.Z
+ float ZF[4][2];
+ ZF[0][0] = F[0][0] + F[0][1] + F[0][2];
+ ZF[0][1] = F[0][1] - F[0][2] - F[0][3];
+ ZF[1][0] = F[1][0] + F[1][1] + F[1][2];
+ ZF[1][1] = F[1][1] - F[1][2] - F[1][3];
+ ZF[2][0] = F[2][0] + F[2][1] + F[2][2];
+ ZF[2][1] = F[2][1] - F[2][2] - F[2][3];
+ ZF[3][0] = F[3][0] + F[3][1] + F[3][2];
+ ZF[3][1] = F[3][1] - F[3][2] - F[3][3];
+
+ // Hence compute the output matrix Z^T . (F.Z)
+ *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0];
+ *(outptr01++) = ZF[0][1] + ZF[1][1] + ZF[2][1];
+ *(outptr10++) = ZF[1][0] - ZF[2][0] - ZF[3][0];
+ *(outptr11++) = ZF[1][1] - ZF[2][1] - ZF[3][1];
+ }
+
+ // Progress the input pointers to the next row
+ for (int i = 0; i < 16; i++) {
+ inptr[i] += matrix_row_stride - n_channels;
+ }
+
+ // Progress the output pointers to the next column
+ outptr00 += n_channels;
+ outptr01 += n_channels;
+ outptr10 += n_channels;
+ outptr11 += n_channels;
+ }
+
+ if (tail_N) {
+        // Only evaluate the left-most column of the output tile
+ for (int channel = 0; channel < n_channels; channel++) {
+ // Read values from the input pointers
+ float F[4][3];
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 3; j++) {
+ F[i][j] = *(inptr[i*4 + j]++);
+ }
+ }
+ for (int i = 0; i < 4; i++) {
+ inptr[i*4 + 3]++;
+ }
+
+ // Compute the matrix F.Z
+ float ZF[4][1];
+ ZF[0][0] = F[0][0] + F[0][1] + F[0][2];
+ ZF[1][0] = F[1][0] + F[1][1] + F[1][2];
+ ZF[2][0] = F[2][0] + F[2][1] + F[2][2];
+ ZF[3][0] = F[3][0] + F[3][1] + F[3][2];
+
+ // Hence compute the output matrix Z^T . (F.Z)
+ *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0];
+ *(outptr10++) = ZF[1][0] - ZF[2][0] - ZF[3][0];
+ }
+
+ // Progress the input pointers to the next row
+ for (int i = 0; i < 16; i++) {
+ inptr[i] += matrix_row_stride - n_channels;
+ }
+
+ // Progress the output pointers to the next column
+ outptr01 += n_channels; // Account for being skipped above
+ outptr11 += n_channels; // Account for being skipped above
+ }
+
+ // Progress the output pointers to the next row
+ outptr00 += output_shape.n_cols * n_channels;
+ outptr01 += output_shape.n_cols * n_channels;
+ outptr10 += output_shape.n_cols * n_channels;
+ outptr11 += output_shape.n_cols * n_channels;
+ }
+
+ if (tail_M) {
+ // Only work on the upper row of the output
+ for (int tile_j = 0; tile_j < tile_N; tile_j++) {
+ for (int channel = 0; channel < n_channels; channel++) {
+ // Read values from the input pointers
+ float F[3][4];
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 4; j++) {
+ F[i][j] = *(inptr[i*4 + j]++);
+ }
+ }
+ for (int j = 0; j < 4; j++) {
+ inptr[12 + j]++;
+ }
+
+ // Compute the matrix F.Z
+ float ZF[3][2];
+ ZF[0][0] = F[0][0] + F[0][1] + F[0][2];
+ ZF[0][1] = F[0][1] - F[0][2] - F[0][3];
+ ZF[1][0] = F[1][0] + F[1][1] + F[1][2];
+ ZF[1][1] = F[1][1] - F[1][2] - F[1][3];
+ ZF[2][0] = F[2][0] + F[2][1] + F[2][2];
+ ZF[2][1] = F[2][1] - F[2][2] - F[2][3];
+
+ // Hence compute the output matrix Z^T . (F.Z)
+ *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0];
+ *(outptr01++) = ZF[0][1] + ZF[1][1] + ZF[2][1];
+ }
+
+ // Progress the input pointers to the next row
+ for (int i = 0; i < 16; i++) {
+ inptr[i] += matrix_row_stride - n_channels;
+ }
+
+ // Progress the output pointers to the next column
+ outptr00 += n_channels;
+ outptr01 += n_channels;
+ outptr10 += 2 * n_channels; // Account for being skipped above
+ outptr11 += 2 * n_channels; // Account for being skipped above
+ }
+
+ if (tail_N) {
+ // Only evaluate the upper-left cell of the output
+ for (int channel = 0; channel < n_channels; channel++) {
+ // Read values from the input pointers
+ float F[3][3];
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 3; j++) {
+ F[i][j] = *(inptr[i*4 + j]);
+ }
+ }
+ for (int i = 0; i < 16; i++) {
+ inptr[i]++;
+ }
+
+ // Compute the matrix F.Z
+ float ZF[3][1];
+ ZF[0][0] = F[0][0] + F[0][1] + F[0][2];
+ ZF[1][0] = F[1][0] + F[1][1] + F[1][2];
+ ZF[2][0] = F[2][0] + F[2][1] + F[2][2];
+
+ // Hence compute the output matrix Z^T . (F.Z)
+ *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0];
+ }
+
+ // Progress the input pointers to the next row
+ for (int i = 0; i < 16; i++) {
+ inptr[i] += matrix_row_stride - n_channels;
+ }
+
+ // Progress the output pointers to the next column
+ outptr01 += n_channels; // Account for being skipped above
+ outptr10 += n_channels; // Account for being skipped above
+ outptr11 += n_channels; // Account for being skipped above
+ }
+ }
+ }
+}
+
+/*****************************************************************************/
+template <>
+inline void Winograd2x2_3x3GemmOutput<float>::execute(
+ const Tensor4DShape &output_shape,
+ float* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride,
+ float* const output
+) {
+ // Dispatch to an appropriate implementation based on the shape of the output
+ // tensor.
+ if (output_shape.n_rows % 2 && output_shape.n_cols % 2) {
+ constexpr bool tail_M = true, tail_N = true;
+ switch (output_shape.n_channels % 4) {
+ case 0:
+ _execute<tail_M, tail_N, 0>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
+ break;
+ case 1:
+ _execute<tail_M, tail_N, 1>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
+ break;
+ case 2:
+ _execute<tail_M, tail_N, 2>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
+ break;
+ case 3:
+ _execute<tail_M, tail_N, 3>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
+ break;
+ default:
+ assert(0);
+ break;
+ }
+ } else if (output_shape.n_rows % 2) {
+ constexpr bool tail_M = true, tail_N = false;
+ switch (output_shape.n_channels % 4) {
+ case 0:
+ _execute<tail_M, tail_N, 0>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
+ break;
+ case 1:
+ _execute<tail_M, tail_N, 1>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
+ break;
+ case 2:
+ _execute<tail_M, tail_N, 2>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
+ break;
+ case 3:
+ _execute<tail_M, tail_N, 3>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
+ break;
+ default:
+ assert(0);
+ break;
+ }
+ } else if (output_shape.n_cols % 2) {
+ constexpr bool tail_M = false, tail_N = true;
+ switch (output_shape.n_channels % 4) {
+ case 0:
+ _execute<tail_M, tail_N, 0>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
+ break;
+ case 1:
+ _execute<tail_M, tail_N, 1>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
+ break;
+ case 2:
+ _execute<tail_M, tail_N, 2>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
+ break;
+ case 3:
+ _execute<tail_M, tail_N, 3>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
+ break;
+ default:
+ assert(0);
+ break;
+
+ }
+ } else {
+ constexpr bool tail_M = false, tail_N = false;
+ switch (output_shape.n_channels % 4) {
+ case 0:
+ _execute<tail_M, tail_N, 0>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
+ break;
+ case 1:
+ _execute<tail_M, tail_N, 1>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
+ break;
+ case 2:
+ _execute<tail_M, tail_N, 2>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
+ break;
+ case 3:
+ _execute<tail_M, tail_N, 3>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
+ break;
+ default:
+ assert(0);
+ break;
+
+ }
+ }
+}
+/*****************************************************************************/
+
+} // namespace winograd
+#endif // __aarch64__
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp
new file mode 100644
index 0000000000..f551b12b52
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp
@@ -0,0 +1,655 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+#ifdef __aarch64__
+
+/*****************************************************************************/
+// Compute ZF specializations
+
+template <>
+template <>
+inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage<float>::compute_zf<0>(
+ const int n_rows, const int n_channels,
+ float* output, const float* const input[16]
+) {
+ // Make copies of some variables
+ int row = n_rows;
+ float* outptr = output;
+ const float* inptr = input[0];
+
+ // Perform the transformation
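+  // The block below folds the 16 Winograd-domain matrices into 8 by
+  // combining, channel-wise, each group of four matrices m[4k] .. m[4k+3]
+  // (one tile row) into
+  //   out[2k]   = m[4k]   + m[4k+1] + m[4k+2]
+  //   out[2k+1] = m[4k+1] - m[4k+2] - m[4k+3]
+  // mirroring the commented-out scalar reference compute_zf in
+  // output_2x2_3x3.hpp. The other 15 input and 7 output pointers are derived
+  // from `mstride` at the top of the block.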
+ asm volatile (
+ // "inptr0 .req %x[inptr]\n"
+ "inptr1 .req x0\n"
+ "inptr2 .req x1\n"
+ "inptr3 .req x2\n"
+ "inptr4 .req x3\n"
+ "inptr5 .req x4\n"
+ "inptr6 .req x5\n"
+ "inptr7 .req x6\n"
+ "inptr8 .req x7\n"
+ "inptr9 .req x8\n"
+ "inptr10 .req x9\n"
+ "inptr11 .req x10\n"
+ "inptr12 .req x11\n"
+ "inptr13 .req x12\n"
+ "inptr14 .req x13\n"
+ "inptr15 .req x14\n"
+
+ // "outptr0 .req %x[outptr]\n"
+ "outptr1 .req x15\n"
+ "outptr2 .req x16\n"
+ "outptr3 .req x17\n"
+ "outptr4 .req x18\n"
+ "outptr5 .req x19\n"
+ "outptr6 .req x20\n"
+ "outptr7 .req x21\n"
+
+ // Compute additional pointers into the input and output matrices.
+ "mstride .req x22\n" // Matrix stride
+ "mul mstride, %x[row], %x[n_channels]\n"
+ "lsl mstride, mstride, #2\n" // * sizeof(float)
+
+ "add inptr1, %x[inptr], mstride\n"
+ "add inptr2, %x[inptr], mstride, LSL #1\n"
+ "add inptr3, inptr2, mstride\n"
+ "add inptr4, inptr3, mstride\n"
+ "add inptr5, inptr4, mstride\n"
+ "add inptr6, inptr5, mstride\n"
+ "add inptr7, inptr6, mstride\n"
+ "add inptr8, inptr7, mstride\n"
+ "add inptr9, inptr8, mstride\n"
+ "add inptr10, inptr9, mstride\n"
+ "add inptr11, inptr10, mstride\n"
+ "add inptr12, inptr11, mstride\n"
+ "add inptr13, inptr12, mstride\n"
+ "add inptr14, inptr13, mstride\n"
+ "add inptr15, inptr14, mstride\n"
+
+ "add outptr1, %[outptr], mstride\n"
+ "add outptr2, outptr1, mstride\n"
+ "add outptr3, outptr2, mstride\n"
+ "add outptr4, outptr3, mstride\n"
+ "add outptr5, outptr4, mstride\n"
+ "add outptr6, outptr5, mstride\n"
+ "add outptr7, outptr6, mstride\n"
+
+ ".unreq mstride\n"
+
+ "column .req x22\n" // Column loop counter
+
+ "1:" // Loop over rows
+ "ldr q0, [%x[inptr]], #0x10\n"
+ "ldr q1, [inptr1], #0x10\n"
+ "ldr q2, [inptr2], #0x10\n"
+ "ldr q3, [inptr3], #0x10\n"
+ "ldr q4, [inptr4], #0x10\n"
+ "ldr q5, [inptr5], #0x10\n"
+ "ldr q6, [inptr6], #0x10\n"
+ "ldr q7, [inptr7], #0x10\n"
+ "subs column, %x[n_channels], #0x4\n"
+ "beq 3f\n"
+
+ "2:" // Loop over columns
+ "ldr q8, [inptr8], #0x10\n"
+ "prfm pldl1keep, [%x[inptr], #196]\n"
+ "fadd v16.4s, v0.4s, v1.4s\n"
+
+ "ldr q9, [inptr9], #0x10\n"
+ "prfm pldl1keep, [inptr1, #196]\n"
+ "fsub v17.4s, v1.4s, v2.4s\n"
+
+ "ldr q10, [inptr10], #0x10\n"
+ "prfm pldl1keep, [inptr2, #196]\n"
+ "fadd v16.4s, v16.4s, v2.4s\n"
+
+ "ldr q11, [inptr11], #0x10\n"
+ "prfm pldl1keep, [inptr3, #196]\n"
+ "fsub v17.4s, v17.4s, v3.4s\n"
+
+ "ldr q12, [inptr12], #0x10\n"
+ "prfm pldl1keep, [inptr4, #196]\n"
+ "str q16, [%x[outptr]], #0x10\n"
+
+ "ldr q13, [inptr13], #0x10\n"
+ "prfm pldl1keep, [inptr5, #196]\n"
+ "str q17, [outptr1], #0x10\n"
+
+ "ldr q14, [inptr14], #0x10\n"
+ "prfm pldl1keep, [inptr6, #196]\n"
+ "fadd v16.4s, v4.4s, v5.4s\n"
+
+ "ldr q15, [inptr15], #0x10\n"
+ "prfm pldl1keep, [inptr7, #196]\n"
+ "fsub v17.4s, v5.4s, v6.4s\n"
+
+ "ldr q0, [%x[inptr]], #0x10\n"
+ "prfm pldl1keep, [inptr8, #196]\n"
+ "fadd v16.4s, v16.4s, v6.4s\n"
+
+ "ldr q1, [inptr1], #0x10\n"
+ "prfm pldl1keep, [inptr9, #196]\n"
+ "fsub v17.4s, v17.4s, v7.4s\n"
+
+ "ldr q2, [inptr2], #0x10\n"
+ "prfm pldl1keep, [inptr10, #196]\n"
+ "str q16, [outptr2], #0x10\n"
+
+ "ldr q3, [inptr3], #0x10\n"
+ "prfm pldl1keep, [inptr11, #196]\n"
+ "str q17, [outptr3], #0x10\n"
+
+ "ldr q4, [inptr4], #0x10\n"
+ "prfm pldl1keep, [inptr12, #196]\n"
+ "fadd v16.4s, v8.4s, v9.4s\n"
+
+ "ldr q5, [inptr5], #0x10\n"
+ "prfm pldl1keep, [inptr13, #196]\n"
+ "fsub v17.4s, v9.4s, v10.4s\n"
+
+ "ldr q6, [inptr6], #0x10\n"
+ "prfm pldl1keep, [inptr14, #196]\n"
+ "fadd v16.4s, v16.4s, v10.4s\n"
+
+ "ldr q7, [inptr7], #0x10\n"
+ "prfm pldl1keep, [inptr15, #196]\n"
+ "fsub v17.4s, v17.4s, v11.4s\n"
+
+ "str q16, [outptr4], #0x10\n"
+ "fadd v16.4s, v12.4s, v13.4s\n"
+ "fsub v18.4s, v13.4s, v14.4s\n"
+
+ "str q17, [outptr5], #0x10\n"
+ "fadd v16.4s, v16.4s, v14.4s\n"
+ "fsub v18.4s, v18.4s, v15.4s\n"
+
+ "str q16, [outptr6], #0x10\n"
+ "subs column, column, #0x4\n"
+
+ "str q18, [outptr7], #0x10\n"
+ "bne 2b\n"
+
+ "3:" // Tail
+ "ldr q8, [inptr8], #0x10\n"
+ "prfm pldl1keep, [%x[inptr], #196]\n"
+ "fadd v16.4s, v0.4s, v1.4s\n"
+
+ "ldr q9, [inptr9], #0x10\n"
+ "prfm pldl1keep, [inptr1, #196]\n"
+ "fsub v17.4s, v1.4s, v2.4s\n"
+
+ "ldr q10, [inptr10], #0x10\n"
+ "prfm pldl1keep, [inptr2, #196]\n"
+ "fadd v16.4s, v16.4s, v2.4s\n"
+
+ "ldr q11, [inptr11], #0x10\n"
+ "prfm pldl1keep, [inptr3, #196]\n"
+ "fsub v17.4s, v17.4s, v3.4s\n"
+
+ "ldr q12, [inptr12], #0x10\n"
+ "prfm pldl1keep, [inptr4, #196]\n"
+ "str q16, [%x[outptr]], #0x10\n"
+
+ "ldr q13, [inptr13], #0x10\n"
+ "prfm pldl1keep, [inptr5, #196]\n"
+ "str q17, [outptr1], #0x10\n"
+
+ "ldr q14, [inptr14], #0x10\n"
+ "prfm pldl1keep, [inptr6, #196]\n"
+ "fadd v16.4s, v4.4s, v5.4s\n"
+
+ "ldr q15, [inptr15], #0x10\n"
+ "prfm pldl1keep, [inptr7, #196]\n"
+ "fsub v17.4s, v5.4s, v6.4s\n"
+
+ "prfm pldl1keep, [inptr8, #196]\n"
+ "prfm pldl1keep, [inptr9, #196]\n"
+ "fadd v16.4s, v16.4s, v6.4s\n"
+
+ "prfm pldl1keep, [inptr10, #196]\n"
+ "prfm pldl1keep, [inptr11, #196]\n"
+ "fsub v17.4s, v17.4s, v7.4s\n"
+
+ "prfm pldl1keep, [inptr12, #196]\n"
+ "prfm pldl1keep, [inptr13, #196]\n"
+ "str q16, [outptr2], #0x10\n"
+
+ "prfm pldl1keep, [inptr14, #196]\n"
+ "prfm pldl1keep, [inptr15, #196]\n"
+ "str q17, [outptr3], #0x10\n"
+
+ "fadd v16.4s, v8.4s, v9.4s\n"
+ "fsub v17.4s, v9.4s, v10.4s\n"
+
+ "fadd v16.4s, v16.4s, v10.4s\n"
+ "fsub v17.4s, v17.4s, v11.4s\n"
+
+ "str q16, [outptr4], #0x10\n"
+ "fadd v16.4s, v12.4s, v13.4s\n"
+ "fsub v18.4s, v13.4s, v14.4s\n"
+
+ "str q17, [outptr5], #0x10\n"
+ "fadd v16.4s, v16.4s, v14.4s\n"
+ "fsub v18.4s, v18.4s, v15.4s\n"
+
+ "str q16, [outptr6], #0x10\n"
+ "str q18, [outptr7], #0x10\n"
+
+ "subs %x[row], %x[row], #0x1\n"
+ "bne 1b\n"
+
+ ".unreq inptr1\n"
+ ".unreq inptr2\n"
+ ".unreq inptr3\n"
+ ".unreq inptr4\n"
+ ".unreq inptr5\n"
+ ".unreq inptr6\n"
+ ".unreq inptr7\n"
+ ".unreq inptr8\n"
+ ".unreq inptr9\n"
+ ".unreq inptr10\n"
+ ".unreq inptr11\n"
+ ".unreq inptr12\n"
+ ".unreq inptr13\n"
+ ".unreq inptr14\n"
+ ".unreq inptr15\n"
+ ".unreq outptr1\n"
+ ".unreq outptr2\n"
+ ".unreq outptr3\n"
+ ".unreq outptr4\n"
+ ".unreq outptr5\n"
+ ".unreq outptr6\n"
+ ".unreq outptr7\n"
+
+ : [row] "+r" (row),
+ [inptr] "+r" (inptr),
+ [outptr] "+r" (outptr)
+ : [n_channels] "r" (n_channels),
+ [sizeof_float] "i" (sizeof(float))
+ : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
+      "q12", "q13", "q14", "q15", "q16", "q17", "q18", "x0", "x1", "x2", "x3", "x4",
+ "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15",
+ "x16", "x17", "x18", "x19", "x20", "x21", "x22", "cc", "memory"
+ );
+}
+
+/*****************************************************************************/
+// Compute ZFZ^T specializations
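+// The second stage applies the same [[1, 1, 1, 0], [0, 1, -1, -1]] transform
+// along the other tile dimension, combining the 8 ZF matrices into the final
+// 2x2 output tiles. The template parameters select the handling of an odd
+// number of output rows/columns and of the channel remainder.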
+
+template <>
+template <>
+inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage<float>::compute_zfzT<false, false, 0>(
+ const Tensor4DShape &output_shape,
+ float* const output, const float* const input
+) {
+ const int tile_M = output_shape.n_rows / 2;
+ const int tile_N = output_shape.n_cols / 2;
+ int batch = output_shape.n_batches;
+ float *outptr = output;
+ const float *inptr = input;
+
+ asm volatile (
+ // Compute input pointers
+ "inptr1 .req x0\n"
+ "inptr2 .req x1\n"
+ "inptr3 .req x2\n"
+ "inptr4 .req x3\n"
+ "inptr5 .req x4\n"
+ "inptr6 .req x5\n"
+ "inptr7 .req x6\n"
+ "inptr8 .req x7\n"
+
+ "mstride .req x8\n"
+ "mul mstride, %x[tile_M], %x[tile_N]\n"
+ "mul mstride, mstride, %x[n_channels]\n"
+ "lsl mstride, mstride, #2\n" // * sizeof(float)
+
+ "add inptr1, %[inptr], mstride\n"
+ "add inptr2, inptr1, mstride\n"
+ "add inptr3, inptr2, mstride\n"
+ "add inptr4, inptr3, mstride\n"
+ "add inptr5, inptr4, mstride\n"
+ "add inptr6, inptr5, mstride\n"
+ "add inptr7, inptr6, mstride\n"
+ "add inptr8, inptr7, mstride\n"
+
+ ".unreq mstride\n"
+
+ // Compute initial output pointers
+ "outptr01 .req x8\n"
+ "outptr10 .req x9\n"
+ "outptr11 .req x10\n"
+
+ "add outptr01, %x[outptr], %x[n_channels], LSL #2\n"
+ "add outptr10, %x[outptr], %x[row_stride], LSL #2\n"
+ "add outptr11, outptr10, %x[n_channels], LSL #2\n"
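+    // %x[outptr] addresses element (0, 0) of the current 2x2 output tile;
+    // outptr01, outptr10 and outptr11 address elements (0, 1), (1, 0) and
+    // (1, 1) respectively (NHWC layout).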
+
+ "tile_i .req x11\n"
+ "tile_j .req x12\n"
+ "channel .req x13\n"
+
+ "1:" // Loop over batches
+ "mov tile_i, %x[tile_M]\n"
+
+ "2:" // Loop over rows of output tiles
+ "mov tile_j, %x[tile_N]\n"
+
+ "3:" // Loop over columns of output tiles
+ "ldr q0, [%x[inptr]], #0x10\n"
+ "ldr q2, [inptr2], #0x10\n"
+ "subs channel, %x[n_channels], #0x4\n"
+
+ "ldr q1, [inptr1], #0x10\n"
+ "ldr q3, [inptr3], #0x10\n"
+ "beq 6f\n"
+
+ "4:"
+ "ldr q4, [inptr4], #0x10\n"
+ "ldr q5, [inptr5], #0x10\n"
+ "fadd v16.4s, v0.4s, v2.4s\n"
+
+ "ldr q6, [inptr6], #0x10\n"
+ "ldr q7, [inptr7], #0x10\n"
+ "fadd v17.4s, v1.4s, v3.4s\n"
+
+ "ldr q8, [%x[inptr]], #0x10\n"
+ "ldr q10, [inptr2], #0x10\n"
+ "fadd v16.4s, v16.4s, v4.4s\n"
+
+ "ldr q9, [inptr1], #0x10\n"
+ "ldr q11, [inptr3], #0x10\n"
+ "fadd v17.4s, v17.4s, v5.4s\n"
+
+ "str q16, [%x[outptr]], #0x10\n"
+ "prfm pldl1strm, [%x[inptr], #196]\n"
+ "fsub v18.4s, v2.4s, v4.4s\n"
+
+ "str q17, [outptr01], #0x10\n"
+ "prfm pldl1strm, [inptr2, #196]\n"
+ "fsub v19.4s, v3.4s, v5.4s\n"
+
+ "prfm pldl1strm, [inptr1, #196]\n"
+ "prfm pldl1strm, [inptr3, #196]\n"
+ "fsub v18.4s, v18.4s, v6.4s\n"
+
+ "prfm pldl1strm, [inptr4, #196]\n"
+ "prfm pldl1strm, [inptr5, #196]\n"
+ "fsub v19.4s, v19.4s, v7.4s\n"
+
+ "str q18, [outptr10], #0x10\n"
+ "prfm pldl1strm, [inptr6, #196]\n"
+ "prfm pldl1strm, [inptr7, #196]\n"
+
+ "subs channel, channel, #0x4\n"
+
+ "str q19, [outptr11], #0x10\n"
+ "beq 6f\n" // Branch to tail
+
+ "ldr q12, [inptr4], #0x10\n"
+ "ldr q13, [inptr5], #0x10\n"
+ "fadd v16.4s, v8.4s, v10.4s\n"
+
+ "ldr q14, [inptr6], #0x10\n"
+ "ldr q15, [inptr7], #0x10\n"
+ "fadd v17.4s, v9.4s, v11.4s\n"
+
+ "ldr q0, [%x[inptr]], #0x10\n"
+ "ldr q2, [inptr2], #0x10\n"
+ "fadd v16.4s, v16.4s, v12.4s\n"
+
+ "ldr q1, [inptr1], #0x10\n"
+ "ldr q3, [inptr3], #0x10\n"
+ "fadd v17.4s, v17.4s, v13.4s\n"
+
+ "str q16, [%x[outptr]], #0x10\n"
+ "prfm pldl1strm, [%x[inptr], #196]\n"
+ "fsub v18.4s, v10.4s, v12.4s\n"
+
+ "str q17, [outptr01], #0x10\n"
+ "prfm pldl1strm, [inptr2, #196]\n"
+ "fsub v19.4s, v11.4s, v13.4s\n"
+
+ "prfm pldl1strm, [inptr1, #196]\n"
+ "prfm pldl1strm, [inptr3, #196]\n"
+ "fsub v18.4s, v18.4s, v14.4s\n"
+
+ "prfm pldl1strm, [inptr4, #196]\n"
+ "prfm pldl1strm, [inptr5, #196]\n"
+ "fsub v19.4s, v19.4s, v15.4s\n"
+
+ "str q18, [outptr10], #0x10\n"
+ "prfm pldl1strm, [inptr6, #196]\n"
+ "prfm pldl1strm, [inptr7, #196]\n"
+
+ "subs channel, channel, #0x4\n"
+
+ "str q19, [outptr11], #0x10\n"
+ "bne 4b\n" // Continue loop
+
+ "5:" // Tail
+ "ldr q12, [inptr4], #0x10\n"
+ "ldr q13, [inptr5], #0x10\n"
+ "fadd v16.4s, v8.4s, v10.4s\n"
+
+ "ldr q14, [inptr6], #0x10\n"
+ "ldr q15, [inptr7], #0x10\n"
+ "fadd v17.4s, v9.4s, v11.4s\n"
+
+ "fadd v16.4s, v16.4s, v12.4s\n"
+
+ "fadd v17.4s, v17.4s, v13.4s\n"
+
+ "str q16, [%x[outptr]], #0x10\n"
+ "fsub v18.4s, v10.4s, v12.4s\n"
+ "fsub v19.4s, v11.4s, v13.4s\n"
+
+ "str q17, [outptr01], #0x10\n"
+ "fsub v18.4s, v18.4s, v14.4s\n"
+ "fsub v19.4s, v19.4s, v15.4s\n"
+
+ "str q18, [outptr10], #0x10\n"
+ "str q19, [outptr11], #0x10\n"
+ "b 7f\n"
+
+ "6:" // Tail
+ "ldr q4, [inptr4], #0x10\n"
+ "ldr q5, [inptr5], #0x10\n"
+ "fadd v16.4s, v0.4s, v2.4s\n"
+
+ "ldr q6, [inptr6], #0x10\n"
+ "ldr q7, [inptr7], #0x10\n"
+ "fadd v17.4s, v1.4s, v3.4s\n"
+
+ "fadd v16.4s, v16.4s, v4.4s\n"
+
+ "fadd v17.4s, v17.4s, v5.4s\n"
+
+ "str q16, [%x[outptr]], #0x10\n"
+ "fsub v18.4s, v2.4s, v4.4s\n"
+ "fsub v19.4s, v3.4s, v5.4s\n"
+
+ "str q17, [outptr01], #0x10\n"
+ "fsub v18.4s, v18.4s, v6.4s\n"
+ "fsub v19.4s, v19.4s, v7.4s\n"
+
+ "str q18, [outptr10], #0x10\n"
+ "str q19, [outptr11], #0x10\n"
+
+ "7:"
+ "add %x[outptr], %x[outptr], %x[n_channels], LSL #2\n"
+ "add outptr01, outptr01, %x[n_channels], LSL #2\n"
+ "add outptr10, outptr10, %x[n_channels], LSL #2\n"
+ "add outptr11, outptr11, %x[n_channels], LSL #2\n"
+
+ "subs tile_j, tile_j, #1\n"
+ "bne 3b\n"
+
+ // Progress the output pointers to the new row
+ "add %x[outptr], %x[outptr], %x[row_stride], LSL #2\n"
+ "add outptr01, outptr01, %x[row_stride], LSL #2\n"
+ "add outptr10, outptr10, %x[row_stride], LSL #2\n"
+ "add outptr11, outptr11, %x[row_stride], LSL #2\n"
+
+ "subs tile_i, tile_i, #1\n"
+ "bne 2b\n"
+
+ "subs %[batch], %[batch], #1\n"
+ "bne 1b\n"
+ "5:"
+
+ ".unreq inptr1\n"
+ ".unreq inptr2\n"
+ ".unreq inptr3\n"
+ ".unreq inptr4\n"
+ ".unreq inptr5\n"
+ ".unreq inptr6\n"
+ ".unreq inptr7\n"
+ ".unreq inptr8\n"
+ ".unreq outptr01\n"
+ ".unreq outptr10\n"
+ ".unreq outptr11\n"
+ : [batch] "+r" (batch),
+ [outptr] "+r" (outptr),
+ [inptr] "+r" (inptr)
+ : [tile_M] "r" (tile_M),
+ [tile_N] "r" (tile_N),
+ [n_channels] "r" (output_shape.n_channels),
+ [row_stride] "r" (output_shape.n_cols * output_shape.n_channels)
+ : "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11",
+      "x12", "x13", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
+      "q10", "q11", "q12", "q13", "q14", "q15", "q16", "q17", "q18", "q19",
+      "cc", "memory"
+ );
+}
+/*****************************************************************************/
+
+/*****************************************************************************/
+template <>
+inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage<float>::execute(
+ const Tensor4DShape &output_shape,
+ float* const matrices[16], float* const output
+) {
+ // profiler prof;
+
+ // Allocate memory for the intermediate matrices
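+  // (the first stage halves the 16 GEMM output matrices, so space for 8
+  //  matrices of n_rows * n_channels floats is required)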
+ const int tile_M = iceildiv(output_shape.n_rows, 2);
+ const int tile_N = iceildiv(output_shape.n_cols, 2);
+ const int n_rows = output_shape.n_batches * tile_M * tile_N;
+ const int n_channels = output_shape.n_channels;
+ float* matrices_zf = reinterpret_cast<float*>(
+ calloc(8 * n_rows * n_channels, sizeof(float))
+ );
+
+ // Perform the first stage transform, computing ZF.
+ const auto f_compute_zf = [&] () {
+ switch (n_channels % 4) {
+ case 0:
+ compute_zf<0>(n_rows, n_channels, matrices_zf, matrices);
+ break;
+ case 1:
+ compute_zf<1>(n_rows, n_channels, matrices_zf, matrices);
+ break;
+ case 2:
+ compute_zf<2>(n_rows, n_channels, matrices_zf, matrices);
+ break;
+ case 3:
+ compute_zf<3>(n_rows, n_channels, matrices_zf, matrices);
+ };
+ };
+ // prof("Compute ZF", f_compute_zf, 16 * n_rows * n_channels * sizeof(float), 0, 8 * n_rows * n_channels * sizeof(float));
+ f_compute_zf();
+
+ // Perform the second stage transform, finishing Z F Z^T - variable dispatch
+ // based on size of the output and the channel tail.
+ const auto f_compute_zfzT = [&] () {
+ if (output_shape.n_rows % 2 && output_shape.n_cols % 2) {
+ constexpr bool tail_M = true, tail_N = true;
+ switch (n_channels % 4) {
+ case 0:
+ compute_zfzT<tail_M, tail_N, 0>(output_shape, output, matrices_zf);
+ break;
+ case 1:
+ compute_zfzT<tail_M, tail_N, 1>(output_shape, output, matrices_zf);
+ break;
+ case 2:
+ compute_zfzT<tail_M, tail_N, 2>(output_shape, output, matrices_zf);
+ break;
+ case 3:
+ compute_zfzT<tail_M, tail_N, 3>(output_shape, output, matrices_zf);
+ }
+ } else if (output_shape.n_rows % 2) {
+ constexpr bool tail_M = true, tail_N = false;
+ switch (n_channels % 4) {
+ case 0:
+ compute_zfzT<tail_M, tail_N, 0>(output_shape, output, matrices_zf);
+ break;
+ case 1:
+ compute_zfzT<tail_M, tail_N, 1>(output_shape, output, matrices_zf);
+ break;
+ case 2:
+ compute_zfzT<tail_M, tail_N, 2>(output_shape, output, matrices_zf);
+ break;
+ case 3:
+ compute_zfzT<tail_M, tail_N, 3>(output_shape, output, matrices_zf);
+ }
+ } else if (output_shape.n_cols % 2) {
+ constexpr bool tail_M = false, tail_N = true;
+ switch (n_channels % 4) {
+ case 0:
+ compute_zfzT<tail_M, tail_N, 0>(output_shape, output, matrices_zf);
+ break;
+ case 1:
+ compute_zfzT<tail_M, tail_N, 1>(output_shape, output, matrices_zf);
+ break;
+ case 2:
+ compute_zfzT<tail_M, tail_N, 2>(output_shape, output, matrices_zf);
+ break;
+ case 3:
+ compute_zfzT<tail_M, tail_N, 3>(output_shape, output, matrices_zf);
+ }
+ } else {
+ constexpr bool tail_M = false, tail_N = false;
+ switch (n_channels % 4) {
+ case 0:
+ compute_zfzT<tail_M, tail_N, 0>(output_shape, output, matrices_zf);
+ break;
+ case 1:
+ compute_zfzT<tail_M, tail_N, 1>(output_shape, output, matrices_zf);
+ break;
+ case 2:
+ compute_zfzT<tail_M, tail_N, 2>(output_shape, output, matrices_zf);
+ break;
+ case 3:
+ compute_zfzT<tail_M, tail_N, 3>(output_shape, output, matrices_zf);
+ }
+ }
+ };
+ // prof("Compute ZFZT", f_compute_zfzT, 8 * n_rows * n_channels * sizeof(float), 0, 4 * n_rows * n_channels * sizeof(float));
+ f_compute_zfzT();
+
+ free(reinterpret_cast<void*>(matrices_zf));
+}
+/*****************************************************************************/
+
+#endif // __aarch64__
diff --git a/arm_compute/core/NEON/kernels/winograd/utils.hpp b/arm_compute/core/NEON/kernels/winograd/utils.hpp
new file mode 100644
index 0000000000..14e709f028
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/utils.hpp
@@ -0,0 +1,55 @@
+
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+#include <ctime>
+
+inline double TimeInUs(void) {
+#ifdef CYCLE_PROFILING
+ timespec t;
+ clock_gettime(CLOCK_THREAD_CPUTIME_ID, &t);
+ return 1e6*t.tv_sec + 1e-3*t.tv_nsec;
+#else
+ return 0;
+#endif
+}
+
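+// Integer division, rounding up: e.g. iceildiv(7, 2) == 4.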
+inline int iceildiv(const int a, const int b) {
+ return (a + b - 1) / b;
+}
+
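+// Round a up to a multiple of b. Note that when a is already a multiple of b
+// this returns a + b, i.e. it always advances to the next block boundary.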
+template <typename T>
+inline T roundup(const T a, const T b) {
+ return a + b - (a % b);
+}
+
+inline void PrintMatrix(const float* const m, const int M, const int N, const int row_stride) {
+ for (int i = 0; i < M; i++) {
+ for (int j = 0; j < N; j++) {
+ printf("%.3f ", m[i*row_stride + j]);
+ }
+ printf("\n");
+ }
+ printf("\n");
+}
diff --git a/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp b/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp
new file mode 100644
index 0000000000..c990cd0252
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp
@@ -0,0 +1,346 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+#include <cstdint>
+#include <cstdlib>
+#include <cassert>
+
+#include "alloc.hpp"
+#include "gemm.hpp"
+#include "profiler.hpp"
+#include "utils.hpp"
+#include "shims.hpp"
+
+#include "transforms.hpp"
+
+namespace winograd {
+ /***************************************************************************/
+ /* Implementation of the Winograd F(2x2, 3x3, 4x4) algorithm using GEMM
+ * internally.
+ */
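+  /* Typical usage (see NEWinogradLayer for a concrete example): allocate
+   * get_kernel_storage_size() bytes and construct the operator, call
+   * transform_weights() once to move the kernel into the Winograd domain,
+   * then for each input call reshape_input(), execute(first, last) to run the
+   * 16 GEMMs (the range allows the work to be split across threads) and
+   * finally reshape_output().
+   */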
+ template <typename TOut, typename TIn>
+ class Winograd2x2_3x3GEMM {
+ public:
+ /* Instantiate a new Winograd operator.
+ */
+ Winograd2x2_3x3GEMM(const KernelShape &kernel_shape, const Tensor4DShape input_shape, const PaddingType padding_type, void *kernel_storage);
+ virtual ~Winograd2x2_3x3GEMM();
+
+ /** Transform the weights into the Winograd domain.
+ */
+ template <typename KernelTransform=winograd2x2_3x3_gemm_kernel_transform_impl<TIn>>
+ void transform_weights(const TIn* const kernel, void *transform_working_space);
+
+    /* Initialise the matrix pointers and transform the input into the Winograd
+     * domain; must be called before execute().
+     */
+ template <typename InputTransform=Winograd2x2_3x3GemmInputChannelwise<TIn>>
+ void reshape_input(const Tensor4DShape &input_shape, const PaddingType padding_type, const TIn* const input, void* working_space);
+
+    /* Transform the output matrices back into the spatial domain, writing the
+     * result into the given output tensor.
+     */
+ template <typename OutputTransform=Winograd2x2_3x3GemmOutput<TOut>>
+ void reshape_output(const Tensor4DShape& input_shape, const PaddingType padding_type, TOut* const output);
+
+
+    /* Perform the GEMMs for the given (inclusive) range of the 16 Winograd
+     * matrices; the range allows the work to be split across threads.
+     */
+ void execute(size_t first, size_t last);
+
+ /* Get the memory required to transform the kernel.
+ */
+ static inline size_t get_kernel_transform_working_size(const KernelShape &shape);
+
+ /* Get the output shape of a convolution.
+ */
+ static Tensor4DShape get_output_shape(const Tensor4DShape &input_shape, const KernelShape &k_shape,
+ const PaddingType padding_type);
+
+ /* Get the memory required to instantiate a new Winograd operator.
+ */
+ static size_t get_kernel_storage_size(const KernelShape &shape);
+
+ /* Get the memory required to apply a Winograd operator to some input.
+ */
+ static size_t get_working_space_size(const Tensor4DShape &input_shape,const KernelShape &k_shape,
+ const PaddingType padding);
+
+
+    /** Prevent instances of this class from being copied (As this class contains pointers) */
+    Winograd2x2_3x3GEMM(const Winograd2x2_3x3GEMM &) = delete;
+    /** Prevent instances of this class from being copied (As this class contains pointers) */
+ Winograd2x2_3x3GEMM &operator=(const Winograd2x2_3x3GEMM &) = delete;
+ /** Allow instances of this class to be moved */
+ Winograd2x2_3x3GEMM(Winograd2x2_3x3GEMM &&) = default;
+ /** Allow instances of this class to be moved */
+ Winograd2x2_3x3GEMM &operator=(Winograd2x2_3x3GEMM &&) = default;
+
+ protected:
+ /* Get the memory required by a single "input" matrix.
+ */
+ static size_t get_input_matrix_size(const Tensor4DShape &input_shape,const KernelShape &k_shape,
+ const PaddingType padding);
+
+ /* Get the memory required by a single "output" matrix.
+ */
+ static size_t get_output_matrix_size(const Tensor4DShape &input_shape, const KernelShape &k_shape,
+ const PaddingType padding);
+
+ /* Get the memory required by a single "kernel" matrix.
+ */
+ static size_t get_kernel_matrix_size(const KernelShape &shape);
+
+ const KernelShape kernel_shape; // Shape of applied kernel
+ const Tensor4DShape in_shape;
+ const PaddingType padding;
+
+ const int kernel_matrix_row_stride; // Stride within kernel matrix
+
+ const bool manage_kernel_storage; // Free kernel storage when done
+ void* const _kernel_storage; // Base pointer for kernel matrices
+
+ profiler prof; // Profiler
+
+ TIn *kernel_matrices[16]; // Prepared form of kernel
+ TIn *input_matrices[16];
+ TOut *output_matrices[16];
+
+
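+    // Blocking factors used by BlockedGemm; the matrix sizes and row strides
+    // computed below are rounded up to multiples of these.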
+ static const int M_BLOCK = 4;
+ static const int N_BLOCK = 16;
+ };
+} // namespace winograd
+
+template <typename TOut, typename TIn>
+size_t winograd::Winograd2x2_3x3GEMM<TOut, TIn>::get_kernel_transform_working_size(
+ const KernelShape &shape
+)
+{
+  // The kernel needs to be re-ordered into HWIO form, so require enough space
+  // to represent the whole tensor.
+ return sizeof(TIn) * shape.size();
+}
+
+
+template <typename TOut, typename TIn>
+template <typename KernelTransform>
+void winograd::Winograd2x2_3x3GEMM<TOut, TIn>::transform_weights(
+ const TIn* const kernel,
+ void *transform_working_space
+)
+{
+ const int kernel_matrix_size_bytes = get_kernel_matrix_size(kernel_shape);
+ int8_t* const ks_bytes = reinterpret_cast<int8_t *>(_kernel_storage);
+ for (int i = 0; i < 16; i++) {
+ kernel_matrices[i] = reinterpret_cast<TIn *>(
+ ks_bytes + i*kernel_matrix_size_bytes);
+ }
+
+ const TIn *kernel_hwio = kernel;
+  if(transform_working_space)
+ {
+ kernel_hwio = reinterpret_cast<TIn *>(transform_working_space);
+ ofm_ifm_h_w_to_h_w_ifm_ofm(
+ kernel, const_cast<TIn *>(kernel_hwio),
+ kernel_shape.n_output_channels,
+ kernel_shape.n_input_channels,
+ kernel_shape.n_rows,
+ kernel_shape.n_cols
+ );
+ }
+ KernelTransform::execute(
+ kernel_shape, kernel_hwio, kernel_matrices[0],
+ kernel_matrix_size_bytes / sizeof(TIn),
+ kernel_matrix_row_stride
+ );
+}
+
+template <typename TOut, typename TIn>
+winograd::Winograd2x2_3x3GEMM<TOut, TIn>::Winograd2x2_3x3GEMM( const KernelShape &kernel_shape, const Tensor4DShape input_shape,
+ const PaddingType padding_type, void *kernel_storage)
+ : kernel_shape(kernel_shape), in_shape(input_shape), padding(padding_type),kernel_matrix_row_stride(roundup(kernel_shape.n_output_channels, N_BLOCK)), manage_kernel_storage(false),
+ _kernel_storage(kernel_storage), prof() {
+  memset(kernel_matrices, 0x00, sizeof(kernel_matrices));
+  memset(input_matrices, 0x00, sizeof(input_matrices));
+  memset(output_matrices, 0x00, sizeof(output_matrices));
+}
+
+/*****************************************************************************/
+template <typename TOut, typename TIn>
+winograd::Winograd2x2_3x3GEMM<TOut, TIn>::~Winograd2x2_3x3GEMM() {}
+
+/*****************************************************************************/
+template <typename TOut, typename TIn>
+template <typename InputTransform>
+void winograd::Winograd2x2_3x3GEMM<TOut, TIn>::reshape_input(
+ const Tensor4DShape& input_shape,
+ const PaddingType padding_type,
+ const TIn* const input,
+ void *working_space
+) {
+ assert(working_space);
+ int8_t* const ws_bytes = reinterpret_cast<int8_t *>(working_space);
+  // Split the working space into the regions required for the 16 input
+  // matrices and the 16 output matrices.
+ const int in_matrix_stride_bytes = get_input_matrix_size(input_shape, kernel_shape, padding_type);
+ const int out_matrix_stride_bytes = get_output_matrix_size(input_shape, kernel_shape, padding_type);
+
+ for (int i = 0; i < 16; i++) {
+ input_matrices[i] = reinterpret_cast<TIn *>(
+ ws_bytes + i*in_matrix_stride_bytes);
+    output_matrices[i] = reinterpret_cast<TOut *>(
+ ws_bytes + 16*in_matrix_stride_bytes + i*out_matrix_stride_bytes);
+ }
+
+ // Compute shape for the GEMM
+ const auto output_shape = get_output_shape(input_shape,kernel_shape, padding_type);
+ const int tile_rows = iceildiv(output_shape.n_rows, 2);
+ const int tile_cols = iceildiv(output_shape.n_cols, 2);
+ const int K = kernel_shape.n_input_channels;
+
+ const int in_matrix_row_stride = K;
+ const int in_matrix_batch_stride = tile_rows*tile_cols*in_matrix_row_stride;
+
+ // Transform the input tensor into an appropriate form
+ auto input_prep = [&] () {
+ InputTransform::execute(
+ input, input_shape, padding_type, tile_rows, tile_cols,
+ input_matrices[0], in_matrix_stride_bytes / sizeof(TIn),
+ in_matrix_batch_stride, in_matrix_row_stride
+ );
+ };
+ prof(
+ "Input Prep", input_prep,
+ InputTransform::bytes_read(input_shape, output_shape),
+ InputTransform::flops_performed(input_shape, output_shape),
+ InputTransform::bytes_written(input_shape, output_shape)
+ );
+
+}
+
+/*****************************************************************************/
+template <typename TOut, typename TIn>
+template <typename OutputTransform>
+void winograd::Winograd2x2_3x3GEMM<TOut, TIn>::reshape_output(const Tensor4DShape& input_shape, const PaddingType padding_type, TOut* const output) {
+ assert(output_matrices[0]);
+ const int out_matrix_stride_bytes = get_output_matrix_size(input_shape, kernel_shape, padding_type);
+ const auto output_shape = get_output_shape(input_shape,kernel_shape, padding_type);
+ const int out_matrix_row_stride = kernel_matrix_row_stride;
+
+ // Transform the output tensor into an appropriate form
+ OutputTransform::execute(
+ output_shape,
+ output_matrices[0],
+ out_matrix_stride_bytes / sizeof(TOut),
+ out_matrix_row_stride,
+ output
+ );
+}
+
+
+/*****************************************************************************/
+template <typename TOut, typename TIn>
+void winograd::Winograd2x2_3x3GEMM<TOut, TIn>::execute( size_t first, size_t last ) {
+ assert(input_matrices[0] && kernel_matrices[0] && output_matrices[0]);
+  assert(first < 16 && last < 16 && first <= last);
+ // Compute shape for the GEMM
+ const auto output_shape = get_output_shape(in_shape,kernel_shape, padding);
+ const int tile_rows = iceildiv(output_shape.n_rows, 2);
+ const int tile_cols = iceildiv(output_shape.n_cols, 2);
+ const int M = in_shape.n_batches * tile_rows * tile_cols;
+ const int K = kernel_shape.n_input_channels;
+ const int N = kernel_shape.n_output_channels;
+
+ const int in_matrix_row_stride = K;
+ const int out_matrix_row_stride = kernel_matrix_row_stride;
+ // Perform the GEMMs
+ for (size_t i = first; i <= last; i++) {
+ BlockedGemm<M_BLOCK, N_BLOCK>(
+ input_matrices[i], kernel_matrices[i], output_matrices[i], M, K, N,
+ in_matrix_row_stride, kernel_matrix_row_stride, out_matrix_row_stride
+ );
+// prof("GEMM", perform_gemm, 0, 2*M*K*N, 0); // TODO Memory
+ }
+
+}
+
+/*****************************************************************************/
+template <typename TOut, typename TIn>
+Tensor4DShape winograd::Winograd2x2_3x3GEMM<TOut, TIn>::get_output_shape(
+ const Tensor4DShape &in_shape, const KernelShape &k_shape, const PaddingType padding) {
+ return Tensor4DShape {
+ in_shape.n_batches,
+ (padding == PADDING_SAME) ? in_shape.n_rows : in_shape.n_rows - 2,
+ (padding == PADDING_SAME) ? in_shape.n_cols : in_shape.n_cols - 2,
+ k_shape.n_output_channels
+ };
+}
+
+template <typename TOut, typename TIn>
+size_t winograd::Winograd2x2_3x3GEMM<TOut, TIn>::get_kernel_storage_size(
+ const KernelShape &shape) {
+ return 16 * get_kernel_matrix_size(shape);
+}
+
+template <typename TOut, typename TIn>
+size_t winograd::Winograd2x2_3x3GEMM<TOut, TIn>::get_kernel_matrix_size(
+ const KernelShape &shape) {
+ const int K = shape.n_input_channels;
+ const int N = roundup(shape.n_output_channels, N_BLOCK);
+ return sizeof(TIn) * K * N;
+}
+
+template <typename TOut, typename TIn>
+size_t winograd::Winograd2x2_3x3GEMM<TOut, TIn>::get_working_space_size(
+ const Tensor4DShape& input_shape, const KernelShape &k_shape, const PaddingType padding_type
+) {
+ return 16 * get_input_matrix_size(input_shape, k_shape, padding_type) +
+ 16 * get_output_matrix_size(input_shape, k_shape, padding_type);
+}
+
+template <typename TOut, typename TIn>
+size_t winograd::Winograd2x2_3x3GEMM<TOut, TIn>::get_input_matrix_size(
+ const Tensor4DShape& input_shape, const KernelShape &k_shape, const PaddingType padding_type
+) {
+ // Compute shape for the GEMM
+ const auto output_shape = get_output_shape(input_shape, k_shape, padding_type);
+ const int tile_rows = iceildiv(output_shape.n_rows, 2);
+ const int tile_cols = iceildiv(output_shape.n_cols, 2);
+ const int M = roundup(tile_rows * tile_cols, M_BLOCK);
+ const int K = k_shape.n_input_channels;
+
+ return input_shape.n_batches * M * K * sizeof(TIn);
+}
+
+template <typename TOut, typename TIn>
+size_t winograd::Winograd2x2_3x3GEMM<TOut, TIn>::get_output_matrix_size(
+ const Tensor4DShape& input_shape, const KernelShape &k_shape,const PaddingType padding_type
+) {
+ // Compute shape for the GEMM
+ const auto output_shape = get_output_shape(input_shape, k_shape, padding_type);
+ const int tile_rows = iceildiv(output_shape.n_rows, 2);
+ const int tile_cols = iceildiv(output_shape.n_cols, 2);
+ const int M = roundup(tile_rows * tile_cols, M_BLOCK);
+ const int N = roundup(k_shape.n_output_channels, N_BLOCK);
+
+ return input_shape.n_batches * M * N * sizeof(TOut);
+}
diff --git a/arm_compute/core/NEON/kernels/winograd/winograd_shim_nchw.hpp b/arm_compute/core/NEON/kernels/winograd/winograd_shim_nchw.hpp
new file mode 100644
index 0000000000..4c7e291c58
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/winograd_shim_nchw.hpp
@@ -0,0 +1,192 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+#include <cstdint>
+#include <cstdlib>
+
+#include "alloc.hpp"
+#include "gemm.hpp"
+#include "profiler.hpp"
+#include "utils.hpp"
+#include "shims.hpp"
+#include "winograd_gemm.hpp"
+
+#include "transforms.hpp"
+
+#ifndef ALLOC_ALIGN
+#define ALLOC_ALIGN 64
+#endif // ALLOC_ALIGN
+
+
+namespace winograd_shim_nchw {
+ /***************************************************************************/
+ /* Implementation of the Winograd F(2x2, 3x3, 4x4) algorithm using GEMM
+ * internally.
+ */
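+  /* This shim wraps the NHWC implementation for NCHW tensors: nchw2nhwc() and
+   * nhwc2nchw() copy the input and output through NHWC scratch buffers placed
+   * in the working space after the 16 input and output matrices.
+   */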
+ template <typename TOut, typename TIn>
+ class Winograd2x2_3x3GEMM : public winograd::Winograd2x2_3x3GEMM<TOut, TIn> {
+ public:
+ /* Instantiate a new Winograd operator.
+ */
+ Winograd2x2_3x3GEMM(const KernelShape &kernel_shape, const Tensor4DShape input_shape, const PaddingType padding_type, void *kernel_storage);
+
+ void nchw2nhwc( const Tensor4DShape& input_shape, const PaddingType padding_type, void *working_space, const TIn* const input);
+ void nhwc2nchw( const Tensor4DShape& input_shape, const PaddingType padding_type, void *working_space, TOut* const output);
+
+
+ std::pair<TOut*,TIn*> get_nhwc_ptrs(const Tensor4DShape& input_shape,const PaddingType padding_type,void *working_space);
+
+ static size_t get_working_space_size(const Tensor4DShape &input_shape,const KernelShape &k_shape, const PaddingType padding);
+ protected:
+ /* Get the memory required to store an NHWC copy of the input tensor. */
+ static size_t get_working_nhwc_input_size(const Tensor4DShape &input_shape);
+
+    /* Get the memory required to store an NHWC copy of the output tensor. */
+ static size_t get_working_nhwc_output_size(const Tensor4DShape &output_shape, const KernelShape &k_shape, const PaddingType padding) ;
+ };
+} // namespace winograd_shim_nchw
+
+/*****************************************************************************/
+template <typename TOut, typename TIn>
+winograd_shim_nchw::Winograd2x2_3x3GEMM<TOut, TIn>::Winograd2x2_3x3GEMM(
+ const KernelShape &kernel_shape, const Tensor4DShape input_shape,
+ const PaddingType padding_type, void *kernel_storage
+) : winograd::Winograd2x2_3x3GEMM<TOut, TIn>(kernel_shape,input_shape,padding_type,kernel_storage) {
+}
+
+/*****************************************************************************/
+template <typename TOut, typename TIn>
+void winograd_shim_nchw::Winograd2x2_3x3GEMM<TOut, TIn>::nchw2nhwc(const Tensor4DShape& input_shape, const PaddingType padding_type, void *working_space, const TIn* const input) {
+ assert(working_space);
+ int8_t* const ws_bytes = reinterpret_cast<int8_t *>(working_space);
+
+ // Extract the top chunk of the working space to store the input and output
+ // tensors in NHWC format.
+ const int in_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM<TOut, TIn>::get_input_matrix_size(input_shape, this->kernel_shape, padding_type);
+ const int out_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM<TOut, TIn>::get_output_matrix_size(input_shape, this->kernel_shape, padding_type);
+
+ // Allocate working space for the input and output in NHWC format
+ TIn* const input_nhwc = reinterpret_cast<TIn *>(
+ ws_bytes + 16*(in_matrix_stride_bytes + out_matrix_stride_bytes)
+ );
+
+ // Re-order the input tensor
+ this->prof(
+ "NCHW -> NHWC",
+ [input, input_shape, input_nhwc] () {
+ nchw_to_nhwc(
+ input, input_nhwc,
+ input_shape.n_batches,
+ input_shape.n_channels,
+ input_shape.n_rows,
+ input_shape.n_cols
+ );
+ },
+ input_shape.size(), 0, input_shape.size()
+ );
+}
+
+/*****************************************************************************/
+template <typename TOut, typename TIn>
+void winograd_shim_nchw::Winograd2x2_3x3GEMM<TOut, TIn>::nhwc2nchw(const Tensor4DShape& input_shape, const PaddingType padding_type,
+ void *working_space, TOut* const output) {
+
+ assert(working_space);
+ int8_t* const ws_bytes = reinterpret_cast<int8_t *>(working_space);
+
+ // Extract the top chunk of the working space to store the input and output
+ // tensors in NHWC format.
+ const int in_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM<TOut, TIn>::get_input_matrix_size(input_shape, this->kernel_shape, padding_type);
+ const int out_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM<TOut, TIn>::get_output_matrix_size(input_shape, this->kernel_shape, padding_type);
+
+ TOut* const output_nhwc = reinterpret_cast<TOut *>(ws_bytes + 16*(in_matrix_stride_bytes + out_matrix_stride_bytes) + get_working_nhwc_input_size(input_shape));
+
+ // Re-order the output tensor into NCHW
+ const auto output_shape = winograd::Winograd2x2_3x3GEMM<TOut, TIn>::get_output_shape(input_shape, this->kernel_shape, padding_type);
+ this->prof(
+ "NHWC -> NCHW",
+ [output_nhwc, output_shape, output] () {
+ nhwc_to_nchw(
+ output_nhwc, output,
+ output_shape.n_batches,
+ output_shape.n_rows,
+ output_shape.n_cols,
+ output_shape.n_channels
+ );
+ },
+ output_shape.size(), 0, output_shape.size()
+ );
+}
+
+
+/*****************************************************************************/
+template <typename TOut, typename TIn>
+std::pair<TOut*,TIn*> winograd_shim_nchw::Winograd2x2_3x3GEMM<TOut, TIn>::get_nhwc_ptrs(
+ const Tensor4DShape& input_shape,
+ const PaddingType padding_type,
+ void *working_space
+) {
+ assert(working_space);
+ int8_t* const ws_bytes = reinterpret_cast<int8_t *>(working_space);
+
+ // Extract the top chunk of the working space to store the input and output
+ // tensors in NHWC format.
+ const int in_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM<TOut, TIn>::get_input_matrix_size(input_shape, this->kernel_shape, padding_type);
+ const int out_matrix_stride_bytes = winograd::Winograd2x2_3x3GEMM<TOut, TIn>::get_output_matrix_size(input_shape, this->kernel_shape, padding_type);
+
+ // Allocate working space for the input and output in NHWC format
+ TIn* input_nhwc = reinterpret_cast<TIn *>(ws_bytes + 16*(in_matrix_stride_bytes + out_matrix_stride_bytes));
+ TOut* output_nhwc = reinterpret_cast<TOut *>(ws_bytes + 16*(in_matrix_stride_bytes + out_matrix_stride_bytes) + get_working_nhwc_input_size(input_shape));
+ return std::make_pair(output_nhwc,input_nhwc);
+}
+
+
+
+
+/*****************************************************************************/
+template <typename TOut, typename TIn>
+size_t winograd_shim_nchw::Winograd2x2_3x3GEMM<TOut, TIn>::get_working_space_size(
+ const Tensor4DShape& input_shape, const KernelShape &k_shape, const PaddingType padding_type
+) {
+  // Add the memory required for the NHWC copies of the input and output tensors.
+ return winograd::Winograd2x2_3x3GEMM<TOut, TIn>::get_working_space_size(
+ input_shape, k_shape, padding_type)
+ + get_working_nhwc_input_size(input_shape)
+ + get_working_nhwc_output_size(input_shape, k_shape, padding_type);
+}
+
+template <typename TOut, typename TIn>
+size_t winograd_shim_nchw::Winograd2x2_3x3GEMM<TOut, TIn>::get_working_nhwc_input_size(
+ const Tensor4DShape& input_shape
+) {
+ return roundup(input_shape.size() * sizeof(TIn), static_cast<size_t>(ALLOC_ALIGN));
+}
+
+template <typename TOut, typename TIn>
+size_t winograd_shim_nchw::Winograd2x2_3x3GEMM<TOut, TIn>::get_working_nhwc_output_size(
+ const Tensor4DShape& input_shape, const KernelShape &k_shape, const PaddingType padding_type
+) {
+ const auto output_shape = winograd::Winograd2x2_3x3GEMM<TOut, TIn>::get_output_shape(input_shape,k_shape, padding_type);
+  return roundup(output_shape.size() * sizeof(TOut), static_cast<size_t>(ALLOC_ALIGN));
+}
diff --git a/arm_compute/runtime/NEON/NEFunctions.h b/arm_compute/runtime/NEON/NEFunctions.h
index 5baaa50d40..2e8c084371 100644
--- a/arm_compute/runtime/NEON/NEFunctions.h
+++ b/arm_compute/runtime/NEON/NEFunctions.h
@@ -108,5 +108,6 @@
#include "arm_compute/runtime/NEON/functions/NETranspose.h"
#include "arm_compute/runtime/NEON/functions/NEWarpAffine.h"
#include "arm_compute/runtime/NEON/functions/NEWarpPerspective.h"
+#include "arm_compute/runtime/NEON/functions/NEWinogradLayer.h"
#endif /* __ARM_COMPUTE_NEFUNCTIONS_H__ */
diff --git a/arm_compute/runtime/NEON/functions/NEWinogradLayer.h b/arm_compute/runtime/NEON/functions/NEWinogradLayer.h
new file mode 100644
index 0000000000..7dca4570e5
--- /dev/null
+++ b/arm_compute/runtime/NEON/functions/NEWinogradLayer.h
@@ -0,0 +1,84 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_NEWINOGRADLAYER_H__
+#define __ARM_COMPUTE_NEWINOGRADLAYER_H__
+
+#include "arm_compute/runtime/IFunction.h"
+
+#include "arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h"
+#include "arm_compute/core/Types.h"
+#include "arm_compute/runtime/MemoryGroup.h"
+#include "arm_compute/runtime/Tensor.h"
+
+#include <memory>
+
+namespace arm_compute
+{
+class ITensor;
+/** Basic function to run a convolution layer using the Winograd algorithm. This function calls the following NEON kernels:
+ *
+ * -# @ref NEWinogradLayerKernel
+ */
+class NEWinogradLayer : public IFunction
+{
+public:
+ /** Constructor */
+ NEWinogradLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
+
+ /** Set the input and output tensors.
+ *
+ * @param[in] input Source tensor. 3 lower dimensions represent a single input [width, height, IFM],
+ * while every optional dimension from 4 and above represent a batch of inputs.
+ * Data types supported: F32.
+ * @param[in] weights Weights tensor. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM]. Data type supported: Same as @p input.
+ * Currently only 3x3 kernels are supported.
+ * @param[in] biases Not supported, biases will be ignored.
+ * @param[out] output Destination tensor. 3 lower dimensions represent a single output [width, height, OFM], while the rest represent batch of outputs.
+ * Data types supported: Same as @p input.
+ * @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo. Currently only unit strides are supported.
+ */
+ void configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info);
+
+ // Inherited methods overridden:
+ void run() override;
+
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ NEWinogradLayer(const NEWinogradLayer &) = delete;
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ NEWinogradLayer &operator=(const NEWinogradLayer &) = delete;
+
+private:
+ using Winograd3x3F32 = NEWinogradLayerKernel::Winograd3x3F32;
+
+ MemoryGroup _memory_group;
+ NEWinogradLayerKernel _winograd_kernel;
+ Tensor _weights_workspace;
+ Tensor _workspace;
+ Tensor _kernel_storage;
+ const ITensor *_input;
+ const ITensor *_weights;
+ ITensor *_output;
+ bool _reshaped_kernel;
+ std::unique_ptr<Winograd3x3F32> _conv;
+};
+}
+#endif /* __ARM_COMPUTE_NEWINOGRADLAYER_H__ */
diff --git a/scripts/check_bad_style.sh b/scripts/check_bad_style.sh
index e7f6f1af54..4cd69757d6 100755
--- a/scripts/check_bad_style.sh
+++ b/scripts/check_bad_style.sh
@@ -5,7 +5,7 @@ set -e
DIRECTORIES="./arm_compute ./src ./examples ./tests ./utils ./support"
-grep -HrnP --exclude-dir=assembly "/\*\*$" $DIRECTORIES | tee bad_style.log
+grep -HrnP --exclude-dir=assembly --exclude-dir=winograd "/\*\*$" $DIRECTORIES | tee bad_style.log
if (( `cat bad_style.log | wc -l` > 0 ))
then
echo ""
@@ -13,7 +13,7 @@ then
exit -1
fi
-grep -Hnr --exclude-dir=assembly --exclude=Doxyfile "@brief" $DIRECTORIES | tee bad_style.log
+grep -Hnr --exclude-dir=assembly --exclude-dir=winograd --exclude=Doxyfile "@brief" $DIRECTORIES | tee bad_style.log
if (( `cat bad_style.log | wc -l` > 0 ))
then
echo ""
@@ -21,7 +21,7 @@ then
exit -1
fi
-grep -HnRE --exclude-dir=assembly "\buint " --exclude-dir=cl_kernels --exclude-dir=cs_shaders $DIRECTORIES | tee bad_style.log
+grep -HnRE --exclude-dir=assembly --exclude-dir=winograd "\buint " --exclude-dir=cl_kernels --exclude-dir=cs_shaders $DIRECTORIES | tee bad_style.log
if [[ $(cat bad_style.log | wc -l) > 0 ]]
then
echo ""
@@ -29,7 +29,7 @@ then
exit -1
fi
-grep -HnR --exclude-dir=assembly "float32_t" $DIRECTORIES | tee bad_style.log
+grep -HnR --exclude-dir=assembly --exclude-dir=winograd "float32_t" $DIRECTORIES | tee bad_style.log
if [[ $(cat bad_style.log | wc -l) > 0 ]]
then
echo ""
@@ -37,7 +37,7 @@ then
exit -1
fi
-grep -Hnir --exclude-dir=assembly "arm[_ ]\?cv" $DIRECTORIES | tee bad_style.log
+grep -Hnir --exclude-dir=assembly --exclude-dir=winograd "arm[_ ]\?cv" $DIRECTORIES | tee bad_style.log
if [[ $(cat bad_style.log | wc -l) > 0 ]]
then
echo ""
@@ -45,7 +45,7 @@ then
exit -1
fi
-grep -Hnir --exclude-dir=assembly "#.*if.*defined[^(]" $DIRECTORIES | tee bad_style.log
+grep -Hnir --exclude-dir=assembly --exclude-dir=winograd "#.*if.*defined[^(]" $DIRECTORIES | tee bad_style.log
if [[ $(cat bad_style.log | wc -l) > 0 ]]
then
echo ""
@@ -53,7 +53,7 @@ then
exit -1
fi
-grep -Hnir --exclude-dir=assembly "#else$\|#endif$" $DIRECTORIES | tee bad_style.log
+grep -Hnir --exclude-dir=assembly --exclude-dir=winograd "#else$\|#endif$" $DIRECTORIES | tee bad_style.log
if [[ $(cat bad_style.log | wc -l) > 0 ]]
then
echo ""
@@ -61,7 +61,7 @@ then
exit -1
fi
-grep -Hnir --exclude-dir=assembly "ARM_COMPUTE_AARCH64_V8_2" ./tests/validation/CL | tee bad_style.log
+grep -Hnir --exclude-dir=assembly --exclude-dir=winograd "ARM_COMPUTE_AARCH64_V8_2" ./tests/validation/CL | tee bad_style.log
if [[ $(cat bad_style.log | wc -l) > 0 ]]
then
echo ""
diff --git a/scripts/clang_tidy_rules.py b/scripts/clang_tidy_rules.py
index 9c012680d4..5b27dd5be5 100755
--- a/scripts/clang_tidy_rules.py
+++ b/scripts/clang_tidy_rules.py
@@ -42,6 +42,9 @@ def filter_clang_tidy_lines( lines ):
if "/assembly/" in line:
continue
+ if "/winograd/" in line:
+ continue
+
if "error:" in line:
if (("Utils.cpp" in line and "'arm_compute_version.embed' file not found" in line) or
("cl2.hpp" in line and "cast from pointer to smaller type 'cl_context_properties' (aka 'int') loses information" in line) or
diff --git a/src/core/NEON/kernels/NEWinogradLayerKernel.cpp b/src/core/NEON/kernels/NEWinogradLayerKernel.cpp
new file mode 100644
index 0000000000..b9109dcff2
--- /dev/null
+++ b/src/core/NEON/kernels/NEWinogradLayerKernel.cpp
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h"
+
+#include "arm_compute/core/Error.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/TensorInfo.h"
+
+namespace arm_compute
+{
+NEWinogradLayerKernel::NEWinogradLayerKernel()
+ : _convolver(nullptr), _output(nullptr)
+{
+}
+
+void NEWinogradLayerKernel::configure(ITensor *output, Winograd3x3F32 *convolver)
+{
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F32);
+ _convolver = convolver;
+ Window win = calculate_max_window(*output->info());
+ INEKernel::configure(win);
+}
+
+void NEWinogradLayerKernel::run(const Window &window, const ThreadInfo &info)
+{
+ ARM_COMPUTE_UNUSED(window);
+ ARM_COMPUTE_UNUSED(info);
+ ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+ ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
+ ARM_COMPUTE_ERROR_ON(info.num_threads < 1);
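+    // Split the 16 Winograd GEMMs across the available threads; the last
+    // thread picks up any remainder so every GEMM in [0, 15] runs exactly once.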
+ const size_t tid = info.thread_id;
+ const size_t num_threads = std::min(info.num_threads, 16);
+ const size_t num_gemms_per_thread = 16 / num_threads;
+ const size_t first_gemm = tid * num_gemms_per_thread;
+ const size_t last_gemm = (tid == (num_threads - 1)) ? 15 : first_gemm + num_gemms_per_thread - 1;
+ _convolver->execute(first_gemm, last_gemm);
+}
+} // namespace arm_compute
diff --git a/src/runtime/NEON/functions/NEWinogradLayer.cpp b/src/runtime/NEON/functions/NEWinogradLayer.cpp
new file mode 100644
index 0000000000..a9dec4ea0d
--- /dev/null
+++ b/src/runtime/NEON/functions/NEWinogradLayer.cpp
@@ -0,0 +1,155 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/runtime/NEON/functions/NEWinogradLayer.h"
+
+#include "arm_compute/core/Utils.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "support/ToolchainSupport.h"
+
+namespace
+{
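+// Translate the ACL NCHW tensor (dimension 0 = width, 1 = height, 2 = channels,
+// 3 = batches) into the {batches, rows, cols, channels} shape used by the
+// Winograd code.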
+inline Tensor4DShape internal_get_input_shape(const arm_compute::ITensor *input)
+{
+ const int in_width = input->info()->dimension(0);
+ const int in_height = input->info()->dimension(1);
+ const int in_batches = input->info()->dimension(3);
+ const int in_channels = input->info()->dimension(2);
+ return Tensor4DShape({ in_batches, in_height, in_width, in_channels });
+}
+} /* namespace */
+
+namespace arm_compute
+{
+NEWinogradLayer::NEWinogradLayer(std::shared_ptr<IMemoryManager> memory_manager)
+ : _memory_group(std::move(memory_manager)), _winograd_kernel(), _weights_workspace(), _workspace(), _kernel_storage(), _input(), _weights(), _output(), _reshaped_kernel(false), _conv()
+{
+}
+
+void NEWinogradLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info)
+{
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
+ ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(1) != 3 || weights->info()->dimension(0) != 3, "Only 3x3 kernels are supported");
+ ARM_COMPUTE_ERROR_ON(weights->info()->num_dimensions() > 4);
+
+ if(biases != nullptr)
+ {
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases);
+ ARM_COMPUTE_ERROR_ON(biases->info()->num_dimensions() > 1);
+ }
+
+ _weights = weights;
+ _input = input;
+ _output = output;
+
+ // Get parameters from conv_info
+ unsigned int stride_x = 0;
+ unsigned int stride_y = 0;
+ std::tie(stride_x, stride_y) = conv_info.stride();
+ ARM_COMPUTE_ERROR_ON_MSG(stride_y != 1 || stride_x != 1, "Winograd layer only supports unit strides.");
+
+ // Get convolved dimensions
+ auto padding = PADDING_VALID;
+ const int in_channels = input->info()->dimension(2);
+
+ const int out_channels = output->info()->dimension(2);
+ const int weights_width = weights->info()->dimension(0);
+ const int weights_height = weights->info()->dimension(1);
+
+ const KernelShape kernel_shape({ out_channels, weights_height, weights_width, in_channels });
+ const Tensor4DShape in_shape(internal_get_input_shape(input));
+
+ // Get the memory required to instantiate a new Winograd operator.
+ constexpr size_t kstore_alignment = 64;
+ const size_t kernel_storage_per_thread = Winograd3x3F32::get_kernel_storage_size(kernel_shape);
+ _kernel_storage.allocator()->init(TensorInfo(TensorShape{ (kernel_storage_per_thread + kstore_alignment - 1) }, 1, DataType::U8));
+ _memory_group.manage(&_kernel_storage);
+
+ // Get workbench size and allocate memory
+ constexpr size_t wspace_alignment = 64;
+ const size_t ws_size = Winograd3x3F32::get_working_space_size(in_shape, kernel_shape, padding);
+ _workspace.allocator()->init(TensorInfo(TensorShape{ (ws_size + wspace_alignment - 1) }, 1, DataType::U8));
+ _memory_group.manage(&_workspace);
+
+ // Workspace for weights transform
+ const size_t weights_transform_size = Winograd3x3F32::get_kernel_transform_working_size(kernel_shape);
+ _weights_workspace.allocator()->init(TensorInfo(TensorShape{ (weights_transform_size + wspace_alignment - 1) }, 1, DataType::U8));
+ _memory_group.manage(&_weights_workspace);
+
+ _kernel_storage.allocator()->allocate();
+ _workspace.allocator()->allocate();
+ _weights_workspace.allocator()->allocate();
+
+ // Create Winograd operator object
+ _conv = support::cpp14::make_unique<Winograd3x3F32>(kernel_shape, in_shape, padding, _kernel_storage.buffer());
+
+    // Configure the kernel; padding is not needed so it is safe to call configure after allocating the tensors.
+ _winograd_kernel.configure(output, _conv.get());
+}
+
+void NEWinogradLayer::run()
+{
+#if defined(__aarch64__)
+ _memory_group.acquire();
+ if(!_reshaped_kernel)
+ {
+ _conv->transform_weights(reinterpret_cast<const float *>(_weights->buffer()), reinterpret_cast<float *>(_weights_workspace.buffer()));
+ _reshaped_kernel = true;
+ }
+ const Tensor4DShape in_shape(internal_get_input_shape(_input));
+ auto padding = PADDING_VALID;
+
+ //Bring channels to the front as Winograd code expects the tensor to be in the format NHWC
+ _conv->nchw2nhwc(in_shape, padding, _workspace.buffer(), reinterpret_cast<const float *>(_input->buffer()));
+
+ //Get ptrs into the workspace
+ std::pair<float *, float *> nhwc_ptrs = _conv->get_nhwc_ptrs(in_shape, padding, _workspace.buffer());
+
+    //Set up the matrix pointers and transform the input tensor to the appropriate form before running the GEMMs.
+ _conv->reshape_input(in_shape, padding, nhwc_ptrs.second, _workspace.buffer());
+
+ //Run 16 GEMMs in multiple threads, each kernel runs one or more GEMMs
+ NEScheduler::get().schedule(&_winograd_kernel, Window::DimY);
+
+ //Transform the output to the appropriate form
+ _conv->reshape_output(in_shape, padding, nhwc_ptrs.first);
+
+ //Transform back to NCHW
+ _conv->nhwc2nchw(in_shape, padding, _workspace.buffer(), reinterpret_cast<float *>(_output->buffer()));
+
+ _memory_group.release();
+#else /* __aarch64__ */
+ ARM_COMPUTE_UNUSED(_winograd_kernel);
+ ARM_COMPUTE_UNUSED(_workspace);
+ ARM_COMPUTE_UNUSED(_kernel_storage);
+ ARM_COMPUTE_UNUSED(_input);
+ ARM_COMPUTE_UNUSED(_weights);
+ ARM_COMPUTE_UNUSED(_output);
+ ARM_COMPUTE_UNUSED(_reshaped_kernel);
+ ARM_COMPUTE_UNUSED(_conv);
+ ARM_COMPUTE_ERROR("Winograd only supported for aarch64, recompile with arch=arm64-v8a.");
+#endif /* __aarch64__ */
+}
+} // namespace arm_compute
diff --git a/tests/datasets/SmallConvolutionLayerDataset.h b/tests/datasets/SmallConvolutionLayerDataset.h
index aa9d9f8899..ccdd6e16af 100644
--- a/tests/datasets/SmallConvolutionLayerDataset.h
+++ b/tests/datasets/SmallConvolutionLayerDataset.h
@@ -37,6 +37,18 @@ namespace test
{
namespace datasets
{
+class SmallWinogradLayerDataset final : public ConvolutionLayerDataset
+{
+public:
+ SmallWinogradLayerDataset()
+ {
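+        // Each entry is (src, weights, biases, dst, conv info): 3x3 kernels with
+        // unit stride and no padding, so the output spatial size is the input
+        // size minus 2.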
+ // Batch size 1
+ add_config(TensorShape(8U, 8U, 2U), TensorShape(3U, 3U, 2U, 1U), TensorShape(1U), TensorShape(6U, 6U, 1U), PadStrideInfo(1, 1, 0, 0));
+ // Batch size 4
+ add_config(TensorShape(23U, 27U, 5U, 4U), TensorShape(3U, 3U, 5U, 21U), TensorShape(21U), TensorShape(21U, 25U, 21U, 4U), PadStrideInfo(1, 1, 0, 0));
+ }
+};
+
class SmallConvolutionLayerDataset final : public ConvolutionLayerDataset
{
public:
diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp
index 5e14a7c3cc..575ffe17a9 100644
--- a/tests/validation/NEON/ConvolutionLayer.cpp
+++ b/tests/validation/NEON/ConvolutionLayer.cpp
@@ -23,6 +23,7 @@
*/
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEConvolutionLayer.h"
+#include "arm_compute/runtime/NEON/functions/NEWinogradLayer.h"
#include "arm_compute/runtime/Tensor.h"
#include "arm_compute/runtime/TensorAllocator.h"
#include "tests/NEON/Accessor.h"
@@ -34,6 +35,7 @@
#include "tests/framework/datasets/Datasets.h"
#include "tests/validation/Validation.h"
#include "tests/validation/fixtures/ConvolutionLayerFixture.h"
+#include "tests/validation/fixtures/WinogradLayerFixture.h"
namespace arm_compute
{
@@ -62,6 +64,23 @@ const auto CNNDataTypes = framework::dataset::make("DataType",
} // namespace
TEST_SUITE(NEON)
+
+#if defined(__aarch64__)
+TEST_SUITE(WinogradLayer)
+template <typename T>
+using NEWinogradLayerFixture = WinogradLayerValidationFixture<Tensor, Accessor, NEWinogradLayer, T>;
+
+TEST_SUITE(FP32)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEWinogradLayerFixture<float>, framework::DatasetMode::PRECOMMIT, datasets::SmallWinogradLayerDataset())
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_f32);
+}
+
+TEST_SUITE_END()
+TEST_SUITE_END()
+#endif /* __aarch64__ */
+
TEST_SUITE(ConvolutionLayer)
DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(framework::dataset::concat(datasets::SmallConvolutionLayerDataset(), datasets::LargeConvolutionLayerDataset()), CNNDataTypes),
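The new WinogradLayer suite is compiled only for aarch64 builds and currently registers a single PRECOMMIT case over SmallWinogradLayerDataset; validate() compares the NEON output against the reference tensor element by element under the FP32 tolerance defined near the top of this file. Conceptually the comparison reduces to something like the following (a simplified stand-in, not the framework's actual implementation):

    #include <cmath>
    #include <cstddef>

    // Simplified element-wise check under a relative tolerance. The real
    // framework also supports absolute tolerances, counts mismatches, and
    // prints diagnostics on failure.
    bool all_close(const float *target, const float *reference, size_t count, float rel_tol)
    {
        for(size_t i = 0; i < count; ++i)
        {
            const float diff  = std::fabs(target[i] - reference[i]);
            const float bound = rel_tol * std::fabs(reference[i]);
            if(diff > bound)
            {
                return false;
            }
        }
        return true;
    }
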
diff --git a/tests/validation/fixtures/WinogradLayerFixture.h b/tests/validation/fixtures/WinogradLayerFixture.h
new file mode 100644
index 0000000000..a5d6fc966d
--- /dev/null
+++ b/tests/validation/fixtures/WinogradLayerFixture.h
@@ -0,0 +1,145 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_TEST_WINOGRAD_LAYER_FIXTURE
+#define ARM_COMPUTE_TEST_WINOGRAD_LAYER_FIXTURE
+
+#include "arm_compute/core/TensorShape.h"
+#include "arm_compute/core/Types.h"
+#include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "tests/AssetsLibrary.h"
+#include "tests/Globals.h"
+#include "tests/IAccessor.h"
+#include "tests/framework/Asserts.h"
+#include "tests/framework/Fixture.h"
+#include "tests/validation/CPP/ConvolutionLayer.h"
+#include "tests/validation/CPP/Utils.h"
+#include "tests/validation/Helpers.h"
+
+#include <random>
+
+namespace arm_compute
+{
+class NEWinogradLayer;
+
+namespace test
+{
+namespace validation
+{
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class WinogradLayerValidationFixture : public framework::Fixture
+{
+public:
+ template <typename...>
+ void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info)
+ {
+ _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info);
+ _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info);
+ }
+
+protected:
+ template <typename U>
+ void fill(U &&tensor, int i, float min, float max)
+ {
+ switch(tensor.data_type())
+ {
+ case DataType::F32:
+ {
+ std::uniform_real_distribution<> distribution(min, max);
+ library->fill(tensor, distribution, i);
+ break;
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Not supported");
+ library->fill_tensor_uniform(tensor, i);
+ break;
+ }
+ }
+ }
+
+ TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info)
+ {
+ // Create tensors
+ TensorType src = create_tensor<TensorType>(input_shape, DataType::F32, 1);
+ TensorType weights = create_tensor<TensorType>(weights_shape, DataType::F32, 1);
+ TensorType bias = create_tensor<TensorType>(bias_shape, DataType::F32, 1);
+ TensorType dst = create_tensor<TensorType>(output_shape, DataType::F32, 1);
+
+ // Create and configure function
+ FunctionType conv;
+ conv.configure(&src, &weights, nullptr, &dst, info);
+
+ ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(weights.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+ // Allocate tensors
+ src.allocator()->allocate();
+ weights.allocator()->allocate();
+ bias.allocator()->allocate();
+ dst.allocator()->allocate();
+
+ ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!weights.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!bias.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+ // Fill tensors
+ fill(AccessorType(src), 0, -1.f, 1.f);
+ fill(AccessorType(weights), 1, -1.f, 1.f);
+ fill(AccessorType(bias), 2, 0.f, 0.f);
+ fill(AccessorType(dst), 3, -1.f, 1.f);
+
+ // Compute NEWinogradLayer function
+ conv.run();
+
+ return dst;
+ }
+
+ SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info)
+ {
+ // Create reference
+ SimpleTensor<T> src{ input_shape, DataType::F32, 1 };
+ SimpleTensor<T> weights{ weights_shape, DataType::F32, 1 };
+ SimpleTensor<T> bias{ bias_shape, DataType::F32, 1 };
+
+ // Fill reference
+ fill(src, 0, -1.f, 1.f);
+ fill(weights, 1, -1.f, 1.f);
+ fill(bias, 2, 0.f, 0.f);
+
+ return reference::convolution_layer<T>(src, weights, bias, output_shape, info);
+ }
+
+ TensorType _target{};
+ SimpleTensor<T> _reference{};
+ int _fractional_bits{};
+ DataType _data_type{};
+};
+
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
+#endif /* ARM_COMPUTE_TEST_WINOGRAD_LAYER_FIXTURE */
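
compute_reference() delegates to reference::convolution_layer, the scalar model against which the NEON result is validated; note that the fixture passes nullptr biases to configure() but fills the bias tensor with zeros on the reference side, so the two paths stay comparable. For orientation only, the computation the reference models for FP32 with stride 1 and no padding (all that SmallWinogradLayerDataset exercises) is the textbook direct convolution below; this sketch works on flat NCHW/OIHW buffers and is not the framework's SimpleTensor implementation:

    #include <cstddef>
    #include <vector>

    // Direct convolution, NCHW input, OIHW weights, stride 1, no padding:
    // each output value accumulates over all input channels and the kh x kw
    // kernel window, plus a per-output-channel bias.
    std::vector<float> direct_conv_nchw(const std::vector<float> &src, const std::vector<float> &wgt,
                                        const std::vector<float> &bias,
                                        size_t n, size_t ci, size_t h, size_t w,
                                        size_t co, size_t kh, size_t kw)
    {
        const size_t oh = h - kh + 1;
        const size_t ow = w - kw + 1;
        std::vector<float> dst(n * co * oh * ow, 0.f);

        for(size_t b = 0; b < n; ++b)
            for(size_t oc = 0; oc < co; ++oc)
                for(size_t y = 0; y < oh; ++y)
                    for(size_t x = 0; x < ow; ++x)
                    {
                        float acc = bias[oc];
                        for(size_t ic = 0; ic < ci; ++ic)
                            for(size_t ky = 0; ky < kh; ++ky)
                                for(size_t kx = 0; kx < kw; ++kx)
                                {
                                    const float in_val = src[((b * ci + ic) * h + (y + ky)) * w + (x + kx)];
                                    const float w_val  = wgt[((oc * ci + ic) * kh + ky) * kw + kx];
                                    acc += in_val * w_val;
                                }
                        dst[((b * co + oc) * oh + y) * ow + x] = acc;
                    }
        return dst;
    }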