From 9ceebbeb8dfe61746fdc7022a147f8e2d24c5493 Mon Sep 17 00:00:00 2001 From: Pablo Tello Date: Wed, 10 Jan 2018 16:44:13 +0000 Subject: COMPMID-815: Updated NEWinogradLayer with the lastest code from Research. Change-Id: I86d7f53b5f5d1dbc22078aea5c32b08a25d1f49e Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/116634 Tested-by: Jenkins Reviewed-by: Anthony Barbier --- .../core/NEON/kernels/NEWinogradLayerKernel.h | 53 +- arm_compute/core/NEON/kernels/winograd/alloc.hpp | 1 + arm_compute/core/NEON/kernels/winograd/arm.hpp | 39 + .../NEON/kernels/winograd/batched_blocked_gemm.hpp | 69 + .../core/NEON/kernels/winograd/convolution.hpp | 29 + .../NEON/kernels/winograd/direct_convolution.hpp | 34 + arm_compute/core/NEON/kernels/winograd/gemm.hpp | 127 ++ .../core/NEON/kernels/winograd/gemm/a64_sgemm.hpp | 355 +++++ .../NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp | 1446 ++++++++++++++++++++ arm_compute/core/NEON/kernels/winograd/perf.h | 9 + .../core/NEON/kernels/winograd/profiler.hpp | 326 +++++ arm_compute/core/NEON/kernels/winograd/shims.hpp | 747 ++++++++++ arm_compute/core/NEON/kernels/winograd/tensor.hpp | 225 ++- .../core/NEON/kernels/winograd/tensor_utils.hpp | 43 + .../NEON/kernels/winograd/transforms/input.hpp | 195 +++ .../NEON/kernels/winograd/transforms/kernel.hpp | 77 ++ .../NEON/kernels/winograd/transforms/output.hpp | 174 +++ arm_compute/core/NEON/kernels/winograd/utils.hpp | 37 + .../core/NEON/kernels/winograd/winograd_gemm.hpp | 441 ++++++ .../core/NEON/kernels/winograd/winograd_layer.hpp | 128 ++ 20 files changed, 4414 insertions(+), 141 deletions(-) create mode 100644 arm_compute/core/NEON/kernels/winograd/arm.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/batched_blocked_gemm.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/convolution.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/direct_convolution.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/gemm.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/perf.h create mode 100644 arm_compute/core/NEON/kernels/winograd/profiler.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/shims.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/tensor_utils.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/input.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/kernel.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/output.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/utils.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/winograd_layer.hpp (limited to 'arm_compute/core') diff --git a/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h b/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h index 73b7e8d2b7..95261929ca 100644 --- a/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h +++ b/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017, 2018 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -25,6 +25,7 @@ #define __ARM_COMPUTE_NEGEMMWINOGRADLAYERKERNEL_H__ #include "arm_compute/core/NEON/INEKernel.h" +#include "arm_compute/core/NEON/kernels/winograd/convolution.hpp" #include "arm_compute/core/NEON/kernels/winograd/tensor.hpp" namespace arm_compute @@ -36,11 +37,25 @@ class Winograd3x3F32 final { public: friend class NEWinogradLayerKernel; - Winograd3x3F32(const KernelShape &kernel_shape, const Tensor4DShape input_shape, const PaddingType padding_type, void *kernel_storage); + Winograd3x3F32( + const int n_batches, /** Number of batches in the input and output tensors. */ + const int n_input_channels, /** Number of feature maps in a batch of the input tensor. */ + const int n_input_rows, /** Number of rows in a feature map of the input tensor. */ + const int n_input_cols, /** Number of columns in a feature map of the input tensor. */ + const int n_output_channels, /** Number of feature maps in the output tensor. */ + const bool same_padding, /** Use "SAME" padding, otherwise use "VALID". */ + const float *const weights, /** Pointer to weight tensor in spatial domain. Must be ordered as "Height x Rows x Input Feature Maps x Output Feature Maps. */ + float *const weights_storage, /** Pointer to storage for weight tensor in the Winograd domain. Must be at least the size returned by `get_weight_storage_size`. */ + const float *const input, /** Pointer to NHWC ordered input tensor, in the spatial domain. */ + float *const winograd_input, /** Pointer to working space for the input tensor in the Winograd domain. Must be at least the size returned by `get_input_storage_size`. */ + float *const output, /** Pointer to NHWC ordered output tensor, in the spatial domain. */ + float *const winograd_output /** Pointer to working space for the output tensor in the Winograd domain. Must be at least the size returned by `get_output_storage_size`. */ + ); + ~Winograd3x3F32(); - void transform_weights(const void *const kernel, void *transform_working_space); - void reshape_input(const Tensor4DShape &input_shape, const PaddingType padding_type, const void *const input, void *working_space); - void reshape_output(const Tensor4DShape &input_shape, const PaddingType padding_type, void *const output); + void transform_weights(); + void transform_input(); + void transform_output(); private: class Private; @@ -75,15 +90,29 @@ public: /* Get the memory required to instantiate a new Winograd operator. */ - static size_t get_kernel_storage_size(const KernelShape &shape); + static size_t get_weight_storage_size( + const int n_output_channels, /** Number of output feature maps. */ + const int n_input_channels /** Number of input feature maps. */ + ); - /* Get the memory required to apply a Winograd operator to some input. - */ - static size_t get_working_space_size(const Tensor4DShape &input_shape, const KernelShape &k_shape, const PaddingType padding); + static unsigned int get_input_storage_size( + const int n_batches, /** Number of batches in the input tensor. */ + const int n_channels, /** Number of feature maps in the input tensor. */ + const int n_rows, /** Number of rows in each feature map. */ + const int n_cols, /** Number of columns in each feature map. */ + const bool same_padding /** Use "SAME" padding, otherwise use "VALID". */ + ); - /* Get the memory required to transform the kernel. - */ - static size_t get_kernel_transform_working_size(const KernelShape &shape); + /** Determine how much memory (in units of TOut) to allocate for the + * (Winograd domain) output. + */ + static unsigned int get_output_storage_size( + const int n_batches, /** Number of batches in the output tensor. */ + const int n_rows, /** Number of rows in each feature map of the input tensor. */ + const int n_cols, /** Number of columns in each feature map of the input tensor. */ + const int n_output_channels, /** Number of feature maps in the output tensor. */ + const bool same_padding /** Use "SAME" padding, otherwise use "VALID". */ + ); protected: Winograd3x3F32 *_convolver; diff --git a/arm_compute/core/NEON/kernels/winograd/alloc.hpp b/arm_compute/core/NEON/kernels/winograd/alloc.hpp index ef6f2b5115..799e95d3e6 100644 --- a/arm_compute/core/NEON/kernels/winograd/alloc.hpp +++ b/arm_compute/core/NEON/kernels/winograd/alloc.hpp @@ -21,6 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ + #pragma once #ifdef ALLOC_ALIGN diff --git a/arm_compute/core/NEON/kernels/winograd/arm.hpp b/arm_compute/core/NEON/kernels/winograd/arm.hpp new file mode 100644 index 0000000000..90e7828553 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/arm.hpp @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/** Sets the macro __arm_any__ if compiling for Aarch32 or Aarch64. + * Includes `arm_neon.h` if compiling for either architecture. + */ + +#ifdef __arm__ +#define __arm_any__ +#endif // __arm__ + +#ifdef __aarch64__ +#define __arm_any__ +#endif // __aarch64__ + +#ifdef __arm_any__ +#include +#endif // __arm_any__ diff --git a/arm_compute/core/NEON/kernels/winograd/batched_blocked_gemm.hpp b/arm_compute/core/NEON/kernels/winograd/batched_blocked_gemm.hpp new file mode 100644 index 0000000000..663b3c414f --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/batched_blocked_gemm.hpp @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +namespace winograd +{ + +template +class BatchedBlockedGemm +{ + public: + /** Create a new batched blocked GEMM operator. */ + BatchedBlockedGemm( + const unsigned int n_gemms, + const int M, const int K, const int N, + const int a_matrix_stride, + const int a_row_stride, + const int b_matrix_stride, + const int b_row_stride, + const int c_matrix_stride, + const int c_row_stride, + const TIn* const a_ptr, + const TIn* const b_ptr, + TOut* const c_ptr + ); + + BatchedBlockedGemm(const BatchedBlockedGemm&) = delete; + BatchedBlockedGemm operator=(const BatchedBlockedGemm&) = delete; + + /** Get a window of work performed by the operator. */ + unsigned int get_window() const; + + /** Perform a portion of the work of the operator. */ + void run(const unsigned int start, const unsigned int stop); + + private: + const unsigned int n_gemms; + const int M, N, K; + const int a_matrix_stride, a_row_stride; + const int b_matrix_stride, b_row_stride; + const int c_matrix_stride, c_row_stride; + const TIn* const a_ptr; + const TIn* const b_ptr; + TOut* const c_ptr; +}; + +} // namespace winograd diff --git a/arm_compute/core/NEON/kernels/winograd/convolution.hpp b/arm_compute/core/NEON/kernels/winograd/convolution.hpp new file mode 100644 index 0000000000..2ab2597785 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/convolution.hpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +enum PaddingType { + PADDING_SAME, PADDING_VALID +}; diff --git a/arm_compute/core/NEON/kernels/winograd/direct_convolution.hpp b/arm_compute/core/NEON/kernels/winograd/direct_convolution.hpp new file mode 100644 index 0000000000..725f6cab65 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/direct_convolution.hpp @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once +#include "convolution.hpp" +#include "tensor.hpp" + +void direct_convolution( + const Tensor4D& input, + const Tensor4D& kernel, + Tensor4D& output, + const PaddingType padding +); diff --git a/arm_compute/core/NEON/kernels/winograd/gemm.hpp b/arm_compute/core/NEON/kernels/winograd/gemm.hpp new file mode 100644 index 0000000000..e48d31b4e6 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/gemm.hpp @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once +#include "utils.hpp" + +template +inline void Gemm(const TIn* const a, const TIn* const b, TOut *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride, + const bool a_transposed=false, + const bool b_transposed=false) { + // Array access methods + const auto A = [a, a_transposed, M, K, a_row_stride] (const int i, const int j) -> TIn { + return a[(!a_transposed) ? i*a_row_stride + j : i + j*M]; + }; + + const auto B = [b, b_transposed, K, N, b_row_stride] (const int i, const int j) -> TIn { + return b[(!b_transposed) ? i*b_row_stride + j : i + j*N]; + }; + + const auto C = [c, c_row_stride] (const int i, const int j) -> TOut& { + return c[i*c_row_stride + j]; + }; + + // Perform the matrix multiplication + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + for (int k = 0; k < K; k++) { + C(i, j) += A(i, k) * B(k, j); + } + } + } +} + +template +inline void BlockedGemm( + const TIn* const a, const TIn* const b, TOut *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + // Array access methods + const auto A = [a, M, K, a_row_stride] (const int i, const int j) -> TIn { + return a[i*a_row_stride + j]; + }; + + const auto B = [b, K, N, b_row_stride] (const int i, const int j) -> TIn { + return b[i*b_row_stride + j]; + }; + + const auto C = [c, c_row_stride] (const int i, const int j) -> TOut& { + return c[i*c_row_stride + j]; + }; + + const int M_BLOCKS = iceildiv(M, M_BLOCK); + const int N_BLOCKS = iceildiv(N, N_BLOCK); + + // For each block of output rows + for (int mblock = 0; mblock < M_BLOCKS; mblock++) { + // For each block of output columns + for (int nblock = 0; nblock < N_BLOCKS; nblock++) { + // Create an appropriately sized block of accumulators + TOut accum[M_BLOCK][N_BLOCK]; + for (int i = 0; i < M_BLOCK; i++) { + for (int j = 0; j < N_BLOCK; j++) { + accum[i][j] = static_cast(0); + } + } + + // Perform this portion of the matrix multiply + for (int k = 0; k < K; k++) { + // Load elements of A + TIn elems_a[M_BLOCK]; + for (int i = 0; i < M_BLOCK; i++) { + elems_a[i] = A(mblock*M_BLOCK + i, k); + } + + // Load elements of B + TIn elems_b[N_BLOCK]; + for (int j = 0; j < N_BLOCK; j++) { + elems_b[j] = B(k, nblock*N_BLOCK + j); + } + + // Perform the partial matrix multiply + for (int i = 0; i < M_BLOCK; i++) { + for (int j = 0; j < N_BLOCK; j++) { + accum[i][j] += elems_a[i] * elems_b[j]; + } + } + } + + // Store the partial product + for (int i = 0; i < M_BLOCK; i++) { + for (int j = 0; j < N_BLOCK; j++) { + C(mblock*M_BLOCK + i, nblock*N_BLOCK + j) = accum[i][j]; + } + } + } + } +} + +#include "gemm/a64_sgemm.hpp" diff --git a/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp new file mode 100644 index 0000000000..caeb48f65a --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp @@ -0,0 +1,355 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once +#include +#include "../utils.hpp" + +#ifdef __aarch64__ + +template <> +inline void BlockedGemm<8, 12, float, float>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + const int M_BLOCK = 8; + const int N_BLOCK = 12; + + const int m_blocks = iceildiv(M, M_BLOCK); + const int n_blocks = iceildiv(N, N_BLOCK); + + // For each block of output rows + for (int mblock = 0; mblock < m_blocks; mblock++) { + // For each block of output columns + for (int nblock = 0; nblock < n_blocks; nblock++) { + const float *aptr = a + mblock*M_BLOCK*a_row_stride; + const float *bptr = b + nblock*N_BLOCK; + float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; + int k = K; + + asm volatile ( + // Create an 8x12 block of accumulators + " A_1 .req v27\n" + "sA_1 .req s27\n" + " A_2 .req v28\n" + "sA_2 .req s28\n" + " A_3 .req v29\n" + "sA_3 .req s29\n" + " A_4 .req v30\n" + "sA_4 .req s30\n" + + " B_1 .req v24\n" " B_2 .req v25\n" " B_3 .req v26\n" + "qB_1 .req q24\n" "qB_2 .req q25\n" "qB_3 .req q26\n" + + " C_11 .req v0\n" " C_12 .req v1\n" " C_13 .req v2\n" + " C_21 .req v3\n" " C_22 .req v4\n" " C_23 .req v5\n" + " C_31 .req v6\n" " C_32 .req v7\n" " C_33 .req v8\n" + " C_41 .req v9\n" " C_42 .req v10\n" " C_43 .req v11\n" + " C_51 .req v12\n" " C_52 .req v13\n" " C_53 .req v14\n" + " C_61 .req v15\n" " C_62 .req v16\n" " C_63 .req v17\n" + " C_71 .req v18\n" " C_72 .req v19\n" " C_73 .req v20\n" + " C_81 .req v21\n" " C_82 .req v22\n" " C_83 .req v23\n" + + "qC_11 .req q0\n" "qC_12 .req q1\n" "qC_13 .req q2\n" + "qC_21 .req q3\n" "qC_22 .req q4\n" "qC_23 .req q5\n" + "qC_31 .req q6\n" "qC_32 .req q7\n" "qC_33 .req q8\n" + "qC_41 .req q9\n" "qC_42 .req q10\n" "qC_43 .req q11\n" + "qC_51 .req q12\n" "qC_52 .req q13\n" "qC_53 .req q14\n" + "qC_61 .req q15\n" "qC_62 .req q16\n" "qC_63 .req q17\n" + "qC_71 .req q18\n" "qC_72 .req q19\n" "qC_73 .req q20\n" + "qC_81 .req q21\n" "qC_82 .req q22\n" "qC_83 .req q23\n" + + "aptr1 .req x17\n" + "aptr2 .req x18\n" + "aptr3 .req x19\n" + "aptr4 .req x20\n" + "aptr5 .req x21\n" + "aptr6 .req x22\n" + "aptr7 .req x23\n" + + // Initialise accumulators with 0 + // Initialise pointers + "movi C_11.4s, #0\n" + "add aptr1, %x[aptr], %x[a_row_stride]\n" + "movi C_12.4s, #0\n" + "add aptr2, aptr1, %x[a_row_stride]\n" + "movi C_13.4s, #0\n" + "add aptr3, aptr2, %x[a_row_stride]\n" + "movi C_21.4s, #0\n" + "add aptr4, aptr3, %x[a_row_stride]\n" + "movi C_22.4s, #0\n" + "add aptr5, aptr4, %x[a_row_stride]\n" + "movi C_23.4s, #0\n" + "add aptr6, aptr5, %x[a_row_stride]\n" + "movi C_31.4s, #0\n" + "add aptr7, aptr6, %x[a_row_stride]\n" + "movi C_32.4s, #0\n" + "ldr qB_1, [%x[bptr]]\n" + "movi C_33.4s, #0\n" + "ldr qB_2, [%x[bptr], #0x10]\n" + "movi C_41.4s, #0\n" + "prfm pldl1keep, [%x[bptr], #0x00]\n" + "movi C_42.4s, #0\n" + "prfm pldl1keep, [%x[bptr], #0x10]\n" + "movi C_43.4s, #0\n" + "prfm pldl1keep, [%x[bptr], #0x20]\n" + "movi C_51.4s, #0\n" + "prfm pldl1keep, [%x[aptr], #0x00]\n" + "movi C_52.4s, #0\n" + "prfm pldl1keep, [ aptr1, #0x00]\n" + "movi C_53.4s, #0\n" + "prfm pldl1keep, [ aptr2, #0x00]\n" + "movi C_61.4s, #0\n" + "prfm pldl1keep, [ aptr3, #0x00]\n" + "movi C_62.4s, #0\n" + "prfm pldl1keep, [ aptr4, #0x00]\n" + "movi C_63.4s, #0\n" + "prfm pldl1keep, [ aptr5, #0x00]\n" + "movi C_71.4s, #0\n" + "prfm pldl1keep, [ aptr6, #0x00]\n" + "movi C_72.4s, #0\n" + "prfm pldl1keep, [ aptr7, #0x00]\n" + "movi C_73.4s, #0\n" + "ldr sA_1, [%x[aptr]], #0x4\n" + "movi C_81.4s, #0\n" + "ldr sA_2, [ aptr1], #0x4\n" + "movi C_82.4s, #0\n" + "ldr sA_3, [ aptr2], #0x4\n" + "movi C_83.4s, #0\n" + "subs %x[k], %x[k], #1\n" + "beq 2f\n" + + "1:" + "fmla C_11.4s, B_1.4s, A_1.s[0]\n" + "ldr qB_3, [%x[bptr], #0x20]\n" + "fmla C_12.4s, B_2.4s, A_1.s[0]\n" + "ldr sA_4, [ aptr3], #0x4\n" + "fmla C_13.4s, B_3.4s, A_1.s[0]\n" + "ldr sA_1, [ aptr4], #0x04\n" + + "fmla C_21.4s, B_1.4s, A_2.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride]\n" + "fmla C_22.4s, B_2.4s, A_2.s[0]\n" + "prfm pldl1keep, [ aptr3, #0x10]\n" + "fmla C_23.4s, B_3.4s, A_2.s[0]\n" + "ldr sA_2, [ aptr5], #0x04\n" + + "fmla C_31.4s, B_1.4s, A_3.s[0]\n" + "prfm pldl1keep, [%x[bptr], #0x00]\n" + "fmla C_32.4s, B_2.4s, A_3.s[0]\n" + "prfm pldl1keep, [%x[bptr], #0x10]\n" + "fmla C_33.4s, B_3.4s, A_3.s[0]\n" + "ldr sA_3, [ aptr6], #0x04\n" + + "fmla C_41.4s, B_1.4s, A_4.s[0]\n" + "prfm pldl1keep, [%x[bptr], #0x20]\n" + "fmla C_42.4s, B_2.4s, A_4.s[0]\n" + "prfm pldl1keep, [ aptr4, #0x10]\n" + "fmla C_43.4s, B_3.4s, A_4.s[0]\n" + "ldr sA_4, [ aptr7], #0x04\n" + + "fmla C_51.4s, B_1.4s, A_1.s[0]\n" + "prfm pldl1keep, [ aptr5, #0x10]\n" + "fmla C_52.4s, B_2.4s, A_1.s[0]\n" + "prfm pldl1keep, [ aptr6, #0x10]\n" + "fmla C_53.4s, B_3.4s, A_1.s[0]\n" + "ldr sA_1, [%x[aptr]], #0x04\n" + + "fmla C_61.4s, B_1.4s, A_2.s[0]\n" + "prfm pldl1keep, [ aptr7, #0x10]\n" + "fmla C_62.4s, B_2.4s, A_2.s[0]\n" + "subs %x[k], %x[k], #1\n" + "fmla C_63.4s, B_3.4s, A_2.s[0]\n" + "ldr sA_2, [ aptr1], #0x04\n" + + "fmla C_71.4s, B_1.4s, A_3.s[0]\n" + "prfm pldl1keep, [%x[aptr], #0x10]\n" + "fmla C_72.4s, B_2.4s, A_3.s[0]\n" + "prfm pldl1keep, [ aptr1, #0x10]\n" + "fmla C_73.4s, B_3.4s, A_3.s[0]\n" + "ldr sA_3, [ aptr2], #0x04\n" + + "fmla C_81.4s, B_1.4s, A_4.s[0]\n" + "prfm pldl1keep, [ aptr2, #0x10]\n" + "fmla C_82.4s, B_2.4s, A_4.s[0]\n" + "ldp qB_1, qB_2, [%x[bptr]]\n" + "fmla C_83.4s, B_3.4s, A_4.s[0]\n" + "bne 1b\n" + + "2:" + "fmla C_11.4s, B_1.4s, A_1.s[0]\n" + "ldr qB_3, [%x[bptr], #0x20]\n" + "fmla C_12.4s, B_2.4s, A_1.s[0]\n" + "stp qC_11, qC_12, [%x[cptr]]\n" + "fmla C_13.4s, B_3.4s, A_1.s[0]\n" + "str qC_13, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + "ldr sA_1, [ aptr4], #0x04\n" + + "fmla C_21.4s, B_1.4s, A_2.s[0]\n" + "ldr sA_4, [ aptr3], #0x4\n" + "fmla C_22.4s, B_2.4s, A_2.s[0]\n" + "stp qC_21, qC_22, [%x[cptr]]\n" + "fmla C_23.4s, B_3.4s, A_2.s[0]\n" + "str qC_23, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + "ldr sA_2, [ aptr5], #0x04\n" + + "fmla C_31.4s, B_1.4s, A_3.s[0]\n" + "fmla C_32.4s, B_2.4s, A_3.s[0]\n" + "stp qC_31, qC_32, [%x[cptr]]\n" + "fmla C_33.4s, B_3.4s, A_3.s[0]\n" + "str qC_33, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + "ldr sA_3, [ aptr6], #0x04\n" + + "fmla C_41.4s, B_1.4s, A_4.s[0]\n" + "fmla C_42.4s, B_2.4s, A_4.s[0]\n" + "stp qC_41, qC_42, [%x[cptr]]\n" + "fmla C_43.4s, B_3.4s, A_4.s[0]\n" + "str qC_43, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + "ldr sA_4, [ aptr7], #0x04\n" + + "fmla C_51.4s, B_1.4s, A_1.s[0]\n" + "fmla C_52.4s, B_2.4s, A_1.s[0]\n" + "stp qC_51, qC_52, [%x[cptr]]\n" + "fmla C_53.4s, B_3.4s, A_1.s[0]\n" + "str qC_53, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + + "fmla C_61.4s, B_1.4s, A_2.s[0]\n" + "fmla C_62.4s, B_2.4s, A_2.s[0]\n" + "stp qC_61, qC_62, [%x[cptr]]\n" + "fmla C_63.4s, B_3.4s, A_2.s[0]\n" + "str qC_63, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + + "fmla C_71.4s, B_1.4s, A_3.s[0]\n" + "fmla C_72.4s, B_2.4s, A_3.s[0]\n" + "stp qC_71, qC_72, [%x[cptr]]\n" + "fmla C_73.4s, B_3.4s, A_3.s[0]\n" + "str qC_73, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + + "fmla C_81.4s, B_1.4s, A_4.s[0]\n" + "fmla C_82.4s, B_2.4s, A_4.s[0]\n" + "stp qC_81, qC_82, [%x[cptr]]\n" + "fmla C_83.4s, B_3.4s, A_4.s[0]\n" + "str qC_83, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + + // Clear aliases + ".unreq aptr1\n" + ".unreq aptr2\n" + ".unreq aptr3\n" + ".unreq aptr4\n" + ".unreq aptr5\n" + ".unreq aptr6\n" + ".unreq aptr7\n" + + ".unreq A_1\n" ".unreq A_2\n" ".unreq A_3\n" ".unreq A_4\n" + ".unreq sA_1\n" ".unreq sA_2\n" ".unreq sA_3\n" ".unreq sA_4\n" + + ".unreq B_1\n" ".unreq B_2\n" ".unreq B_3\n" + ".unreq qB_1\n" ".unreq qB_2\n" ".unreq qB_3\n" + + ".unreq C_11\n" ".unreq C_12\n" ".unreq C_13\n" + ".unreq C_21\n" ".unreq C_22\n" ".unreq C_23\n" + ".unreq C_31\n" ".unreq C_32\n" ".unreq C_33\n" + ".unreq C_41\n" ".unreq C_42\n" ".unreq C_43\n" + ".unreq C_51\n" ".unreq C_52\n" ".unreq C_53\n" + ".unreq C_61\n" ".unreq C_62\n" ".unreq C_63\n" + ".unreq C_71\n" ".unreq C_72\n" ".unreq C_73\n" + ".unreq C_81\n" ".unreq C_82\n" ".unreq C_83\n" + + ".unreq qC_11\n" ".unreq qC_12\n" ".unreq qC_13\n" + ".unreq qC_21\n" ".unreq qC_22\n" ".unreq qC_23\n" + ".unreq qC_31\n" ".unreq qC_32\n" ".unreq qC_33\n" + ".unreq qC_41\n" ".unreq qC_42\n" ".unreq qC_43\n" + ".unreq qC_51\n" ".unreq qC_52\n" ".unreq qC_53\n" + ".unreq qC_61\n" ".unreq qC_62\n" ".unreq qC_63\n" + ".unreq qC_71\n" ".unreq qC_72\n" ".unreq qC_73\n" + ".unreq qC_81\n" ".unreq qC_82\n" ".unreq qC_83\n" + : [aptr] "+r" (aptr), + [bptr] "+r" (bptr), + [cptr] "+r" (cptr), + [k] "+r" (k) + : [a_row_stride] "r" (a_row_stride * sizeof(float)), + [b_row_stride] "r" (b_row_stride * sizeof(float)), + [c_row_stride] "r" (c_row_stride * sizeof(float)) + : "cc", "memory", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "x17", "x18", "x19", "x20", "x21", "x22", "x23" + ); + } + } +} + +/*****************************************************************************/ +/* 4x16 blocked GEMM with specialised tails + */ +#include "a64_sgemm_4x16.hpp" + +template <> +inline void BlockedGemm<4, 16, float, float>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + // Despatch based on tail of K + switch (K % 4) { + case 3: + sgemm_4x16_impl<3>( + a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride + ); + break; + case 2: + sgemm_4x16_impl<2>( + a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride + ); + break; + case 1: + sgemm_4x16_impl<1>( + a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride + ); + break; + case 0: + sgemm_4x16_impl<0>( + a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride + ); + break; + default: + assert(false); + } +} + +#endif // __aarch64__ diff --git a/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp new file mode 100644 index 0000000000..5cd37de7a0 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp @@ -0,0 +1,1446 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +template +inline void sgemm_4x16_impl( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +); + +template <> +inline void sgemm_4x16_impl<0>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + const int TAIL_SIZE = 0; + const int M_BLOCK = 4; + const int N_BLOCK = 16; + + const int m_blocks = iceildiv(M, M_BLOCK); + const int n_blocks = iceildiv(N, N_BLOCK); + + // For each block of output rows + for (int mblock = 0; mblock < m_blocks; mblock++) { + // For each block of output columns + for (int nblock = 0; nblock < n_blocks; nblock++) { + const float *aptr = a + mblock*M_BLOCK*a_row_stride; + const float *bptr = b + nblock*N_BLOCK; + float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; + int k = (K - TAIL_SIZE) / 4; + + asm volatile( + "aptr2 .req X20\n" + "aptr3 .req X21\n" + "aptr4 .req X22\n" + "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" + "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" + "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" + "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" + "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" + "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" + "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" + "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" + "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" + "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" + "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" + "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" + "vB1 .req v20\n" "qB1 .req q20\n" + "vB2 .req v21\n" "qB2 .req q21\n" + "vB3 .req v22\n" "qB3 .req q22\n" + "vB4 .req v23\n" "qB4 .req q23\n" + + // Clear accumulators, initialise pointers + "movi vC11.4s, #0\n" + "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n" + "movi vC12.4s, #0\n" + "add aptr3, aptr2, %x[a_row_stride_bytes]\n" + "movi vC13.4s, #0\n" + "add aptr4, aptr3, %x[a_row_stride_bytes]\n" + "movi vC14.4s, #0\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "movi vC21.4s, #0\n" + "ldr qA2, [ aptr2], #0x10\n" + "movi vC22.4s, #0\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "movi vC23.4s, #0\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "movi vC24.4s, #0\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "movi vC31.4s, #0\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "subs %x[k], %x[k], #1\n" + "beq 2f\n" + + "1:" // Loop proper + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "subs %x[k], %x[k], #1\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr qA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + "bne 1b\n" + + "2:" // Tail + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "stp qC11, qC12, [%x[cptr], #0x00]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "stp qC13, qC14, [%x[cptr], #0x20]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "stp qC21, qC22, [%x[cptr], #0x00]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "stp qC23, qC24, [%x[cptr], #0x20]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "stp qC31, qC32, [%x[cptr], #0x00]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "stp qC33, qC34, [%x[cptr], #0x20]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "stp qC41, qC42, [%x[cptr], #0x00]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + "stp qC43, qC44, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + + ".unreq vB4\n" ".unreq qB4\n" + ".unreq vB3\n" ".unreq qB3\n" + ".unreq vB2\n" ".unreq qB2\n" + ".unreq vB1\n" ".unreq qB1\n" + ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" + ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" + ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" + ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" + ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" + ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n" + ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" + ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" + ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" + ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n" + ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" + ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" + ".unreq aptr2\n" + ".unreq aptr3\n" + ".unreq aptr4\n" + + : [aptr] "+r" (aptr), + [bptr] "+r" (bptr), + [cptr] "+r" (cptr), + [k] "+r" (k) + : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), + [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), + [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) + : "cc", "memory", "x20", "x21", "x22", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23" + ); + } + } +} + +template <> +inline void sgemm_4x16_impl<1>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + const int TAIL_SIZE = 1; + const int M_BLOCK = 4; + const int N_BLOCK = 16; + + const int m_blocks = iceildiv(M, M_BLOCK); + const int n_blocks = iceildiv(N, N_BLOCK); + + // For each block of output rows + for (int mblock = 0; mblock < m_blocks; mblock++) { + // For each block of output columns + for (int nblock = 0; nblock < n_blocks; nblock++) { + const float *aptr = a + mblock*M_BLOCK*a_row_stride; + const float *bptr = b + nblock*N_BLOCK; + float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; + int k = (K - TAIL_SIZE) / 4; + + asm volatile( + "aptr2 .req X20\n" + "aptr3 .req X21\n" + "aptr4 .req X22\n" + "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" + "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" + "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" + "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" + "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" + "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" + "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" + "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" + "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" + "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" + "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" + "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" + "vB1 .req v20\n" "qB1 .req q20\n" + "vB2 .req v21\n" "qB2 .req q21\n" + "vB3 .req v22\n" "qB3 .req q22\n" + "vB4 .req v23\n" "qB4 .req q23\n" + + // Clear accumulators, initialise pointers + "movi vC11.4s, #0\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "movi vC12.4s, #0\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "movi vC13.4s, #0\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "movi vC14.4s, #0\n" + "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n" + "movi vC21.4s, #0\n" + "add aptr3, aptr2, %x[a_row_stride_bytes]\n" + "movi vC22.4s, #0\n" + "add aptr4, aptr3, %x[a_row_stride_bytes]\n" + "movi vC23.4s, #0\n" + "cbnz %x[k], 3f\n" + + // Prepare for tail in K + "movi vC24.4s, #0\n" + "ldr sA1, [%x[aptr]], #0x04\n" + "movi vC31.4s, #0\n" + "ldr sA2, [ aptr2], #0x04\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "b 2f\n" // Jump to tail + + "3:" // Prepare for loop over K + "movi vC24.4s, #0\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "movi vC31.4s, #0\n" + "ldr qA2, [ aptr2], #0x10\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "subs %x[k], %x[k], #1\n" + "beq 4f\n" + + "1:" // Loop proper + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "subs %x[k], %x[k], #1\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr qA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + "bne 1b\n" + + "4:" // Tail iteration + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr sA1, [%x[aptr]], #0x04\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr sA2, [ aptr2], #0x04\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + + "2:" // Common tail + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "stp qC11, qC12, [%x[cptr], #0x00]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "ldr sA3, [ aptr3], #0x04\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "stp qC13, qC14, [%x[cptr], #0x20]\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "stp qC21, qC22, [%x[cptr], #0x00]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "ldr sA4, [ aptr4], #0x04\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "stp qC23, qC24, [%x[cptr], #0x20]\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "stp qC31, qC32, [%x[cptr], #0x00]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "stp qC33, qC34, [%x[cptr], #0x20]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "stp qC41, qC42, [%x[cptr], #0x00]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + "stp qC43, qC44, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + + ".unreq vB4\n" ".unreq qB4\n" + ".unreq vB3\n" ".unreq qB3\n" + ".unreq vB2\n" ".unreq qB2\n" + ".unreq vB1\n" ".unreq qB1\n" + ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" + ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" + ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" + ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" + ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" + ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n" + ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" + ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" + ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" + ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n" + ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" + ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" + ".unreq aptr2\n" + ".unreq aptr3\n" + ".unreq aptr4\n" + + : [aptr] "+r" (aptr), + [bptr] "+r" (bptr), + [cptr] "+r" (cptr), + [k] "+r" (k) + : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), + [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), + [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) + : "cc", "memory", "x20", "x21", "x22", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23" + ); + } + } +} + +template <> +inline void sgemm_4x16_impl<2>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + const int TAIL_SIZE = 2; + const int M_BLOCK = 4; + const int N_BLOCK = 16; + + const int m_blocks = iceildiv(M, M_BLOCK); + const int n_blocks = iceildiv(N, N_BLOCK); + + // For each block of output rows + for (int mblock = 0; mblock < m_blocks; mblock++) { + // For each block of output columns + for (int nblock = 0; nblock < n_blocks; nblock++) { + const float *aptr = a + mblock*M_BLOCK*a_row_stride; + const float *bptr = b + nblock*N_BLOCK; + float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; + int k = (K - TAIL_SIZE) / 4; + + asm volatile( + "aptr2 .req X20\n" + "aptr3 .req X21\n" + "aptr4 .req X22\n" + "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" + "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" + "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" + "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" + "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" + "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" + "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" + "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" + "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" + "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" + "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" + "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" + "vB1 .req v20\n" "qB1 .req q20\n" + "vB2 .req v21\n" "qB2 .req q21\n" + "vB3 .req v22\n" "qB3 .req q22\n" + "vB4 .req v23\n" "qB4 .req q23\n" + + // Clear accumulators, initialise pointers + "movi vC11.4s, #0\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "movi vC12.4s, #0\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "movi vC13.4s, #0\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "movi vC14.4s, #0\n" + "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n" + "movi vC21.4s, #0\n" + "add aptr3, aptr2, %x[a_row_stride_bytes]\n" + "movi vC22.4s, #0\n" + "add aptr4, aptr3, %x[a_row_stride_bytes]\n" + "movi vC23.4s, #0\n" + "cbnz %x[k], 3f\n" + + // Prepare for tail in K + "movi vC24.4s, #0\n" + "ldr dA1, [%x[aptr]], #0x08\n" + "movi vC31.4s, #0\n" + "ldr dA2, [ aptr2], #0x08\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "b 2f\n" // Jump to tail + + "3:" // Prepare for loop over K + "movi vC24.4s, #0\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "movi vC31.4s, #0\n" + "ldr qA2, [ aptr2], #0x10\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "subs %x[k], %x[k], #1\n" + "beq 4f\n" + + "1:" // Loop proper + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "subs %x[k], %x[k], #1\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr qA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + "bne 1b\n" + + "4:" // Tail iteration + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr dA1, [%x[aptr]], #0x08\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr dA2, [ aptr2], #0x08\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + + "2:" // Common tail + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr dA3, [ aptr3], #0x08\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr dA4, [ aptr4], #0x08\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "stp qC11, qC12, [%x[cptr], #0x00]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "stp qC13, qC14, [%x[cptr], #0x20]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "stp qC21, qC22, [%x[cptr], #0x00]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "stp qC23, qC24, [%x[cptr], #0x20]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "stp qC31, qC32, [%x[cptr], #0x00]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "stp qC33, qC34, [%x[cptr], #0x20]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "stp qC41, qC42, [%x[cptr], #0x00]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + "stp qC43, qC44, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + + ".unreq vB4\n" ".unreq qB4\n" + ".unreq vB3\n" ".unreq qB3\n" + ".unreq vB2\n" ".unreq qB2\n" + ".unreq vB1\n" ".unreq qB1\n" + ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" + ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" + ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" + ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" + ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" + ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n" + ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" + ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" + ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" + ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n" + ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" + ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" + ".unreq aptr2\n" + ".unreq aptr3\n" + ".unreq aptr4\n" + + : [aptr] "+r" (aptr), + [bptr] "+r" (bptr), + [cptr] "+r" (cptr), + [k] "+r" (k) + : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), + [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), + [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) + : "cc", "memory", "x20", "x21", "x22", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23" + ); + } + } +} + +template <> +inline void sgemm_4x16_impl<3>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + const int TAIL_SIZE = 3; + const int M_BLOCK = 4; + const int N_BLOCK = 16; + + const int m_blocks = iceildiv(M, M_BLOCK); + const int n_blocks = iceildiv(N, N_BLOCK); + + // For each block of output rows + for (int mblock = 0; mblock < m_blocks; mblock++) { + // For each block of output columns + for (int nblock = 0; nblock < n_blocks; nblock++) { + const float *aptr = a + mblock*M_BLOCK*a_row_stride; + const float *bptr = b + nblock*N_BLOCK; + float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; + int k = (K - TAIL_SIZE) / 4; + + asm volatile( + "aptr2 .req X20\n" + "aptr3 .req X21\n" + "aptr4 .req X22\n" + "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" + "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" + "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" + "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" + "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" + "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" + "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" + "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" + "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" + "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" + "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" + "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" + "vB1 .req v20\n" "qB1 .req q20\n" + "vB2 .req v21\n" "qB2 .req q21\n" + "vB3 .req v22\n" "qB3 .req q22\n" + "vB4 .req v23\n" "qB4 .req q23\n" + + // Clear accumulators, initialise pointers + "movi vC11.4s, #0\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "movi vC12.4s, #0\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "movi vC13.4s, #0\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "movi vC14.4s, #0\n" + "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n" + "movi vC21.4s, #0\n" + "add aptr3, aptr2, %x[a_row_stride_bytes]\n" + "movi vC22.4s, #0\n" + "add aptr4, aptr3, %x[a_row_stride_bytes]\n" + "movi vC23.4s, #0\n" + "cbnz %x[k], 3f\n" + + // Prepare for tail in K + "movi vC24.4s, #0\n" + "ldr dA1, [%x[aptr]], #0x08\n" + "movi vC31.4s, #0\n" + "ldr dA2, [ aptr2], #0x08\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "b 2f\n" // Jump to tail + + "3:" // Prepare for loop over K + "movi vC24.4s, #0\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "movi vC31.4s, #0\n" + "ldr qA2, [ aptr2], #0x10\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "subs %x[k], %x[k], #1\n" + "beq 4f\n" + + "1:" // Loop proper + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "subs %x[k], %x[k], #1\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr qA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + "bne 1b\n" + + "4:" // Tail iteration + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr dA1, [%x[aptr]], #0x08\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr dA2, [ aptr2], #0x08\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + + "2:" // Common tail + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr dA3, [ aptr3], #0x08\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr dA4, [ aptr4], #0x08\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "ldr sA1, [%x[aptr]], #0x04\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "ldr sA2, [ aptr2], #0x04\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "stp qC11, qC12, [%x[cptr], #0x00]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "ldr sA3, [ aptr3], #0x04\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "stp qC13, qC14, [%x[cptr], #0x20]\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "stp qC21, qC22, [%x[cptr], #0x00]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "ldr sA4, [ aptr4], #0x04\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "stp qC23, qC24, [%x[cptr], #0x20]\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "stp qC31, qC32, [%x[cptr], #0x00]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "stp qC33, qC34, [%x[cptr], #0x20]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "stp qC41, qC42, [%x[cptr], #0x00]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + "stp qC43, qC44, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + + ".unreq vB4\n" ".unreq qB4\n" + ".unreq vB3\n" ".unreq qB3\n" + ".unreq vB2\n" ".unreq qB2\n" + ".unreq vB1\n" ".unreq qB1\n" + ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" + ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" + ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" + ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" + ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" + ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n" + ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" + ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" + ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" + ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n" + ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" + ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" + ".unreq aptr2\n" + ".unreq aptr3\n" + ".unreq aptr4\n" + + : [aptr] "+r" (aptr), + [bptr] "+r" (bptr), + [cptr] "+r" (cptr), + [k] "+r" (k) + : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), + [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), + [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) + : "cc", "memory", "x20", "x21", "x22", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23" + ); + } + } +} diff --git a/arm_compute/core/NEON/kernels/winograd/perf.h b/arm_compute/core/NEON/kernels/winograd/perf.h new file mode 100644 index 0000000000..0cdf742a25 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/perf.h @@ -0,0 +1,9 @@ +#pragma once + +/* Prototypes from perf.c */ + +void start_counter(int fd); +long long get_counter(int fd); +long long stop_counter(int fd); +int open_instruction_counter(void); +int open_cycle_counter(void); diff --git a/arm_compute/core/NEON/kernels/winograd/profiler.hpp b/arm_compute/core/NEON/kernels/winograd/profiler.hpp new file mode 100644 index 0000000000..01fafa9604 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/profiler.hpp @@ -0,0 +1,326 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "perf.h" +#include + +#ifdef CYCLE_PROFILING +class EventIDContainer +{ + public: + EventIDContainer() : container_lock(), event_ids() + { + } + + int get_event_id(const char *id) + { + std::lock_guard lock(container_lock); + if (!event_ids.count(id)) { + event_ids.emplace(id, event_ids.size()); + } + return event_ids[id]; + } + + unsigned int size() const + { + return event_ids.size(); + } + + auto begin() + { + return event_ids.begin(); + } + + auto end() + { + return event_ids.end(); + } + + private: + std::mutex container_lock; + std::map event_ids; +}; + + +class ThreadEventCounterContainer +{ + public: + ThreadEventCounterContainer() : container_lock(), thread_counter_fds() + { + } + + int get_counter_fd() + { + const auto id = std::this_thread::get_id(); + std::lock_guard lock(container_lock); + if (!thread_counter_fds.count(id)) + { + thread_counter_fds.emplace(id, open_cycle_counter()); + } + return thread_counter_fds[id]; + } + + ~ThreadEventCounterContainer() + { + // Close all counter file descriptors + for (auto& fd : thread_counter_fds) + { + close(fd.second); + } + } + + private: + std::mutex container_lock; + std::map thread_counter_fds; +}; +#endif // CYCLE_PROFILING + + +class profiler { +private: +#ifdef CYCLE_PROFILING + struct ProfileEntry { + int event_id; + long int bytes_read, ops, bytes_written; + long int duration; + }; + + static const int maxevents = 10000; + ProfileEntry events[maxevents]; + int currentevent; + std::mutex event_lock; + + EventIDContainer event_ids; + ThreadEventCounterContainer thread_counter_fds; + + int get_event_id(const char *id) + { + return event_ids.get_event_id(id); + } +#endif // CYCLE_PROFILING + +public: +#ifdef CYCLE_PROFILING + profiler() : + currentevent(0), + event_lock(), + event_ids(), + thread_counter_fds() + { + } + + ~profiler() { + std::lock_guard lock_events(event_lock); + + // Compute performance from recorded events + struct ProfileResult { + ProfileResult() : total_calls(0), + total_duration(0), + total_bytes_read(0), + total_ops(0), + total_bytes_written(0) { + } + + void operator+=(const ProfileEntry &rhs) { + total_calls++; + total_duration += rhs.duration; + total_bytes_read += rhs.bytes_read; + total_ops += rhs.ops; + total_bytes_written = rhs.bytes_written; + } + + float avg_duration(void) const { + return static_cast(total_duration) / + static_cast(total_calls); + } + + float bytes_read_per_cycle(void) const { + return static_cast(total_bytes_read) / + static_cast(total_duration); + } + + float ops_per_cycle(void) const { + return static_cast(total_ops) / + static_cast(total_duration); + } + + float bytes_written_per_cycle(void) const { + return static_cast(total_bytes_written) / + static_cast(total_duration); + } + + long int total_calls, + total_duration, + total_bytes_read, + total_ops, + total_bytes_written; + }; + + std::vector totals; + totals.resize(event_ids.size()); + for (int i = 0; i < currentevent; i++) { + const auto &event = events[i]; + totals[event.event_id] += event; + } + + // Get the longest label + int len_label = 0; + for (const auto &kv : event_ids) { + len_label = std::max(len_label, static_cast(strlen(kv.first))); + } + + // Get the longest values for every other field + const auto get_length_of_field = + [totals] (const char *title, auto f, auto len) -> size_t { + size_t l = strlen(title); + for (const auto &v : totals) { + l = std::max(l, len(f(v))); + } + return l; + }; + + // Get the strlen for an int + const auto intlen = [] (long int x) -> size_t { + size_t len = 0; + do { + x /= 10; + len++; + } while (x); + return len; + }; + + // Get the strlen for a float + const auto floatlen = [] (const int precision) { + return [precision] (float x) { + size_t len = 0; + + if (!std::isfinite(x)) { + return static_cast(3); + } + + do { + x /= 10.0f; + len++; + } while (x > 1.0f); + return len + 1 + precision; + }; + }; + + const int len_calls = get_length_of_field( + "Calls", [] (const auto &v) {return v.total_calls;}, + intlen + ); + const int len_duration = get_length_of_field( + "Duration", [] (const auto &v) {return v.total_duration;}, + intlen + ); + const int len_average_duration = get_length_of_field( + "Average", [] (const auto &v) {return v.avg_duration();}, + floatlen(2) + ); + const int len_reads_per_cycle = get_length_of_field( + "Reads / cycle", + [] (const auto &v) {return v.bytes_read_per_cycle();}, + floatlen(6) + ); + const int len_ops_per_cycle = get_length_of_field( + "Ops / cycle", + [] (const auto &v) {return v.ops_per_cycle();}, + floatlen(6) + ); + const int len_writes_per_cycle = get_length_of_field( + "Writes / cycle", + [] (const auto &v) {return v.bytes_written_per_cycle();}, + floatlen(6) + ); + + // Print header + printf( + "%*s %*s %*s %*s %*s %*s %*s\n", + len_label, "", + len_calls, "Calls", + len_duration, "Duration", + len_average_duration, "Average", + len_reads_per_cycle, "Reads / cycle", + len_ops_per_cycle, "Ops / cycle", + len_writes_per_cycle, "Writes / cycle" + ); + for (const auto &kv : event_ids) { + const auto id = kv.second; + printf( + "%*s %*ld %*ld %*.2f %*.6f %*.6f %*.6f\n", + len_label, kv.first, + len_calls, totals[id].total_calls, + len_duration, totals[id].total_duration, + len_average_duration, totals[id].avg_duration(), + len_reads_per_cycle, totals[id].bytes_read_per_cycle(), + len_ops_per_cycle, totals[id].ops_per_cycle(), + len_writes_per_cycle, totals[id].bytes_written_per_cycle() + ); + } + printf("\n"); + } +#endif // CYCLE_PROFILING + + template + void operator() (const char * event, + T func, + long int bytes_read = 0, + long int ops = 0, + long int bytes_written = 0) { +#ifdef CYCLE_PROFILING + if (currentevent==maxevents) { + func(); + } else { + const auto countfd = thread_counter_fds.get_counter_fd(); + start_counter(countfd); + func(); + long long cycs = stop_counter(countfd); + + // Store the profiling data + std::lock_guard lock_events(event_lock); + events[currentevent++] = { + get_event_id(event), bytes_read, ops, bytes_written, cycs + }; + } +#else + (void) event; + (void) bytes_read; + (void) ops; + (void) bytes_written; + func(); +#endif // CYCLE_PROFILING + } +}; diff --git a/arm_compute/core/NEON/kernels/winograd/shims.hpp b/arm_compute/core/NEON/kernels/winograd/shims.hpp new file mode 100644 index 0000000000..09e14577ff --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/shims.hpp @@ -0,0 +1,747 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once +#include +#include "arm.hpp" + +namespace reorder { +/** Re-order a tensor from NCHW format to NHWC. + * + * @note The stride parameters are optional and are provided to allow padding in either input or output tensors. + * + * @param[in] in Input tensor in NCHW format. + * @param[out] out Output tensor, to be written in NHWC format. + * @param n_batches Number of batches in the tensors. + * @param n_channels Number of channels in the tensors + * @param n_rows Height of the tensor + * @param n_cols Width of the tensor + * @param in_batch_stride Stride over batches in the input tensor. If `0` defaults to `n_channels * in_channel_stride`. + * @param in_channel_stride Stride over channels in the input tensor. If `0` defaults to `n_rows * in_row_stride`. + * @param in_row_stride Stride over rows in the input tensor. If `0` defaults to `n_cols`. + * @param out_batch_stride Stride over batches in the output tensor. If `0` defaults to `n_rows * out_row_stride`. + * @param out_row_stride Stride over rows in the output tensor. If `0` defaults to `n_cols * out_col_stride`. + * @param out_col_stride Stride over columns in the output tensor. If `0` defaults to `n_channels`. + */ +template +inline void nchw_to_nhwc( + const T* const in, + T* const out, + const int n_batches, + const int n_channels, + const int n_rows, + const int n_cols, + int in_batch_stride=0, + int in_channel_stride=0, + int in_row_stride=0, + int out_batch_stride=0, + int out_row_stride=0, + int out_col_stride=0 +); + +/** Re-order a tensor from NHWC format to NCHW. + * + * @note The stride parameters are optional and are provided to allow padding in either input or output tensors. + * + * @param[in] in Input tensor in NHWC format. + * @param[out] out Output tensor, to be written in NCHW format. + * @param n_batches Number of batches in the tensors. + * @param n_rows Height of the tensor + * @param n_cols Width of the tensor + * @param n_channels Number of channels in the tensors + * @param in_batch_stride Stride over batches in the input tensor. If `0` defaults to `n_rows * in_row_stride`. + * @param in_row_stride Stride over rows in the input tensor. If `0` defaults to `n_cols * in_col_stride`. + * @param in_col_stride Stride over columns in the input tensor. If `0` defaults to `n_channels`. + * @param out_batch_stride Stride over batches in the output tensor. If `0` defaults to `n_channels * out_channel_stride`. + * @param out_channel_stride Stride over channels in the output tensor. If `0` defaults to `n_rows * out_row_stride`. + * @param out_row_stride Stride over rows in the output tensor. If `0` defaults to `n_cols`. + */ +template +inline void nhwc_to_nchw( + const T* const in, // Input data in NHWC form + T* const out, // Output data in NCHW form + const int n_batches, + const int n_rows, + const int n_cols, + const int n_channels, + int in_batch_stride=0, + int in_row_stride=0, + int in_col_stride=0, + int out_batch_stride=0, + int out_channel_stride=0, + int out_row_stride=0 +); + +/** Re-order a weight tensor from [Output feature map x Input feature map x + * Height x Width] format to [Height x Width x Input feature map x Output + * feature map] format. + */ +template +inline void ofm_ifm_h_w_to_h_w_ifm_ofm( + const T* const in, // Input in [Output x Input x Height x Width] form + T* const out, // Output in [Height x Width x Input x Output] form + const int n_output_feature_maps, + const int n_input_feature_maps, + const int n_rows, + const int n_cols, + int in_output_feature_map_stride=0, + int in_input_feature_map_stride=0, + int in_row_stride=0, + int out_row_stride=0, + int out_col_stride=0, + int out_input_feature_map_stride=0 +); + +/** Re-order a weight tensor from [Height x Width x Input feature map x Output + * feature map] format to [Output feature map x Input feature map x Height x + * Width] format. + */ +template +inline void h_w_ifm_ofm_to_ofm_ifm_h_w( + const T* const in, // Input in [Height x Width x Input x Output] form + T* const out, // Output in [Output x Input x Height x Width] form + const int n_rows, + const int n_cols, + const int n_input_feature_maps, + const int n_output_feature_maps, + int in_row_stride=0, + int in_col_stride=0, + int in_input_feature_map_stride=0, + int out_output_feature_map_stride=0, + int out_input_feature_map_stride=0, + int out_row_stride=0 +); + +/*****************************************************************************/ +/* 32-bit implementation : NCHW -> NHWC + */ +template <> +inline void nchw_to_nhwc( + const int32_t* const in, + int32_t* const out, + const int n_batches, + const int n_channels, + const int n_rows, + const int n_cols, + int in_batch_stride, + int in_channel_stride, + int in_row_stride, + int out_batch_stride, + int out_row_stride, + int out_col_stride +) +{ + typedef int32_t T; + + // Fill in the stride values + in_row_stride = (in_row_stride) ? in_row_stride : n_cols; + in_channel_stride = (in_channel_stride) ? in_channel_stride + : n_rows * in_row_stride; + in_batch_stride = (in_batch_stride) ? in_batch_stride + : n_channels * in_channel_stride; + + out_col_stride = (out_col_stride) ? out_col_stride : n_channels; + out_row_stride = (out_row_stride) ? out_row_stride : n_cols * out_col_stride; + out_batch_stride = (out_batch_stride) ? out_batch_stride + : n_rows * out_row_stride; + + // Perform the re-ordering + for (int n = 0; n < n_batches; n++) + { + const T* const in_batch = in + n*in_batch_stride; + T* const out_batch = out + n*out_batch_stride; + + for (int i = 0; i < n_rows; i++) + { + const T* const in_row = in_batch + i*in_row_stride; + T* const out_row = out_batch + i*out_row_stride; + + int j = 0, j_remaining = n_cols; +#ifdef __arm_any__ + for (; j_remaining >= 4; j += 4, j_remaining -= 4) + { + int c = 0, c_remaining = n_channels; + for (; c_remaining >= 4; c += 4, c_remaining -= 4) + { + // Read 4 channels worth of 4 columns, then zip to produce 4 columns + // worth of 4 channels. + int32x4_t channel_pixels[4]; + channel_pixels[0] = vld1q_s32(in_row + (c + 0)*in_channel_stride + j); + channel_pixels[1] = vld1q_s32(in_row + (c + 1)*in_channel_stride + j); + channel_pixels[2] = vld1q_s32(in_row + (c + 2)*in_channel_stride + j); + channel_pixels[3] = vld1q_s32(in_row + (c + 3)*in_channel_stride + j); + + const auto zip1 = vzipq_s32(channel_pixels[0], channel_pixels[2]); + const auto zip2 = vzipq_s32(channel_pixels[1], channel_pixels[3]); + const auto out_0 = vzipq_s32(zip1.val[0], zip2.val[0]); + const auto out_1 = vzipq_s32(zip1.val[1], zip2.val[1]); + + vst1q_s32(out_row + (j + 0)*out_col_stride + c, out_0.val[0]); + vst1q_s32(out_row + (j + 1)*out_col_stride + c, out_0.val[1]); + vst1q_s32(out_row + (j + 2)*out_col_stride + c, out_1.val[0]); + vst1q_s32(out_row + (j + 3)*out_col_stride + c, out_1.val[1]); + } + for (; c_remaining; c++, c_remaining--) + { + for (int _j = 0; _j < 4; _j++) + { + const T* const in_col = in_row + j + _j; + T* const out_col = out_row + (j + _j)*out_col_stride; + const T* const in_channel = in_col + c*in_channel_stride; + out_col[c] = *(in_channel); + } + } + } + for (; j_remaining >= 2; j += 2, j_remaining -= 2) + { + int c = 0, c_remaining = n_channels; + for (; c_remaining >= 2; c += 2, c_remaining -= 2) + { + // Read 2 channels worth of 2 columns, then zip to produce 2 columns + // worth of 2 channels. + int32x2_t channel_pixels[2]; + channel_pixels[0] = vld1_s32(in_row + (c + 0)*in_channel_stride + j); + channel_pixels[1] = vld1_s32(in_row + (c + 1)*in_channel_stride + j); + + const auto output = vzip_s32(channel_pixels[0], channel_pixels[1]); + + vst1_s32(out_row + (j + 0)*out_col_stride + c, output.val[0]); + vst1_s32(out_row + (j + 1)*out_col_stride + c, output.val[1]); + } + for (; c_remaining; c++, c_remaining--) + { + for (int _j = 0; _j < 2; _j++) + { + const T* const in_col = in_row + j + _j; + T* const out_col = out_row + (j + _j)*out_col_stride; + const T* const in_channel = in_col + c*in_channel_stride; + out_col[c] = *(in_channel); + } + } + } +#endif // __arm_any__ + for (; j_remaining; j++, j_remaining--) + { + const T* const in_col = in_row + j; + T* const out_col = out_row + j*out_col_stride; + + for (int c = 0; c < n_channels; c++) + { + const T* const in_channel = in_col + c*in_channel_stride; + out_col[c] = *(in_channel); + } + } + } + } +} + +template <> +inline void nchw_to_nhwc( + const uint32_t* const in, + uint32_t* const out, + const int n_batches, + const int n_channels, + const int n_rows, + const int n_cols, + int in_batch_stride, + int in_channel_stride, + int in_row_stride, + int out_batch_stride, + int out_row_stride, + int out_col_stride +) +{ + nchw_to_nhwc( + reinterpret_cast(in), + reinterpret_cast(out), + n_batches, n_channels, n_rows, n_cols, + in_batch_stride, in_channel_stride, in_row_stride, + out_batch_stride, out_row_stride, out_col_stride + ); +} + +template <> +inline void nchw_to_nhwc( + const float* const in, + float* const out, + const int n_batches, + const int n_channels, + const int n_rows, + const int n_cols, + int in_batch_stride, + int in_channel_stride, + int in_row_stride, + int out_batch_stride, + int out_row_stride, + int out_col_stride +) +{ + nchw_to_nhwc( + reinterpret_cast(in), + reinterpret_cast(out), + n_batches, n_channels, n_rows, n_cols, + in_batch_stride, in_channel_stride, in_row_stride, + out_batch_stride, out_row_stride, out_col_stride + ); +} + +/*****************************************************************************/ +/* Generic implementation : NCHW -> NHWC + */ +template +inline void nchw_to_nhwc( + const T* const in, + T* const out, + const int n_batches, + const int n_channels, + const int n_rows, + const int n_cols, + int in_batch_stride, + int in_channel_stride, + int in_row_stride, + int out_batch_stride, + int out_row_stride, + int out_col_stride +) +{ + // Fill in the stride values + in_row_stride = (in_row_stride) ? in_row_stride : n_cols; + in_channel_stride = (in_channel_stride) ? in_channel_stride + : n_rows * in_row_stride; + in_batch_stride = (in_batch_stride) ? in_batch_stride + : n_channels * in_channel_stride; + + out_col_stride = (out_col_stride) ? out_col_stride : n_channels; + out_row_stride = (out_row_stride) ? out_row_stride : n_cols * out_col_stride; + out_batch_stride = (out_batch_stride) ? out_batch_stride + : n_rows * out_row_stride; + + // Perform the re-ordering + for (int n = 0; n < n_batches; n++) + { + const T* const in_batch = in + n*in_batch_stride; + T* const out_batch = out + n*out_batch_stride; + + for (int i = 0; i < n_rows; i++) + { + const T* const in_row = in_batch + i*in_row_stride; + T* const out_row = out_batch + i*out_row_stride; + + for (int j = 0; j < n_cols; j++) + { + const T* const in_col = in_row + j; + T* const out_col = out_row + j*out_col_stride; + + for (int c = 0; c < n_channels; c++) + { + const T* const in_channel = in_col + c*in_channel_stride; + out_col[c] = *(in_channel); + } + } + } + } +} + +/*****************************************************************************/ +/* 32-bit implementation : NHWC -> NCHW + */ +template <> +inline void nhwc_to_nchw( + const int32_t* const in, // Input data in NHWC form + int32_t* const out, // Output data in NCHW form + const int n_batches, + const int n_rows, + const int n_cols, + const int n_channels, + int in_batch_stride, + int in_row_stride, + int in_col_stride, + int out_batch_stride, + int out_channel_stride, + int out_row_stride +) +{ + typedef int32_t T; + + // Fill in stride values + in_col_stride = (in_col_stride) ? in_col_stride : n_channels; + in_row_stride = (in_row_stride) ? in_row_stride : n_cols * in_col_stride; + in_batch_stride = (in_batch_stride) ? in_batch_stride + : n_rows * in_row_stride; + + out_row_stride = (out_row_stride) ? out_row_stride : n_cols; + out_channel_stride = (out_channel_stride) ? out_channel_stride + : n_rows * out_row_stride; + out_batch_stride = (out_batch_stride) ? out_batch_stride + : n_channels * out_channel_stride; + + // Perform the re-ordering + // For every batch + for (int n = 0; n < n_batches; n++) + { + const T* const in_batch = in + n*in_batch_stride; + T* const out_batch = out + n*out_batch_stride; + + // For every row + for (int i = 0; i < n_rows; i++) + { + const T* const in_i = in_batch + i*in_row_stride; + T* const out_i = out_batch + i*out_row_stride; + + // For every column, beginning with chunks of 4 + int j = 0, j_remaining = n_cols; +#ifdef __arm_any__ + for (; j_remaining >= 4; j += 4, j_remaining -=4) + { + // For every channel, beginning with chunks of 4 + int c = 0, c_remaining = n_channels; + for (; c_remaining >= 4; c += 4, c_remaining -= 4) + { + // Read 4 columns worth of 4 channels then zip to produce 4 channels + // worth of 4 columns. + int32x4_t pixel_channels[4]; + pixel_channels[0] = vld1q_s32(in_i + (j + 0)*in_col_stride + c); + pixel_channels[1] = vld1q_s32(in_i + (j + 1)*in_col_stride + c); + pixel_channels[2] = vld1q_s32(in_i + (j + 2)*in_col_stride + c); + pixel_channels[3] = vld1q_s32(in_i + (j + 3)*in_col_stride + c); + + const auto zip1 = vzipq_s32(pixel_channels[0], pixel_channels[2]); + const auto zip2 = vzipq_s32(pixel_channels[1], pixel_channels[3]); + const auto out_0 = vzipq_s32(zip1.val[0], zip2.val[0]); + const auto out_1 = vzipq_s32(zip1.val[1], zip2.val[1]); + + vst1q_s32(out_i + j + (c + 0)*out_channel_stride, out_0.val[0]); + vst1q_s32(out_i + j + (c + 1)*out_channel_stride, out_0.val[1]); + vst1q_s32(out_i + j + (c + 2)*out_channel_stride, out_1.val[0]); + vst1q_s32(out_i + j + (c + 3)*out_channel_stride, out_1.val[1]); + } + for (; c_remaining; c++, c_remaining--) + { + for (int _j = 0; _j < 4; _j++) + { + const T* const in_j = in_i + (j + _j)*in_col_stride; + T* const out_j = out_i + (j + _j); + + const T* const in_channel = in_j + c; + T* const out_channel = out_j + c*out_channel_stride; + *(out_channel) = *(in_channel); + } + } + } + for (; j_remaining >= 2; j += 2, j_remaining -=2) + { + int c = 0, c_remaining = n_channels; + for (; c_remaining >= 2; c += 2, c_remaining -= 2) + { + // Read 2 columns worth of 2 channels then zip to produce 2 channels + // worth of 2 columns. + int32x2_t pixel_channels[2]; + pixel_channels[0] = vld1_s32(in_i + (j + 0)*in_col_stride + c); + pixel_channels[1] = vld1_s32(in_i + (j + 1)*in_col_stride + c); + + const auto output = vzip_s32(pixel_channels[0], pixel_channels[1]); + + vst1_s32(out_i + j + (c + 0)*out_channel_stride, output.val[0]); + vst1_s32(out_i + j + (c + 1)*out_channel_stride, output.val[1]); + } + for (; c_remaining; c++, c_remaining--) + { + for (int _j = 0; _j < 2; _j++) + { + const T* const in_j = in_i + (j + _j)*in_col_stride; + T* const out_j = out_i + (j + _j); + + const T* const in_channel = in_j + c; + T* const out_channel = out_j + c*out_channel_stride; + *(out_channel) = *(in_channel); + } + } + } +#endif // __arm_any__ + for (; j_remaining; j++, j_remaining--) + { + const T* const in_j = in_i + j*in_col_stride; + T* const out_j = out_i + j; + + // For every channel + for (int c = 0; c < n_channels; c++) + { + const T* const in_channel = in_j + c; + T* const out_channel = out_j + c*out_channel_stride; + *(out_channel) = *(in_channel); + } + } + } + } +} + +template <> +inline void nhwc_to_nchw( + const uint32_t* const in, // Input data in NHWC form + uint32_t* const out, // Output data in NCHW form + const int n_batches, + const int n_rows, + const int n_cols, + const int n_channels, + int in_batch_stride, + int in_row_stride, + int in_col_stride, + int out_batch_stride, + int out_channel_stride, + int out_row_stride +) +{ + // Redirect to generic 32-bit implementation + nhwc_to_nchw( + reinterpret_cast(in), + reinterpret_cast(out), + n_batches, n_rows, n_cols, n_channels, + in_batch_stride, in_row_stride, in_col_stride, + out_batch_stride, out_channel_stride, out_row_stride + ); +} + +template <> +inline void nhwc_to_nchw( + const float* const in, // Input data in NHWC form + float* const out, // Output data in NCHW form + const int n_batches, + const int n_rows, + const int n_cols, + const int n_channels, + int in_batch_stride, + int in_row_stride, + int in_col_stride, + int out_batch_stride, + int out_channel_stride, + int out_row_stride +) +{ + // Redirect to generic 32-bit implementation + nhwc_to_nchw( + reinterpret_cast(in), + reinterpret_cast(out), + n_batches, n_rows, n_cols, n_channels, + in_batch_stride, in_row_stride, in_col_stride, + out_batch_stride, out_channel_stride, out_row_stride + ); +} + +/*****************************************************************************/ +/* Generic implementation : NHWC -> NCHW + */ +template +inline void nhwc_to_nchw( + const T* const in, // Input data in NHWC form + T* const out, // Output data in NCHW form + const int n_batches, + const int n_rows, + const int n_cols, + const int n_channels, + int in_batch_stride, + int in_row_stride, + int in_col_stride, + int out_batch_stride, + int out_channel_stride, + int out_row_stride +) +{ + // Fill in stride values + in_col_stride = (in_col_stride) ? in_col_stride : n_channels; + in_row_stride = (in_row_stride) ? in_row_stride : n_cols * in_col_stride; + in_batch_stride = (in_batch_stride) ? in_batch_stride + : n_rows * in_row_stride; + + out_row_stride = (out_row_stride) ? out_row_stride : n_cols; + out_channel_stride = (out_channel_stride) ? out_channel_stride + : n_rows * out_row_stride; + out_batch_stride = (out_batch_stride) ? out_batch_stride + : n_channels * out_channel_stride; + + // Perform the re-ordering + // For every batch + for (int n = 0; n < n_batches; n++) + { + const T* const in_batch = in + n*in_batch_stride; + T* const out_batch = out + n*out_batch_stride; + + // For every row + for (int i = 0; i < n_rows; i++) + { + const T* const in_i = in_batch + i*in_row_stride; + T* const out_i = out_batch + i*out_row_stride; + + // For every column + for (int j = 0; j < n_cols; j++) + { + const T* const in_j = in_i + j*in_col_stride; + T* const out_j = out_i + j; + + // For every channel + for (int c = 0; c < n_channels; c++) + { + const T* const in_channel = in_j + c; + T* const out_channel = out_j + c*out_channel_stride; + *(out_channel) = *(in_channel); + } + } + } + } +} + +/*****************************************************************************/ +/* Generic weight re-order implementation. + */ +template +inline void ofm_ifm_h_w_to_h_w_ifm_ofm( + const T* const in, // Input in [Output x Input x Height x Width] form + T* const out, // Output in [Height x Width x Input x Output] form + const int n_output_feature_maps, + const int n_input_feature_maps, + const int n_rows, + const int n_cols, + int in_output_feature_map_stride, + int in_input_feature_map_stride, + int in_row_stride, + int out_row_stride, + int out_col_stride, + int out_input_feature_map_stride +) +{ + // Fill in stride values + in_row_stride = (in_row_stride) + ? in_row_stride + : n_cols; + in_input_feature_map_stride = (in_input_feature_map_stride) + ? in_input_feature_map_stride + : n_rows * in_row_stride; + in_output_feature_map_stride = (in_output_feature_map_stride) + ? in_output_feature_map_stride + : n_input_feature_maps * in_input_feature_map_stride; + + out_input_feature_map_stride = (out_input_feature_map_stride) + ? out_input_feature_map_stride + : n_output_feature_maps; + out_col_stride = (out_col_stride) + ? out_col_stride + : n_input_feature_maps * out_input_feature_map_stride; + out_row_stride = (out_row_stride) + ? out_row_stride + : n_cols * out_col_stride; + + // Perform the re-ordering + for (int i = 0; i < n_rows; i++) + { + const T* const in_row = in + i * in_row_stride; + T* out_row = out + i * out_row_stride; + + for (int j = 0; j < n_cols; j++) + { + const T* const in_col = in_row + j; + T* const out_col = out_row + j * out_col_stride; + + for (int ifm = 0; ifm < n_input_feature_maps; ifm++) + { + const T* const in_ifm = in_col + ifm * in_input_feature_map_stride; + T* const out_ifm = out_col + ifm * out_input_feature_map_stride; + + for (int ofm = 0; ofm < n_output_feature_maps; ofm++) + { + const T* const in_ofm = in_ifm + ofm * in_output_feature_map_stride; + T* const out_ofm = out_ifm + ofm; + *(out_ofm) = *(in_ofm); + } + } + } + } +} + +/*****************************************************************************/ +/* Generic weight re-order implementation. + */ +template +inline void h_w_ifm_ofm_to_ofm_ifm_h_w( + const T* const in, // Input in [Height x Width x Input x Output] form + T* const out, // Output in [Output x Input x Height x Width] form + const int n_rows, + const int n_cols, + const int n_input_feature_maps, + const int n_output_feature_maps, + int in_row_stride, + int in_col_stride, + int in_input_feature_map_stride, + int out_output_feature_map_stride, + int out_input_feature_map_stride, + int out_row_stride +) +{ + // Fill in the stride values + in_input_feature_map_stride = (in_input_feature_map_stride) + ? in_input_feature_map_stride + : n_output_feature_maps; + in_col_stride = (in_col_stride) + ? in_col_stride + : n_input_feature_maps * in_input_feature_map_stride; + in_row_stride = (in_row_stride) + ? in_row_stride + : n_cols * in_col_stride; + + out_row_stride = (out_row_stride) + ? out_row_stride + : n_cols; + out_input_feature_map_stride = (out_input_feature_map_stride) + ? out_input_feature_map_stride + : n_rows * out_row_stride; + out_output_feature_map_stride = (out_output_feature_map_stride) + ? out_output_feature_map_stride + : n_input_feature_maps * out_input_feature_map_stride; + + // Perform the re-ordering + for (int i = 0; i < n_rows; i++) + { + const T* const in_row = in + i * in_row_stride; + T* const out_row = out + i * out_row_stride; + + for (int j = 0; j < n_cols; j++) + { + const T* const in_col = in_row + j * in_col_stride; + T* const out_col = out_row + j; + + for (int ifm = 0; ifm < n_input_feature_maps; ifm++) + { + const T* const in_ifm = in_col + ifm * in_input_feature_map_stride; + T* const out_ifm = out_col + ifm * out_input_feature_map_stride; + + for (int ofm = 0; ofm < n_output_feature_maps; ofm++) + { + const T* const in_ofm = in_ifm + ofm; + T* const out_ofm = out_ifm + ofm * out_output_feature_map_stride; + *(out_ofm) = *(in_ofm); + } + } + } + } +} + +} // namespace reorder diff --git a/arm_compute/core/NEON/kernels/winograd/tensor.hpp b/arm_compute/core/NEON/kernels/winograd/tensor.hpp index 70ef65d2a5..6567eeb23d 100644 --- a/arm_compute/core/NEON/kernels/winograd/tensor.hpp +++ b/arm_compute/core/NEON/kernels/winograd/tensor.hpp @@ -23,39 +23,44 @@ */ #pragma once -#include #include #include #include "alloc.hpp" -/*****************************************************************************/ -/* Padding definitions */ -enum PaddingType { - PADDING_SAME, PADDING_VALID +enum TensorOrder +{ + NHWC, ///< [Batch x Height x Width x Channels] + NCHW, ///< [Batch x Channels x Height x Width] }; -/*****************************************************************************/ -/* Shape of a kernel */ -struct KernelShape { - int n_output_channels, n_rows, n_cols, n_input_channels; - - int size(void) const { - return n_output_channels * n_rows * n_cols * n_input_channels; +struct Tensor4DShape +{ + int n_batches, n_rows, n_cols, n_channels; + TensorOrder ordering; + + // Create a new tensor with the default (NHWC) ordering + inline Tensor4DShape( + const int n_batches, + const int n_rows, + const int n_cols, + const int n_channels, + const TensorOrder ordering=NHWC + ) : n_batches(n_batches), + n_rows(n_rows), + n_cols(n_cols), + n_channels(n_channels), + ordering(ordering) + { } -}; - -struct Tensor4DShape { - int n_batches, - n_rows, - n_cols, - n_channels; - int size() const { + inline int size() const + { return n_batches * n_rows * n_cols * n_channels; } - bool TestEq(const Tensor4DShape& other) const { + inline bool TestEq(const Tensor4DShape& other) const + { return (n_batches == other.n_batches && n_rows == other.n_rows && n_cols == other.n_cols && @@ -63,148 +68,110 @@ struct Tensor4DShape { } }; + +enum WeightOrder +{ + HWIO, ///< [Height x Width x Input channels x Output channels] + OIHW, ///< [Output channels x Input channels x Height x Width] +}; + +struct KernelShape +{ + int n_output_channels, n_rows, n_cols, n_input_channels; + WeightOrder ordering; + + inline KernelShape( + const int n_output_channels, + const int n_rows, + const int n_cols, + const int n_input_channels, + const WeightOrder ordering=HWIO + ) : n_output_channels(n_output_channels), + n_rows(n_rows), + n_cols(n_cols), + n_input_channels(n_input_channels), + ordering(ordering) + { + } + + inline int size(void) const + { + return n_output_channels * n_rows * n_cols * n_input_channels; + } +}; + + template -class Tensor4D final { +class Tensor4D final +{ public: Tensor4D(ShapeT shape) : - _shape(shape), - _data(reinterpret_cast(ALLOCATE(size_bytes()))) { + shape(shape), + _data(reinterpret_cast(ALLOCATE(size_bytes()))) + { Clear(); } + Tensor4D(const Tensor4D&) = delete; + Tensor4D operator=(const Tensor4D&) = delete; + ~Tensor4D() { free(_data); } - T* ptr() const { + inline T* ptr() const { return _data; } - const ShapeT& shape() const { - return _shape; + inline size_t size_bytes() const { + return shape.size() * sizeof(T); } - size_t size_bytes() const { - return _shape.size() * sizeof(T); - } - - bool TestEq(Tensor4D& other) const; - T& element(int, int, int, int) const; - void Print() const; + inline T& element(int, int, int, int) const; - void Clear() { + inline void Clear() { Fill(static_cast(0)); } - void Fill(T val) { - for (int i = 0; i < _shape.size(); i++) + inline void Fill(T val) { + for (int i = 0; i < shape.size(); i++) _data[i] = val; } - void TestPattern() { - for (int i = 0; i < _shape.size(); i++) - _data[i] = static_cast(i); - } - - void Rand(const int seed=2311) { - std::mt19937 gen(seed); - std::uniform_int_distribution<> dis(-50, +50); - - for (int i = 0; i < _shape.size(); i++) { - _data[i] = static_cast(dis(gen)); - } - } - Tensor4D(const Tensor4D &) = delete; - /** Prevent instances of this class from being copied (As this class contains pointers) */ - Tensor4D &operator=(const Tensor4D &) = delete; - /** Allow instances of this class to be moved */ - Tensor4D(Tensor4D &&) = default; - /** Allow instances of this class to be moved */ - Tensor4D &operator=(Tensor4D &&) = default; - + const ShapeT shape; private: - const ShapeT _shape; T* const _data; }; template <> -inline float& Tensor4D::element(int n, int i, int j, int c) const { - int index = ((n*_shape.n_rows + i)*_shape.n_cols + j)*_shape.n_channels + c; - return _data[index]; -} - - -template <> -inline float& Tensor4D::element(int oc, int i, int j, int ic) const { - int index = ((i*_shape.n_cols + j)*_shape.n_input_channels + ic)*_shape.n_output_channels + oc; - return _data[index]; -} - -template <> -inline bool Tensor4D::TestEq(Tensor4D& other) const { - // Test equivalence, printing errors - // First test the shapes are the same - if (!_shape.TestEq(other.shape())) { - printf("Tensors have different shapes.\n"); - return false; - } else { - int incorrects = 0; - - for (int n = 0; n < _shape.n_batches; n++) { - for (int i = 0; i < _shape.n_rows; i++) { - for (int j = 0; j < _shape.n_cols; j++) { - for (int c = 0; c < _shape.n_channels; c++) { - // Check elements for equivalence - const auto a = this->element(n, i, j, c); - const auto b = other.element(n, i, j, c); - - if (a != b) { - printf("Difference at element {%d, %d, %d, %d}: %.3f != %.3f\n", n, i, j, c, a, b); - - if (++incorrects > 100) { - printf("More than 100 incorrect values, stopping test.\n"); - return false; - } - } - } - } - } - } - - return incorrects == 0; +inline float& Tensor4D::element(int n, int i, int j, int c) const +{ + int index; + if (shape.ordering == NHWC) + { + index = ((n*shape.n_rows + i)*shape.n_cols + j)*shape.n_channels + c; } -} - - -template <> -inline void Tensor4D::Print() const { - for (int n = 0; n < _shape.n_batches; n++) { - for (int c = 0; c < _shape.n_channels; c++) { - for (int i = 0; i < _shape.n_rows; i++) { - for (int j = 0; j < _shape.n_cols; j++) { - printf("%5.2f ", element(n, i, j, c)); - } - printf("\n"); - } - printf("\n"); - } + else // NCHW + { + index = ((n*shape.n_channels + c)*shape.n_rows + i)*shape.n_cols + j; } + return _data[index]; } template <> -inline void Tensor4D::Print() const { - for (int oc = 0; oc < _shape.n_output_channels; oc++) { - for (int ic = 0; ic < _shape.n_input_channels; ic++) { - for (int i = 0; i < _shape.n_rows; i++) { - for (int j = 0; j < _shape.n_cols; j++) { - printf("%5.2f ", element(oc, i, j, ic)); - } - printf("\n"); - } - printf("\n"); - } +inline float& Tensor4D::element(int oc, int i, int j, int ic) const +{ + int index; + if (shape.ordering == HWIO) + { + index = ((i*shape.n_cols + j)*shape.n_input_channels + ic)*shape.n_output_channels + oc; } + else // OIHW + { + index = ((oc*shape.n_input_channels + ic)*shape.n_rows + i)*shape.n_cols + j; + } + return _data[index]; } diff --git a/arm_compute/core/NEON/kernels/winograd/tensor_utils.hpp b/arm_compute/core/NEON/kernels/winograd/tensor_utils.hpp new file mode 100644 index 0000000000..68a5c6a178 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/tensor_utils.hpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once +#include "tensor.hpp" + +// Methods to print tensors and weights +void PrintTensor(const Tensor4D& tensor); +void PrintWeights(const Tensor4D& weights); + +// Test the equivalence of two tensors +bool CmpTensors(const Tensor4D& a, + const Tensor4D& b, + const float max_delta=0.0f); + +// Fill the tensor with a test pattern +void TestPattern(Tensor4D& tensor); +void TestPattern(Tensor4D& weights); + +// Fill the tensor with random values +void Randomise(Tensor4D& tensor, const int seed=0); +void Randomise(Tensor4D& weights, const int seed=0); diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/input.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/input.hpp new file mode 100644 index 0000000000..39b444184e --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/transforms/input.hpp @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once +#include "../winograd_gemm.hpp" + +namespace winograd +{ + /***************************************************************************/ + /* Instance-less API */ + template + template + void WinogradGEMM::InputTransform::execute( + const T *inptr, + const Tensor4DShape& input_shape, + const PaddingType padding_type, + const int tile_M, + const int tile_N, + T *outptr_base, + const int matrix_stride, + const int matrix_batch_stride, + const int matrix_row_stride + ) + { + // Compute the padding required on each edge of the image + const bool base_padding = (padding_type == PADDING_SAME) ? 1 : 0; + const int pad_top = base_padding; + const int pad_left = base_padding; + const int tile_overlap = kernel_rows - 1; + + // Compute striding values (assuming NHWC ordered data) + const int input_col_stride = input_shape.n_channels; + const int input_row_stride = input_shape.n_cols * input_col_stride; + const int input_batch_stride = input_shape.n_rows * input_row_stride; + const int output_col_stride = matrix_row_stride; + const int output_row_stride = tile_N * output_col_stride; + + // Loop over batches + for (int batch = 0; batch < input_shape.n_batches; batch++) + { + // Pointer to the batch + const T* const input_base_batch = inptr + batch * input_batch_stride; + T* const outptr_base_batch = outptr_base + batch * matrix_batch_stride; + + // Loop over rows of tiles + for (int tile_i = 0; tile_i < tile_M; tile_i++) + { + // Pointer to the row + const int row_offset = (tile_i == 0) ? + 0 : ((padding_type == PADDING_VALID) ? 0 : 1); + const T* const input_base_row = ( + input_base_batch + ((inner_tile_rows - 2)*tile_i - row_offset)*input_row_stride + ); + T* const outptr_base_row = outptr_base_batch + tile_i*output_row_stride; + + // Padding (top + bottom) for the row + const int row_top = tile_i*(inner_tile_rows - tile_overlap) - pad_top; + const int row_bottom = row_top + inner_tile_rows; + const int row_pad_top = (tile_i == 0) ? pad_top : 0; + const int row_pad_bottom = (row_bottom <= input_shape.n_rows) ? 0 : row_bottom - input_shape.n_rows; + + // Process the row + process_tile_row( + tile_N, input_shape.n_channels, + input_base_row, input_row_stride, input_col_stride, + outptr_base_row, matrix_stride, matrix_row_stride, + row_pad_top, pad_left, row_pad_bottom, input_shape.n_cols + ); + } + } + } + + template + template + void WinogradGEMM::InputTransform::process_tile_row( + const int tile_N, + int n_channels, + const T* const input_base, + const int input_row_stride, + const int input_col_stride, + T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + const int pad_top, + const int row_pad_left, + const int pad_bottom, + const int n_cols + ) + { + constexpr int tile_overlap = kernel_cols - 1; + + // Loop over columns of tiles + for (int tile_j = 0; tile_j < tile_N; tile_j++) + { + // Padding (left + right) for the tile + const int t_pad_left = (tile_j == 0) ? row_pad_left : 0; + const int t_start = tile_j*(inner_tile_cols - tile_overlap) - row_pad_left; + const int t_end = t_start + inner_tile_cols; + const int t_pad_right = (t_end <= n_cols) ? 0 : t_end - n_cols; + + // Get pointers into the inputs and outputs + const int col_offset = (tile_j == 0) ? 0 : row_pad_left; + const T* const input_base_col = ( + input_base + ((inner_tile_cols - tile_overlap)*tile_j - col_offset)*input_col_stride + ); + T* const outptr = matrix_base + tile_j*matrix_row_stride; + + // Apply the specific tile processing function + tile_fns[pad_top][t_pad_left][pad_bottom][t_pad_right]( + n_channels, + input_base_col, + input_row_stride, + input_col_stride, + outptr, + matrix_stride + ); + } + } + + /***************************************************************************/ + template + template + WinogradGEMM::InputTransform::InputTransform( + const T* const input, /** Input tensor data */ + const int n_batches, /** Number of batches in input tensor. */ + const int n_rows, /** Number of rows in input tensor. */ + const int n_cols, /** Number of columns in input tensor. */ + const int n_channels, /** Number of channels in input tensor. */ + const PaddingType padding, /** Padding type. */ + T* const output, /** Base of output matrices. */ + const int matrix_stride, /** Stride between output matrices. */ + const int matrix_row_stride /** Stride within matrices. */ + ) : _inptr(input), _outptr(output), + _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols), _n_channels(n_channels), + _matrix_stride(matrix_stride), _matrix_row_stride(matrix_row_stride), + _tiles_M(iceildiv((padding == PADDING_SAME) ? n_rows : n_rows - 2, output_tile_rows)), + _tiles_N(iceildiv((padding == PADDING_SAME) ? n_cols : n_cols - 2, output_tile_cols)), + _padding_type(padding) + { + } + + template + template + unsigned int WinogradGEMM::InputTransform::get_window() const + { + // TODO When the input transform supports multithreading, return the total + // number of tile rows (allowing for multiple batches). For now we return 1 + // to indicate that the activations must be transformed as a single block. + return 1; // TODO _tiles_M * _n_batches; + } + + template + template + void WinogradGEMM::InputTransform::run( + const unsigned int start, const unsigned int stop + ) + { + // TODO When the input transform supports multithreading call execute for a + // portion of the tile rows. + (void) start; + (void) stop; + + // For now, just do all of the work. + const Tensor4DShape input_shape = { + _n_batches, _n_rows, _n_cols, _n_channels, NHWC + }; + execute( + _inptr, input_shape, _padding_type, _tiles_M, _tiles_N, _outptr, + _matrix_stride, _matrix_row_stride * _tiles_M * _tiles_N, _matrix_row_stride + ); + } +} diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/kernel.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/kernel.hpp new file mode 100644 index 0000000000..4b54dfdf08 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/transforms/kernel.hpp @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "winograd_gemm.hpp" +using namespace winograd; + + +template +template +WinogradGEMM::WeightsTransform::WeightsTransform( + const T* const input, + T* const output, + const int matrix_stride, /** Stride across matrices in the output. */ + const int matrix_row_stride, /** Stride across rows of the matrix. */ + const int n_output_channels, + const int n_input_channels +) : inptr(input), outptr(output), + matrix_stride(matrix_stride), matrix_row_stride(matrix_row_stride), + n_output_channels(n_output_channels), n_input_channels(n_input_channels) +{ +} + + +template +template +unsigned int WinogradGEMM::WeightsTransform::get_window() const +{ + // TODO When the weights transform supports multithreading, return the number + // of output channels. For now we return 1 to indicate that the weights must + // be transformed as a single block. + // return n_output_channels; + return 1; +} + + +template +template +void WinogradGEMM::WeightsTransform::run( + const unsigned int start, const unsigned int stop +) +{ + // TODO When the weights transform supports multithreading call execute for a + // portion of the output channels. + (void) start; + (void) stop; + + // For now, just do all of the work. + execute( + n_output_channels, + n_input_channels, + inptr, + outptr, + matrix_stride, + matrix_row_stride + ); +} diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/output.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/output.hpp new file mode 100644 index 0000000000..7fa5ee9617 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/transforms/output.hpp @@ -0,0 +1,174 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once +#include "../winograd_gemm.hpp" + +namespace winograd +{ + template + template + void WinogradGEMM::OutputTransform::execute( + const Tensor4DShape &output_shape, + const T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + T* const output + ) + { + // Compute the number of tiles and hence the padding required on the bottom + // and right of the image. + const int tile_M = iceildiv(output_shape.n_rows, output_tile_rows); + const int tile_N = iceildiv(output_shape.n_cols, output_tile_cols); + const int pad_bottom = output_tile_rows*tile_M - output_shape.n_rows; + const int pad_right = output_tile_cols*tile_N - output_shape.n_cols; + + const int matrix_tile_row_stride = tile_N * matrix_row_stride; + const int matrix_batch_stride = tile_M * matrix_tile_row_stride; + const int output_col_stride = output_shape.n_channels; + const int output_row_stride = output_shape.n_cols * output_col_stride; + const int output_batch_stride = output_shape.n_rows * output_row_stride; + + // Perform the output transformation for each batch + for (int batch = 0; batch < output_shape.n_batches; batch++) + { + // Get batch offset for input and outputs. + const T* const matrix_batch = matrix_base + batch*matrix_batch_stride; + T* const outptr_batch = output + batch*output_batch_stride; + + // Perform the output transformation for each row of the output tensor. + for (int tile_i = 0; tile_i < tile_M; tile_i++) + { + // Compute properties of this row of output tiles + const int row_pad_bottom = (tile_i < tile_M - 1) ? 0: pad_bottom; + const T* const matrix_tile_row = matrix_batch + tile_i * matrix_tile_row_stride; + T* const outptr_row = outptr_batch + output_tile_rows*tile_i*output_row_stride; + + // Process the row + process_tile_row( + tile_N, output_shape.n_channels, matrix_tile_row, matrix_stride, + matrix_row_stride, outptr_row, output_row_stride, + output_col_stride, row_pad_bottom, pad_right + ); + } + } + } + + template + template + void WinogradGEMM::OutputTransform::process_tile_row( + const int tile_N, + const int n_channels, + const T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + T* const output, + const int output_row_stride, + const int output_col_stride, + const int row_pad_bottom, + const int row_pad_right + ) + { + // Loop over columns of tiles + for (int tile_j = 0; tile_j < tile_N; tile_j++) + { + // Properties of this tile + const int tile_pad_right = (tile_j < tile_N - 1) ? 0 : row_pad_right; + const T* const matrix_row = matrix_base + tile_j * matrix_row_stride; + T* const outptr = output + output_tile_cols*tile_j*output_col_stride; + + // Perform the output transformation + tile_fns[row_pad_bottom][tile_pad_right]( + n_channels, matrix_row, matrix_stride, + outptr, output_row_stride, output_col_stride + ); + } + } + + template + template + size_t WinogradGEMM::OutputTransform::bytes_read(const Tensor4DShape &shape) + { + const int M = iceildiv(shape.n_rows, output_tile_rows) * + iceildiv(shape.n_cols, output_tile_cols); + const int N = shape.n_channels; + return inner_tile_rows * inner_tile_cols * M * N * sizeof(T); + } + + template + template + size_t WinogradGEMM::OutputTransform::bytes_written(const Tensor4DShape &shape) + { + return shape.size() * sizeof(T); + } + + template + template + WinogradGEMM::OutputTransform::OutputTransform( + const T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + T* const output, + const int n_batches, + const int n_rows, + const int n_cols, + const int n_channels + ) : _matrix_base(matrix_base), _matrix_stride(matrix_stride), _matrix_row_stride(matrix_row_stride), + _outptr(output), _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols), _n_channels(n_channels), + _tile_M(iceildiv(n_rows, output_tile_rows)), _tile_N(iceildiv(n_cols, output_tile_cols)) + { + } + + template + template + unsigned int WinogradGEMM::OutputTransform::get_window() const + { + // TODO When the output transform supports multithreading, return the total + // number of tile rows (allowing for multiple batches). For now we return 1 + // to indicate that the activations must be transformed as a single block. + return 1; // TODO _tile_M * _n_batches; + } + + template + template + void WinogradGEMM::OutputTransform::run( + const unsigned int start, const unsigned int stop + ) + { + // TODO When the output transform supports multithreading call execute for a + // portion of the tile rows. + (void) start; + (void) stop; + + // For now, just do all of the work. + const Tensor4DShape output_shape = { + _n_batches, _n_rows, _n_cols, _n_channels, NHWC + }; + execute( + output_shape, _matrix_base, _matrix_stride, _matrix_row_stride, _outptr + ); + } +} // namespace winograd diff --git a/arm_compute/core/NEON/kernels/winograd/utils.hpp b/arm_compute/core/NEON/kernels/winograd/utils.hpp new file mode 100644 index 0000000000..d8b9c3b7d3 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/utils.hpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +double TimeInUs(void); +void PrintMatrix(const float* const m, const int M, const int N, const int row_stride); + +inline int iceildiv(const int a, const int b) { + return (a + b - 1) / b; +} + +template +inline T roundup(const T a, const T b) { + return a + b - (a % b); +} diff --git a/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp b/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp new file mode 100644 index 0000000000..adca48a6d6 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp @@ -0,0 +1,441 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#include "alloc.hpp" +#include "convolution.hpp" +#include "gemm.hpp" +#include "profiler.hpp" +#include "shims.hpp" +#include "tensor.hpp" +#include "utils.hpp" + +#include +#include +#include + +// Generic Winograd implementation using GEMM +namespace winograd +{ + +template +class WinogradGEMM +{ + public: + // Information about the specific Winograd instance + static constexpr int output_tile_rows = OutputTileRows; + static constexpr int output_tile_cols = OutputTileCols; + static constexpr int kernel_rows = KernelRows; + static constexpr int kernel_cols = KernelCols; + static constexpr int inner_tile_rows = output_tile_rows + kernel_rows - 1; // TODO Check + static constexpr int inner_tile_cols = output_tile_cols + kernel_cols - 1; // TODO Check + static constexpr int N_GEMMS = inner_tile_rows * inner_tile_cols; + + /** Transform weights from the spatial to the Winograd domain. */ + template + struct WeightsTransform + { + /** Get the bytes read during the transform. */ + static inline size_t bytes_read(const KernelShape &shape) + { + return shape.size() * sizeof(T); + } + + /** Get the bytes written during the transform. */ + static inline size_t bytes_written(const KernelShape &shape) + { + const int inner_tile_size = inner_tile_rows * inner_tile_cols; + return (inner_tile_size * shape.n_input_channels * + shape.n_output_channels * sizeof(T)); + } + + /** Get the count of operations performed by the transform. */ + static int ops_performed(const KernelShape &shape); + + /** Apply the transform to a tensor. */ + static void execute( + const int n_output_channels, + const int n_input_channels, + const T* const input, + T* const output, + const int matrix_stride, + const int matrix_row_stride + ); + + /** Create a WeightsTransform operator fixed on a given problem and set + * of pointers. + */ + WeightsTransform( + const T* const input, + T* const output, + const int matrix_stride, /** Stride across matrices in the output. */ + const int matrix_row_stride, /** Stride across rows of the matrix. */ + const int n_output_channels, /** Number of filters. */ + const int n_input_channels /** Number of channels in each filter. */ + ); + + /** Get the window of work a given operator can perform. */ + unsigned int get_window() const; + + /** Perform work upon a window of the input. */ + void run(const unsigned int start, const unsigned int stop); + + private: + const T* const inptr; /** Fixed pointer to input data. */ + T* const outptr; /** Fixed pointer to output memory. */ + const int matrix_stride; /** Stride between output matrices. */ + const int matrix_row_stride; /** Stride within output matrices. */ + const int n_output_channels; /** Number of filters. */ + const int n_input_channels; /** Number of channels in each filter. */ + }; + + /** Transform input feature maps from the spatial to the Winograd domain. + */ + template + struct InputTransform + { + /** Get the bytes read during the transform. */ + static size_t bytes_read(const Tensor4DShape &shape) + { + return shape.size() * sizeof(T); + } + + /** Get the bytes written during the transform. */ + static size_t bytes_written(const Tensor4DShape &shape) + { + const int M = iceildiv(shape.n_rows, inner_tile_rows) * + iceildiv(shape.n_cols, inner_tile_cols); + const int K = shape.n_channels; + return inner_tile_rows * inner_tile_cols * M * K * sizeof(T); + } + + /** Get the count of operations performed by the transform. */ + static int ops_performed(const Tensor4DShape &shape); + + /** Apply the transform to a tensor. */ + static void execute( + const T *inptr, + const Tensor4DShape& input_shape, + const PaddingType padding_type, + const int tile_M, + const int tile_N, + T *outptr_base, + const int matrix_stride, + const int matrix_batch_stride, + const int matrix_row_stride + ); + + /***********************************************************************/ + /** Create an InputTransform operator fixed on a given problem and set of + * pointers. + */ + InputTransform( + const T* const input, /** Input tensor data */ + const int n_batches, /** Number of batches in input tensor. */ + const int n_rows, /** Number of rows in input tensor. */ + const int n_cols, /** Number of columns in input tensor. */ + const int n_channels, /** Number of channels in input tensor. */ + const PaddingType padding, /** Padding type. */ + T* const output, /** Base of output matrices. */ + const int matrix_stride, /** Stride between output matrices. */ + const int matrix_row_stride /** Stride within matrices. */ + ); + + /** Get the winodw of work a given operator can perform. */ + unsigned int get_window() const; + + /** Perform work upon a window of the input. */ + void run(const unsigned int start, const unsigned int stop); + /***********************************************************************/ + + private: + static void process_tile_row( + const int tile_N, + int n_channels, + const T* const input_base, + const int input_row_stride, + const int input_col_stride, + T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + const int row_pad_top, + const int row_pad_left, + const int row_pad_bottom, + const int row_pad_right + ); + + static constexpr int max_pad_bottom = inner_tile_rows - 1; + static constexpr int max_pad_right = inner_tile_cols - 1; + + /** Process a single tile of the input tensor. */ + template + static void process_tile(int, const T*, int, int, T*, int); + + // Array of methods to transform tiles of the input tensor. + typedef void (*TileFn)(int, const T*, int, int, T*, int); + static const TileFn tile_fns[2][2][max_pad_bottom][max_pad_right]; + + /* Member values for instance-based API. */ + const T* const _inptr; + T* const _outptr; + const int _n_batches, _n_rows, _n_cols, _n_channels, _matrix_stride, + _matrix_row_stride, _tiles_M, _tiles_N; + const PaddingType _padding_type; + }; + + /** Transform output feature maps from the Winograd to the spatial domain. + */ + template + struct OutputTransform + { + /** Get the bytes read during the transform. */ + static size_t bytes_read(const Tensor4DShape &shape); + + /** Get the bytes written during the transform. */ + static size_t bytes_written(const Tensor4DShape &shape); + + /** Get the count of operations performed by the transform. */ + static int ops_performed(const Tensor4DShape &shape); + + /** Apply the transform to create a tensor. */ + static void execute( + const Tensor4DShape &output_shape, + const T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + T* const output + ); + + /***********************************************************************/ + /** Create an OutputTransform operator fixed on a given problem and set + * of pointers. + */ + OutputTransform( + const T* const matrix_base, /** Pointer to base of matrices. */ + const int matrix_stride, /** Stride between matrices. */ + const int matrix_row_stride, /** Stride within a matrix. */ + T* const output, /** Pointer to output tensor. */ + const int n_batches, /** Number of batches in output tensor. */ + const int n_rows, /** Number of rows in output tensor. */ + const int n_cols, /** Number of columns in output tensor. */ + const int n_channels /** Number of channels in output tensor. */ + ); + + /** Get the window of work a given operator can perform. */ + unsigned int get_window() const; + + /** Perform work upon a window of the input. */ + void run(const unsigned int start, const unsigned int stop); + /***********************************************************************/ + + private: + static void process_tile_row( + const int tile_N, + const int n_channels, + const T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + T* const output, + const int output_row_stride, + const int output_col_stride, + const int row_pad_bottom, + const int row_pad_right + ); + + // Limits on the amount of anti-padding to be applied + static constexpr int max_pad_bottom = output_tile_rows; + static constexpr int max_pad_right = output_tile_cols; + + /** Prepare a single tile of the output tensor. */ + template + static void process_tile(int, const T*, int, T*, int, int); + + // Array of methods to produce tiles of output tensor. + typedef void (*TileFn)(int, const T*, int, T*, int, int); + static const TileFn tile_fns[max_pad_bottom][max_pad_right]; + + /** Member constants for instances of the transform. */ + const T* const _matrix_base; + const int _matrix_stride, _matrix_row_stride; + T* const _outptr; + const int _n_batches, _n_rows, _n_cols, _n_channels, _tile_M, _tile_N; + }; + + /** Perform a convolution. + */ + template + class Convolution + { + public: + // Information about the typed Winograd instance + typedef TOut OutputType; + typedef TIn InputType; + + /** Create a new Winograd operator. */ + Convolution( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding, + void *kernel_storage=NULL + ); + + Convolution(const Convolution&) = delete; + Convolution operator=(const Convolution&) = delete; + + /** Create a new Winograd operator and initialise the weights. */ + Convolution( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding, + const TIn* const kernel, + void *kernel_storage=NULL, + void *transform_working_space=NULL + ); + + /** Clean up a convolution engine. */ + ~Convolution(); + + /** Transform the weights into the Winograd domain. */ + template > + void transform_weights( + const TIn* const kernel, + void *transform_working_space=NULL + ); + + /* Apply the Winograd operator to some input. */ + void execute( + TOut* const output, + const TIn* const input, + void* working_space=NULL, + const int n_threads=1 + ); + + /* Apply the Winograd operator to some input. */ + void execute( + TOut* const output, + const TIn* const input, + const int n_threads + ); + + /** Get the output shape of a convolution. */ + static Tensor4DShape get_output_shape( + const KernelShape &kernel_shape, + const Tensor4DShape &in_shape, + const PaddingType padding + ); + + /* Get the memory required to transform the kernel. + */ + static size_t get_kernel_transform_working_size(const KernelShape &shape); + + /** Get the memory required to store the kernel transformed into the + * Winograd domain. + */ + static size_t get_kernel_storage_size(const KernelShape &shape); + + /** Get the memory required to store the input tensor transformed into + * the Winograd domain. + */ + static size_t get_input_storage_size( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding_type + ); + + /** Get the memory required to store the output tensor in the Winograd + * domain. + */ + static size_t get_output_storage_size( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding_type + ); + + /** Get the memory required to apply a Winograd operator to some input. + */ + static size_t get_working_space_size( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding_type + ); + + /* Get the memory required by a single "input" matrix. + */ + static size_t get_input_matrix_size( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding_type + ); + + static int get_input_matrix_stride( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding_type + ); + + /* Get the memory required by a single "output" matrix. + */ + static size_t get_output_matrix_size( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding_type + ); + + static int get_output_matrix_stride( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding_type + ); + + /* Get the memory required by a single "kernel" matrix. + */ + static size_t get_kernel_matrix_size(const KernelShape &shape); + static int get_kernel_matrix_stride(const KernelShape &shape); + + static constexpr int M_BLOCK = 4; /** Size of block used by GEMM. */ + static constexpr int N_BLOCK = 16; /** Size of block used by GEMM. */ + + private: + const KernelShape kernel_shape; /** Shape of the kernel to be applied. */ + TIn *kernel_matrices[N_GEMMS]; /** Pointers into the kernel matrices. */ + const int kernel_matrix_row_stride; /** Stride within the kernel matrices. */ + + const bool manage_kernel_storage; /** Kernel storage is managed by the instance. */ + void* const _kernel_storage; /** Base pointer for kernel storage. */ + + const Tensor4DShape input_shape; /** Shape of the input tensor. */ + const PaddingType padding; /** Padding applied by the operator. */ + + const Tensor4DShape output_shape; /** Output shape produced by the operator. */ + + const int tile_rows; /** Number of rows of tiles. */ + const int tile_cols; /** Number of columns of tiles. */ + const int M, K, N; /** Sizes of underlying fundamental matrix multiplications. */ + + profiler prof; + }; +}; + +} // namespace winograd diff --git a/arm_compute/core/NEON/kernels/winograd/winograd_layer.hpp b/arm_compute/core/NEON/kernels/winograd/winograd_layer.hpp new file mode 100644 index 0000000000..a3b3db42dd --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/winograd_layer.hpp @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#include + +#include "batched_blocked_gemm.hpp" +#include "winograd_gemm.hpp" + +/** Example of how to construct an ACL-like interface. + * + * Use `get_weight_storage_size`, `get_input_storage_size` and + * `get_output_storage_size` to allocate memory for the convolution engine. + * Then create a `WinogradConvolutionLayer`. + * + * Initialise the weights using `weights_transform.run(...)`. + * + * For each inference: + * 1. Transform the inputs to the Winograd domain using `input_transform.run(...)` + * 2. Perform a number of GEMMs using `gemms.run(...)` + * 3. Transform the output to the spatial domain using `output_transform.run(...)` + */ +template +class WinogradConvolutionLayer +{ + private: + const KernelShape _kernel_shape; + const Tensor4DShape _input_shape; + const PaddingType _padding; + const Tensor4DShape _output_shape; + const int _n_output_rows, _n_output_cols; + const int _kernel_matrix_stride, _kernel_matrix_row_stride; + const int _input_matrix_stride, _input_matrix_row_stride; + const int _output_matrix_stride, _output_matrix_row_stride; + const int _tile_rows, _tile_cols; + const int _m, _k, _n; + + public: + using WinogradBase = winograd::WinogradGEMM; + using WeightsTransform = typename WinogradBase::template WeightsTransform; + using InputTransform = typename WinogradBase::template InputTransform; + using WinogradConv = typename WinogradBase::template Convolution; + using MultiGEMM = winograd::BatchedBlockedGemm; + using OutputTransform = typename WinogradBase::template OutputTransform; + + /* Public member variables. */ + WeightsTransform weights_transform; /** Operator to transform weights to Winograd domain. */ + InputTransform input_transform; /** Operator to transform input to Winograd domain. */ + MultiGEMM gemms; /** Operator to perform multiple GEMMs. */ + OutputTransform output_transform; /** Operator to transform output from Winograd domain. */ + + /** Determine how much memory (in units of TIn) to allocate for the + * transformed weights. + */ + static unsigned int get_weight_storage_size( + const int n_output_channels, /** Number of output feature maps. */ + const int n_input_channels /** Number of input feature maps. */ + ); + + /** Determine how much memory (in units of TIn) to allocate for the + * transformed input. + */ + static unsigned int get_input_storage_size( + const int n_batches, /** Number of batches in the input tensor. */ + const int n_channels, /** Number of feature maps in the input tensor. */ + const int n_rows, /** Number of rows in each feature map. */ + const int n_cols, /** Number of columns in each feature map. */ + const bool same_padding /** Use "SAME" padding, otherwise use "VALID". */ + ); + + /** Determine how much memory (in units of TOut) to allocate for the + * (Winograd domain) output. + */ + static unsigned int get_output_storage_size( + const int n_batches, /** Number of batches in the output tensor. */ + const int n_rows, /** Number of rows in each feature map of the input tensor. */ + const int n_cols, /** Number of columns in each feature map of the input tensor. */ + const int n_output_channels, /** Number of feature maps in the output tensor. */ + const bool same_padding /** Use "SAME" padding, otherwise use "VALID". */ + ); + + /** Get the shape (rows, cols) of a feature map of the output tensor. */ + static std::pair get_output_feature_map_shape( + const int n_input_rows, /** Number of rows in the input feature map. */ + const int n_input_cols, /** Number of columns in the input feature map. */ + const bool same_padding /** Use "SAME" padding, otherwise use "VALID". */ + ); + + /** Create a new Winograd convolution layer. + */ + WinogradConvolutionLayer( + const int n_batches, /** Number of batches in the input and output tensors. */ + const int n_input_channels, /** Number of feature maps in a batch of the input tensor. */ + const int n_input_rows, /** Number of rows in a feature map of the input tensor. */ + const int n_input_cols, /** Number of columns in a feature map of the input tensor. */ + const int n_output_channels, /** Number of feature maps in the output tensor. */ + const bool same_padding, /** Use "SAME" padding, otherwise use "VALID". */ + const TIn* const weights, /** Pointer to weight tensor in spatial domain. Must be ordered as "Height x Rows x Input Feature Maps x Output Feature Maps. */ + TIn* const weights_storage, /** Pointer to storage for weight tensor in the Winograd domain. Must be at least the size returned by `get_weight_storage_size`. */ + const TIn* const input, /** Pointer to NHWC ordered input tensor, in the spatial domain. */ + TIn* const winograd_input, /** Pointer to working space for the input tensor in the Winograd domain. Must be at least the size returned by `get_input_storage_size`. */ + TOut* const output, /** Pointer to NHWC ordered output tensor, in the spatial domain. */ + TOut* const winograd_output /** Pointer to working space for the output tensor in the Winograd domain. Must be at least the size returned by `get_output_storage_size`. */ + ); +}; -- cgit v1.2.1