Diffstat (limited to 'arm_compute/core/NEON/kernels/convolution/winograd')
-rw-r--r--  arm_compute/core/NEON/kernels/convolution/winograd/batched_blocked_gemm.hpp |   69
-rw-r--r--  arm_compute/core/NEON/kernels/convolution/winograd/gemm.hpp                 |  127
-rw-r--r--  arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm.hpp       |  355
-rw-r--r--  arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm_4x16.hpp  | 1446
-rw-r--r--  arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp     |  195
-rw-r--r--  arm_compute/core/NEON/kernels/convolution/winograd/transforms/kernel.hpp    |   77
-rw-r--r--  arm_compute/core/NEON/kernels/convolution/winograd/transforms/output.hpp    |  181
-rw-r--r--  arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp        |  447
8 files changed, 2897 insertions, 0 deletions
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/batched_blocked_gemm.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/batched_blocked_gemm.hpp
new file mode 100644
index 0000000000..663b3c414f
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/batched_blocked_gemm.hpp
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+namespace winograd
+{
+
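+/** Perform a batch of `n_gemms` independent blocked GEMM operations of
+ * identical shape: each multiplies an M x K matrix by a K x N matrix to
+ * produce an M x N result. The matrix and row strides describe how the
+ * individual operands and outputs are laid out within the A, B and C buffers.
+ */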
+template <const int M_BLOCK, const int N_BLOCK, typename TIn, typename TOut>
+class BatchedBlockedGemm
+{
+ public:
+ /** Create a new batched blocked GEMM operator. */
+ BatchedBlockedGemm(
+ const unsigned int n_gemms,
+ const int M, const int K, const int N,
+ const int a_matrix_stride,
+ const int a_row_stride,
+ const int b_matrix_stride,
+ const int b_row_stride,
+ const int c_matrix_stride,
+ const int c_row_stride,
+ const TIn* const a_ptr,
+ const TIn* const b_ptr,
+ TOut* const c_ptr
+ );
+
+ BatchedBlockedGemm(const BatchedBlockedGemm&) = delete;
+ BatchedBlockedGemm operator=(const BatchedBlockedGemm&) = delete;
+
+ /** Get a window of work performed by the operator. */
+ unsigned int get_window() const;
+
+ /** Perform a portion of the work of the operator. */
+ void run(const unsigned int start, const unsigned int stop);
+
+ private:
+ const unsigned int n_gemms;
+ const int M, N, K;
+ const int a_matrix_stride, a_row_stride;
+ const int b_matrix_stride, b_row_stride;
+ const int c_matrix_stride, c_row_stride;
+ const TIn* const a_ptr;
+ const TIn* const b_ptr;
+ TOut* const c_ptr;
+};
+
+} // namespace winograd
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/gemm.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/gemm.hpp
new file mode 100644
index 0000000000..62a20c9eea
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/gemm.hpp
@@ -0,0 +1,127 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+#include "arm_compute/core/NEON/kernels/convolution/common/utils.hpp"
+
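+// Reference GEMM: accumulates the product of A (M x K) and B (K x N) into
+// C (M x N). Note that the update is `+=`, so the caller is expected to
+// initialise C; either operand may optionally be supplied transposed.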
+template <typename TIn, typename TOut>
+inline void Gemm(const TIn* const a, const TIn* const b, TOut *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride,
+ const bool a_transposed=false,
+ const bool b_transposed=false) {
+ // Array access methods
+ const auto A = [a, a_transposed, M, K, a_row_stride] (const int i, const int j) -> TIn {
+ return a[(!a_transposed) ? i*a_row_stride + j : i + j*M];
+ };
+
+ const auto B = [b, b_transposed, K, N, b_row_stride] (const int i, const int j) -> TIn {
+ return b[(!b_transposed) ? i*b_row_stride + j : i + j*N];
+ };
+
+ const auto C = [c, c_row_stride] (const int i, const int j) -> TOut& {
+ return c[i*c_row_stride + j];
+ };
+
+ // Perform the matrix multiplication
+ for (int i = 0; i < M; i++) {
+ for (int j = 0; j < N; j++) {
+ for (int k = 0; k < K; k++) {
+ C(i, j) += A(i, k) * B(k, j);
+ }
+ }
+ }
+}
+
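+// Register-blocked GEMM: the output is computed in M_BLOCK x N_BLOCK tiles of
+// accumulators which are then written back, overwriting C (unlike Gemm above,
+// which accumulates). No edge handling is performed, so M and N are assumed
+// to be multiples of the block sizes.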
+template <const int M_BLOCK, const int N_BLOCK, typename TIn, typename TOut>
+inline void BlockedGemm(
+ const TIn* const a, const TIn* const b, TOut *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+) {
+ // Array access methods
+ const auto A = [a, M, K, a_row_stride] (const int i, const int j) -> TIn {
+ return a[i*a_row_stride + j];
+ };
+
+ const auto B = [b, K, N, b_row_stride] (const int i, const int j) -> TIn {
+ return b[i*b_row_stride + j];
+ };
+
+ const auto C = [c, c_row_stride] (const int i, const int j) -> TOut& {
+ return c[i*c_row_stride + j];
+ };
+
+ const int M_BLOCKS = iceildiv(M, M_BLOCK);
+ const int N_BLOCKS = iceildiv(N, N_BLOCK);
+
+ // For each block of output rows
+ for (int mblock = 0; mblock < M_BLOCKS; mblock++) {
+ // For each block of output columns
+ for (int nblock = 0; nblock < N_BLOCKS; nblock++) {
+ // Create an appropriately sized block of accumulators
+ TOut accum[M_BLOCK][N_BLOCK];
+ for (int i = 0; i < M_BLOCK; i++) {
+ for (int j = 0; j < N_BLOCK; j++) {
+ accum[i][j] = static_cast<TOut>(0);
+ }
+ }
+
+ // Perform this portion of the matrix multiply
+ for (int k = 0; k < K; k++) {
+ // Load elements of A
+ TIn elems_a[M_BLOCK];
+ for (int i = 0; i < M_BLOCK; i++) {
+ elems_a[i] = A(mblock*M_BLOCK + i, k);
+ }
+
+ // Load elements of B
+ TIn elems_b[N_BLOCK];
+ for (int j = 0; j < N_BLOCK; j++) {
+ elems_b[j] = B(k, nblock*N_BLOCK + j);
+ }
+
+ // Perform the partial matrix multiply
+ for (int i = 0; i < M_BLOCK; i++) {
+ for (int j = 0; j < N_BLOCK; j++) {
+ accum[i][j] += elems_a[i] * elems_b[j];
+ }
+ }
+ }
+
+ // Store the partial product
+ for (int i = 0; i < M_BLOCK; i++) {
+ for (int j = 0; j < N_BLOCK; j++) {
+ C(mblock*M_BLOCK + i, nblock*N_BLOCK + j) = accum[i][j];
+ }
+ }
+ }
+ }
+}
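+
+// Illustrative usage of the reference kernels above (a sketch, not part of
+// this header): multiply two 8x8 row-major matrices with contiguous rows.
+//
+//   float a[8 * 8] = { /* ... */ }, b[8 * 8] = { /* ... */ };
+//   float c[8 * 8] = {};               // Gemm accumulates, so zero C first
+//   Gemm(a, b, c, 8, 8, 8, 8, 8, 8);   // C += A * B
+//   BlockedGemm<4, 4>(a, b, c, 8, 8, 8, 8, 8, 8);  // C = A * B, 4x4 tiles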
+
+#include "gemm/a64_sgemm.hpp"
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm.hpp
new file mode 100644
index 0000000000..8073cb1896
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm.hpp
@@ -0,0 +1,355 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+#include <cassert>
+#include "arm_compute/core/NEON/kernels/convolution/common/utils.hpp"
+
+#ifdef __aarch64__
+
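+// AArch64 specialisation of BlockedGemm for 8x12 output tiles. The tile of
+// accumulators is held in v0-v23, three quad-registers of B in v24-v26 and
+// four scalar A values in v27-v30, with the inner loop written in inline
+// assembly.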
+template <>
+inline void BlockedGemm<8, 12, float, float>(
+ const float* const a, const float* const b, float *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+) {
+ const int M_BLOCK = 8;
+ const int N_BLOCK = 12;
+
+ const int m_blocks = iceildiv(M, M_BLOCK);
+ const int n_blocks = iceildiv(N, N_BLOCK);
+
+ // For each block of output rows
+ for (int mblock = 0; mblock < m_blocks; mblock++) {
+ // For each block of output columns
+ for (int nblock = 0; nblock < n_blocks; nblock++) {
+ const float *aptr = a + mblock*M_BLOCK*a_row_stride;
+ const float *bptr = b + nblock*N_BLOCK;
+ float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
+ int k = K;
+
+ asm volatile (
+ // Create an 8x12 block of accumulators
+ " A_1 .req v27\n"
+ "sA_1 .req s27\n"
+ " A_2 .req v28\n"
+ "sA_2 .req s28\n"
+ " A_3 .req v29\n"
+ "sA_3 .req s29\n"
+ " A_4 .req v30\n"
+ "sA_4 .req s30\n"
+
+ " B_1 .req v24\n" " B_2 .req v25\n" " B_3 .req v26\n"
+ "qB_1 .req q24\n" "qB_2 .req q25\n" "qB_3 .req q26\n"
+
+ " C_11 .req v0\n" " C_12 .req v1\n" " C_13 .req v2\n"
+ " C_21 .req v3\n" " C_22 .req v4\n" " C_23 .req v5\n"
+ " C_31 .req v6\n" " C_32 .req v7\n" " C_33 .req v8\n"
+ " C_41 .req v9\n" " C_42 .req v10\n" " C_43 .req v11\n"
+ " C_51 .req v12\n" " C_52 .req v13\n" " C_53 .req v14\n"
+ " C_61 .req v15\n" " C_62 .req v16\n" " C_63 .req v17\n"
+ " C_71 .req v18\n" " C_72 .req v19\n" " C_73 .req v20\n"
+ " C_81 .req v21\n" " C_82 .req v22\n" " C_83 .req v23\n"
+
+ "qC_11 .req q0\n" "qC_12 .req q1\n" "qC_13 .req q2\n"
+ "qC_21 .req q3\n" "qC_22 .req q4\n" "qC_23 .req q5\n"
+ "qC_31 .req q6\n" "qC_32 .req q7\n" "qC_33 .req q8\n"
+ "qC_41 .req q9\n" "qC_42 .req q10\n" "qC_43 .req q11\n"
+ "qC_51 .req q12\n" "qC_52 .req q13\n" "qC_53 .req q14\n"
+ "qC_61 .req q15\n" "qC_62 .req q16\n" "qC_63 .req q17\n"
+ "qC_71 .req q18\n" "qC_72 .req q19\n" "qC_73 .req q20\n"
+ "qC_81 .req q21\n" "qC_82 .req q22\n" "qC_83 .req q23\n"
+
+ "aptr1 .req x17\n"
+ "aptr2 .req x18\n"
+ "aptr3 .req x19\n"
+ "aptr4 .req x20\n"
+ "aptr5 .req x21\n"
+ "aptr6 .req x22\n"
+ "aptr7 .req x23\n"
+
+ // Initialise accumulators with 0
+ // Initialise pointers
+ "movi C_11.4s, #0\n"
+ "add aptr1, %x[aptr], %x[a_row_stride]\n"
+ "movi C_12.4s, #0\n"
+ "add aptr2, aptr1, %x[a_row_stride]\n"
+ "movi C_13.4s, #0\n"
+ "add aptr3, aptr2, %x[a_row_stride]\n"
+ "movi C_21.4s, #0\n"
+ "add aptr4, aptr3, %x[a_row_stride]\n"
+ "movi C_22.4s, #0\n"
+ "add aptr5, aptr4, %x[a_row_stride]\n"
+ "movi C_23.4s, #0\n"
+ "add aptr6, aptr5, %x[a_row_stride]\n"
+ "movi C_31.4s, #0\n"
+ "add aptr7, aptr6, %x[a_row_stride]\n"
+ "movi C_32.4s, #0\n"
+ "ldr qB_1, [%x[bptr]]\n"
+ "movi C_33.4s, #0\n"
+ "ldr qB_2, [%x[bptr], #0x10]\n"
+ "movi C_41.4s, #0\n"
+ "prfm pldl1keep, [%x[bptr], #0x00]\n"
+ "movi C_42.4s, #0\n"
+ "prfm pldl1keep, [%x[bptr], #0x10]\n"
+ "movi C_43.4s, #0\n"
+ "prfm pldl1keep, [%x[bptr], #0x20]\n"
+ "movi C_51.4s, #0\n"
+ "prfm pldl1keep, [%x[aptr], #0x00]\n"
+ "movi C_52.4s, #0\n"
+ "prfm pldl1keep, [ aptr1, #0x00]\n"
+ "movi C_53.4s, #0\n"
+ "prfm pldl1keep, [ aptr2, #0x00]\n"
+ "movi C_61.4s, #0\n"
+ "prfm pldl1keep, [ aptr3, #0x00]\n"
+ "movi C_62.4s, #0\n"
+ "prfm pldl1keep, [ aptr4, #0x00]\n"
+ "movi C_63.4s, #0\n"
+ "prfm pldl1keep, [ aptr5, #0x00]\n"
+ "movi C_71.4s, #0\n"
+ "prfm pldl1keep, [ aptr6, #0x00]\n"
+ "movi C_72.4s, #0\n"
+ "prfm pldl1keep, [ aptr7, #0x00]\n"
+ "movi C_73.4s, #0\n"
+ "ldr sA_1, [%x[aptr]], #0x4\n"
+ "movi C_81.4s, #0\n"
+ "ldr sA_2, [ aptr1], #0x4\n"
+ "movi C_82.4s, #0\n"
+ "ldr sA_3, [ aptr2], #0x4\n"
+ "movi C_83.4s, #0\n"
+ "subs %x[k], %x[k], #1\n"
+ "beq 2f\n"
+
+ "1:"
+ "fmla C_11.4s, B_1.4s, A_1.s[0]\n"
+ "ldr qB_3, [%x[bptr], #0x20]\n"
+ "fmla C_12.4s, B_2.4s, A_1.s[0]\n"
+ "ldr sA_4, [ aptr3], #0x4\n"
+ "fmla C_13.4s, B_3.4s, A_1.s[0]\n"
+ "ldr sA_1, [ aptr4], #0x04\n"
+
+ "fmla C_21.4s, B_1.4s, A_2.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride]\n"
+ "fmla C_22.4s, B_2.4s, A_2.s[0]\n"
+ "prfm pldl1keep, [ aptr3, #0x10]\n"
+ "fmla C_23.4s, B_3.4s, A_2.s[0]\n"
+ "ldr sA_2, [ aptr5], #0x04\n"
+
+ "fmla C_31.4s, B_1.4s, A_3.s[0]\n"
+ "prfm pldl1keep, [%x[bptr], #0x00]\n"
+ "fmla C_32.4s, B_2.4s, A_3.s[0]\n"
+ "prfm pldl1keep, [%x[bptr], #0x10]\n"
+ "fmla C_33.4s, B_3.4s, A_3.s[0]\n"
+ "ldr sA_3, [ aptr6], #0x04\n"
+
+ "fmla C_41.4s, B_1.4s, A_4.s[0]\n"
+ "prfm pldl1keep, [%x[bptr], #0x20]\n"
+ "fmla C_42.4s, B_2.4s, A_4.s[0]\n"
+ "prfm pldl1keep, [ aptr4, #0x10]\n"
+ "fmla C_43.4s, B_3.4s, A_4.s[0]\n"
+ "ldr sA_4, [ aptr7], #0x04\n"
+
+ "fmla C_51.4s, B_1.4s, A_1.s[0]\n"
+ "prfm pldl1keep, [ aptr5, #0x10]\n"
+ "fmla C_52.4s, B_2.4s, A_1.s[0]\n"
+ "prfm pldl1keep, [ aptr6, #0x10]\n"
+ "fmla C_53.4s, B_3.4s, A_1.s[0]\n"
+ "ldr sA_1, [%x[aptr]], #0x04\n"
+
+ "fmla C_61.4s, B_1.4s, A_2.s[0]\n"
+ "prfm pldl1keep, [ aptr7, #0x10]\n"
+ "fmla C_62.4s, B_2.4s, A_2.s[0]\n"
+ "subs %x[k], %x[k], #1\n"
+ "fmla C_63.4s, B_3.4s, A_2.s[0]\n"
+ "ldr sA_2, [ aptr1], #0x04\n"
+
+ "fmla C_71.4s, B_1.4s, A_3.s[0]\n"
+ "prfm pldl1keep, [%x[aptr], #0x10]\n"
+ "fmla C_72.4s, B_2.4s, A_3.s[0]\n"
+ "prfm pldl1keep, [ aptr1, #0x10]\n"
+ "fmla C_73.4s, B_3.4s, A_3.s[0]\n"
+ "ldr sA_3, [ aptr2], #0x04\n"
+
+ "fmla C_81.4s, B_1.4s, A_4.s[0]\n"
+ "prfm pldl1keep, [ aptr2, #0x10]\n"
+ "fmla C_82.4s, B_2.4s, A_4.s[0]\n"
+ "ldp qB_1, qB_2, [%x[bptr]]\n"
+ "fmla C_83.4s, B_3.4s, A_4.s[0]\n"
+ "bne 1b\n"
+
+ "2:"
+ "fmla C_11.4s, B_1.4s, A_1.s[0]\n"
+ "ldr qB_3, [%x[bptr], #0x20]\n"
+ "fmla C_12.4s, B_2.4s, A_1.s[0]\n"
+ "stp qC_11, qC_12, [%x[cptr]]\n"
+ "fmla C_13.4s, B_3.4s, A_1.s[0]\n"
+ "str qC_13, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+ "ldr sA_1, [ aptr4], #0x04\n"
+
+ "fmla C_21.4s, B_1.4s, A_2.s[0]\n"
+ "ldr sA_4, [ aptr3], #0x4\n"
+ "fmla C_22.4s, B_2.4s, A_2.s[0]\n"
+ "stp qC_21, qC_22, [%x[cptr]]\n"
+ "fmla C_23.4s, B_3.4s, A_2.s[0]\n"
+ "str qC_23, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+ "ldr sA_2, [ aptr5], #0x04\n"
+
+ "fmla C_31.4s, B_1.4s, A_3.s[0]\n"
+ "fmla C_32.4s, B_2.4s, A_3.s[0]\n"
+ "stp qC_31, qC_32, [%x[cptr]]\n"
+ "fmla C_33.4s, B_3.4s, A_3.s[0]\n"
+ "str qC_33, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+ "ldr sA_3, [ aptr6], #0x04\n"
+
+ "fmla C_41.4s, B_1.4s, A_4.s[0]\n"
+ "fmla C_42.4s, B_2.4s, A_4.s[0]\n"
+ "stp qC_41, qC_42, [%x[cptr]]\n"
+ "fmla C_43.4s, B_3.4s, A_4.s[0]\n"
+ "str qC_43, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+ "ldr sA_4, [ aptr7], #0x04\n"
+
+ "fmla C_51.4s, B_1.4s, A_1.s[0]\n"
+ "fmla C_52.4s, B_2.4s, A_1.s[0]\n"
+ "stp qC_51, qC_52, [%x[cptr]]\n"
+ "fmla C_53.4s, B_3.4s, A_1.s[0]\n"
+ "str qC_53, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+
+ "fmla C_61.4s, B_1.4s, A_2.s[0]\n"
+ "fmla C_62.4s, B_2.4s, A_2.s[0]\n"
+ "stp qC_61, qC_62, [%x[cptr]]\n"
+ "fmla C_63.4s, B_3.4s, A_2.s[0]\n"
+ "str qC_63, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+
+ "fmla C_71.4s, B_1.4s, A_3.s[0]\n"
+ "fmla C_72.4s, B_2.4s, A_3.s[0]\n"
+ "stp qC_71, qC_72, [%x[cptr]]\n"
+ "fmla C_73.4s, B_3.4s, A_3.s[0]\n"
+ "str qC_73, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+
+ "fmla C_81.4s, B_1.4s, A_4.s[0]\n"
+ "fmla C_82.4s, B_2.4s, A_4.s[0]\n"
+ "stp qC_81, qC_82, [%x[cptr]]\n"
+ "fmla C_83.4s, B_3.4s, A_4.s[0]\n"
+ "str qC_83, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+
+ // Clear aliases
+ ".unreq aptr1\n"
+ ".unreq aptr2\n"
+ ".unreq aptr3\n"
+ ".unreq aptr4\n"
+ ".unreq aptr5\n"
+ ".unreq aptr6\n"
+ ".unreq aptr7\n"
+
+ ".unreq A_1\n" ".unreq A_2\n" ".unreq A_3\n" ".unreq A_4\n"
+ ".unreq sA_1\n" ".unreq sA_2\n" ".unreq sA_3\n" ".unreq sA_4\n"
+
+ ".unreq B_1\n" ".unreq B_2\n" ".unreq B_3\n"
+ ".unreq qB_1\n" ".unreq qB_2\n" ".unreq qB_3\n"
+
+ ".unreq C_11\n" ".unreq C_12\n" ".unreq C_13\n"
+ ".unreq C_21\n" ".unreq C_22\n" ".unreq C_23\n"
+ ".unreq C_31\n" ".unreq C_32\n" ".unreq C_33\n"
+ ".unreq C_41\n" ".unreq C_42\n" ".unreq C_43\n"
+ ".unreq C_51\n" ".unreq C_52\n" ".unreq C_53\n"
+ ".unreq C_61\n" ".unreq C_62\n" ".unreq C_63\n"
+ ".unreq C_71\n" ".unreq C_72\n" ".unreq C_73\n"
+ ".unreq C_81\n" ".unreq C_82\n" ".unreq C_83\n"
+
+ ".unreq qC_11\n" ".unreq qC_12\n" ".unreq qC_13\n"
+ ".unreq qC_21\n" ".unreq qC_22\n" ".unreq qC_23\n"
+ ".unreq qC_31\n" ".unreq qC_32\n" ".unreq qC_33\n"
+ ".unreq qC_41\n" ".unreq qC_42\n" ".unreq qC_43\n"
+ ".unreq qC_51\n" ".unreq qC_52\n" ".unreq qC_53\n"
+ ".unreq qC_61\n" ".unreq qC_62\n" ".unreq qC_63\n"
+ ".unreq qC_71\n" ".unreq qC_72\n" ".unreq qC_73\n"
+ ".unreq qC_81\n" ".unreq qC_82\n" ".unreq qC_83\n"
+ : [aptr] "+r" (aptr),
+ [bptr] "+r" (bptr),
+ [cptr] "+r" (cptr),
+ [k] "+r" (k)
+ : [a_row_stride] "r" (a_row_stride * sizeof(float)),
+ [b_row_stride] "r" (b_row_stride * sizeof(float)),
+ [c_row_stride] "r" (c_row_stride * sizeof(float))
+ : "cc", "memory",
+ "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+ "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
+ "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
+ "v29", "v30", "x17", "x18", "x19", "x20", "x21", "x22", "x23"
+ );
+ }
+ }
+}
+
+/*****************************************************************************/
+/* 4x16 blocked GEMM with specialised tails
+ */
+#include "a64_sgemm_4x16.hpp"
+
+template <>
+inline void BlockedGemm<4, 16, float, float>(
+ const float* const a, const float* const b, float *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+) {
+ // Despatch based on tail of K
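+  // The main loops of the 4x16 kernels consume four values of K per
+  // iteration, so the remainder selects a specialised tail implementation.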
+ switch (K % 4) {
+ case 3:
+ sgemm_4x16_impl<3>(
+ a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride
+ );
+ break;
+ case 2:
+ sgemm_4x16_impl<2>(
+ a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride
+ );
+ break;
+ case 1:
+ sgemm_4x16_impl<1>(
+ a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride
+ );
+ break;
+ case 0:
+ sgemm_4x16_impl<0>(
+ a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride
+ );
+ break;
+ default:
+ assert(false);
+ }
+}
+
+#endif // __aarch64__
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm_4x16.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm_4x16.hpp
new file mode 100644
index 0000000000..5cd37de7a0
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm_4x16.hpp
@@ -0,0 +1,1446 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
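+// 4x16 blocked SGEMM for AArch64. The `tail` template parameter is K % 4:
+// the main loop of each specialisation processes four values of K per
+// iteration and is followed by an epilogue tailored to that remainder.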
+template <const unsigned int tail>
+inline void sgemm_4x16_impl(
+ const float* const a, const float* const b, float *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+);
+
+template <>
+inline void sgemm_4x16_impl<0>(
+ const float* const a, const float* const b, float *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+) {
+ const int TAIL_SIZE = 0;
+ const int M_BLOCK = 4;
+ const int N_BLOCK = 16;
+
+ const int m_blocks = iceildiv(M, M_BLOCK);
+ const int n_blocks = iceildiv(N, N_BLOCK);
+
+ // For each block of output rows
+ for (int mblock = 0; mblock < m_blocks; mblock++) {
+ // For each block of output columns
+ for (int nblock = 0; nblock < n_blocks; nblock++) {
+ const float *aptr = a + mblock*M_BLOCK*a_row_stride;
+ const float *bptr = b + nblock*N_BLOCK;
+ float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
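+      // Number of four-deep unrolled K iterations; the final iteration is
+      // peeled into the tail below so that the stores can be overlapped with
+      // the last multiply-accumulates.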
+ int k = (K - TAIL_SIZE) / 4;
+
+ asm volatile(
+ "aptr2 .req X20\n"
+ "aptr3 .req X21\n"
+ "aptr4 .req X22\n"
+ "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n"
+ "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n"
+ "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n"
+ "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n"
+ "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n"
+ "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n"
+ "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n"
+ "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n"
+ "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n"
+ "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n"
+ "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n"
+ "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n"
+ "vB1 .req v20\n" "qB1 .req q20\n"
+ "vB2 .req v21\n" "qB2 .req q21\n"
+ "vB3 .req v22\n" "qB3 .req q22\n"
+ "vB4 .req v23\n" "qB4 .req q23\n"
+
+ // Clear accumulators, initialise pointers
+ "movi vC11.4s, #0\n"
+ "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n"
+ "movi vC12.4s, #0\n"
+ "add aptr3, aptr2, %x[a_row_stride_bytes]\n"
+ "movi vC13.4s, #0\n"
+ "add aptr4, aptr3, %x[a_row_stride_bytes]\n"
+ "movi vC14.4s, #0\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "movi vC21.4s, #0\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "movi vC22.4s, #0\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "movi vC23.4s, #0\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "movi vC24.4s, #0\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "movi vC31.4s, #0\n"
+ "movi vC32.4s, #0\n"
+ "movi vC33.4s, #0\n"
+ "movi vC34.4s, #0\n"
+ "movi vC41.4s, #0\n"
+ "movi vC42.4s, #0\n"
+ "movi vC43.4s, #0\n"
+ "movi vC44.4s, #0\n"
+ "subs %x[k], %x[k], #1\n"
+ "beq 2f\n"
+
+ "1:" // Loop proper
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "subs %x[k], %x[k], #1\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+ "bne 1b\n"
+
+ "2:" // Tail
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "stp qC11, qC12, [%x[cptr], #0x00]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "stp qC13, qC14, [%x[cptr], #0x20]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "stp qC21, qC22, [%x[cptr], #0x00]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "stp qC23, qC24, [%x[cptr], #0x20]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "stp qC31, qC32, [%x[cptr], #0x00]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "stp qC33, qC34, [%x[cptr], #0x20]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "stp qC41, qC42, [%x[cptr], #0x00]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+ "stp qC43, qC44, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+
+ ".unreq vB4\n" ".unreq qB4\n"
+ ".unreq vB3\n" ".unreq qB3\n"
+ ".unreq vB2\n" ".unreq qB2\n"
+ ".unreq vB1\n" ".unreq qB1\n"
+ ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n"
+ ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n"
+ ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n"
+ ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n"
+ ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n"
+ ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n"
+ ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n"
+ ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n"
+ ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n"
+ ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n"
+ ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n"
+ ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n"
+ ".unreq aptr2\n"
+ ".unreq aptr3\n"
+ ".unreq aptr4\n"
+
+ : [aptr] "+r" (aptr),
+ [bptr] "+r" (bptr),
+ [cptr] "+r" (cptr),
+ [k] "+r" (k)
+ : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)),
+ [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)),
+ [c_row_stride_bytes] "r" (c_row_stride * sizeof(float))
+ : "cc", "memory", "x20", "x21", "x22",
+ "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+ "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
+ "v21", "v22", "v23"
+ );
+ }
+ }
+}
+
+template <>
+inline void sgemm_4x16_impl<1>(
+ const float* const a, const float* const b, float *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+) {
+ const int TAIL_SIZE = 1;
+ const int M_BLOCK = 4;
+ const int N_BLOCK = 16;
+
+ const int m_blocks = iceildiv(M, M_BLOCK);
+ const int n_blocks = iceildiv(N, N_BLOCK);
+
+ // For each block of output rows
+ for (int mblock = 0; mblock < m_blocks; mblock++) {
+ // For each block of output columns
+ for (int nblock = 0; nblock < n_blocks; nblock++) {
+ const float *aptr = a + mblock*M_BLOCK*a_row_stride;
+ const float *bptr = b + nblock*N_BLOCK;
+ float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
+ int k = (K - TAIL_SIZE) / 4;
+
+ asm volatile(
+ "aptr2 .req X20\n"
+ "aptr3 .req X21\n"
+ "aptr4 .req X22\n"
+ "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n"
+ "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n"
+ "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n"
+ "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n"
+ "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n"
+ "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n"
+ "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n"
+ "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n"
+ "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n"
+ "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n"
+ "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n"
+ "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n"
+ "vB1 .req v20\n" "qB1 .req q20\n"
+ "vB2 .req v21\n" "qB2 .req q21\n"
+ "vB3 .req v22\n" "qB3 .req q22\n"
+ "vB4 .req v23\n" "qB4 .req q23\n"
+
+ // Clear accumulators, initialise pointers
+ "movi vC11.4s, #0\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "movi vC12.4s, #0\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "movi vC13.4s, #0\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "movi vC14.4s, #0\n"
+ "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n"
+ "movi vC21.4s, #0\n"
+ "add aptr3, aptr2, %x[a_row_stride_bytes]\n"
+ "movi vC22.4s, #0\n"
+ "add aptr4, aptr3, %x[a_row_stride_bytes]\n"
+ "movi vC23.4s, #0\n"
+ "cbnz %x[k], 3f\n"
+
+ // Prepare for tail in K
+ "movi vC24.4s, #0\n"
+ "ldr sA1, [%x[aptr]], #0x04\n"
+ "movi vC31.4s, #0\n"
+ "ldr sA2, [ aptr2], #0x04\n"
+ "movi vC32.4s, #0\n"
+ "movi vC33.4s, #0\n"
+ "movi vC34.4s, #0\n"
+ "movi vC41.4s, #0\n"
+ "movi vC42.4s, #0\n"
+ "movi vC43.4s, #0\n"
+ "movi vC44.4s, #0\n"
+ "b 2f\n" // Jump to tail
+
+ "3:" // Prepare for loop over K
+ "movi vC24.4s, #0\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "movi vC31.4s, #0\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "movi vC32.4s, #0\n"
+ "movi vC33.4s, #0\n"
+ "movi vC34.4s, #0\n"
+ "movi vC41.4s, #0\n"
+ "movi vC42.4s, #0\n"
+ "movi vC43.4s, #0\n"
+ "movi vC44.4s, #0\n"
+ "subs %x[k], %x[k], #1\n"
+ "beq 4f\n"
+
+ "1:" // Loop proper
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "subs %x[k], %x[k], #1\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+ "bne 1b\n"
+
+ "4:" // Tail iteration
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "ldr sA1, [%x[aptr]], #0x04\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "ldr sA2, [ aptr2], #0x04\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+
+ "2:" // Common tail
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "stp qC11, qC12, [%x[cptr], #0x00]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "ldr sA3, [ aptr3], #0x04\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "stp qC13, qC14, [%x[cptr], #0x20]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "stp qC21, qC22, [%x[cptr], #0x00]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "ldr sA4, [ aptr4], #0x04\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "stp qC23, qC24, [%x[cptr], #0x20]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "stp qC31, qC32, [%x[cptr], #0x00]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "stp qC33, qC34, [%x[cptr], #0x20]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "stp qC41, qC42, [%x[cptr], #0x00]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+ "stp qC43, qC44, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+
+ ".unreq vB4\n" ".unreq qB4\n"
+ ".unreq vB3\n" ".unreq qB3\n"
+ ".unreq vB2\n" ".unreq qB2\n"
+ ".unreq vB1\n" ".unreq qB1\n"
+ ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n"
+ ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n"
+ ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n"
+ ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n"
+ ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n"
+ ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n"
+ ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n"
+ ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n"
+ ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n"
+ ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n"
+ ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n"
+ ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n"
+ ".unreq aptr2\n"
+ ".unreq aptr3\n"
+ ".unreq aptr4\n"
+
+ : [aptr] "+r" (aptr),
+ [bptr] "+r" (bptr),
+ [cptr] "+r" (cptr),
+ [k] "+r" (k)
+ : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)),
+ [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)),
+ [c_row_stride_bytes] "r" (c_row_stride * sizeof(float))
+ : "cc", "memory", "x20", "x21", "x22",
+ "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+ "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
+ "v21", "v22", "v23"
+ );
+ }
+ }
+}
+
+template <>
+inline void sgemm_4x16_impl<2>(
+ const float* const a, const float* const b, float *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+) {
+ const int TAIL_SIZE = 2;
+ const int M_BLOCK = 4;
+ const int N_BLOCK = 16;
+
+ const int m_blocks = iceildiv(M, M_BLOCK);
+ const int n_blocks = iceildiv(N, N_BLOCK);
+
+ // For each block of output rows
+ for (int mblock = 0; mblock < m_blocks; mblock++) {
+ // For each block of output columns
+ for (int nblock = 0; nblock < n_blocks; nblock++) {
+ const float *aptr = a + mblock*M_BLOCK*a_row_stride;
+ const float *bptr = b + nblock*N_BLOCK;
+ float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
+ int k = (K - TAIL_SIZE) / 4;
+
+ asm volatile(
+ "aptr2 .req X20\n"
+ "aptr3 .req X21\n"
+ "aptr4 .req X22\n"
+ "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n"
+ "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n"
+ "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n"
+ "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n"
+ "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n"
+ "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n"
+ "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n"
+ "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n"
+ "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n"
+ "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n"
+ "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n"
+ "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n"
+ "vB1 .req v20\n" "qB1 .req q20\n"
+ "vB2 .req v21\n" "qB2 .req q21\n"
+ "vB3 .req v22\n" "qB3 .req q22\n"
+ "vB4 .req v23\n" "qB4 .req q23\n"
+
+ // Clear accumulators, initialise pointers
+ "movi vC11.4s, #0\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "movi vC12.4s, #0\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "movi vC13.4s, #0\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "movi vC14.4s, #0\n"
+ "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n"
+ "movi vC21.4s, #0\n"
+ "add aptr3, aptr2, %x[a_row_stride_bytes]\n"
+ "movi vC22.4s, #0\n"
+ "add aptr4, aptr3, %x[a_row_stride_bytes]\n"
+ "movi vC23.4s, #0\n"
+ "cbnz %x[k], 3f\n"
+
+ // Prepare for tail in K
+ "movi vC24.4s, #0\n"
+ "ldr dA1, [%x[aptr]], #0x08\n"
+ "movi vC31.4s, #0\n"
+ "ldr dA2, [ aptr2], #0x08\n"
+ "movi vC32.4s, #0\n"
+ "movi vC33.4s, #0\n"
+ "movi vC34.4s, #0\n"
+ "movi vC41.4s, #0\n"
+ "movi vC42.4s, #0\n"
+ "movi vC43.4s, #0\n"
+ "movi vC44.4s, #0\n"
+ "b 2f\n" // Jump to tail
+
+ "3:" // Prepare for loop over K
+ "movi vC24.4s, #0\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "movi vC31.4s, #0\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "movi vC32.4s, #0\n"
+ "movi vC33.4s, #0\n"
+ "movi vC34.4s, #0\n"
+ "movi vC41.4s, #0\n"
+ "movi vC42.4s, #0\n"
+ "movi vC43.4s, #0\n"
+ "movi vC44.4s, #0\n"
+ "subs %x[k], %x[k], #1\n"
+ "beq 4f\n"
+
+ "1:" // Loop proper
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "subs %x[k], %x[k], #1\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+ "bne 1b\n"
+
+ "4:" // Tail iteration
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "ldr dA1, [%x[aptr]], #0x08\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "ldr dA2, [ aptr2], #0x08\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+
+ "2:" // Common tail
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr dA3, [ aptr3], #0x08\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr dA4, [ aptr4], #0x08\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "stp qC11, qC12, [%x[cptr], #0x00]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "stp qC13, qC14, [%x[cptr], #0x20]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "stp qC21, qC22, [%x[cptr], #0x00]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "stp qC23, qC24, [%x[cptr], #0x20]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "stp qC31, qC32, [%x[cptr], #0x00]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "stp qC33, qC34, [%x[cptr], #0x20]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "stp qC41, qC42, [%x[cptr], #0x00]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+ "stp qC43, qC44, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+
+ ".unreq vB4\n" ".unreq qB4\n"
+ ".unreq vB3\n" ".unreq qB3\n"
+ ".unreq vB2\n" ".unreq qB2\n"
+ ".unreq vB1\n" ".unreq qB1\n"
+ ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n"
+ ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n"
+ ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n"
+ ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n"
+ ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n"
+ ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n"
+ ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n"
+ ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n"
+ ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n"
+ ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n"
+ ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n"
+ ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n"
+ ".unreq aptr2\n"
+ ".unreq aptr3\n"
+ ".unreq aptr4\n"
+
+ : [aptr] "+r" (aptr),
+ [bptr] "+r" (bptr),
+ [cptr] "+r" (cptr),
+ [k] "+r" (k)
+ : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)),
+ [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)),
+ [c_row_stride_bytes] "r" (c_row_stride * sizeof(float))
+ : "cc", "memory", "x20", "x21", "x22",
+ "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+ "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
+ "v21", "v22", "v23"
+ );
+ }
+ }
+}
+
+template <>
+inline void sgemm_4x16_impl<3>(
+ const float* const a, const float* const b, float *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+) {
+ const int TAIL_SIZE = 3;
+ const int M_BLOCK = 4;
+ const int N_BLOCK = 16;
+
+ const int m_blocks = iceildiv(M, M_BLOCK);
+ const int n_blocks = iceildiv(N, N_BLOCK);
+
+ // For each block of output rows
+ for (int mblock = 0; mblock < m_blocks; mblock++) {
+ // For each block of output columns
+ for (int nblock = 0; nblock < n_blocks; nblock++) {
+ const float *aptr = a + mblock*M_BLOCK*a_row_stride;
+ const float *bptr = b + nblock*N_BLOCK;
+ float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
+ int k = (K - TAIL_SIZE) / 4;
+
+ asm volatile(
+ "aptr2 .req X20\n"
+ "aptr3 .req X21\n"
+ "aptr4 .req X22\n"
+ "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n"
+ "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n"
+ "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n"
+ "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n"
+ "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n"
+ "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n"
+ "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n"
+ "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n"
+ "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n"
+ "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n"
+ "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n"
+ "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n"
+ "vB1 .req v20\n" "qB1 .req q20\n"
+ "vB2 .req v21\n" "qB2 .req q21\n"
+ "vB3 .req v22\n" "qB3 .req q22\n"
+ "vB4 .req v23\n" "qB4 .req q23\n"
+
+ // Clear accumulators, initialise pointers
+ "movi vC11.4s, #0\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "movi vC12.4s, #0\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "movi vC13.4s, #0\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "movi vC14.4s, #0\n"
+ "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n"
+ "movi vC21.4s, #0\n"
+ "add aptr3, aptr2, %x[a_row_stride_bytes]\n"
+ "movi vC22.4s, #0\n"
+ "add aptr4, aptr3, %x[a_row_stride_bytes]\n"
+ "movi vC23.4s, #0\n"
+ "cbnz %x[k], 3f\n"
+
+ // Prepare for tail in K
+ "movi vC24.4s, #0\n"
+ "ldr dA1, [%x[aptr]], #0x08\n"
+ "movi vC31.4s, #0\n"
+ "ldr dA2, [ aptr2], #0x08\n"
+ "movi vC32.4s, #0\n"
+ "movi vC33.4s, #0\n"
+ "movi vC34.4s, #0\n"
+ "movi vC41.4s, #0\n"
+ "movi vC42.4s, #0\n"
+ "movi vC43.4s, #0\n"
+ "movi vC44.4s, #0\n"
+ "b 2f\n" // Jump to tail
+
+ "3:" // Prepare for loop over K
+ "movi vC24.4s, #0\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "movi vC31.4s, #0\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "movi vC32.4s, #0\n"
+ "movi vC33.4s, #0\n"
+ "movi vC34.4s, #0\n"
+ "movi vC41.4s, #0\n"
+ "movi vC42.4s, #0\n"
+ "movi vC43.4s, #0\n"
+ "movi vC44.4s, #0\n"
+ "subs %x[k], %x[k], #1\n"
+ "beq 4f\n"
+
+ "1:" // Loop proper
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "subs %x[k], %x[k], #1\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+ "bne 1b\n"
+
+ "4:" // Tail iteration
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "ldr dA1, [%x[aptr]], #0x08\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "ldr dA2, [ aptr2], #0x08\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+
+ "2:" // Common tail
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr dA3, [ aptr3], #0x08\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr dA4, [ aptr4], #0x08\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "ldr sA1, [%x[aptr]], #0x04\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "ldr sA2, [ aptr2], #0x04\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "stp qC11, qC12, [%x[cptr], #0x00]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "ldr sA3, [ aptr3], #0x04\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "stp qC13, qC14, [%x[cptr], #0x20]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "stp qC21, qC22, [%x[cptr], #0x00]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "ldr sA4, [ aptr4], #0x04\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "stp qC23, qC24, [%x[cptr], #0x20]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "stp qC31, qC32, [%x[cptr], #0x00]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "stp qC33, qC34, [%x[cptr], #0x20]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "stp qC41, qC42, [%x[cptr], #0x00]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+ "stp qC43, qC44, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+
+ ".unreq vB4\n" ".unreq qB4\n"
+ ".unreq vB3\n" ".unreq qB3\n"
+ ".unreq vB2\n" ".unreq qB2\n"
+ ".unreq vB1\n" ".unreq qB1\n"
+ ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n"
+ ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n"
+ ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n"
+ ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n"
+ ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n"
+ ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n"
+ ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n"
+ ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n"
+ ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n"
+ ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n"
+ ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n"
+ ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n"
+ ".unreq aptr2\n"
+ ".unreq aptr3\n"
+ ".unreq aptr4\n"
+
+ : [aptr] "+r" (aptr),
+ [bptr] "+r" (bptr),
+ [cptr] "+r" (cptr),
+ [k] "+r" (k)
+ : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)),
+ [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)),
+ [c_row_stride_bytes] "r" (c_row_stride * sizeof(float))
+ : "cc", "memory", "x20", "x21", "x22",
+ "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+ "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
+ "v21", "v22", "v23"
+ );
+ }
+ }
+}
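+
+// For reference: a minimal scalar sketch (illustrative only, not used by the
+// kernels above) of the result each 4x16 block computes for one
+// (mblock, nblock) pair, with aptr, bptr and cptr as set in the surrounding
+// loops and the row strides measured in elements:
+//
+//   for (int i = 0; i < 4; i++) {
+//     for (int j = 0; j < 16; j++) {
+//       float acc = 0.0f;
+//       for (int kk = 0; kk < K; kk++) {
+//         acc += aptr[i*a_row_stride + kk] * bptr[kk*b_row_stride + j];
+//       }
+//       cptr[i*c_row_stride + j] = acc;
+//     }
+//   }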
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp
new file mode 100644
index 0000000000..6dd8f5460a
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp
@@ -0,0 +1,195 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+#include "arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp"
+
+namespace winograd
+{
+ /***************************************************************************/
+ /* Instance-less API */
+ template <int output_tile_rows, int output_tile_cols,
+ int kernel_rows, int kernel_cols>
+ template <typename T>
+ void WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>::InputTransform<T>::execute(
+ const T *inptr,
+ const Tensor4DShape& input_shape,
+ const PaddingType padding_type,
+ const int tile_M,
+ const int tile_N,
+ T *outptr_base,
+ const int matrix_stride,
+ const int matrix_batch_stride,
+ const int matrix_row_stride
+ )
+ {
+ // Compute the padding required on each edge of the image
+    const int base_padding = (padding_type == PADDING_SAME) ? 1 : 0;
+ const int pad_top = base_padding;
+ const int pad_left = base_padding;
+ const int tile_overlap = kernel_rows - 1;
+
+ // Compute striding values (assuming NHWC ordered data)
+ const int input_col_stride = input_shape.n_channels;
+ const int input_row_stride = input_shape.n_cols * input_col_stride;
+ const int input_batch_stride = input_shape.n_rows * input_row_stride;
+ const int output_col_stride = matrix_row_stride;
+ const int output_row_stride = tile_N * output_col_stride;
+
+ // Loop over batches
+ for (int batch = 0; batch < input_shape.n_batches; batch++)
+ {
+ // Pointer to the batch
+ const T* const input_base_batch = inptr + batch * input_batch_stride;
+ T* const outptr_base_batch = outptr_base + batch * matrix_batch_stride;
+
+ // Loop over rows of tiles
+ for (int tile_i = 0; tile_i < tile_M; tile_i++)
+ {
+ // Pointer to the row
+ const int row_offset = (tile_i == 0) ?
+ 0 : ((padding_type == PADDING_VALID) ? 0 : 1);
+ const T* const input_base_row = (
+ input_base_batch + ((inner_tile_rows - (kernel_rows - 1))*tile_i - row_offset)*input_row_stride
+ );
+ T* const outptr_base_row = outptr_base_batch + tile_i*output_row_stride;
+
+ // Padding (top + bottom) for the row
+ const int row_top = tile_i*(inner_tile_rows - tile_overlap) - pad_top;
+ const int row_bottom = row_top + inner_tile_rows;
+ const int row_pad_top = (tile_i == 0) ? pad_top : 0;
+ const int row_pad_bottom = (row_bottom <= input_shape.n_rows) ? 0 : row_bottom - input_shape.n_rows;
+
+ // Process the row
+ process_tile_row(
+ tile_N, input_shape.n_channels,
+ input_base_row, input_row_stride, input_col_stride,
+ outptr_base_row, matrix_stride, matrix_row_stride,
+ row_pad_top, pad_left, row_pad_bottom, input_shape.n_cols
+ );
+ }
+ }
+ }
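+
+  /* Worked example of the row tiling above (illustrative; assumes a 3x3 kernel
+   * and a 2x2 output tile, so inner_tile_rows = 4 and tile_overlap = 2, with
+   * SAME padding and an input of 6 rows):
+   *   tile_i = 0 reads input rows [-1, 3): row_pad_top = 1, row_pad_bottom = 0
+   *   tile_i = 1 reads input rows [ 1, 5): no padding
+   *   tile_i = 2 reads input rows [ 3, 7): row_pad_top = 0, row_pad_bottom = 1
+   */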
+
+ template <int output_tile_rows, int output_tile_cols,
+ int kernel_rows, int kernel_cols>
+ template <typename T>
+ void WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>::InputTransform<T>::process_tile_row(
+ const int tile_N,
+ int n_channels,
+ const T* const input_base,
+ const int input_row_stride,
+ const int input_col_stride,
+ T* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride,
+ const int pad_top,
+ const int row_pad_left,
+ const int pad_bottom,
+ const int n_cols
+ )
+ {
+ constexpr int tile_overlap = kernel_cols - 1;
+
+ // Loop over columns of tiles
+ for (int tile_j = 0; tile_j < tile_N; tile_j++)
+ {
+ // Padding (left + right) for the tile
+ const int t_pad_left = (tile_j == 0) ? row_pad_left : 0;
+ const int t_start = tile_j*(inner_tile_cols - tile_overlap) - row_pad_left;
+ const int t_end = t_start + inner_tile_cols;
+ const int t_pad_right = (t_end <= n_cols) ? 0 : t_end - n_cols;
+
+ // Get pointers into the inputs and outputs
+ const int col_offset = (tile_j == 0) ? 0 : row_pad_left;
+ const T* const input_base_col = (
+ input_base + ((inner_tile_cols - tile_overlap)*tile_j - col_offset)*input_col_stride
+ );
+ T* const outptr = matrix_base + tile_j*matrix_row_stride;
+
+ // Apply the specific tile processing function
+ tile_fns[pad_top][t_pad_left][pad_bottom][t_pad_right](
+ n_channels,
+ input_base_col,
+ input_row_stride,
+ input_col_stride,
+ outptr,
+ matrix_stride
+ );
+ }
+ }
+
+ /***************************************************************************/
+ template <int otr, int otc, int kr, int kc>
+ template <typename T>
+ WinogradGEMM<otr, otc, kr, kc>::InputTransform<T>::InputTransform(
+ const T* const input, /** Input tensor data */
+ const int n_batches, /** Number of batches in input tensor. */
+ const int n_rows, /** Number of rows in input tensor. */
+ const int n_cols, /** Number of columns in input tensor. */
+ const int n_channels, /** Number of channels in input tensor. */
+ const PaddingType padding, /** Padding type. */
+ T* const output, /** Base of output matrices. */
+ const int matrix_stride, /** Stride between output matrices. */
+ const int matrix_row_stride /** Stride within matrices. */
+ ) : _inptr(input), _outptr(output),
+ _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols), _n_channels(n_channels),
+ _matrix_stride(matrix_stride), _matrix_row_stride(matrix_row_stride),
+ _tiles_M(iceildiv((padding == PADDING_SAME) ? n_rows : n_rows - 2, output_tile_rows)),
+ _tiles_N(iceildiv((padding == PADDING_SAME) ? n_cols : n_cols - 2, output_tile_cols)),
+ _padding_type(padding)
+ {
+ }
+
+ template <int otr, int otc, int kr, int kc>
+ template <typename T>
+ unsigned int WinogradGEMM<otr, otc, kr, kc>::InputTransform<T>::get_window() const
+ {
+ // TODO When the input transform supports multithreading, return the total
+ // number of tile rows (allowing for multiple batches). For now we return 1
+ // to indicate that the activations must be transformed as a single block.
+ return 1; // TODO _tiles_M * _n_batches;
+ }
+
+ template <int otr, int otc, int kr, int kc>
+ template <typename T>
+ void WinogradGEMM<otr, otc, kr, kc>::InputTransform<T>::run(
+ const unsigned int start, const unsigned int stop
+ )
+ {
+    // TODO When the input transform supports multithreading, call execute for
+    // a portion of the tile rows.
+ (void) start;
+ (void) stop;
+
+ // For now, just do all of the work.
+ const Tensor4DShape input_shape = {
+ _n_batches, _n_rows, _n_cols, _n_channels, NHWC
+ };
+ execute(
+ _inptr, input_shape, _padding_type, _tiles_M, _tiles_N, _outptr,
+ _matrix_stride, _matrix_row_stride * _tiles_M * _tiles_N, _matrix_row_stride
+ );
+ }
+}
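+
+// Illustrative use of the instance-based API above. The template parameters,
+// pointers and strides below are placeholders for the example, not values
+// defined by this header:
+//
+//   using WinogradConv = winograd::WinogradGEMM<2, 2, 3, 3>;
+//   WinogradConv::InputTransform<float> transform(
+//       input, n_batches, n_rows, n_cols, n_channels, PADDING_SAME,
+//       matrices, matrix_stride, matrix_row_stride);
+//   transform.run(0, transform.get_window());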
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/transforms/kernel.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/kernel.hpp
new file mode 100644
index 0000000000..bad3ef2249
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/kernel.hpp
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp"
+using namespace winograd;
+
+
+template <int otr, int otc, int kr, int kc>
+template <typename T>
+WinogradGEMM<otr, otc, kr, kc>::WeightsTransform<T>::WeightsTransform(
+ const T* const input,
+ T* const output,
+ const int matrix_stride, /** Stride across matrices in the output. */
+ const int matrix_row_stride, /** Stride across rows of the matrix. */
+ const int n_output_channels,
+ const int n_input_channels
+) : inptr(input), outptr(output),
+ matrix_stride(matrix_stride), matrix_row_stride(matrix_row_stride),
+ n_output_channels(n_output_channels), n_input_channels(n_input_channels)
+{
+}
+
+
+template <int otr, int otc, int kr, int kc>
+template <typename T>
+unsigned int WinogradGEMM<otr, otc, kr, kc>::WeightsTransform<T>::get_window() const
+{
+ // TODO When the weights transform supports multithreading, return the number
+ // of output channels. For now we return 1 to indicate that the weights must
+ // be transformed as a single block.
+ // return n_output_channels;
+ return 1;
+}
+
+
+template <int otr, int otc, int kr, int kc>
+template <typename T>
+void WinogradGEMM<otr, otc, kr, kc>::WeightsTransform<T>::run(
+ const unsigned int start, const unsigned int stop
+)
+{
+  // TODO When the weights transform supports multithreading, call execute for
+  // a portion of the output channels.
+ (void) start;
+ (void) stop;
+
+ // For now, just do all of the work.
+ execute(
+ n_output_channels,
+ n_input_channels,
+ inptr,
+ outptr,
+ matrix_stride,
+ matrix_row_stride
+ );
+}
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/transforms/output.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/output.hpp
new file mode 100644
index 0000000000..401b2816be
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/output.hpp
@@ -0,0 +1,181 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+#include "arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp"
+
+namespace winograd
+{
+ template <int output_tile_rows, int output_tile_cols,
+ int kernel_rows, int kernel_cols>
+ template <typename T>
+ void WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>::OutputTransform<T>::execute(
+ const Tensor4DShape &output_shape,
+ const T* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride,
+ const T* const biases,
+ T* const output
+ )
+ {
+ // Compute the number of tiles and hence the padding required on the bottom
+ // and right of the image.
+ const int tile_M = iceildiv(output_shape.n_rows, output_tile_rows);
+ const int tile_N = iceildiv(output_shape.n_cols, output_tile_cols);
+ const int pad_bottom = output_tile_rows*tile_M - output_shape.n_rows;
+ const int pad_right = output_tile_cols*tile_N - output_shape.n_cols;
+
+ const int matrix_tile_row_stride = tile_N * matrix_row_stride;
+ const int matrix_batch_stride = tile_M * matrix_tile_row_stride;
+ const int output_col_stride = output_shape.n_channels;
+ const int output_row_stride = output_shape.n_cols * output_col_stride;
+ const int output_batch_stride = output_shape.n_rows * output_row_stride;
+
+ // Perform the output transformation for each batch
+ for (int batch = 0; batch < output_shape.n_batches; batch++)
+ {
+ // Get batch offset for input and outputs.
+ const T* const matrix_batch = matrix_base + batch*matrix_batch_stride;
+ T* const outptr_batch = output + batch*output_batch_stride;
+
+ // Perform the output transformation for each row of the output tensor.
+ for (int tile_i = 0; tile_i < tile_M; tile_i++)
+ {
+ // Compute properties of this row of output tiles
+        const int row_pad_bottom = (tile_i < tile_M - 1) ? 0 : pad_bottom;
+ const T* const matrix_tile_row = matrix_batch + tile_i * matrix_tile_row_stride;
+ T* const outptr_row = outptr_batch + output_tile_rows*tile_i*output_row_stride;
+
+ // Process the row
+ process_tile_row(
+ tile_N, output_shape.n_channels, matrix_tile_row, matrix_stride,
+ matrix_row_stride, biases,
+ outptr_row, output_row_stride, output_col_stride, row_pad_bottom,
+ pad_right
+ );
+ }
+ }
+ }
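+
+  /* Worked example of the padding computed above (illustrative): with a 2x2
+   * output tile and an output of 7 rows by 9 columns,
+   *   tile_M = iceildiv(7, 2) = 4, pad_bottom = 2*4 - 7 = 1,
+   *   tile_N = iceildiv(9, 2) = 5, pad_right  = 2*5 - 9 = 1,
+   * so the final row and column of tiles each write one fewer output element
+   * along that dimension.
+   */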
+
+ template <int output_tile_rows, int output_tile_cols,
+ int kernel_rows, int kernel_cols>
+ template <typename T>
+ void WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>::OutputTransform<T>::process_tile_row(
+ const int tile_N,
+ const int n_channels,
+ const T* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride,
+ const T* const biases,
+ T* const output,
+ const int output_row_stride,
+ const int output_col_stride,
+ const int row_pad_bottom,
+ const int row_pad_right
+ )
+ {
+ // Loop over columns of tiles
+ for (int tile_j = 0; tile_j < tile_N; tile_j++)
+ {
+ // Properties of this tile
+ const int tile_pad_right = (tile_j < tile_N - 1) ? 0 : row_pad_right;
+ const T* const matrix_row = matrix_base + tile_j * matrix_row_stride;
+ T* const outptr = output + output_tile_cols*tile_j*output_col_stride;
+
+ // Perform the output transformation
+ tile_fns[row_pad_bottom][tile_pad_right](
+ n_channels, matrix_row, matrix_stride, biases,
+ outptr, output_row_stride, output_col_stride
+ );
+ }
+ }
+
+ template <int output_tile_rows, int output_tile_cols, int kr, int kc>
+ template <typename T>
+ size_t WinogradGEMM<output_tile_rows, output_tile_cols, kr, kc>::OutputTransform<T>::bytes_read(const Tensor4DShape &shape)
+ {
+ const int M = iceildiv(shape.n_rows, output_tile_rows) *
+ iceildiv(shape.n_cols, output_tile_cols);
+ const int N = shape.n_channels;
+ return inner_tile_rows * inner_tile_cols * M * N * sizeof(T);
+ }
+
+ template <int otr, int otc, int kr, int kc>
+ template <typename T>
+ size_t WinogradGEMM<otr, otc, kr, kc>::OutputTransform<T>::bytes_written(const Tensor4DShape &shape)
+ {
+ return shape.size() * sizeof(T);
+ }
+
+ template <int output_tile_rows, int output_tile_cols, int kr, int kc>
+ template <typename T>
+ WinogradGEMM<output_tile_rows, output_tile_cols, kr, kc>::OutputTransform<T>::OutputTransform(
+ const T* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride,
+ const T* const biases,
+ T* const output,
+ const int n_batches,
+ const int n_rows,
+ const int n_cols,
+ const int n_channels
+ ) : _matrix_base(matrix_base), _biases(biases),
+ _matrix_stride(matrix_stride), _matrix_row_stride(matrix_row_stride),
+ _outptr(output), _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols),
+ _n_channels(n_channels), _tile_M(iceildiv(n_rows, output_tile_rows)),
+ _tile_N(iceildiv(n_cols, output_tile_cols))
+ {
+ }
+
+ template <int otr, int otc, int kr, int kc>
+ template <typename T>
+ unsigned int WinogradGEMM<otr, otc, kr, kc>::OutputTransform<T>::get_window() const
+ {
+ // TODO When the output transform supports multithreading, return the total
+ // number of tile rows (allowing for multiple batches). For now we return 1
+ // to indicate that the activations must be transformed as a single block.
+ return 1; // TODO _tile_M * _n_batches;
+ }
+
+ template <int otr, int otc, int kr, int kc>
+ template <typename T>
+ void WinogradGEMM<otr, otc, kr, kc>::OutputTransform<T>::run(
+ const unsigned int start, const unsigned int stop
+ )
+ {
+    // TODO When the output transform supports multithreading, call execute for
+    // a portion of the tile rows.
+ (void) start;
+ (void) stop;
+
+ // For now, just do all of the work.
+ const Tensor4DShape output_shape = {
+ _n_batches, _n_rows, _n_cols, _n_channels, NHWC
+ };
+ execute(
+ output_shape, _matrix_base, _matrix_stride, _matrix_row_stride, _biases,
+ _outptr
+ );
+ }
+} // namespace winograd
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp
new file mode 100644
index 0000000000..f3b2bb10ed
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp
@@ -0,0 +1,447 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "arm_compute/core/NEON/kernels/convolution/common/alloc.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/common/convolution.hpp"
+#include "gemm.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/common/profiler.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/common/shims.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/common/tensor.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/common/utils.hpp"
+
+#include <thread>
+#include <utility>
+#include <vector>
+
+// Generic Winograd implementation using GEMM
+namespace winograd
+{
+
+template <int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+class WinogradGEMM
+{
+ public:
+ // Information about the specific Winograd instance
+ static constexpr int output_tile_rows = OutputTileRows;
+ static constexpr int output_tile_cols = OutputTileCols;
+ static constexpr int kernel_rows = KernelRows;
+ static constexpr int kernel_cols = KernelCols;
+ static constexpr int inner_tile_rows = output_tile_rows + kernel_rows - 1; // TODO Check
+ static constexpr int inner_tile_cols = output_tile_cols + kernel_cols - 1; // TODO Check
+ static constexpr int N_GEMMS = inner_tile_rows * inner_tile_cols;
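+  // For example, with a 2x2 output tile and a 3x3 kernel the inner tile is
+  // 4x4, giving N_GEMMS = 16 (one GEMM per element of the inner tile).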
+
+ /** Transform weights from the spatial to the Winograd domain. */
+ template <typename T>
+ struct WeightsTransform
+ {
+ /** Get the bytes read during the transform. */
+ static inline size_t bytes_read(const KernelShape &shape)
+ {
+ return shape.size() * sizeof(T);
+ }
+
+ /** Get the bytes written during the transform. */
+ static inline size_t bytes_written(const KernelShape &shape)
+ {
+ const int inner_tile_size = inner_tile_rows * inner_tile_cols;
+ return (inner_tile_size * shape.n_input_channels *
+ shape.n_output_channels * sizeof(T));
+ }
+
+ /** Get the count of operations performed by the transform. */
+ static int ops_performed(const KernelShape &shape);
+
+ /** Apply the transform to a tensor. */
+ static void execute(
+ const int n_output_channels,
+ const int n_input_channels,
+ const T* const input,
+ T* const output,
+ const int matrix_stride,
+ const int matrix_row_stride
+ );
+
+ /** Create a WeightsTransform operator fixed on a given problem and set
+ * of pointers.
+ */
+ WeightsTransform(
+ const T* const input,
+ T* const output,
+ const int matrix_stride, /** Stride across matrices in the output. */
+ const int matrix_row_stride, /** Stride across rows of the matrix. */
+ const int n_output_channels, /** Number of filters. */
+ const int n_input_channels /** Number of channels in each filter. */
+ );
+
+ /** Get the window of work a given operator can perform. */
+ unsigned int get_window() const;
+
+ /** Perform work upon a window of the input. */
+ void run(const unsigned int start, const unsigned int stop);
+
+ private:
+ const T* const inptr; /** Fixed pointer to input data. */
+ T* const outptr; /** Fixed pointer to output memory. */
+ const int matrix_stride; /** Stride between output matrices. */
+ const int matrix_row_stride; /** Stride within output matrices. */
+ const int n_output_channels; /** Number of filters. */
+ const int n_input_channels; /** Number of channels in each filter. */
+ };
+
+ /** Transform input feature maps from the spatial to the Winograd domain.
+ */
+ template <typename T>
+ struct InputTransform
+ {
+ /** Get the bytes read during the transform. */
+ static size_t bytes_read(const Tensor4DShape &shape)
+ {
+ return shape.size() * sizeof(T);
+ }
+
+ /** Get the bytes written during the transform. */
+ static size_t bytes_written(const Tensor4DShape &shape)
+ {
+ const int M = iceildiv(shape.n_rows, inner_tile_rows) *
+ iceildiv(shape.n_cols, inner_tile_cols);
+ const int K = shape.n_channels;
+ return inner_tile_rows * inner_tile_cols * M * K * sizeof(T);
+ }
+
+ /** Get the count of operations performed by the transform. */
+ static int ops_performed(const Tensor4DShape &shape);
+
+ /** Apply the transform to a tensor. */
+ static void execute(
+ const T *inptr,
+ const Tensor4DShape& input_shape,
+ const PaddingType padding_type,
+ const int tile_M,
+ const int tile_N,
+ T *outptr_base,
+ const int matrix_stride,
+ const int matrix_batch_stride,
+ const int matrix_row_stride
+ );
+
+ /***********************************************************************/
+ /** Create an InputTransform operator fixed on a given problem and set of
+ * pointers.
+ */
+ InputTransform(
+ const T* const input, /** Input tensor data */
+ const int n_batches, /** Number of batches in input tensor. */
+ const int n_rows, /** Number of rows in input tensor. */
+ const int n_cols, /** Number of columns in input tensor. */
+ const int n_channels, /** Number of channels in input tensor. */
+ const PaddingType padding, /** Padding type. */
+ T* const output, /** Base of output matrices. */
+ const int matrix_stride, /** Stride between output matrices. */
+ const int matrix_row_stride /** Stride within matrices. */
+ );
+
+    /** Get the window of work a given operator can perform. */
+ unsigned int get_window() const;
+
+ /** Perform work upon a window of the input. */
+ void run(const unsigned int start, const unsigned int stop);
+ /***********************************************************************/
+
+ private:
+ static void process_tile_row(
+ const int tile_N,
+ int n_channels,
+ const T* const input_base,
+ const int input_row_stride,
+ const int input_col_stride,
+ T* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride,
+ const int row_pad_top,
+ const int row_pad_left,
+ const int row_pad_bottom,
+ const int n_cols
+ );
+
+ static constexpr int max_pad_bottom = inner_tile_rows - 1;
+ static constexpr int max_pad_right = inner_tile_cols - 1;
+
+ /** Process a single tile of the input tensor. */
+ template <int pad_top, int pad_left, int pad_bottom, int pad_right>
+ static void process_tile(int, const T*, int, int, T*, int);
+
+ // Array of methods to transform tiles of the input tensor.
+ typedef void (*TileFn)(int, const T*, int, int, T*, int);
+ static const TileFn tile_fns[2][2][max_pad_bottom][max_pad_right];
+
+ /* Member values for instance-based API. */
+ const T* const _inptr;
+ T* const _outptr;
+ const int _n_batches, _n_rows, _n_cols, _n_channels, _matrix_stride,
+ _matrix_row_stride, _tiles_M, _tiles_N;
+ const PaddingType _padding_type;
+ };
+
+ /** Transform output feature maps from the Winograd to the spatial domain.
+ */
+ template <typename T>
+ struct OutputTransform
+ {
+ /** Get the bytes read during the transform. */
+ static size_t bytes_read(const Tensor4DShape &shape);
+
+ /** Get the bytes written during the transform. */
+ static size_t bytes_written(const Tensor4DShape &shape);
+
+ /** Get the count of operations performed by the transform. */
+ static int ops_performed(const Tensor4DShape &shape);
+
+ /** Apply the transform to create a tensor. */
+ static void execute(
+ const Tensor4DShape &output_shape,
+ const T* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride,
+ const T* const biases,
+ T* const output
+ );
+
+ /***********************************************************************/
+ /** Create an OutputTransform operator fixed on a given problem and set
+ * of pointers.
+ */
+ OutputTransform(
+ const T* const matrix_base, /** Pointer to base of matrices. */
+ const int matrix_stride, /** Stride between matrices. */
+ const int matrix_row_stride, /** Stride within a matrix. */
+ const T* const biases, /** Pointer to biases vector. */
+ T* const output, /** Pointer to output tensor. */
+ const int n_batches, /** Number of batches in output tensor. */
+ const int n_rows, /** Number of rows in output tensor. */
+ const int n_cols, /** Number of columns in output tensor. */
+ const int n_channels /** Number of channels in output tensor. */
+ );
+
+ /** Get the window of work a given operator can perform. */
+ unsigned int get_window() const;
+
+ /** Perform work upon a window of the input. */
+ void run(const unsigned int start, const unsigned int stop);
+ /***********************************************************************/
+
+ private:
+ static void process_tile_row(
+ const int tile_N,
+ const int n_channels,
+ const T* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride,
+ const T* const biases,
+ T* const output,
+ const int output_row_stride,
+ const int output_col_stride,
+ const int row_pad_bottom,
+ const int row_pad_right
+ );
+
+ // Limits on the amount of anti-padding to be applied
+ static constexpr int max_pad_bottom = output_tile_rows;
+ static constexpr int max_pad_right = output_tile_cols;
+
+ /** Prepare a single tile of the output tensor. */
+ template <int pad_bottom, int pad_right>
+ static void process_tile(int, const T*, int, const T*, T*, int, int);
+
+ // Array of methods to produce tiles of output tensor.
+ typedef void (*TileFn)(int, const T*, int, const T*, T*, int, int);
+ static const TileFn tile_fns[max_pad_bottom][max_pad_right];
+
+ /** Member constants for instances of the transform. */
+ const T* const _matrix_base;
+ const T* const _biases;
+ const int _matrix_stride, _matrix_row_stride;
+ T* const _outptr;
+ const int _n_batches, _n_rows, _n_cols, _n_channels, _tile_M, _tile_N;
+ };
+
+ /** Perform a convolution.
+ */
+ template <typename TOut, typename TIn>
+ class Convolution
+ {
+ public:
+ // Information about the typed Winograd instance
+ typedef TOut OutputType;
+ typedef TIn InputType;
+
+ /** Create a new Winograd operator. */
+ Convolution(
+ const KernelShape &kernel_shape,
+ const Tensor4DShape &input_shape,
+ const PaddingType padding,
+ void *kernel_storage=NULL
+ );
+
+ Convolution(const Convolution&) = delete;
+ Convolution operator=(const Convolution&) = delete;
+
+ /** Create a new Winograd operator and initialise the weights. */
+ Convolution(
+ const KernelShape &kernel_shape,
+ const Tensor4DShape &input_shape,
+ const PaddingType padding,
+ const TIn* const kernel,
+ void *kernel_storage=NULL,
+ void *transform_working_space=NULL
+ );
+
+ /** Clean up a convolution engine. */
+ ~Convolution();
+
+ /** Transform the weights into the Winograd domain. */
+ template <typename WeightsTransform=WeightsTransform<TIn>>
+ void transform_weights(
+ const TIn* const kernel,
+ void *transform_working_space=NULL
+ );
+
+ /* Apply the Winograd operator to some input. */
+ void execute(
+ TOut* const output,
+ const TIn* const input,
+ const TOut* const biases,
+ void* working_space=NULL,
+ const int n_threads=1
+ );
+
+ /* Apply the Winograd operator to some input. */
+ void execute(
+ TOut* const output,
+ const TIn* const input,
+ const TOut* const biases,
+ const int n_threads
+ );
+
+ /** Get the output shape of a convolution. */
+ static Tensor4DShape get_output_shape(
+ const KernelShape &kernel_shape,
+ const Tensor4DShape &in_shape,
+ const PaddingType padding
+ );
+
+ /* Get the memory required to transform the kernel.
+ */
+ static size_t get_kernel_transform_working_size(const KernelShape &shape);
+
+ /** Get the memory required to store the kernel transformed into the
+ * Winograd domain.
+ */
+ static size_t get_kernel_storage_size(const KernelShape &shape);
+
+ /** Get the memory required to store the input tensor transformed into
+ * the Winograd domain.
+ */
+ static size_t get_input_storage_size(
+ const KernelShape &kernel_shape,
+ const Tensor4DShape &input_shape,
+ const PaddingType padding_type
+ );
+
+ /** Get the memory required to store the output tensor in the Winograd
+ * domain.
+ */
+ static size_t get_output_storage_size(
+ const KernelShape &kernel_shape,
+ const Tensor4DShape &input_shape,
+ const PaddingType padding_type
+ );
+
+ /** Get the memory required to apply a Winograd operator to some input.
+ */
+ static size_t get_working_space_size(
+ const KernelShape &kernel_shape,
+ const Tensor4DShape &input_shape,
+ const PaddingType padding_type
+ );
+
+ /* Get the memory required by a single "input" matrix.
+ */
+ static size_t get_input_matrix_size(
+ const KernelShape &kernel_shape,
+ const Tensor4DShape &input_shape,
+ const PaddingType padding_type
+ );
+
+ static int get_input_matrix_stride(
+ const KernelShape &kernel_shape,
+ const Tensor4DShape &input_shape,
+ const PaddingType padding_type
+ );
+
+ /* Get the memory required by a single "output" matrix.
+ */
+ static size_t get_output_matrix_size(
+ const KernelShape &kernel_shape,
+ const Tensor4DShape &input_shape,
+ const PaddingType padding_type
+ );
+
+ static int get_output_matrix_stride(
+ const KernelShape &kernel_shape,
+ const Tensor4DShape &input_shape,
+ const PaddingType padding_type
+ );
+
+ /* Get the memory required by a single "kernel" matrix.
+ */
+ static size_t get_kernel_matrix_size(const KernelShape &shape);
+ static int get_kernel_matrix_stride(const KernelShape &shape);
+
+ static constexpr int M_BLOCK = 4; /** Size of block used by GEMM. */
+ static constexpr int N_BLOCK = 16; /** Size of block used by GEMM. */
+
+ private:
+ const KernelShape kernel_shape; /** Shape of the kernel to be applied. */
+ TIn *kernel_matrices[N_GEMMS]; /** Pointers into the kernel matrices. */
+ const int kernel_matrix_row_stride; /** Stride within the kernel matrices. */
+
+ const bool manage_kernel_storage; /** Kernel storage is managed by the instance. */
+ void* const _kernel_storage; /** Base pointer for kernel storage. */
+
+ const Tensor4DShape input_shape; /** Shape of the input tensor. */
+ const PaddingType padding; /** Padding applied by the operator. */
+
+ const Tensor4DShape output_shape; /** Output shape produced by the operator. */
+
+ const int tile_rows; /** Number of rows of tiles. */
+ const int tile_cols; /** Number of columns of tiles. */
+ const int M, K, N; /** Sizes of underlying fundamental matrix multiplications. */
+
+ profiler prof;
+ };
+};
+
+} // namespace winograd
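+
+// Illustrative end-to-end use of the Convolution engine declared above. The
+// shapes, pointers and thread count are placeholders for the example, not
+// values defined by this header:
+//
+//   using WinogradConv = winograd::WinogradGEMM<2, 2, 3, 3>;
+//   WinogradConv::Convolution<float, float> conv(
+//       kernel_shape, input_shape, PADDING_VALID);
+//   conv.transform_weights(weights);
+//   conv.execute(output, input, biases, working_space, n_threads);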