From 9ceebbeb8dfe61746fdc7022a147f8e2d24c5493 Mon Sep 17 00:00:00 2001 From: Pablo Tello Date: Wed, 10 Jan 2018 16:44:13 +0000 Subject: COMPMID-815: Updated NEWinogradLayer with the lastest code from Research. Change-Id: I86d7f53b5f5d1dbc22078aea5c32b08a25d1f49e Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/116634 Tested-by: Jenkins Reviewed-by: Anthony Barbier --- SConscript | 5 + .../core/NEON/kernels/NEWinogradLayerKernel.h | 53 +- arm_compute/core/NEON/kernels/winograd/alloc.hpp | 1 + arm_compute/core/NEON/kernels/winograd/arm.hpp | 39 + .../NEON/kernels/winograd/batched_blocked_gemm.hpp | 69 + .../core/NEON/kernels/winograd/convolution.hpp | 29 + .../NEON/kernels/winograd/direct_convolution.hpp | 34 + arm_compute/core/NEON/kernels/winograd/gemm.hpp | 127 ++ .../core/NEON/kernels/winograd/gemm/a64_sgemm.hpp | 355 +++++ .../NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp | 1446 +++++++++++++++++++ arm_compute/core/NEON/kernels/winograd/perf.h | 9 + .../core/NEON/kernels/winograd/profiler.hpp | 326 +++++ arm_compute/core/NEON/kernels/winograd/shims.hpp | 747 ++++++++++ arm_compute/core/NEON/kernels/winograd/tensor.hpp | 225 ++- .../core/NEON/kernels/winograd/tensor_utils.hpp | 43 + .../NEON/kernels/winograd/transforms/input.hpp | 195 +++ .../NEON/kernels/winograd/transforms/kernel.hpp | 77 + .../NEON/kernels/winograd/transforms/output.hpp | 174 +++ arm_compute/core/NEON/kernels/winograd/utils.hpp | 37 + .../core/NEON/kernels/winograd/winograd_gemm.hpp | 441 ++++++ .../core/NEON/kernels/winograd/winograd_layer.hpp | 128 ++ .../runtime/NEON/functions/NEWinogradLayer.h | 5 +- src/core/NEON/kernels/NEWinogradLayerKernel.cpp | 87 +- .../NEON/kernels/winograd/batched_blocked_gemm.cpp | 81 ++ src/core/NEON/kernels/winograd/gemm.hpp | 127 -- src/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp | 355 ----- .../NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp | 1445 ------------------- src/core/NEON/kernels/winograd/perf.h | 32 - src/core/NEON/kernels/winograd/profiler.hpp | 244 ---- src/core/NEON/kernels/winograd/shims.hpp | 319 ----- src/core/NEON/kernels/winograd/transforms.hpp | 29 - .../kernels/winograd/transforms/input_2x2_3x3.hpp | 639 --------- .../transforms/input_2x2_3x3/a64_float.hpp | 1498 -------------------- .../input_2x2_3x3/a64_float_channelwise.hpp | 961 ------------- .../winograd/transforms/input_2x2_3x3_fp32.cpp | 409 ++++++ .../winograd/transforms/input_4x4_3x3_fp32.cpp | 486 +++++++ .../kernels/winograd/transforms/kernel_2x2_3x3.hpp | 195 --- .../transforms/kernel_2x2_3x3/a64_float.hpp | 822 ----------- .../kernels/winograd/transforms/output_2x2_3x3.hpp | 356 ----- .../transforms/output_2x2_3x3/a64_float.hpp | 650 --------- .../output_2x2_3x3/a64_float_two_stage.hpp | 655 --------- .../winograd/transforms/output_2x2_3x3_fp32.cpp | 238 ++++ .../winograd/transforms/output_4x4_3x3_fp32.cpp | 299 ++++ .../winograd/transforms/weights_2x2_3x3_fp32.cpp | 228 +++ .../winograd/transforms/weights_4x4_3x3_fp32.cpp | 266 ++++ src/core/NEON/kernels/winograd/utils.cpp | 50 + src/core/NEON/kernels/winograd/utils.hpp | 55 - src/core/NEON/kernels/winograd/winograd_gemm.cpp | 560 ++++++++ src/core/NEON/kernels/winograd/winograd_gemm.hpp | 345 ----- src/core/NEON/kernels/winograd/winograd_layer.cpp | 204 +++ src/runtime/NEON/functions/NEWinogradLayer.cpp | 113 +- 51 files changed, 7356 insertions(+), 8957 deletions(-) create mode 100644 arm_compute/core/NEON/kernels/winograd/arm.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/batched_blocked_gemm.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/convolution.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/direct_convolution.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/gemm.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/perf.h create mode 100644 arm_compute/core/NEON/kernels/winograd/profiler.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/shims.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/tensor_utils.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/input.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/kernel.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/transforms/output.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/utils.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp create mode 100644 arm_compute/core/NEON/kernels/winograd/winograd_layer.hpp create mode 100644 src/core/NEON/kernels/winograd/batched_blocked_gemm.cpp delete mode 100644 src/core/NEON/kernels/winograd/gemm.hpp delete mode 100644 src/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp delete mode 100644 src/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp delete mode 100644 src/core/NEON/kernels/winograd/perf.h delete mode 100644 src/core/NEON/kernels/winograd/profiler.hpp delete mode 100644 src/core/NEON/kernels/winograd/shims.hpp delete mode 100644 src/core/NEON/kernels/winograd/transforms.hpp delete mode 100644 src/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp delete mode 100644 src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp delete mode 100644 src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp create mode 100644 src/core/NEON/kernels/winograd/transforms/input_2x2_3x3_fp32.cpp create mode 100644 src/core/NEON/kernels/winograd/transforms/input_4x4_3x3_fp32.cpp delete mode 100644 src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp delete mode 100644 src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp delete mode 100644 src/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp delete mode 100644 src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp delete mode 100644 src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp create mode 100644 src/core/NEON/kernels/winograd/transforms/output_2x2_3x3_fp32.cpp create mode 100644 src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp create mode 100644 src/core/NEON/kernels/winograd/transforms/weights_2x2_3x3_fp32.cpp create mode 100644 src/core/NEON/kernels/winograd/transforms/weights_4x4_3x3_fp32.cpp create mode 100644 src/core/NEON/kernels/winograd/utils.cpp delete mode 100644 src/core/NEON/kernels/winograd/utils.hpp create mode 100644 src/core/NEON/kernels/winograd/winograd_gemm.cpp delete mode 100644 src/core/NEON/kernels/winograd/winograd_gemm.hpp create mode 100644 src/core/NEON/kernels/winograd/winograd_layer.cpp diff --git a/SConscript b/SConscript index c7779ca8f7..c9f6d0821e 100644 --- a/SConscript +++ b/SConscript @@ -175,6 +175,11 @@ if env['neon']: core_files += Glob('src/core/NEON/*.cpp') core_files += Glob('src/core/NEON/kernels/*.cpp') + # build winograd sources for either v7a / v8a + core_files += Glob('src/core/NEON/kernels/winograd/*.cpp') + core_files += Glob('src/core/NEON/kernels/winograd/transforms/*.cpp') + arm_compute_env.Append(CPPPATH = ["arm_compute/core/NEON/kernels/winograd/"]) + if env['arch'] == "armv7a": core_files += Glob('src/core/NEON/kernels/arm32/*.cpp') diff --git a/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h b/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h index 73b7e8d2b7..95261929ca 100644 --- a/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h +++ b/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017, 2018 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -25,6 +25,7 @@ #define __ARM_COMPUTE_NEGEMMWINOGRADLAYERKERNEL_H__ #include "arm_compute/core/NEON/INEKernel.h" +#include "arm_compute/core/NEON/kernels/winograd/convolution.hpp" #include "arm_compute/core/NEON/kernels/winograd/tensor.hpp" namespace arm_compute @@ -36,11 +37,25 @@ class Winograd3x3F32 final { public: friend class NEWinogradLayerKernel; - Winograd3x3F32(const KernelShape &kernel_shape, const Tensor4DShape input_shape, const PaddingType padding_type, void *kernel_storage); + Winograd3x3F32( + const int n_batches, /** Number of batches in the input and output tensors. */ + const int n_input_channels, /** Number of feature maps in a batch of the input tensor. */ + const int n_input_rows, /** Number of rows in a feature map of the input tensor. */ + const int n_input_cols, /** Number of columns in a feature map of the input tensor. */ + const int n_output_channels, /** Number of feature maps in the output tensor. */ + const bool same_padding, /** Use "SAME" padding, otherwise use "VALID". */ + const float *const weights, /** Pointer to weight tensor in spatial domain. Must be ordered as "Height x Rows x Input Feature Maps x Output Feature Maps. */ + float *const weights_storage, /** Pointer to storage for weight tensor in the Winograd domain. Must be at least the size returned by `get_weight_storage_size`. */ + const float *const input, /** Pointer to NHWC ordered input tensor, in the spatial domain. */ + float *const winograd_input, /** Pointer to working space for the input tensor in the Winograd domain. Must be at least the size returned by `get_input_storage_size`. */ + float *const output, /** Pointer to NHWC ordered output tensor, in the spatial domain. */ + float *const winograd_output /** Pointer to working space for the output tensor in the Winograd domain. Must be at least the size returned by `get_output_storage_size`. */ + ); + ~Winograd3x3F32(); - void transform_weights(const void *const kernel, void *transform_working_space); - void reshape_input(const Tensor4DShape &input_shape, const PaddingType padding_type, const void *const input, void *working_space); - void reshape_output(const Tensor4DShape &input_shape, const PaddingType padding_type, void *const output); + void transform_weights(); + void transform_input(); + void transform_output(); private: class Private; @@ -75,15 +90,29 @@ public: /* Get the memory required to instantiate a new Winograd operator. */ - static size_t get_kernel_storage_size(const KernelShape &shape); + static size_t get_weight_storage_size( + const int n_output_channels, /** Number of output feature maps. */ + const int n_input_channels /** Number of input feature maps. */ + ); - /* Get the memory required to apply a Winograd operator to some input. - */ - static size_t get_working_space_size(const Tensor4DShape &input_shape, const KernelShape &k_shape, const PaddingType padding); + static unsigned int get_input_storage_size( + const int n_batches, /** Number of batches in the input tensor. */ + const int n_channels, /** Number of feature maps in the input tensor. */ + const int n_rows, /** Number of rows in each feature map. */ + const int n_cols, /** Number of columns in each feature map. */ + const bool same_padding /** Use "SAME" padding, otherwise use "VALID". */ + ); - /* Get the memory required to transform the kernel. - */ - static size_t get_kernel_transform_working_size(const KernelShape &shape); + /** Determine how much memory (in units of TOut) to allocate for the + * (Winograd domain) output. + */ + static unsigned int get_output_storage_size( + const int n_batches, /** Number of batches in the output tensor. */ + const int n_rows, /** Number of rows in each feature map of the input tensor. */ + const int n_cols, /** Number of columns in each feature map of the input tensor. */ + const int n_output_channels, /** Number of feature maps in the output tensor. */ + const bool same_padding /** Use "SAME" padding, otherwise use "VALID". */ + ); protected: Winograd3x3F32 *_convolver; diff --git a/arm_compute/core/NEON/kernels/winograd/alloc.hpp b/arm_compute/core/NEON/kernels/winograd/alloc.hpp index ef6f2b5115..799e95d3e6 100644 --- a/arm_compute/core/NEON/kernels/winograd/alloc.hpp +++ b/arm_compute/core/NEON/kernels/winograd/alloc.hpp @@ -21,6 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ + #pragma once #ifdef ALLOC_ALIGN diff --git a/arm_compute/core/NEON/kernels/winograd/arm.hpp b/arm_compute/core/NEON/kernels/winograd/arm.hpp new file mode 100644 index 0000000000..90e7828553 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/arm.hpp @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/** Sets the macro __arm_any__ if compiling for Aarch32 or Aarch64. + * Includes `arm_neon.h` if compiling for either architecture. + */ + +#ifdef __arm__ +#define __arm_any__ +#endif // __arm__ + +#ifdef __aarch64__ +#define __arm_any__ +#endif // __aarch64__ + +#ifdef __arm_any__ +#include +#endif // __arm_any__ diff --git a/arm_compute/core/NEON/kernels/winograd/batched_blocked_gemm.hpp b/arm_compute/core/NEON/kernels/winograd/batched_blocked_gemm.hpp new file mode 100644 index 0000000000..663b3c414f --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/batched_blocked_gemm.hpp @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +namespace winograd +{ + +template +class BatchedBlockedGemm +{ + public: + /** Create a new batched blocked GEMM operator. */ + BatchedBlockedGemm( + const unsigned int n_gemms, + const int M, const int K, const int N, + const int a_matrix_stride, + const int a_row_stride, + const int b_matrix_stride, + const int b_row_stride, + const int c_matrix_stride, + const int c_row_stride, + const TIn* const a_ptr, + const TIn* const b_ptr, + TOut* const c_ptr + ); + + BatchedBlockedGemm(const BatchedBlockedGemm&) = delete; + BatchedBlockedGemm operator=(const BatchedBlockedGemm&) = delete; + + /** Get a window of work performed by the operator. */ + unsigned int get_window() const; + + /** Perform a portion of the work of the operator. */ + void run(const unsigned int start, const unsigned int stop); + + private: + const unsigned int n_gemms; + const int M, N, K; + const int a_matrix_stride, a_row_stride; + const int b_matrix_stride, b_row_stride; + const int c_matrix_stride, c_row_stride; + const TIn* const a_ptr; + const TIn* const b_ptr; + TOut* const c_ptr; +}; + +} // namespace winograd diff --git a/arm_compute/core/NEON/kernels/winograd/convolution.hpp b/arm_compute/core/NEON/kernels/winograd/convolution.hpp new file mode 100644 index 0000000000..2ab2597785 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/convolution.hpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +enum PaddingType { + PADDING_SAME, PADDING_VALID +}; diff --git a/arm_compute/core/NEON/kernels/winograd/direct_convolution.hpp b/arm_compute/core/NEON/kernels/winograd/direct_convolution.hpp new file mode 100644 index 0000000000..725f6cab65 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/direct_convolution.hpp @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once +#include "convolution.hpp" +#include "tensor.hpp" + +void direct_convolution( + const Tensor4D& input, + const Tensor4D& kernel, + Tensor4D& output, + const PaddingType padding +); diff --git a/arm_compute/core/NEON/kernels/winograd/gemm.hpp b/arm_compute/core/NEON/kernels/winograd/gemm.hpp new file mode 100644 index 0000000000..e48d31b4e6 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/gemm.hpp @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once +#include "utils.hpp" + +template +inline void Gemm(const TIn* const a, const TIn* const b, TOut *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride, + const bool a_transposed=false, + const bool b_transposed=false) { + // Array access methods + const auto A = [a, a_transposed, M, K, a_row_stride] (const int i, const int j) -> TIn { + return a[(!a_transposed) ? i*a_row_stride + j : i + j*M]; + }; + + const auto B = [b, b_transposed, K, N, b_row_stride] (const int i, const int j) -> TIn { + return b[(!b_transposed) ? i*b_row_stride + j : i + j*N]; + }; + + const auto C = [c, c_row_stride] (const int i, const int j) -> TOut& { + return c[i*c_row_stride + j]; + }; + + // Perform the matrix multiplication + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + for (int k = 0; k < K; k++) { + C(i, j) += A(i, k) * B(k, j); + } + } + } +} + +template +inline void BlockedGemm( + const TIn* const a, const TIn* const b, TOut *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + // Array access methods + const auto A = [a, M, K, a_row_stride] (const int i, const int j) -> TIn { + return a[i*a_row_stride + j]; + }; + + const auto B = [b, K, N, b_row_stride] (const int i, const int j) -> TIn { + return b[i*b_row_stride + j]; + }; + + const auto C = [c, c_row_stride] (const int i, const int j) -> TOut& { + return c[i*c_row_stride + j]; + }; + + const int M_BLOCKS = iceildiv(M, M_BLOCK); + const int N_BLOCKS = iceildiv(N, N_BLOCK); + + // For each block of output rows + for (int mblock = 0; mblock < M_BLOCKS; mblock++) { + // For each block of output columns + for (int nblock = 0; nblock < N_BLOCKS; nblock++) { + // Create an appropriately sized block of accumulators + TOut accum[M_BLOCK][N_BLOCK]; + for (int i = 0; i < M_BLOCK; i++) { + for (int j = 0; j < N_BLOCK; j++) { + accum[i][j] = static_cast(0); + } + } + + // Perform this portion of the matrix multiply + for (int k = 0; k < K; k++) { + // Load elements of A + TIn elems_a[M_BLOCK]; + for (int i = 0; i < M_BLOCK; i++) { + elems_a[i] = A(mblock*M_BLOCK + i, k); + } + + // Load elements of B + TIn elems_b[N_BLOCK]; + for (int j = 0; j < N_BLOCK; j++) { + elems_b[j] = B(k, nblock*N_BLOCK + j); + } + + // Perform the partial matrix multiply + for (int i = 0; i < M_BLOCK; i++) { + for (int j = 0; j < N_BLOCK; j++) { + accum[i][j] += elems_a[i] * elems_b[j]; + } + } + } + + // Store the partial product + for (int i = 0; i < M_BLOCK; i++) { + for (int j = 0; j < N_BLOCK; j++) { + C(mblock*M_BLOCK + i, nblock*N_BLOCK + j) = accum[i][j]; + } + } + } + } +} + +#include "gemm/a64_sgemm.hpp" diff --git a/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp new file mode 100644 index 0000000000..caeb48f65a --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp @@ -0,0 +1,355 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once +#include +#include "../utils.hpp" + +#ifdef __aarch64__ + +template <> +inline void BlockedGemm<8, 12, float, float>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + const int M_BLOCK = 8; + const int N_BLOCK = 12; + + const int m_blocks = iceildiv(M, M_BLOCK); + const int n_blocks = iceildiv(N, N_BLOCK); + + // For each block of output rows + for (int mblock = 0; mblock < m_blocks; mblock++) { + // For each block of output columns + for (int nblock = 0; nblock < n_blocks; nblock++) { + const float *aptr = a + mblock*M_BLOCK*a_row_stride; + const float *bptr = b + nblock*N_BLOCK; + float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; + int k = K; + + asm volatile ( + // Create an 8x12 block of accumulators + " A_1 .req v27\n" + "sA_1 .req s27\n" + " A_2 .req v28\n" + "sA_2 .req s28\n" + " A_3 .req v29\n" + "sA_3 .req s29\n" + " A_4 .req v30\n" + "sA_4 .req s30\n" + + " B_1 .req v24\n" " B_2 .req v25\n" " B_3 .req v26\n" + "qB_1 .req q24\n" "qB_2 .req q25\n" "qB_3 .req q26\n" + + " C_11 .req v0\n" " C_12 .req v1\n" " C_13 .req v2\n" + " C_21 .req v3\n" " C_22 .req v4\n" " C_23 .req v5\n" + " C_31 .req v6\n" " C_32 .req v7\n" " C_33 .req v8\n" + " C_41 .req v9\n" " C_42 .req v10\n" " C_43 .req v11\n" + " C_51 .req v12\n" " C_52 .req v13\n" " C_53 .req v14\n" + " C_61 .req v15\n" " C_62 .req v16\n" " C_63 .req v17\n" + " C_71 .req v18\n" " C_72 .req v19\n" " C_73 .req v20\n" + " C_81 .req v21\n" " C_82 .req v22\n" " C_83 .req v23\n" + + "qC_11 .req q0\n" "qC_12 .req q1\n" "qC_13 .req q2\n" + "qC_21 .req q3\n" "qC_22 .req q4\n" "qC_23 .req q5\n" + "qC_31 .req q6\n" "qC_32 .req q7\n" "qC_33 .req q8\n" + "qC_41 .req q9\n" "qC_42 .req q10\n" "qC_43 .req q11\n" + "qC_51 .req q12\n" "qC_52 .req q13\n" "qC_53 .req q14\n" + "qC_61 .req q15\n" "qC_62 .req q16\n" "qC_63 .req q17\n" + "qC_71 .req q18\n" "qC_72 .req q19\n" "qC_73 .req q20\n" + "qC_81 .req q21\n" "qC_82 .req q22\n" "qC_83 .req q23\n" + + "aptr1 .req x17\n" + "aptr2 .req x18\n" + "aptr3 .req x19\n" + "aptr4 .req x20\n" + "aptr5 .req x21\n" + "aptr6 .req x22\n" + "aptr7 .req x23\n" + + // Initialise accumulators with 0 + // Initialise pointers + "movi C_11.4s, #0\n" + "add aptr1, %x[aptr], %x[a_row_stride]\n" + "movi C_12.4s, #0\n" + "add aptr2, aptr1, %x[a_row_stride]\n" + "movi C_13.4s, #0\n" + "add aptr3, aptr2, %x[a_row_stride]\n" + "movi C_21.4s, #0\n" + "add aptr4, aptr3, %x[a_row_stride]\n" + "movi C_22.4s, #0\n" + "add aptr5, aptr4, %x[a_row_stride]\n" + "movi C_23.4s, #0\n" + "add aptr6, aptr5, %x[a_row_stride]\n" + "movi C_31.4s, #0\n" + "add aptr7, aptr6, %x[a_row_stride]\n" + "movi C_32.4s, #0\n" + "ldr qB_1, [%x[bptr]]\n" + "movi C_33.4s, #0\n" + "ldr qB_2, [%x[bptr], #0x10]\n" + "movi C_41.4s, #0\n" + "prfm pldl1keep, [%x[bptr], #0x00]\n" + "movi C_42.4s, #0\n" + "prfm pldl1keep, [%x[bptr], #0x10]\n" + "movi C_43.4s, #0\n" + "prfm pldl1keep, [%x[bptr], #0x20]\n" + "movi C_51.4s, #0\n" + "prfm pldl1keep, [%x[aptr], #0x00]\n" + "movi C_52.4s, #0\n" + "prfm pldl1keep, [ aptr1, #0x00]\n" + "movi C_53.4s, #0\n" + "prfm pldl1keep, [ aptr2, #0x00]\n" + "movi C_61.4s, #0\n" + "prfm pldl1keep, [ aptr3, #0x00]\n" + "movi C_62.4s, #0\n" + "prfm pldl1keep, [ aptr4, #0x00]\n" + "movi C_63.4s, #0\n" + "prfm pldl1keep, [ aptr5, #0x00]\n" + "movi C_71.4s, #0\n" + "prfm pldl1keep, [ aptr6, #0x00]\n" + "movi C_72.4s, #0\n" + "prfm pldl1keep, [ aptr7, #0x00]\n" + "movi C_73.4s, #0\n" + "ldr sA_1, [%x[aptr]], #0x4\n" + "movi C_81.4s, #0\n" + "ldr sA_2, [ aptr1], #0x4\n" + "movi C_82.4s, #0\n" + "ldr sA_3, [ aptr2], #0x4\n" + "movi C_83.4s, #0\n" + "subs %x[k], %x[k], #1\n" + "beq 2f\n" + + "1:" + "fmla C_11.4s, B_1.4s, A_1.s[0]\n" + "ldr qB_3, [%x[bptr], #0x20]\n" + "fmla C_12.4s, B_2.4s, A_1.s[0]\n" + "ldr sA_4, [ aptr3], #0x4\n" + "fmla C_13.4s, B_3.4s, A_1.s[0]\n" + "ldr sA_1, [ aptr4], #0x04\n" + + "fmla C_21.4s, B_1.4s, A_2.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride]\n" + "fmla C_22.4s, B_2.4s, A_2.s[0]\n" + "prfm pldl1keep, [ aptr3, #0x10]\n" + "fmla C_23.4s, B_3.4s, A_2.s[0]\n" + "ldr sA_2, [ aptr5], #0x04\n" + + "fmla C_31.4s, B_1.4s, A_3.s[0]\n" + "prfm pldl1keep, [%x[bptr], #0x00]\n" + "fmla C_32.4s, B_2.4s, A_3.s[0]\n" + "prfm pldl1keep, [%x[bptr], #0x10]\n" + "fmla C_33.4s, B_3.4s, A_3.s[0]\n" + "ldr sA_3, [ aptr6], #0x04\n" + + "fmla C_41.4s, B_1.4s, A_4.s[0]\n" + "prfm pldl1keep, [%x[bptr], #0x20]\n" + "fmla C_42.4s, B_2.4s, A_4.s[0]\n" + "prfm pldl1keep, [ aptr4, #0x10]\n" + "fmla C_43.4s, B_3.4s, A_4.s[0]\n" + "ldr sA_4, [ aptr7], #0x04\n" + + "fmla C_51.4s, B_1.4s, A_1.s[0]\n" + "prfm pldl1keep, [ aptr5, #0x10]\n" + "fmla C_52.4s, B_2.4s, A_1.s[0]\n" + "prfm pldl1keep, [ aptr6, #0x10]\n" + "fmla C_53.4s, B_3.4s, A_1.s[0]\n" + "ldr sA_1, [%x[aptr]], #0x04\n" + + "fmla C_61.4s, B_1.4s, A_2.s[0]\n" + "prfm pldl1keep, [ aptr7, #0x10]\n" + "fmla C_62.4s, B_2.4s, A_2.s[0]\n" + "subs %x[k], %x[k], #1\n" + "fmla C_63.4s, B_3.4s, A_2.s[0]\n" + "ldr sA_2, [ aptr1], #0x04\n" + + "fmla C_71.4s, B_1.4s, A_3.s[0]\n" + "prfm pldl1keep, [%x[aptr], #0x10]\n" + "fmla C_72.4s, B_2.4s, A_3.s[0]\n" + "prfm pldl1keep, [ aptr1, #0x10]\n" + "fmla C_73.4s, B_3.4s, A_3.s[0]\n" + "ldr sA_3, [ aptr2], #0x04\n" + + "fmla C_81.4s, B_1.4s, A_4.s[0]\n" + "prfm pldl1keep, [ aptr2, #0x10]\n" + "fmla C_82.4s, B_2.4s, A_4.s[0]\n" + "ldp qB_1, qB_2, [%x[bptr]]\n" + "fmla C_83.4s, B_3.4s, A_4.s[0]\n" + "bne 1b\n" + + "2:" + "fmla C_11.4s, B_1.4s, A_1.s[0]\n" + "ldr qB_3, [%x[bptr], #0x20]\n" + "fmla C_12.4s, B_2.4s, A_1.s[0]\n" + "stp qC_11, qC_12, [%x[cptr]]\n" + "fmla C_13.4s, B_3.4s, A_1.s[0]\n" + "str qC_13, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + "ldr sA_1, [ aptr4], #0x04\n" + + "fmla C_21.4s, B_1.4s, A_2.s[0]\n" + "ldr sA_4, [ aptr3], #0x4\n" + "fmla C_22.4s, B_2.4s, A_2.s[0]\n" + "stp qC_21, qC_22, [%x[cptr]]\n" + "fmla C_23.4s, B_3.4s, A_2.s[0]\n" + "str qC_23, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + "ldr sA_2, [ aptr5], #0x04\n" + + "fmla C_31.4s, B_1.4s, A_3.s[0]\n" + "fmla C_32.4s, B_2.4s, A_3.s[0]\n" + "stp qC_31, qC_32, [%x[cptr]]\n" + "fmla C_33.4s, B_3.4s, A_3.s[0]\n" + "str qC_33, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + "ldr sA_3, [ aptr6], #0x04\n" + + "fmla C_41.4s, B_1.4s, A_4.s[0]\n" + "fmla C_42.4s, B_2.4s, A_4.s[0]\n" + "stp qC_41, qC_42, [%x[cptr]]\n" + "fmla C_43.4s, B_3.4s, A_4.s[0]\n" + "str qC_43, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + "ldr sA_4, [ aptr7], #0x04\n" + + "fmla C_51.4s, B_1.4s, A_1.s[0]\n" + "fmla C_52.4s, B_2.4s, A_1.s[0]\n" + "stp qC_51, qC_52, [%x[cptr]]\n" + "fmla C_53.4s, B_3.4s, A_1.s[0]\n" + "str qC_53, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + + "fmla C_61.4s, B_1.4s, A_2.s[0]\n" + "fmla C_62.4s, B_2.4s, A_2.s[0]\n" + "stp qC_61, qC_62, [%x[cptr]]\n" + "fmla C_63.4s, B_3.4s, A_2.s[0]\n" + "str qC_63, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + + "fmla C_71.4s, B_1.4s, A_3.s[0]\n" + "fmla C_72.4s, B_2.4s, A_3.s[0]\n" + "stp qC_71, qC_72, [%x[cptr]]\n" + "fmla C_73.4s, B_3.4s, A_3.s[0]\n" + "str qC_73, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + + "fmla C_81.4s, B_1.4s, A_4.s[0]\n" + "fmla C_82.4s, B_2.4s, A_4.s[0]\n" + "stp qC_81, qC_82, [%x[cptr]]\n" + "fmla C_83.4s, B_3.4s, A_4.s[0]\n" + "str qC_83, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride]\n" + + // Clear aliases + ".unreq aptr1\n" + ".unreq aptr2\n" + ".unreq aptr3\n" + ".unreq aptr4\n" + ".unreq aptr5\n" + ".unreq aptr6\n" + ".unreq aptr7\n" + + ".unreq A_1\n" ".unreq A_2\n" ".unreq A_3\n" ".unreq A_4\n" + ".unreq sA_1\n" ".unreq sA_2\n" ".unreq sA_3\n" ".unreq sA_4\n" + + ".unreq B_1\n" ".unreq B_2\n" ".unreq B_3\n" + ".unreq qB_1\n" ".unreq qB_2\n" ".unreq qB_3\n" + + ".unreq C_11\n" ".unreq C_12\n" ".unreq C_13\n" + ".unreq C_21\n" ".unreq C_22\n" ".unreq C_23\n" + ".unreq C_31\n" ".unreq C_32\n" ".unreq C_33\n" + ".unreq C_41\n" ".unreq C_42\n" ".unreq C_43\n" + ".unreq C_51\n" ".unreq C_52\n" ".unreq C_53\n" + ".unreq C_61\n" ".unreq C_62\n" ".unreq C_63\n" + ".unreq C_71\n" ".unreq C_72\n" ".unreq C_73\n" + ".unreq C_81\n" ".unreq C_82\n" ".unreq C_83\n" + + ".unreq qC_11\n" ".unreq qC_12\n" ".unreq qC_13\n" + ".unreq qC_21\n" ".unreq qC_22\n" ".unreq qC_23\n" + ".unreq qC_31\n" ".unreq qC_32\n" ".unreq qC_33\n" + ".unreq qC_41\n" ".unreq qC_42\n" ".unreq qC_43\n" + ".unreq qC_51\n" ".unreq qC_52\n" ".unreq qC_53\n" + ".unreq qC_61\n" ".unreq qC_62\n" ".unreq qC_63\n" + ".unreq qC_71\n" ".unreq qC_72\n" ".unreq qC_73\n" + ".unreq qC_81\n" ".unreq qC_82\n" ".unreq qC_83\n" + : [aptr] "+r" (aptr), + [bptr] "+r" (bptr), + [cptr] "+r" (cptr), + [k] "+r" (k) + : [a_row_stride] "r" (a_row_stride * sizeof(float)), + [b_row_stride] "r" (b_row_stride * sizeof(float)), + [c_row_stride] "r" (c_row_stride * sizeof(float)) + : "cc", "memory", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "x17", "x18", "x19", "x20", "x21", "x22", "x23" + ); + } + } +} + +/*****************************************************************************/ +/* 4x16 blocked GEMM with specialised tails + */ +#include "a64_sgemm_4x16.hpp" + +template <> +inline void BlockedGemm<4, 16, float, float>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + // Despatch based on tail of K + switch (K % 4) { + case 3: + sgemm_4x16_impl<3>( + a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride + ); + break; + case 2: + sgemm_4x16_impl<2>( + a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride + ); + break; + case 1: + sgemm_4x16_impl<1>( + a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride + ); + break; + case 0: + sgemm_4x16_impl<0>( + a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride + ); + break; + default: + assert(false); + } +} + +#endif // __aarch64__ diff --git a/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp new file mode 100644 index 0000000000..5cd37de7a0 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp @@ -0,0 +1,1446 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +template +inline void sgemm_4x16_impl( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +); + +template <> +inline void sgemm_4x16_impl<0>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + const int TAIL_SIZE = 0; + const int M_BLOCK = 4; + const int N_BLOCK = 16; + + const int m_blocks = iceildiv(M, M_BLOCK); + const int n_blocks = iceildiv(N, N_BLOCK); + + // For each block of output rows + for (int mblock = 0; mblock < m_blocks; mblock++) { + // For each block of output columns + for (int nblock = 0; nblock < n_blocks; nblock++) { + const float *aptr = a + mblock*M_BLOCK*a_row_stride; + const float *bptr = b + nblock*N_BLOCK; + float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; + int k = (K - TAIL_SIZE) / 4; + + asm volatile( + "aptr2 .req X20\n" + "aptr3 .req X21\n" + "aptr4 .req X22\n" + "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" + "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" + "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" + "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" + "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" + "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" + "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" + "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" + "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" + "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" + "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" + "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" + "vB1 .req v20\n" "qB1 .req q20\n" + "vB2 .req v21\n" "qB2 .req q21\n" + "vB3 .req v22\n" "qB3 .req q22\n" + "vB4 .req v23\n" "qB4 .req q23\n" + + // Clear accumulators, initialise pointers + "movi vC11.4s, #0\n" + "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n" + "movi vC12.4s, #0\n" + "add aptr3, aptr2, %x[a_row_stride_bytes]\n" + "movi vC13.4s, #0\n" + "add aptr4, aptr3, %x[a_row_stride_bytes]\n" + "movi vC14.4s, #0\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "movi vC21.4s, #0\n" + "ldr qA2, [ aptr2], #0x10\n" + "movi vC22.4s, #0\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "movi vC23.4s, #0\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "movi vC24.4s, #0\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "movi vC31.4s, #0\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "subs %x[k], %x[k], #1\n" + "beq 2f\n" + + "1:" // Loop proper + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "subs %x[k], %x[k], #1\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr qA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + "bne 1b\n" + + "2:" // Tail + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "stp qC11, qC12, [%x[cptr], #0x00]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "stp qC13, qC14, [%x[cptr], #0x20]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "stp qC21, qC22, [%x[cptr], #0x00]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "stp qC23, qC24, [%x[cptr], #0x20]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "stp qC31, qC32, [%x[cptr], #0x00]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "stp qC33, qC34, [%x[cptr], #0x20]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "stp qC41, qC42, [%x[cptr], #0x00]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + "stp qC43, qC44, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + + ".unreq vB4\n" ".unreq qB4\n" + ".unreq vB3\n" ".unreq qB3\n" + ".unreq vB2\n" ".unreq qB2\n" + ".unreq vB1\n" ".unreq qB1\n" + ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" + ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" + ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" + ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" + ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" + ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n" + ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" + ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" + ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" + ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n" + ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" + ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" + ".unreq aptr2\n" + ".unreq aptr3\n" + ".unreq aptr4\n" + + : [aptr] "+r" (aptr), + [bptr] "+r" (bptr), + [cptr] "+r" (cptr), + [k] "+r" (k) + : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), + [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), + [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) + : "cc", "memory", "x20", "x21", "x22", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23" + ); + } + } +} + +template <> +inline void sgemm_4x16_impl<1>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + const int TAIL_SIZE = 1; + const int M_BLOCK = 4; + const int N_BLOCK = 16; + + const int m_blocks = iceildiv(M, M_BLOCK); + const int n_blocks = iceildiv(N, N_BLOCK); + + // For each block of output rows + for (int mblock = 0; mblock < m_blocks; mblock++) { + // For each block of output columns + for (int nblock = 0; nblock < n_blocks; nblock++) { + const float *aptr = a + mblock*M_BLOCK*a_row_stride; + const float *bptr = b + nblock*N_BLOCK; + float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; + int k = (K - TAIL_SIZE) / 4; + + asm volatile( + "aptr2 .req X20\n" + "aptr3 .req X21\n" + "aptr4 .req X22\n" + "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" + "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" + "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" + "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" + "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" + "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" + "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" + "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" + "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" + "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" + "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" + "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" + "vB1 .req v20\n" "qB1 .req q20\n" + "vB2 .req v21\n" "qB2 .req q21\n" + "vB3 .req v22\n" "qB3 .req q22\n" + "vB4 .req v23\n" "qB4 .req q23\n" + + // Clear accumulators, initialise pointers + "movi vC11.4s, #0\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "movi vC12.4s, #0\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "movi vC13.4s, #0\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "movi vC14.4s, #0\n" + "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n" + "movi vC21.4s, #0\n" + "add aptr3, aptr2, %x[a_row_stride_bytes]\n" + "movi vC22.4s, #0\n" + "add aptr4, aptr3, %x[a_row_stride_bytes]\n" + "movi vC23.4s, #0\n" + "cbnz %x[k], 3f\n" + + // Prepare for tail in K + "movi vC24.4s, #0\n" + "ldr sA1, [%x[aptr]], #0x04\n" + "movi vC31.4s, #0\n" + "ldr sA2, [ aptr2], #0x04\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "b 2f\n" // Jump to tail + + "3:" // Prepare for loop over K + "movi vC24.4s, #0\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "movi vC31.4s, #0\n" + "ldr qA2, [ aptr2], #0x10\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "subs %x[k], %x[k], #1\n" + "beq 4f\n" + + "1:" // Loop proper + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "subs %x[k], %x[k], #1\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr qA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + "bne 1b\n" + + "4:" // Tail iteration + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr sA1, [%x[aptr]], #0x04\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr sA2, [ aptr2], #0x04\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + + "2:" // Common tail + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "stp qC11, qC12, [%x[cptr], #0x00]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "ldr sA3, [ aptr3], #0x04\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "stp qC13, qC14, [%x[cptr], #0x20]\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "stp qC21, qC22, [%x[cptr], #0x00]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "ldr sA4, [ aptr4], #0x04\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "stp qC23, qC24, [%x[cptr], #0x20]\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "stp qC31, qC32, [%x[cptr], #0x00]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "stp qC33, qC34, [%x[cptr], #0x20]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "stp qC41, qC42, [%x[cptr], #0x00]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + "stp qC43, qC44, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + + ".unreq vB4\n" ".unreq qB4\n" + ".unreq vB3\n" ".unreq qB3\n" + ".unreq vB2\n" ".unreq qB2\n" + ".unreq vB1\n" ".unreq qB1\n" + ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" + ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" + ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" + ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" + ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" + ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n" + ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" + ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" + ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" + ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n" + ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" + ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" + ".unreq aptr2\n" + ".unreq aptr3\n" + ".unreq aptr4\n" + + : [aptr] "+r" (aptr), + [bptr] "+r" (bptr), + [cptr] "+r" (cptr), + [k] "+r" (k) + : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), + [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), + [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) + : "cc", "memory", "x20", "x21", "x22", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23" + ); + } + } +} + +template <> +inline void sgemm_4x16_impl<2>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + const int TAIL_SIZE = 2; + const int M_BLOCK = 4; + const int N_BLOCK = 16; + + const int m_blocks = iceildiv(M, M_BLOCK); + const int n_blocks = iceildiv(N, N_BLOCK); + + // For each block of output rows + for (int mblock = 0; mblock < m_blocks; mblock++) { + // For each block of output columns + for (int nblock = 0; nblock < n_blocks; nblock++) { + const float *aptr = a + mblock*M_BLOCK*a_row_stride; + const float *bptr = b + nblock*N_BLOCK; + float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; + int k = (K - TAIL_SIZE) / 4; + + asm volatile( + "aptr2 .req X20\n" + "aptr3 .req X21\n" + "aptr4 .req X22\n" + "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" + "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" + "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" + "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" + "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" + "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" + "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" + "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" + "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" + "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" + "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" + "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" + "vB1 .req v20\n" "qB1 .req q20\n" + "vB2 .req v21\n" "qB2 .req q21\n" + "vB3 .req v22\n" "qB3 .req q22\n" + "vB4 .req v23\n" "qB4 .req q23\n" + + // Clear accumulators, initialise pointers + "movi vC11.4s, #0\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "movi vC12.4s, #0\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "movi vC13.4s, #0\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "movi vC14.4s, #0\n" + "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n" + "movi vC21.4s, #0\n" + "add aptr3, aptr2, %x[a_row_stride_bytes]\n" + "movi vC22.4s, #0\n" + "add aptr4, aptr3, %x[a_row_stride_bytes]\n" + "movi vC23.4s, #0\n" + "cbnz %x[k], 3f\n" + + // Prepare for tail in K + "movi vC24.4s, #0\n" + "ldr dA1, [%x[aptr]], #0x08\n" + "movi vC31.4s, #0\n" + "ldr dA2, [ aptr2], #0x08\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "b 2f\n" // Jump to tail + + "3:" // Prepare for loop over K + "movi vC24.4s, #0\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "movi vC31.4s, #0\n" + "ldr qA2, [ aptr2], #0x10\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "subs %x[k], %x[k], #1\n" + "beq 4f\n" + + "1:" // Loop proper + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "subs %x[k], %x[k], #1\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr qA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + "bne 1b\n" + + "4:" // Tail iteration + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr dA1, [%x[aptr]], #0x08\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr dA2, [ aptr2], #0x08\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + + "2:" // Common tail + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr dA3, [ aptr3], #0x08\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr dA4, [ aptr4], #0x08\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "stp qC11, qC12, [%x[cptr], #0x00]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "stp qC13, qC14, [%x[cptr], #0x20]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "stp qC21, qC22, [%x[cptr], #0x00]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "stp qC23, qC24, [%x[cptr], #0x20]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "stp qC31, qC32, [%x[cptr], #0x00]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "stp qC33, qC34, [%x[cptr], #0x20]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "stp qC41, qC42, [%x[cptr], #0x00]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + "stp qC43, qC44, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + + ".unreq vB4\n" ".unreq qB4\n" + ".unreq vB3\n" ".unreq qB3\n" + ".unreq vB2\n" ".unreq qB2\n" + ".unreq vB1\n" ".unreq qB1\n" + ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" + ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" + ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" + ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" + ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" + ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n" + ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" + ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" + ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" + ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n" + ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" + ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" + ".unreq aptr2\n" + ".unreq aptr3\n" + ".unreq aptr4\n" + + : [aptr] "+r" (aptr), + [bptr] "+r" (bptr), + [cptr] "+r" (cptr), + [k] "+r" (k) + : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), + [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), + [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) + : "cc", "memory", "x20", "x21", "x22", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23" + ); + } + } +} + +template <> +inline void sgemm_4x16_impl<3>( + const float* const a, const float* const b, float *c, + const int M, const int K, const int N, + const int a_row_stride, + const int b_row_stride, + const int c_row_stride +) { + const int TAIL_SIZE = 3; + const int M_BLOCK = 4; + const int N_BLOCK = 16; + + const int m_blocks = iceildiv(M, M_BLOCK); + const int n_blocks = iceildiv(N, N_BLOCK); + + // For each block of output rows + for (int mblock = 0; mblock < m_blocks; mblock++) { + // For each block of output columns + for (int nblock = 0; nblock < n_blocks; nblock++) { + const float *aptr = a + mblock*M_BLOCK*a_row_stride; + const float *bptr = b + nblock*N_BLOCK; + float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; + int k = (K - TAIL_SIZE) / 4; + + asm volatile( + "aptr2 .req X20\n" + "aptr3 .req X21\n" + "aptr4 .req X22\n" + "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" + "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" + "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" + "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" + "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" + "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" + "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" + "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" + "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" + "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" + "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" + "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" + "vB1 .req v20\n" "qB1 .req q20\n" + "vB2 .req v21\n" "qB2 .req q21\n" + "vB3 .req v22\n" "qB3 .req q22\n" + "vB4 .req v23\n" "qB4 .req q23\n" + + // Clear accumulators, initialise pointers + "movi vC11.4s, #0\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "movi vC12.4s, #0\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "movi vC13.4s, #0\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "movi vC14.4s, #0\n" + "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n" + "movi vC21.4s, #0\n" + "add aptr3, aptr2, %x[a_row_stride_bytes]\n" + "movi vC22.4s, #0\n" + "add aptr4, aptr3, %x[a_row_stride_bytes]\n" + "movi vC23.4s, #0\n" + "cbnz %x[k], 3f\n" + + // Prepare for tail in K + "movi vC24.4s, #0\n" + "ldr dA1, [%x[aptr]], #0x08\n" + "movi vC31.4s, #0\n" + "ldr dA2, [ aptr2], #0x08\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "b 2f\n" // Jump to tail + + "3:" // Prepare for loop over K + "movi vC24.4s, #0\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "movi vC31.4s, #0\n" + "ldr qA2, [ aptr2], #0x10\n" + "movi vC32.4s, #0\n" + "movi vC33.4s, #0\n" + "movi vC34.4s, #0\n" + "movi vC41.4s, #0\n" + "movi vC42.4s, #0\n" + "movi vC43.4s, #0\n" + "movi vC44.4s, #0\n" + "subs %x[k], %x[k], #1\n" + "beq 4f\n" + + "1:" // Loop proper + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "subs %x[k], %x[k], #1\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr qA1, [%x[aptr]], #0x10\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr qA2, [ aptr2], #0x10\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + "bne 1b\n" + + "4:" // Tail iteration + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qA3, [ aptr3], #0x10\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr qA4, [ aptr4], #0x10\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[2]\n" + "fmla vC21.4s, vB1.4s, vA2.s[2]\n" + "fmla vC31.4s, vB1.4s, vA3.s[2]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[2]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[2]\n" + "fmla vC22.4s, vB2.4s, vA2.s[2]\n" + "fmla vC32.4s, vB2.4s, vA3.s[2]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[2]\n" + "fmla vC13.4s, vB3.4s, vA1.s[2]\n" + "fmla vC23.4s, vB3.4s, vA2.s[2]\n" + "fmla vC33.4s, vB3.4s, vA3.s[2]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[2]\n" + "fmla vC14.4s, vB4.4s, vA1.s[2]\n" + "fmla vC24.4s, vB4.4s, vA2.s[2]\n" + "fmla vC34.4s, vB4.4s, vA3.s[2]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[2]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[3]\n" + "fmla vC21.4s, vB1.4s, vA2.s[3]\n" + "fmla vC31.4s, vB1.4s, vA3.s[3]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[3]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[3]\n" + "fmla vC22.4s, vB2.4s, vA2.s[3]\n" + "fmla vC32.4s, vB2.4s, vA3.s[3]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[3]\n" + "fmla vC13.4s, vB3.4s, vA1.s[3]\n" + "fmla vC23.4s, vB3.4s, vA2.s[3]\n" + "fmla vC33.4s, vB3.4s, vA3.s[3]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[3]\n" + "fmla vC14.4s, vB4.4s, vA1.s[3]\n" + "ldr dA1, [%x[aptr]], #0x08\n" + "fmla vC24.4s, vB4.4s, vA2.s[3]\n" + "ldr dA2, [ aptr2], #0x08\n" + "fmla vC34.4s, vB4.4s, vA3.s[3]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[3]\n" + + "2:" // Common tail + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr dA3, [ aptr3], #0x08\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "ldr dA4, [ aptr4], #0x08\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[1]\n" + "fmla vC21.4s, vB1.4s, vA2.s[1]\n" + "fmla vC31.4s, vB1.4s, vA3.s[1]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC41.4s, vB1.4s, vA4.s[1]\n" + "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" + "fmla vC12.4s, vB2.4s, vA1.s[1]\n" + "fmla vC22.4s, vB2.4s, vA2.s[1]\n" + "fmla vC32.4s, vB2.4s, vA3.s[1]\n" + "ldr qB1, [%x[bptr], #0x00]\n" + "fmla vC42.4s, vB2.4s, vA4.s[1]\n" + "fmla vC13.4s, vB3.4s, vA1.s[1]\n" + "fmla vC23.4s, vB3.4s, vA2.s[1]\n" + "fmla vC33.4s, vB3.4s, vA3.s[1]\n" + "ldr qB2, [%x[bptr], #0x10]\n" + "fmla vC43.4s, vB3.4s, vA4.s[1]\n" + "fmla vC14.4s, vB4.4s, vA1.s[1]\n" + "ldr sA1, [%x[aptr]], #0x04\n" + "fmla vC24.4s, vB4.4s, vA2.s[1]\n" + "ldr sA2, [ aptr2], #0x04\n" + "fmla vC34.4s, vB4.4s, vA3.s[1]\n" + "ldr qB3, [%x[bptr], #0x20]\n" + "fmla vC44.4s, vB4.4s, vA4.s[1]\n" + + "fmla vC11.4s, vB1.4s, vA1.s[0]\n" + "ldr qB4, [%x[bptr], #0x30]\n" + "fmla vC12.4s, vB2.4s, vA1.s[0]\n" + "stp qC11, qC12, [%x[cptr], #0x00]\n" + "fmla vC13.4s, vB3.4s, vA1.s[0]\n" + "ldr sA3, [ aptr3], #0x04\n" + "fmla vC14.4s, vB4.4s, vA1.s[0]\n" + "stp qC13, qC14, [%x[cptr], #0x20]\n" + "fmla vC21.4s, vB1.4s, vA2.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC22.4s, vB2.4s, vA2.s[0]\n" + "stp qC21, qC22, [%x[cptr], #0x00]\n" + "fmla vC23.4s, vB3.4s, vA2.s[0]\n" + "ldr sA4, [ aptr4], #0x04\n" + "fmla vC24.4s, vB4.4s, vA2.s[0]\n" + "stp qC23, qC24, [%x[cptr], #0x20]\n" + "fmla vC31.4s, vB1.4s, vA3.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC32.4s, vB2.4s, vA3.s[0]\n" + "stp qC31, qC32, [%x[cptr], #0x00]\n" + "fmla vC33.4s, vB3.4s, vA3.s[0]\n" + "fmla vC34.4s, vB4.4s, vA3.s[0]\n" + "stp qC33, qC34, [%x[cptr], #0x20]\n" + "fmla vC41.4s, vB1.4s, vA4.s[0]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + "fmla vC42.4s, vB2.4s, vA4.s[0]\n" + "stp qC41, qC42, [%x[cptr], #0x00]\n" + "fmla vC43.4s, vB3.4s, vA4.s[0]\n" + "fmla vC44.4s, vB4.4s, vA4.s[0]\n" + "stp qC43, qC44, [%x[cptr], #0x20]\n" + "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" + + ".unreq vB4\n" ".unreq qB4\n" + ".unreq vB3\n" ".unreq qB3\n" + ".unreq vB2\n" ".unreq qB2\n" + ".unreq vB1\n" ".unreq qB1\n" + ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" + ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" + ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" + ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" + ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" + ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n" + ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" + ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" + ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" + ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n" + ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" + ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" + ".unreq aptr2\n" + ".unreq aptr3\n" + ".unreq aptr4\n" + + : [aptr] "+r" (aptr), + [bptr] "+r" (bptr), + [cptr] "+r" (cptr), + [k] "+r" (k) + : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), + [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), + [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) + : "cc", "memory", "x20", "x21", "x22", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23" + ); + } + } +} diff --git a/arm_compute/core/NEON/kernels/winograd/perf.h b/arm_compute/core/NEON/kernels/winograd/perf.h new file mode 100644 index 0000000000..0cdf742a25 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/perf.h @@ -0,0 +1,9 @@ +#pragma once + +/* Prototypes from perf.c */ + +void start_counter(int fd); +long long get_counter(int fd); +long long stop_counter(int fd); +int open_instruction_counter(void); +int open_cycle_counter(void); diff --git a/arm_compute/core/NEON/kernels/winograd/profiler.hpp b/arm_compute/core/NEON/kernels/winograd/profiler.hpp new file mode 100644 index 0000000000..01fafa9604 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/profiler.hpp @@ -0,0 +1,326 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "perf.h" +#include + +#ifdef CYCLE_PROFILING +class EventIDContainer +{ + public: + EventIDContainer() : container_lock(), event_ids() + { + } + + int get_event_id(const char *id) + { + std::lock_guard lock(container_lock); + if (!event_ids.count(id)) { + event_ids.emplace(id, event_ids.size()); + } + return event_ids[id]; + } + + unsigned int size() const + { + return event_ids.size(); + } + + auto begin() + { + return event_ids.begin(); + } + + auto end() + { + return event_ids.end(); + } + + private: + std::mutex container_lock; + std::map event_ids; +}; + + +class ThreadEventCounterContainer +{ + public: + ThreadEventCounterContainer() : container_lock(), thread_counter_fds() + { + } + + int get_counter_fd() + { + const auto id = std::this_thread::get_id(); + std::lock_guard lock(container_lock); + if (!thread_counter_fds.count(id)) + { + thread_counter_fds.emplace(id, open_cycle_counter()); + } + return thread_counter_fds[id]; + } + + ~ThreadEventCounterContainer() + { + // Close all counter file descriptors + for (auto& fd : thread_counter_fds) + { + close(fd.second); + } + } + + private: + std::mutex container_lock; + std::map thread_counter_fds; +}; +#endif // CYCLE_PROFILING + + +class profiler { +private: +#ifdef CYCLE_PROFILING + struct ProfileEntry { + int event_id; + long int bytes_read, ops, bytes_written; + long int duration; + }; + + static const int maxevents = 10000; + ProfileEntry events[maxevents]; + int currentevent; + std::mutex event_lock; + + EventIDContainer event_ids; + ThreadEventCounterContainer thread_counter_fds; + + int get_event_id(const char *id) + { + return event_ids.get_event_id(id); + } +#endif // CYCLE_PROFILING + +public: +#ifdef CYCLE_PROFILING + profiler() : + currentevent(0), + event_lock(), + event_ids(), + thread_counter_fds() + { + } + + ~profiler() { + std::lock_guard lock_events(event_lock); + + // Compute performance from recorded events + struct ProfileResult { + ProfileResult() : total_calls(0), + total_duration(0), + total_bytes_read(0), + total_ops(0), + total_bytes_written(0) { + } + + void operator+=(const ProfileEntry &rhs) { + total_calls++; + total_duration += rhs.duration; + total_bytes_read += rhs.bytes_read; + total_ops += rhs.ops; + total_bytes_written = rhs.bytes_written; + } + + float avg_duration(void) const { + return static_cast(total_duration) / + static_cast(total_calls); + } + + float bytes_read_per_cycle(void) const { + return static_cast(total_bytes_read) / + static_cast(total_duration); + } + + float ops_per_cycle(void) const { + return static_cast(total_ops) / + static_cast(total_duration); + } + + float bytes_written_per_cycle(void) const { + return static_cast(total_bytes_written) / + static_cast(total_duration); + } + + long int total_calls, + total_duration, + total_bytes_read, + total_ops, + total_bytes_written; + }; + + std::vector totals; + totals.resize(event_ids.size()); + for (int i = 0; i < currentevent; i++) { + const auto &event = events[i]; + totals[event.event_id] += event; + } + + // Get the longest label + int len_label = 0; + for (const auto &kv : event_ids) { + len_label = std::max(len_label, static_cast(strlen(kv.first))); + } + + // Get the longest values for every other field + const auto get_length_of_field = + [totals] (const char *title, auto f, auto len) -> size_t { + size_t l = strlen(title); + for (const auto &v : totals) { + l = std::max(l, len(f(v))); + } + return l; + }; + + // Get the strlen for an int + const auto intlen = [] (long int x) -> size_t { + size_t len = 0; + do { + x /= 10; + len++; + } while (x); + return len; + }; + + // Get the strlen for a float + const auto floatlen = [] (const int precision) { + return [precision] (float x) { + size_t len = 0; + + if (!std::isfinite(x)) { + return static_cast(3); + } + + do { + x /= 10.0f; + len++; + } while (x > 1.0f); + return len + 1 + precision; + }; + }; + + const int len_calls = get_length_of_field( + "Calls", [] (const auto &v) {return v.total_calls;}, + intlen + ); + const int len_duration = get_length_of_field( + "Duration", [] (const auto &v) {return v.total_duration;}, + intlen + ); + const int len_average_duration = get_length_of_field( + "Average", [] (const auto &v) {return v.avg_duration();}, + floatlen(2) + ); + const int len_reads_per_cycle = get_length_of_field( + "Reads / cycle", + [] (const auto &v) {return v.bytes_read_per_cycle();}, + floatlen(6) + ); + const int len_ops_per_cycle = get_length_of_field( + "Ops / cycle", + [] (const auto &v) {return v.ops_per_cycle();}, + floatlen(6) + ); + const int len_writes_per_cycle = get_length_of_field( + "Writes / cycle", + [] (const auto &v) {return v.bytes_written_per_cycle();}, + floatlen(6) + ); + + // Print header + printf( + "%*s %*s %*s %*s %*s %*s %*s\n", + len_label, "", + len_calls, "Calls", + len_duration, "Duration", + len_average_duration, "Average", + len_reads_per_cycle, "Reads / cycle", + len_ops_per_cycle, "Ops / cycle", + len_writes_per_cycle, "Writes / cycle" + ); + for (const auto &kv : event_ids) { + const auto id = kv.second; + printf( + "%*s %*ld %*ld %*.2f %*.6f %*.6f %*.6f\n", + len_label, kv.first, + len_calls, totals[id].total_calls, + len_duration, totals[id].total_duration, + len_average_duration, totals[id].avg_duration(), + len_reads_per_cycle, totals[id].bytes_read_per_cycle(), + len_ops_per_cycle, totals[id].ops_per_cycle(), + len_writes_per_cycle, totals[id].bytes_written_per_cycle() + ); + } + printf("\n"); + } +#endif // CYCLE_PROFILING + + template + void operator() (const char * event, + T func, + long int bytes_read = 0, + long int ops = 0, + long int bytes_written = 0) { +#ifdef CYCLE_PROFILING + if (currentevent==maxevents) { + func(); + } else { + const auto countfd = thread_counter_fds.get_counter_fd(); + start_counter(countfd); + func(); + long long cycs = stop_counter(countfd); + + // Store the profiling data + std::lock_guard lock_events(event_lock); + events[currentevent++] = { + get_event_id(event), bytes_read, ops, bytes_written, cycs + }; + } +#else + (void) event; + (void) bytes_read; + (void) ops; + (void) bytes_written; + func(); +#endif // CYCLE_PROFILING + } +}; diff --git a/arm_compute/core/NEON/kernels/winograd/shims.hpp b/arm_compute/core/NEON/kernels/winograd/shims.hpp new file mode 100644 index 0000000000..09e14577ff --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/shims.hpp @@ -0,0 +1,747 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once +#include +#include "arm.hpp" + +namespace reorder { +/** Re-order a tensor from NCHW format to NHWC. + * + * @note The stride parameters are optional and are provided to allow padding in either input or output tensors. + * + * @param[in] in Input tensor in NCHW format. + * @param[out] out Output tensor, to be written in NHWC format. + * @param n_batches Number of batches in the tensors. + * @param n_channels Number of channels in the tensors + * @param n_rows Height of the tensor + * @param n_cols Width of the tensor + * @param in_batch_stride Stride over batches in the input tensor. If `0` defaults to `n_channels * in_channel_stride`. + * @param in_channel_stride Stride over channels in the input tensor. If `0` defaults to `n_rows * in_row_stride`. + * @param in_row_stride Stride over rows in the input tensor. If `0` defaults to `n_cols`. + * @param out_batch_stride Stride over batches in the output tensor. If `0` defaults to `n_rows * out_row_stride`. + * @param out_row_stride Stride over rows in the output tensor. If `0` defaults to `n_cols * out_col_stride`. + * @param out_col_stride Stride over columns in the output tensor. If `0` defaults to `n_channels`. + */ +template +inline void nchw_to_nhwc( + const T* const in, + T* const out, + const int n_batches, + const int n_channels, + const int n_rows, + const int n_cols, + int in_batch_stride=0, + int in_channel_stride=0, + int in_row_stride=0, + int out_batch_stride=0, + int out_row_stride=0, + int out_col_stride=0 +); + +/** Re-order a tensor from NHWC format to NCHW. + * + * @note The stride parameters are optional and are provided to allow padding in either input or output tensors. + * + * @param[in] in Input tensor in NHWC format. + * @param[out] out Output tensor, to be written in NCHW format. + * @param n_batches Number of batches in the tensors. + * @param n_rows Height of the tensor + * @param n_cols Width of the tensor + * @param n_channels Number of channels in the tensors + * @param in_batch_stride Stride over batches in the input tensor. If `0` defaults to `n_rows * in_row_stride`. + * @param in_row_stride Stride over rows in the input tensor. If `0` defaults to `n_cols * in_col_stride`. + * @param in_col_stride Stride over columns in the input tensor. If `0` defaults to `n_channels`. + * @param out_batch_stride Stride over batches in the output tensor. If `0` defaults to `n_channels * out_channel_stride`. + * @param out_channel_stride Stride over channels in the output tensor. If `0` defaults to `n_rows * out_row_stride`. + * @param out_row_stride Stride over rows in the output tensor. If `0` defaults to `n_cols`. + */ +template +inline void nhwc_to_nchw( + const T* const in, // Input data in NHWC form + T* const out, // Output data in NCHW form + const int n_batches, + const int n_rows, + const int n_cols, + const int n_channels, + int in_batch_stride=0, + int in_row_stride=0, + int in_col_stride=0, + int out_batch_stride=0, + int out_channel_stride=0, + int out_row_stride=0 +); + +/** Re-order a weight tensor from [Output feature map x Input feature map x + * Height x Width] format to [Height x Width x Input feature map x Output + * feature map] format. + */ +template +inline void ofm_ifm_h_w_to_h_w_ifm_ofm( + const T* const in, // Input in [Output x Input x Height x Width] form + T* const out, // Output in [Height x Width x Input x Output] form + const int n_output_feature_maps, + const int n_input_feature_maps, + const int n_rows, + const int n_cols, + int in_output_feature_map_stride=0, + int in_input_feature_map_stride=0, + int in_row_stride=0, + int out_row_stride=0, + int out_col_stride=0, + int out_input_feature_map_stride=0 +); + +/** Re-order a weight tensor from [Height x Width x Input feature map x Output + * feature map] format to [Output feature map x Input feature map x Height x + * Width] format. + */ +template +inline void h_w_ifm_ofm_to_ofm_ifm_h_w( + const T* const in, // Input in [Height x Width x Input x Output] form + T* const out, // Output in [Output x Input x Height x Width] form + const int n_rows, + const int n_cols, + const int n_input_feature_maps, + const int n_output_feature_maps, + int in_row_stride=0, + int in_col_stride=0, + int in_input_feature_map_stride=0, + int out_output_feature_map_stride=0, + int out_input_feature_map_stride=0, + int out_row_stride=0 +); + +/*****************************************************************************/ +/* 32-bit implementation : NCHW -> NHWC + */ +template <> +inline void nchw_to_nhwc( + const int32_t* const in, + int32_t* const out, + const int n_batches, + const int n_channels, + const int n_rows, + const int n_cols, + int in_batch_stride, + int in_channel_stride, + int in_row_stride, + int out_batch_stride, + int out_row_stride, + int out_col_stride +) +{ + typedef int32_t T; + + // Fill in the stride values + in_row_stride = (in_row_stride) ? in_row_stride : n_cols; + in_channel_stride = (in_channel_stride) ? in_channel_stride + : n_rows * in_row_stride; + in_batch_stride = (in_batch_stride) ? in_batch_stride + : n_channels * in_channel_stride; + + out_col_stride = (out_col_stride) ? out_col_stride : n_channels; + out_row_stride = (out_row_stride) ? out_row_stride : n_cols * out_col_stride; + out_batch_stride = (out_batch_stride) ? out_batch_stride + : n_rows * out_row_stride; + + // Perform the re-ordering + for (int n = 0; n < n_batches; n++) + { + const T* const in_batch = in + n*in_batch_stride; + T* const out_batch = out + n*out_batch_stride; + + for (int i = 0; i < n_rows; i++) + { + const T* const in_row = in_batch + i*in_row_stride; + T* const out_row = out_batch + i*out_row_stride; + + int j = 0, j_remaining = n_cols; +#ifdef __arm_any__ + for (; j_remaining >= 4; j += 4, j_remaining -= 4) + { + int c = 0, c_remaining = n_channels; + for (; c_remaining >= 4; c += 4, c_remaining -= 4) + { + // Read 4 channels worth of 4 columns, then zip to produce 4 columns + // worth of 4 channels. + int32x4_t channel_pixels[4]; + channel_pixels[0] = vld1q_s32(in_row + (c + 0)*in_channel_stride + j); + channel_pixels[1] = vld1q_s32(in_row + (c + 1)*in_channel_stride + j); + channel_pixels[2] = vld1q_s32(in_row + (c + 2)*in_channel_stride + j); + channel_pixels[3] = vld1q_s32(in_row + (c + 3)*in_channel_stride + j); + + const auto zip1 = vzipq_s32(channel_pixels[0], channel_pixels[2]); + const auto zip2 = vzipq_s32(channel_pixels[1], channel_pixels[3]); + const auto out_0 = vzipq_s32(zip1.val[0], zip2.val[0]); + const auto out_1 = vzipq_s32(zip1.val[1], zip2.val[1]); + + vst1q_s32(out_row + (j + 0)*out_col_stride + c, out_0.val[0]); + vst1q_s32(out_row + (j + 1)*out_col_stride + c, out_0.val[1]); + vst1q_s32(out_row + (j + 2)*out_col_stride + c, out_1.val[0]); + vst1q_s32(out_row + (j + 3)*out_col_stride + c, out_1.val[1]); + } + for (; c_remaining; c++, c_remaining--) + { + for (int _j = 0; _j < 4; _j++) + { + const T* const in_col = in_row + j + _j; + T* const out_col = out_row + (j + _j)*out_col_stride; + const T* const in_channel = in_col + c*in_channel_stride; + out_col[c] = *(in_channel); + } + } + } + for (; j_remaining >= 2; j += 2, j_remaining -= 2) + { + int c = 0, c_remaining = n_channels; + for (; c_remaining >= 2; c += 2, c_remaining -= 2) + { + // Read 2 channels worth of 2 columns, then zip to produce 2 columns + // worth of 2 channels. + int32x2_t channel_pixels[2]; + channel_pixels[0] = vld1_s32(in_row + (c + 0)*in_channel_stride + j); + channel_pixels[1] = vld1_s32(in_row + (c + 1)*in_channel_stride + j); + + const auto output = vzip_s32(channel_pixels[0], channel_pixels[1]); + + vst1_s32(out_row + (j + 0)*out_col_stride + c, output.val[0]); + vst1_s32(out_row + (j + 1)*out_col_stride + c, output.val[1]); + } + for (; c_remaining; c++, c_remaining--) + { + for (int _j = 0; _j < 2; _j++) + { + const T* const in_col = in_row + j + _j; + T* const out_col = out_row + (j + _j)*out_col_stride; + const T* const in_channel = in_col + c*in_channel_stride; + out_col[c] = *(in_channel); + } + } + } +#endif // __arm_any__ + for (; j_remaining; j++, j_remaining--) + { + const T* const in_col = in_row + j; + T* const out_col = out_row + j*out_col_stride; + + for (int c = 0; c < n_channels; c++) + { + const T* const in_channel = in_col + c*in_channel_stride; + out_col[c] = *(in_channel); + } + } + } + } +} + +template <> +inline void nchw_to_nhwc( + const uint32_t* const in, + uint32_t* const out, + const int n_batches, + const int n_channels, + const int n_rows, + const int n_cols, + int in_batch_stride, + int in_channel_stride, + int in_row_stride, + int out_batch_stride, + int out_row_stride, + int out_col_stride +) +{ + nchw_to_nhwc( + reinterpret_cast(in), + reinterpret_cast(out), + n_batches, n_channels, n_rows, n_cols, + in_batch_stride, in_channel_stride, in_row_stride, + out_batch_stride, out_row_stride, out_col_stride + ); +} + +template <> +inline void nchw_to_nhwc( + const float* const in, + float* const out, + const int n_batches, + const int n_channels, + const int n_rows, + const int n_cols, + int in_batch_stride, + int in_channel_stride, + int in_row_stride, + int out_batch_stride, + int out_row_stride, + int out_col_stride +) +{ + nchw_to_nhwc( + reinterpret_cast(in), + reinterpret_cast(out), + n_batches, n_channels, n_rows, n_cols, + in_batch_stride, in_channel_stride, in_row_stride, + out_batch_stride, out_row_stride, out_col_stride + ); +} + +/*****************************************************************************/ +/* Generic implementation : NCHW -> NHWC + */ +template +inline void nchw_to_nhwc( + const T* const in, + T* const out, + const int n_batches, + const int n_channels, + const int n_rows, + const int n_cols, + int in_batch_stride, + int in_channel_stride, + int in_row_stride, + int out_batch_stride, + int out_row_stride, + int out_col_stride +) +{ + // Fill in the stride values + in_row_stride = (in_row_stride) ? in_row_stride : n_cols; + in_channel_stride = (in_channel_stride) ? in_channel_stride + : n_rows * in_row_stride; + in_batch_stride = (in_batch_stride) ? in_batch_stride + : n_channels * in_channel_stride; + + out_col_stride = (out_col_stride) ? out_col_stride : n_channels; + out_row_stride = (out_row_stride) ? out_row_stride : n_cols * out_col_stride; + out_batch_stride = (out_batch_stride) ? out_batch_stride + : n_rows * out_row_stride; + + // Perform the re-ordering + for (int n = 0; n < n_batches; n++) + { + const T* const in_batch = in + n*in_batch_stride; + T* const out_batch = out + n*out_batch_stride; + + for (int i = 0; i < n_rows; i++) + { + const T* const in_row = in_batch + i*in_row_stride; + T* const out_row = out_batch + i*out_row_stride; + + for (int j = 0; j < n_cols; j++) + { + const T* const in_col = in_row + j; + T* const out_col = out_row + j*out_col_stride; + + for (int c = 0; c < n_channels; c++) + { + const T* const in_channel = in_col + c*in_channel_stride; + out_col[c] = *(in_channel); + } + } + } + } +} + +/*****************************************************************************/ +/* 32-bit implementation : NHWC -> NCHW + */ +template <> +inline void nhwc_to_nchw( + const int32_t* const in, // Input data in NHWC form + int32_t* const out, // Output data in NCHW form + const int n_batches, + const int n_rows, + const int n_cols, + const int n_channels, + int in_batch_stride, + int in_row_stride, + int in_col_stride, + int out_batch_stride, + int out_channel_stride, + int out_row_stride +) +{ + typedef int32_t T; + + // Fill in stride values + in_col_stride = (in_col_stride) ? in_col_stride : n_channels; + in_row_stride = (in_row_stride) ? in_row_stride : n_cols * in_col_stride; + in_batch_stride = (in_batch_stride) ? in_batch_stride + : n_rows * in_row_stride; + + out_row_stride = (out_row_stride) ? out_row_stride : n_cols; + out_channel_stride = (out_channel_stride) ? out_channel_stride + : n_rows * out_row_stride; + out_batch_stride = (out_batch_stride) ? out_batch_stride + : n_channels * out_channel_stride; + + // Perform the re-ordering + // For every batch + for (int n = 0; n < n_batches; n++) + { + const T* const in_batch = in + n*in_batch_stride; + T* const out_batch = out + n*out_batch_stride; + + // For every row + for (int i = 0; i < n_rows; i++) + { + const T* const in_i = in_batch + i*in_row_stride; + T* const out_i = out_batch + i*out_row_stride; + + // For every column, beginning with chunks of 4 + int j = 0, j_remaining = n_cols; +#ifdef __arm_any__ + for (; j_remaining >= 4; j += 4, j_remaining -=4) + { + // For every channel, beginning with chunks of 4 + int c = 0, c_remaining = n_channels; + for (; c_remaining >= 4; c += 4, c_remaining -= 4) + { + // Read 4 columns worth of 4 channels then zip to produce 4 channels + // worth of 4 columns. + int32x4_t pixel_channels[4]; + pixel_channels[0] = vld1q_s32(in_i + (j + 0)*in_col_stride + c); + pixel_channels[1] = vld1q_s32(in_i + (j + 1)*in_col_stride + c); + pixel_channels[2] = vld1q_s32(in_i + (j + 2)*in_col_stride + c); + pixel_channels[3] = vld1q_s32(in_i + (j + 3)*in_col_stride + c); + + const auto zip1 = vzipq_s32(pixel_channels[0], pixel_channels[2]); + const auto zip2 = vzipq_s32(pixel_channels[1], pixel_channels[3]); + const auto out_0 = vzipq_s32(zip1.val[0], zip2.val[0]); + const auto out_1 = vzipq_s32(zip1.val[1], zip2.val[1]); + + vst1q_s32(out_i + j + (c + 0)*out_channel_stride, out_0.val[0]); + vst1q_s32(out_i + j + (c + 1)*out_channel_stride, out_0.val[1]); + vst1q_s32(out_i + j + (c + 2)*out_channel_stride, out_1.val[0]); + vst1q_s32(out_i + j + (c + 3)*out_channel_stride, out_1.val[1]); + } + for (; c_remaining; c++, c_remaining--) + { + for (int _j = 0; _j < 4; _j++) + { + const T* const in_j = in_i + (j + _j)*in_col_stride; + T* const out_j = out_i + (j + _j); + + const T* const in_channel = in_j + c; + T* const out_channel = out_j + c*out_channel_stride; + *(out_channel) = *(in_channel); + } + } + } + for (; j_remaining >= 2; j += 2, j_remaining -=2) + { + int c = 0, c_remaining = n_channels; + for (; c_remaining >= 2; c += 2, c_remaining -= 2) + { + // Read 2 columns worth of 2 channels then zip to produce 2 channels + // worth of 2 columns. + int32x2_t pixel_channels[2]; + pixel_channels[0] = vld1_s32(in_i + (j + 0)*in_col_stride + c); + pixel_channels[1] = vld1_s32(in_i + (j + 1)*in_col_stride + c); + + const auto output = vzip_s32(pixel_channels[0], pixel_channels[1]); + + vst1_s32(out_i + j + (c + 0)*out_channel_stride, output.val[0]); + vst1_s32(out_i + j + (c + 1)*out_channel_stride, output.val[1]); + } + for (; c_remaining; c++, c_remaining--) + { + for (int _j = 0; _j < 2; _j++) + { + const T* const in_j = in_i + (j + _j)*in_col_stride; + T* const out_j = out_i + (j + _j); + + const T* const in_channel = in_j + c; + T* const out_channel = out_j + c*out_channel_stride; + *(out_channel) = *(in_channel); + } + } + } +#endif // __arm_any__ + for (; j_remaining; j++, j_remaining--) + { + const T* const in_j = in_i + j*in_col_stride; + T* const out_j = out_i + j; + + // For every channel + for (int c = 0; c < n_channels; c++) + { + const T* const in_channel = in_j + c; + T* const out_channel = out_j + c*out_channel_stride; + *(out_channel) = *(in_channel); + } + } + } + } +} + +template <> +inline void nhwc_to_nchw( + const uint32_t* const in, // Input data in NHWC form + uint32_t* const out, // Output data in NCHW form + const int n_batches, + const int n_rows, + const int n_cols, + const int n_channels, + int in_batch_stride, + int in_row_stride, + int in_col_stride, + int out_batch_stride, + int out_channel_stride, + int out_row_stride +) +{ + // Redirect to generic 32-bit implementation + nhwc_to_nchw( + reinterpret_cast(in), + reinterpret_cast(out), + n_batches, n_rows, n_cols, n_channels, + in_batch_stride, in_row_stride, in_col_stride, + out_batch_stride, out_channel_stride, out_row_stride + ); +} + +template <> +inline void nhwc_to_nchw( + const float* const in, // Input data in NHWC form + float* const out, // Output data in NCHW form + const int n_batches, + const int n_rows, + const int n_cols, + const int n_channels, + int in_batch_stride, + int in_row_stride, + int in_col_stride, + int out_batch_stride, + int out_channel_stride, + int out_row_stride +) +{ + // Redirect to generic 32-bit implementation + nhwc_to_nchw( + reinterpret_cast(in), + reinterpret_cast(out), + n_batches, n_rows, n_cols, n_channels, + in_batch_stride, in_row_stride, in_col_stride, + out_batch_stride, out_channel_stride, out_row_stride + ); +} + +/*****************************************************************************/ +/* Generic implementation : NHWC -> NCHW + */ +template +inline void nhwc_to_nchw( + const T* const in, // Input data in NHWC form + T* const out, // Output data in NCHW form + const int n_batches, + const int n_rows, + const int n_cols, + const int n_channels, + int in_batch_stride, + int in_row_stride, + int in_col_stride, + int out_batch_stride, + int out_channel_stride, + int out_row_stride +) +{ + // Fill in stride values + in_col_stride = (in_col_stride) ? in_col_stride : n_channels; + in_row_stride = (in_row_stride) ? in_row_stride : n_cols * in_col_stride; + in_batch_stride = (in_batch_stride) ? in_batch_stride + : n_rows * in_row_stride; + + out_row_stride = (out_row_stride) ? out_row_stride : n_cols; + out_channel_stride = (out_channel_stride) ? out_channel_stride + : n_rows * out_row_stride; + out_batch_stride = (out_batch_stride) ? out_batch_stride + : n_channels * out_channel_stride; + + // Perform the re-ordering + // For every batch + for (int n = 0; n < n_batches; n++) + { + const T* const in_batch = in + n*in_batch_stride; + T* const out_batch = out + n*out_batch_stride; + + // For every row + for (int i = 0; i < n_rows; i++) + { + const T* const in_i = in_batch + i*in_row_stride; + T* const out_i = out_batch + i*out_row_stride; + + // For every column + for (int j = 0; j < n_cols; j++) + { + const T* const in_j = in_i + j*in_col_stride; + T* const out_j = out_i + j; + + // For every channel + for (int c = 0; c < n_channels; c++) + { + const T* const in_channel = in_j + c; + T* const out_channel = out_j + c*out_channel_stride; + *(out_channel) = *(in_channel); + } + } + } + } +} + +/*****************************************************************************/ +/* Generic weight re-order implementation. + */ +template +inline void ofm_ifm_h_w_to_h_w_ifm_ofm( + const T* const in, // Input in [Output x Input x Height x Width] form + T* const out, // Output in [Height x Width x Input x Output] form + const int n_output_feature_maps, + const int n_input_feature_maps, + const int n_rows, + const int n_cols, + int in_output_feature_map_stride, + int in_input_feature_map_stride, + int in_row_stride, + int out_row_stride, + int out_col_stride, + int out_input_feature_map_stride +) +{ + // Fill in stride values + in_row_stride = (in_row_stride) + ? in_row_stride + : n_cols; + in_input_feature_map_stride = (in_input_feature_map_stride) + ? in_input_feature_map_stride + : n_rows * in_row_stride; + in_output_feature_map_stride = (in_output_feature_map_stride) + ? in_output_feature_map_stride + : n_input_feature_maps * in_input_feature_map_stride; + + out_input_feature_map_stride = (out_input_feature_map_stride) + ? out_input_feature_map_stride + : n_output_feature_maps; + out_col_stride = (out_col_stride) + ? out_col_stride + : n_input_feature_maps * out_input_feature_map_stride; + out_row_stride = (out_row_stride) + ? out_row_stride + : n_cols * out_col_stride; + + // Perform the re-ordering + for (int i = 0; i < n_rows; i++) + { + const T* const in_row = in + i * in_row_stride; + T* out_row = out + i * out_row_stride; + + for (int j = 0; j < n_cols; j++) + { + const T* const in_col = in_row + j; + T* const out_col = out_row + j * out_col_stride; + + for (int ifm = 0; ifm < n_input_feature_maps; ifm++) + { + const T* const in_ifm = in_col + ifm * in_input_feature_map_stride; + T* const out_ifm = out_col + ifm * out_input_feature_map_stride; + + for (int ofm = 0; ofm < n_output_feature_maps; ofm++) + { + const T* const in_ofm = in_ifm + ofm * in_output_feature_map_stride; + T* const out_ofm = out_ifm + ofm; + *(out_ofm) = *(in_ofm); + } + } + } + } +} + +/*****************************************************************************/ +/* Generic weight re-order implementation. + */ +template +inline void h_w_ifm_ofm_to_ofm_ifm_h_w( + const T* const in, // Input in [Height x Width x Input x Output] form + T* const out, // Output in [Output x Input x Height x Width] form + const int n_rows, + const int n_cols, + const int n_input_feature_maps, + const int n_output_feature_maps, + int in_row_stride, + int in_col_stride, + int in_input_feature_map_stride, + int out_output_feature_map_stride, + int out_input_feature_map_stride, + int out_row_stride +) +{ + // Fill in the stride values + in_input_feature_map_stride = (in_input_feature_map_stride) + ? in_input_feature_map_stride + : n_output_feature_maps; + in_col_stride = (in_col_stride) + ? in_col_stride + : n_input_feature_maps * in_input_feature_map_stride; + in_row_stride = (in_row_stride) + ? in_row_stride + : n_cols * in_col_stride; + + out_row_stride = (out_row_stride) + ? out_row_stride + : n_cols; + out_input_feature_map_stride = (out_input_feature_map_stride) + ? out_input_feature_map_stride + : n_rows * out_row_stride; + out_output_feature_map_stride = (out_output_feature_map_stride) + ? out_output_feature_map_stride + : n_input_feature_maps * out_input_feature_map_stride; + + // Perform the re-ordering + for (int i = 0; i < n_rows; i++) + { + const T* const in_row = in + i * in_row_stride; + T* const out_row = out + i * out_row_stride; + + for (int j = 0; j < n_cols; j++) + { + const T* const in_col = in_row + j * in_col_stride; + T* const out_col = out_row + j; + + for (int ifm = 0; ifm < n_input_feature_maps; ifm++) + { + const T* const in_ifm = in_col + ifm * in_input_feature_map_stride; + T* const out_ifm = out_col + ifm * out_input_feature_map_stride; + + for (int ofm = 0; ofm < n_output_feature_maps; ofm++) + { + const T* const in_ofm = in_ifm + ofm; + T* const out_ofm = out_ifm + ofm * out_output_feature_map_stride; + *(out_ofm) = *(in_ofm); + } + } + } + } +} + +} // namespace reorder diff --git a/arm_compute/core/NEON/kernels/winograd/tensor.hpp b/arm_compute/core/NEON/kernels/winograd/tensor.hpp index 70ef65d2a5..6567eeb23d 100644 --- a/arm_compute/core/NEON/kernels/winograd/tensor.hpp +++ b/arm_compute/core/NEON/kernels/winograd/tensor.hpp @@ -23,39 +23,44 @@ */ #pragma once -#include #include #include #include "alloc.hpp" -/*****************************************************************************/ -/* Padding definitions */ -enum PaddingType { - PADDING_SAME, PADDING_VALID +enum TensorOrder +{ + NHWC, ///< [Batch x Height x Width x Channels] + NCHW, ///< [Batch x Channels x Height x Width] }; -/*****************************************************************************/ -/* Shape of a kernel */ -struct KernelShape { - int n_output_channels, n_rows, n_cols, n_input_channels; - - int size(void) const { - return n_output_channels * n_rows * n_cols * n_input_channels; +struct Tensor4DShape +{ + int n_batches, n_rows, n_cols, n_channels; + TensorOrder ordering; + + // Create a new tensor with the default (NHWC) ordering + inline Tensor4DShape( + const int n_batches, + const int n_rows, + const int n_cols, + const int n_channels, + const TensorOrder ordering=NHWC + ) : n_batches(n_batches), + n_rows(n_rows), + n_cols(n_cols), + n_channels(n_channels), + ordering(ordering) + { } -}; - -struct Tensor4DShape { - int n_batches, - n_rows, - n_cols, - n_channels; - int size() const { + inline int size() const + { return n_batches * n_rows * n_cols * n_channels; } - bool TestEq(const Tensor4DShape& other) const { + inline bool TestEq(const Tensor4DShape& other) const + { return (n_batches == other.n_batches && n_rows == other.n_rows && n_cols == other.n_cols && @@ -63,148 +68,110 @@ struct Tensor4DShape { } }; + +enum WeightOrder +{ + HWIO, ///< [Height x Width x Input channels x Output channels] + OIHW, ///< [Output channels x Input channels x Height x Width] +}; + +struct KernelShape +{ + int n_output_channels, n_rows, n_cols, n_input_channels; + WeightOrder ordering; + + inline KernelShape( + const int n_output_channels, + const int n_rows, + const int n_cols, + const int n_input_channels, + const WeightOrder ordering=HWIO + ) : n_output_channels(n_output_channels), + n_rows(n_rows), + n_cols(n_cols), + n_input_channels(n_input_channels), + ordering(ordering) + { + } + + inline int size(void) const + { + return n_output_channels * n_rows * n_cols * n_input_channels; + } +}; + + template -class Tensor4D final { +class Tensor4D final +{ public: Tensor4D(ShapeT shape) : - _shape(shape), - _data(reinterpret_cast(ALLOCATE(size_bytes()))) { + shape(shape), + _data(reinterpret_cast(ALLOCATE(size_bytes()))) + { Clear(); } + Tensor4D(const Tensor4D&) = delete; + Tensor4D operator=(const Tensor4D&) = delete; + ~Tensor4D() { free(_data); } - T* ptr() const { + inline T* ptr() const { return _data; } - const ShapeT& shape() const { - return _shape; + inline size_t size_bytes() const { + return shape.size() * sizeof(T); } - size_t size_bytes() const { - return _shape.size() * sizeof(T); - } - - bool TestEq(Tensor4D& other) const; - T& element(int, int, int, int) const; - void Print() const; + inline T& element(int, int, int, int) const; - void Clear() { + inline void Clear() { Fill(static_cast(0)); } - void Fill(T val) { - for (int i = 0; i < _shape.size(); i++) + inline void Fill(T val) { + for (int i = 0; i < shape.size(); i++) _data[i] = val; } - void TestPattern() { - for (int i = 0; i < _shape.size(); i++) - _data[i] = static_cast(i); - } - - void Rand(const int seed=2311) { - std::mt19937 gen(seed); - std::uniform_int_distribution<> dis(-50, +50); - - for (int i = 0; i < _shape.size(); i++) { - _data[i] = static_cast(dis(gen)); - } - } - Tensor4D(const Tensor4D &) = delete; - /** Prevent instances of this class from being copied (As this class contains pointers) */ - Tensor4D &operator=(const Tensor4D &) = delete; - /** Allow instances of this class to be moved */ - Tensor4D(Tensor4D &&) = default; - /** Allow instances of this class to be moved */ - Tensor4D &operator=(Tensor4D &&) = default; - + const ShapeT shape; private: - const ShapeT _shape; T* const _data; }; template <> -inline float& Tensor4D::element(int n, int i, int j, int c) const { - int index = ((n*_shape.n_rows + i)*_shape.n_cols + j)*_shape.n_channels + c; - return _data[index]; -} - - -template <> -inline float& Tensor4D::element(int oc, int i, int j, int ic) const { - int index = ((i*_shape.n_cols + j)*_shape.n_input_channels + ic)*_shape.n_output_channels + oc; - return _data[index]; -} - -template <> -inline bool Tensor4D::TestEq(Tensor4D& other) const { - // Test equivalence, printing errors - // First test the shapes are the same - if (!_shape.TestEq(other.shape())) { - printf("Tensors have different shapes.\n"); - return false; - } else { - int incorrects = 0; - - for (int n = 0; n < _shape.n_batches; n++) { - for (int i = 0; i < _shape.n_rows; i++) { - for (int j = 0; j < _shape.n_cols; j++) { - for (int c = 0; c < _shape.n_channels; c++) { - // Check elements for equivalence - const auto a = this->element(n, i, j, c); - const auto b = other.element(n, i, j, c); - - if (a != b) { - printf("Difference at element {%d, %d, %d, %d}: %.3f != %.3f\n", n, i, j, c, a, b); - - if (++incorrects > 100) { - printf("More than 100 incorrect values, stopping test.\n"); - return false; - } - } - } - } - } - } - - return incorrects == 0; +inline float& Tensor4D::element(int n, int i, int j, int c) const +{ + int index; + if (shape.ordering == NHWC) + { + index = ((n*shape.n_rows + i)*shape.n_cols + j)*shape.n_channels + c; } -} - - -template <> -inline void Tensor4D::Print() const { - for (int n = 0; n < _shape.n_batches; n++) { - for (int c = 0; c < _shape.n_channels; c++) { - for (int i = 0; i < _shape.n_rows; i++) { - for (int j = 0; j < _shape.n_cols; j++) { - printf("%5.2f ", element(n, i, j, c)); - } - printf("\n"); - } - printf("\n"); - } + else // NCHW + { + index = ((n*shape.n_channels + c)*shape.n_rows + i)*shape.n_cols + j; } + return _data[index]; } template <> -inline void Tensor4D::Print() const { - for (int oc = 0; oc < _shape.n_output_channels; oc++) { - for (int ic = 0; ic < _shape.n_input_channels; ic++) { - for (int i = 0; i < _shape.n_rows; i++) { - for (int j = 0; j < _shape.n_cols; j++) { - printf("%5.2f ", element(oc, i, j, ic)); - } - printf("\n"); - } - printf("\n"); - } +inline float& Tensor4D::element(int oc, int i, int j, int ic) const +{ + int index; + if (shape.ordering == HWIO) + { + index = ((i*shape.n_cols + j)*shape.n_input_channels + ic)*shape.n_output_channels + oc; } + else // OIHW + { + index = ((oc*shape.n_input_channels + ic)*shape.n_rows + i)*shape.n_cols + j; + } + return _data[index]; } diff --git a/arm_compute/core/NEON/kernels/winograd/tensor_utils.hpp b/arm_compute/core/NEON/kernels/winograd/tensor_utils.hpp new file mode 100644 index 0000000000..68a5c6a178 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/tensor_utils.hpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once +#include "tensor.hpp" + +// Methods to print tensors and weights +void PrintTensor(const Tensor4D& tensor); +void PrintWeights(const Tensor4D& weights); + +// Test the equivalence of two tensors +bool CmpTensors(const Tensor4D& a, + const Tensor4D& b, + const float max_delta=0.0f); + +// Fill the tensor with a test pattern +void TestPattern(Tensor4D& tensor); +void TestPattern(Tensor4D& weights); + +// Fill the tensor with random values +void Randomise(Tensor4D& tensor, const int seed=0); +void Randomise(Tensor4D& weights, const int seed=0); diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/input.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/input.hpp new file mode 100644 index 0000000000..39b444184e --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/transforms/input.hpp @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once +#include "../winograd_gemm.hpp" + +namespace winograd +{ + /***************************************************************************/ + /* Instance-less API */ + template + template + void WinogradGEMM::InputTransform::execute( + const T *inptr, + const Tensor4DShape& input_shape, + const PaddingType padding_type, + const int tile_M, + const int tile_N, + T *outptr_base, + const int matrix_stride, + const int matrix_batch_stride, + const int matrix_row_stride + ) + { + // Compute the padding required on each edge of the image + const bool base_padding = (padding_type == PADDING_SAME) ? 1 : 0; + const int pad_top = base_padding; + const int pad_left = base_padding; + const int tile_overlap = kernel_rows - 1; + + // Compute striding values (assuming NHWC ordered data) + const int input_col_stride = input_shape.n_channels; + const int input_row_stride = input_shape.n_cols * input_col_stride; + const int input_batch_stride = input_shape.n_rows * input_row_stride; + const int output_col_stride = matrix_row_stride; + const int output_row_stride = tile_N * output_col_stride; + + // Loop over batches + for (int batch = 0; batch < input_shape.n_batches; batch++) + { + // Pointer to the batch + const T* const input_base_batch = inptr + batch * input_batch_stride; + T* const outptr_base_batch = outptr_base + batch * matrix_batch_stride; + + // Loop over rows of tiles + for (int tile_i = 0; tile_i < tile_M; tile_i++) + { + // Pointer to the row + const int row_offset = (tile_i == 0) ? + 0 : ((padding_type == PADDING_VALID) ? 0 : 1); + const T* const input_base_row = ( + input_base_batch + ((inner_tile_rows - 2)*tile_i - row_offset)*input_row_stride + ); + T* const outptr_base_row = outptr_base_batch + tile_i*output_row_stride; + + // Padding (top + bottom) for the row + const int row_top = tile_i*(inner_tile_rows - tile_overlap) - pad_top; + const int row_bottom = row_top + inner_tile_rows; + const int row_pad_top = (tile_i == 0) ? pad_top : 0; + const int row_pad_bottom = (row_bottom <= input_shape.n_rows) ? 0 : row_bottom - input_shape.n_rows; + + // Process the row + process_tile_row( + tile_N, input_shape.n_channels, + input_base_row, input_row_stride, input_col_stride, + outptr_base_row, matrix_stride, matrix_row_stride, + row_pad_top, pad_left, row_pad_bottom, input_shape.n_cols + ); + } + } + } + + template + template + void WinogradGEMM::InputTransform::process_tile_row( + const int tile_N, + int n_channels, + const T* const input_base, + const int input_row_stride, + const int input_col_stride, + T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + const int pad_top, + const int row_pad_left, + const int pad_bottom, + const int n_cols + ) + { + constexpr int tile_overlap = kernel_cols - 1; + + // Loop over columns of tiles + for (int tile_j = 0; tile_j < tile_N; tile_j++) + { + // Padding (left + right) for the tile + const int t_pad_left = (tile_j == 0) ? row_pad_left : 0; + const int t_start = tile_j*(inner_tile_cols - tile_overlap) - row_pad_left; + const int t_end = t_start + inner_tile_cols; + const int t_pad_right = (t_end <= n_cols) ? 0 : t_end - n_cols; + + // Get pointers into the inputs and outputs + const int col_offset = (tile_j == 0) ? 0 : row_pad_left; + const T* const input_base_col = ( + input_base + ((inner_tile_cols - tile_overlap)*tile_j - col_offset)*input_col_stride + ); + T* const outptr = matrix_base + tile_j*matrix_row_stride; + + // Apply the specific tile processing function + tile_fns[pad_top][t_pad_left][pad_bottom][t_pad_right]( + n_channels, + input_base_col, + input_row_stride, + input_col_stride, + outptr, + matrix_stride + ); + } + } + + /***************************************************************************/ + template + template + WinogradGEMM::InputTransform::InputTransform( + const T* const input, /** Input tensor data */ + const int n_batches, /** Number of batches in input tensor. */ + const int n_rows, /** Number of rows in input tensor. */ + const int n_cols, /** Number of columns in input tensor. */ + const int n_channels, /** Number of channels in input tensor. */ + const PaddingType padding, /** Padding type. */ + T* const output, /** Base of output matrices. */ + const int matrix_stride, /** Stride between output matrices. */ + const int matrix_row_stride /** Stride within matrices. */ + ) : _inptr(input), _outptr(output), + _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols), _n_channels(n_channels), + _matrix_stride(matrix_stride), _matrix_row_stride(matrix_row_stride), + _tiles_M(iceildiv((padding == PADDING_SAME) ? n_rows : n_rows - 2, output_tile_rows)), + _tiles_N(iceildiv((padding == PADDING_SAME) ? n_cols : n_cols - 2, output_tile_cols)), + _padding_type(padding) + { + } + + template + template + unsigned int WinogradGEMM::InputTransform::get_window() const + { + // TODO When the input transform supports multithreading, return the total + // number of tile rows (allowing for multiple batches). For now we return 1 + // to indicate that the activations must be transformed as a single block. + return 1; // TODO _tiles_M * _n_batches; + } + + template + template + void WinogradGEMM::InputTransform::run( + const unsigned int start, const unsigned int stop + ) + { + // TODO When the input transform supports multithreading call execute for a + // portion of the tile rows. + (void) start; + (void) stop; + + // For now, just do all of the work. + const Tensor4DShape input_shape = { + _n_batches, _n_rows, _n_cols, _n_channels, NHWC + }; + execute( + _inptr, input_shape, _padding_type, _tiles_M, _tiles_N, _outptr, + _matrix_stride, _matrix_row_stride * _tiles_M * _tiles_N, _matrix_row_stride + ); + } +} diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/kernel.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/kernel.hpp new file mode 100644 index 0000000000..4b54dfdf08 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/transforms/kernel.hpp @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "winograd_gemm.hpp" +using namespace winograd; + + +template +template +WinogradGEMM::WeightsTransform::WeightsTransform( + const T* const input, + T* const output, + const int matrix_stride, /** Stride across matrices in the output. */ + const int matrix_row_stride, /** Stride across rows of the matrix. */ + const int n_output_channels, + const int n_input_channels +) : inptr(input), outptr(output), + matrix_stride(matrix_stride), matrix_row_stride(matrix_row_stride), + n_output_channels(n_output_channels), n_input_channels(n_input_channels) +{ +} + + +template +template +unsigned int WinogradGEMM::WeightsTransform::get_window() const +{ + // TODO When the weights transform supports multithreading, return the number + // of output channels. For now we return 1 to indicate that the weights must + // be transformed as a single block. + // return n_output_channels; + return 1; +} + + +template +template +void WinogradGEMM::WeightsTransform::run( + const unsigned int start, const unsigned int stop +) +{ + // TODO When the weights transform supports multithreading call execute for a + // portion of the output channels. + (void) start; + (void) stop; + + // For now, just do all of the work. + execute( + n_output_channels, + n_input_channels, + inptr, + outptr, + matrix_stride, + matrix_row_stride + ); +} diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/output.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/output.hpp new file mode 100644 index 0000000000..7fa5ee9617 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/transforms/output.hpp @@ -0,0 +1,174 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once +#include "../winograd_gemm.hpp" + +namespace winograd +{ + template + template + void WinogradGEMM::OutputTransform::execute( + const Tensor4DShape &output_shape, + const T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + T* const output + ) + { + // Compute the number of tiles and hence the padding required on the bottom + // and right of the image. + const int tile_M = iceildiv(output_shape.n_rows, output_tile_rows); + const int tile_N = iceildiv(output_shape.n_cols, output_tile_cols); + const int pad_bottom = output_tile_rows*tile_M - output_shape.n_rows; + const int pad_right = output_tile_cols*tile_N - output_shape.n_cols; + + const int matrix_tile_row_stride = tile_N * matrix_row_stride; + const int matrix_batch_stride = tile_M * matrix_tile_row_stride; + const int output_col_stride = output_shape.n_channels; + const int output_row_stride = output_shape.n_cols * output_col_stride; + const int output_batch_stride = output_shape.n_rows * output_row_stride; + + // Perform the output transformation for each batch + for (int batch = 0; batch < output_shape.n_batches; batch++) + { + // Get batch offset for input and outputs. + const T* const matrix_batch = matrix_base + batch*matrix_batch_stride; + T* const outptr_batch = output + batch*output_batch_stride; + + // Perform the output transformation for each row of the output tensor. + for (int tile_i = 0; tile_i < tile_M; tile_i++) + { + // Compute properties of this row of output tiles + const int row_pad_bottom = (tile_i < tile_M - 1) ? 0: pad_bottom; + const T* const matrix_tile_row = matrix_batch + tile_i * matrix_tile_row_stride; + T* const outptr_row = outptr_batch + output_tile_rows*tile_i*output_row_stride; + + // Process the row + process_tile_row( + tile_N, output_shape.n_channels, matrix_tile_row, matrix_stride, + matrix_row_stride, outptr_row, output_row_stride, + output_col_stride, row_pad_bottom, pad_right + ); + } + } + } + + template + template + void WinogradGEMM::OutputTransform::process_tile_row( + const int tile_N, + const int n_channels, + const T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + T* const output, + const int output_row_stride, + const int output_col_stride, + const int row_pad_bottom, + const int row_pad_right + ) + { + // Loop over columns of tiles + for (int tile_j = 0; tile_j < tile_N; tile_j++) + { + // Properties of this tile + const int tile_pad_right = (tile_j < tile_N - 1) ? 0 : row_pad_right; + const T* const matrix_row = matrix_base + tile_j * matrix_row_stride; + T* const outptr = output + output_tile_cols*tile_j*output_col_stride; + + // Perform the output transformation + tile_fns[row_pad_bottom][tile_pad_right]( + n_channels, matrix_row, matrix_stride, + outptr, output_row_stride, output_col_stride + ); + } + } + + template + template + size_t WinogradGEMM::OutputTransform::bytes_read(const Tensor4DShape &shape) + { + const int M = iceildiv(shape.n_rows, output_tile_rows) * + iceildiv(shape.n_cols, output_tile_cols); + const int N = shape.n_channels; + return inner_tile_rows * inner_tile_cols * M * N * sizeof(T); + } + + template + template + size_t WinogradGEMM::OutputTransform::bytes_written(const Tensor4DShape &shape) + { + return shape.size() * sizeof(T); + } + + template + template + WinogradGEMM::OutputTransform::OutputTransform( + const T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + T* const output, + const int n_batches, + const int n_rows, + const int n_cols, + const int n_channels + ) : _matrix_base(matrix_base), _matrix_stride(matrix_stride), _matrix_row_stride(matrix_row_stride), + _outptr(output), _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols), _n_channels(n_channels), + _tile_M(iceildiv(n_rows, output_tile_rows)), _tile_N(iceildiv(n_cols, output_tile_cols)) + { + } + + template + template + unsigned int WinogradGEMM::OutputTransform::get_window() const + { + // TODO When the output transform supports multithreading, return the total + // number of tile rows (allowing for multiple batches). For now we return 1 + // to indicate that the activations must be transformed as a single block. + return 1; // TODO _tile_M * _n_batches; + } + + template + template + void WinogradGEMM::OutputTransform::run( + const unsigned int start, const unsigned int stop + ) + { + // TODO When the output transform supports multithreading call execute for a + // portion of the tile rows. + (void) start; + (void) stop; + + // For now, just do all of the work. + const Tensor4DShape output_shape = { + _n_batches, _n_rows, _n_cols, _n_channels, NHWC + }; + execute( + output_shape, _matrix_base, _matrix_stride, _matrix_row_stride, _outptr + ); + } +} // namespace winograd diff --git a/arm_compute/core/NEON/kernels/winograd/utils.hpp b/arm_compute/core/NEON/kernels/winograd/utils.hpp new file mode 100644 index 0000000000..d8b9c3b7d3 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/utils.hpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +double TimeInUs(void); +void PrintMatrix(const float* const m, const int M, const int N, const int row_stride); + +inline int iceildiv(const int a, const int b) { + return (a + b - 1) / b; +} + +template +inline T roundup(const T a, const T b) { + return a + b - (a % b); +} diff --git a/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp b/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp new file mode 100644 index 0000000000..adca48a6d6 --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp @@ -0,0 +1,441 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#include "alloc.hpp" +#include "convolution.hpp" +#include "gemm.hpp" +#include "profiler.hpp" +#include "shims.hpp" +#include "tensor.hpp" +#include "utils.hpp" + +#include +#include +#include + +// Generic Winograd implementation using GEMM +namespace winograd +{ + +template +class WinogradGEMM +{ + public: + // Information about the specific Winograd instance + static constexpr int output_tile_rows = OutputTileRows; + static constexpr int output_tile_cols = OutputTileCols; + static constexpr int kernel_rows = KernelRows; + static constexpr int kernel_cols = KernelCols; + static constexpr int inner_tile_rows = output_tile_rows + kernel_rows - 1; // TODO Check + static constexpr int inner_tile_cols = output_tile_cols + kernel_cols - 1; // TODO Check + static constexpr int N_GEMMS = inner_tile_rows * inner_tile_cols; + + /** Transform weights from the spatial to the Winograd domain. */ + template + struct WeightsTransform + { + /** Get the bytes read during the transform. */ + static inline size_t bytes_read(const KernelShape &shape) + { + return shape.size() * sizeof(T); + } + + /** Get the bytes written during the transform. */ + static inline size_t bytes_written(const KernelShape &shape) + { + const int inner_tile_size = inner_tile_rows * inner_tile_cols; + return (inner_tile_size * shape.n_input_channels * + shape.n_output_channels * sizeof(T)); + } + + /** Get the count of operations performed by the transform. */ + static int ops_performed(const KernelShape &shape); + + /** Apply the transform to a tensor. */ + static void execute( + const int n_output_channels, + const int n_input_channels, + const T* const input, + T* const output, + const int matrix_stride, + const int matrix_row_stride + ); + + /** Create a WeightsTransform operator fixed on a given problem and set + * of pointers. + */ + WeightsTransform( + const T* const input, + T* const output, + const int matrix_stride, /** Stride across matrices in the output. */ + const int matrix_row_stride, /** Stride across rows of the matrix. */ + const int n_output_channels, /** Number of filters. */ + const int n_input_channels /** Number of channels in each filter. */ + ); + + /** Get the window of work a given operator can perform. */ + unsigned int get_window() const; + + /** Perform work upon a window of the input. */ + void run(const unsigned int start, const unsigned int stop); + + private: + const T* const inptr; /** Fixed pointer to input data. */ + T* const outptr; /** Fixed pointer to output memory. */ + const int matrix_stride; /** Stride between output matrices. */ + const int matrix_row_stride; /** Stride within output matrices. */ + const int n_output_channels; /** Number of filters. */ + const int n_input_channels; /** Number of channels in each filter. */ + }; + + /** Transform input feature maps from the spatial to the Winograd domain. + */ + template + struct InputTransform + { + /** Get the bytes read during the transform. */ + static size_t bytes_read(const Tensor4DShape &shape) + { + return shape.size() * sizeof(T); + } + + /** Get the bytes written during the transform. */ + static size_t bytes_written(const Tensor4DShape &shape) + { + const int M = iceildiv(shape.n_rows, inner_tile_rows) * + iceildiv(shape.n_cols, inner_tile_cols); + const int K = shape.n_channels; + return inner_tile_rows * inner_tile_cols * M * K * sizeof(T); + } + + /** Get the count of operations performed by the transform. */ + static int ops_performed(const Tensor4DShape &shape); + + /** Apply the transform to a tensor. */ + static void execute( + const T *inptr, + const Tensor4DShape& input_shape, + const PaddingType padding_type, + const int tile_M, + const int tile_N, + T *outptr_base, + const int matrix_stride, + const int matrix_batch_stride, + const int matrix_row_stride + ); + + /***********************************************************************/ + /** Create an InputTransform operator fixed on a given problem and set of + * pointers. + */ + InputTransform( + const T* const input, /** Input tensor data */ + const int n_batches, /** Number of batches in input tensor. */ + const int n_rows, /** Number of rows in input tensor. */ + const int n_cols, /** Number of columns in input tensor. */ + const int n_channels, /** Number of channels in input tensor. */ + const PaddingType padding, /** Padding type. */ + T* const output, /** Base of output matrices. */ + const int matrix_stride, /** Stride between output matrices. */ + const int matrix_row_stride /** Stride within matrices. */ + ); + + /** Get the winodw of work a given operator can perform. */ + unsigned int get_window() const; + + /** Perform work upon a window of the input. */ + void run(const unsigned int start, const unsigned int stop); + /***********************************************************************/ + + private: + static void process_tile_row( + const int tile_N, + int n_channels, + const T* const input_base, + const int input_row_stride, + const int input_col_stride, + T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + const int row_pad_top, + const int row_pad_left, + const int row_pad_bottom, + const int row_pad_right + ); + + static constexpr int max_pad_bottom = inner_tile_rows - 1; + static constexpr int max_pad_right = inner_tile_cols - 1; + + /** Process a single tile of the input tensor. */ + template + static void process_tile(int, const T*, int, int, T*, int); + + // Array of methods to transform tiles of the input tensor. + typedef void (*TileFn)(int, const T*, int, int, T*, int); + static const TileFn tile_fns[2][2][max_pad_bottom][max_pad_right]; + + /* Member values for instance-based API. */ + const T* const _inptr; + T* const _outptr; + const int _n_batches, _n_rows, _n_cols, _n_channels, _matrix_stride, + _matrix_row_stride, _tiles_M, _tiles_N; + const PaddingType _padding_type; + }; + + /** Transform output feature maps from the Winograd to the spatial domain. + */ + template + struct OutputTransform + { + /** Get the bytes read during the transform. */ + static size_t bytes_read(const Tensor4DShape &shape); + + /** Get the bytes written during the transform. */ + static size_t bytes_written(const Tensor4DShape &shape); + + /** Get the count of operations performed by the transform. */ + static int ops_performed(const Tensor4DShape &shape); + + /** Apply the transform to create a tensor. */ + static void execute( + const Tensor4DShape &output_shape, + const T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + T* const output + ); + + /***********************************************************************/ + /** Create an OutputTransform operator fixed on a given problem and set + * of pointers. + */ + OutputTransform( + const T* const matrix_base, /** Pointer to base of matrices. */ + const int matrix_stride, /** Stride between matrices. */ + const int matrix_row_stride, /** Stride within a matrix. */ + T* const output, /** Pointer to output tensor. */ + const int n_batches, /** Number of batches in output tensor. */ + const int n_rows, /** Number of rows in output tensor. */ + const int n_cols, /** Number of columns in output tensor. */ + const int n_channels /** Number of channels in output tensor. */ + ); + + /** Get the window of work a given operator can perform. */ + unsigned int get_window() const; + + /** Perform work upon a window of the input. */ + void run(const unsigned int start, const unsigned int stop); + /***********************************************************************/ + + private: + static void process_tile_row( + const int tile_N, + const int n_channels, + const T* const matrix_base, + const int matrix_stride, + const int matrix_row_stride, + T* const output, + const int output_row_stride, + const int output_col_stride, + const int row_pad_bottom, + const int row_pad_right + ); + + // Limits on the amount of anti-padding to be applied + static constexpr int max_pad_bottom = output_tile_rows; + static constexpr int max_pad_right = output_tile_cols; + + /** Prepare a single tile of the output tensor. */ + template + static void process_tile(int, const T*, int, T*, int, int); + + // Array of methods to produce tiles of output tensor. + typedef void (*TileFn)(int, const T*, int, T*, int, int); + static const TileFn tile_fns[max_pad_bottom][max_pad_right]; + + /** Member constants for instances of the transform. */ + const T* const _matrix_base; + const int _matrix_stride, _matrix_row_stride; + T* const _outptr; + const int _n_batches, _n_rows, _n_cols, _n_channels, _tile_M, _tile_N; + }; + + /** Perform a convolution. + */ + template + class Convolution + { + public: + // Information about the typed Winograd instance + typedef TOut OutputType; + typedef TIn InputType; + + /** Create a new Winograd operator. */ + Convolution( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding, + void *kernel_storage=NULL + ); + + Convolution(const Convolution&) = delete; + Convolution operator=(const Convolution&) = delete; + + /** Create a new Winograd operator and initialise the weights. */ + Convolution( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding, + const TIn* const kernel, + void *kernel_storage=NULL, + void *transform_working_space=NULL + ); + + /** Clean up a convolution engine. */ + ~Convolution(); + + /** Transform the weights into the Winograd domain. */ + template > + void transform_weights( + const TIn* const kernel, + void *transform_working_space=NULL + ); + + /* Apply the Winograd operator to some input. */ + void execute( + TOut* const output, + const TIn* const input, + void* working_space=NULL, + const int n_threads=1 + ); + + /* Apply the Winograd operator to some input. */ + void execute( + TOut* const output, + const TIn* const input, + const int n_threads + ); + + /** Get the output shape of a convolution. */ + static Tensor4DShape get_output_shape( + const KernelShape &kernel_shape, + const Tensor4DShape &in_shape, + const PaddingType padding + ); + + /* Get the memory required to transform the kernel. + */ + static size_t get_kernel_transform_working_size(const KernelShape &shape); + + /** Get the memory required to store the kernel transformed into the + * Winograd domain. + */ + static size_t get_kernel_storage_size(const KernelShape &shape); + + /** Get the memory required to store the input tensor transformed into + * the Winograd domain. + */ + static size_t get_input_storage_size( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding_type + ); + + /** Get the memory required to store the output tensor in the Winograd + * domain. + */ + static size_t get_output_storage_size( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding_type + ); + + /** Get the memory required to apply a Winograd operator to some input. + */ + static size_t get_working_space_size( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding_type + ); + + /* Get the memory required by a single "input" matrix. + */ + static size_t get_input_matrix_size( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding_type + ); + + static int get_input_matrix_stride( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding_type + ); + + /* Get the memory required by a single "output" matrix. + */ + static size_t get_output_matrix_size( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding_type + ); + + static int get_output_matrix_stride( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding_type + ); + + /* Get the memory required by a single "kernel" matrix. + */ + static size_t get_kernel_matrix_size(const KernelShape &shape); + static int get_kernel_matrix_stride(const KernelShape &shape); + + static constexpr int M_BLOCK = 4; /** Size of block used by GEMM. */ + static constexpr int N_BLOCK = 16; /** Size of block used by GEMM. */ + + private: + const KernelShape kernel_shape; /** Shape of the kernel to be applied. */ + TIn *kernel_matrices[N_GEMMS]; /** Pointers into the kernel matrices. */ + const int kernel_matrix_row_stride; /** Stride within the kernel matrices. */ + + const bool manage_kernel_storage; /** Kernel storage is managed by the instance. */ + void* const _kernel_storage; /** Base pointer for kernel storage. */ + + const Tensor4DShape input_shape; /** Shape of the input tensor. */ + const PaddingType padding; /** Padding applied by the operator. */ + + const Tensor4DShape output_shape; /** Output shape produced by the operator. */ + + const int tile_rows; /** Number of rows of tiles. */ + const int tile_cols; /** Number of columns of tiles. */ + const int M, K, N; /** Sizes of underlying fundamental matrix multiplications. */ + + profiler prof; + }; +}; + +} // namespace winograd diff --git a/arm_compute/core/NEON/kernels/winograd/winograd_layer.hpp b/arm_compute/core/NEON/kernels/winograd/winograd_layer.hpp new file mode 100644 index 0000000000..a3b3db42dd --- /dev/null +++ b/arm_compute/core/NEON/kernels/winograd/winograd_layer.hpp @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#include + +#include "batched_blocked_gemm.hpp" +#include "winograd_gemm.hpp" + +/** Example of how to construct an ACL-like interface. + * + * Use `get_weight_storage_size`, `get_input_storage_size` and + * `get_output_storage_size` to allocate memory for the convolution engine. + * Then create a `WinogradConvolutionLayer`. + * + * Initialise the weights using `weights_transform.run(...)`. + * + * For each inference: + * 1. Transform the inputs to the Winograd domain using `input_transform.run(...)` + * 2. Perform a number of GEMMs using `gemms.run(...)` + * 3. Transform the output to the spatial domain using `output_transform.run(...)` + */ +template +class WinogradConvolutionLayer +{ + private: + const KernelShape _kernel_shape; + const Tensor4DShape _input_shape; + const PaddingType _padding; + const Tensor4DShape _output_shape; + const int _n_output_rows, _n_output_cols; + const int _kernel_matrix_stride, _kernel_matrix_row_stride; + const int _input_matrix_stride, _input_matrix_row_stride; + const int _output_matrix_stride, _output_matrix_row_stride; + const int _tile_rows, _tile_cols; + const int _m, _k, _n; + + public: + using WinogradBase = winograd::WinogradGEMM; + using WeightsTransform = typename WinogradBase::template WeightsTransform; + using InputTransform = typename WinogradBase::template InputTransform; + using WinogradConv = typename WinogradBase::template Convolution; + using MultiGEMM = winograd::BatchedBlockedGemm; + using OutputTransform = typename WinogradBase::template OutputTransform; + + /* Public member variables. */ + WeightsTransform weights_transform; /** Operator to transform weights to Winograd domain. */ + InputTransform input_transform; /** Operator to transform input to Winograd domain. */ + MultiGEMM gemms; /** Operator to perform multiple GEMMs. */ + OutputTransform output_transform; /** Operator to transform output from Winograd domain. */ + + /** Determine how much memory (in units of TIn) to allocate for the + * transformed weights. + */ + static unsigned int get_weight_storage_size( + const int n_output_channels, /** Number of output feature maps. */ + const int n_input_channels /** Number of input feature maps. */ + ); + + /** Determine how much memory (in units of TIn) to allocate for the + * transformed input. + */ + static unsigned int get_input_storage_size( + const int n_batches, /** Number of batches in the input tensor. */ + const int n_channels, /** Number of feature maps in the input tensor. */ + const int n_rows, /** Number of rows in each feature map. */ + const int n_cols, /** Number of columns in each feature map. */ + const bool same_padding /** Use "SAME" padding, otherwise use "VALID". */ + ); + + /** Determine how much memory (in units of TOut) to allocate for the + * (Winograd domain) output. + */ + static unsigned int get_output_storage_size( + const int n_batches, /** Number of batches in the output tensor. */ + const int n_rows, /** Number of rows in each feature map of the input tensor. */ + const int n_cols, /** Number of columns in each feature map of the input tensor. */ + const int n_output_channels, /** Number of feature maps in the output tensor. */ + const bool same_padding /** Use "SAME" padding, otherwise use "VALID". */ + ); + + /** Get the shape (rows, cols) of a feature map of the output tensor. */ + static std::pair get_output_feature_map_shape( + const int n_input_rows, /** Number of rows in the input feature map. */ + const int n_input_cols, /** Number of columns in the input feature map. */ + const bool same_padding /** Use "SAME" padding, otherwise use "VALID". */ + ); + + /** Create a new Winograd convolution layer. + */ + WinogradConvolutionLayer( + const int n_batches, /** Number of batches in the input and output tensors. */ + const int n_input_channels, /** Number of feature maps in a batch of the input tensor. */ + const int n_input_rows, /** Number of rows in a feature map of the input tensor. */ + const int n_input_cols, /** Number of columns in a feature map of the input tensor. */ + const int n_output_channels, /** Number of feature maps in the output tensor. */ + const bool same_padding, /** Use "SAME" padding, otherwise use "VALID". */ + const TIn* const weights, /** Pointer to weight tensor in spatial domain. Must be ordered as "Height x Rows x Input Feature Maps x Output Feature Maps. */ + TIn* const weights_storage, /** Pointer to storage for weight tensor in the Winograd domain. Must be at least the size returned by `get_weight_storage_size`. */ + const TIn* const input, /** Pointer to NHWC ordered input tensor, in the spatial domain. */ + TIn* const winograd_input, /** Pointer to working space for the input tensor in the Winograd domain. Must be at least the size returned by `get_input_storage_size`. */ + TOut* const output, /** Pointer to NHWC ordered output tensor, in the spatial domain. */ + TOut* const winograd_output /** Pointer to working space for the output tensor in the Winograd domain. Must be at least the size returned by `get_output_storage_size`. */ + ); +}; diff --git a/arm_compute/runtime/NEON/functions/NEWinogradLayer.h b/arm_compute/runtime/NEON/functions/NEWinogradLayer.h index 6fecf082a2..60cdc97469 100644 --- a/arm_compute/runtime/NEON/functions/NEWinogradLayer.h +++ b/arm_compute/runtime/NEON/functions/NEWinogradLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -73,7 +73,8 @@ private: CPPPermute _permute_input; CPPPermute _permute_weights; CPPPermute _permute_output; - Tensor _workspace; + Tensor _input_workspace; + Tensor _output_workspace; Tensor _kernel_storage; Tensor _input_nhwc; Tensor _output_nhwc; diff --git a/src/core/NEON/kernels/NEWinogradLayerKernel.cpp b/src/core/NEON/kernels/NEWinogradLayerKernel.cpp index d17630a92e..24d72eddd8 100644 --- a/src/core/NEON/kernels/NEWinogradLayerKernel.cpp +++ b/src/core/NEON/kernels/NEWinogradLayerKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017, 2018 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -29,11 +29,11 @@ #include "arm_compute/core/TensorInfo.h" #include "support/ToolchainSupport.h" -#include "src/core/NEON/kernels/winograd/winograd_gemm.hpp" +#include "arm_compute/core/NEON/kernels/winograd/winograd_layer.hpp" namespace { -using T = winograd::Winograd2x2_3x3GEMM; +using T = WinogradConvolutionLayer<2, 2, 3, 3, float, float>; } // namespace namespace arm_compute @@ -41,11 +41,23 @@ namespace arm_compute class Winograd3x3F32::Private { public: - Private(const KernelShape &kernel_shape, const Tensor4DShape input_shape, const PaddingType padding_type, void *kernel_storage) - : convolver(kernel_shape, input_shape, padding_type, kernel_storage) + Private( + const int n_batches, /** Number of batches in the input and output tensors. */ + const int n_input_channels, /** Number of feature maps in a batch of the input tensor. */ + const int n_input_rows, /** Number of rows in a feature map of the input tensor. */ + const int n_input_cols, /** Number of columns in a feature map of the input tensor. */ + const int n_output_channels, /** Number of feature maps in the output tensor. */ + const bool same_padding, /** Use "SAME" padding, otherwise use "VALID". */ + const float *const weights, /** Pointer to weight tensor in spatial domain. Must be ordered as "Height x Rows x Input Feature Maps x Output Feature Maps. */ + float *const weights_storage, /** Pointer to storage for weight tensor in the Winograd domain. Must be at least the size returned by `get_weight_storage_size`. */ + const float *const input, /** Pointer to NHWC ordered input tensor, in the spatial domain. */ + float *const winograd_input, /** Pointer to working space for the input tensor in the Winograd domain. Must be at least the size returned by `get_input_storage_size`. */ + float *const output, /** Pointer to NHWC ordered output tensor, in the spatial domain. */ + float *const winograd_output /** Pointer to working space for the output tensor in the Winograd domain. Must be at least the size returned by `get_output_storage_size`. */ + ) + : convolver(n_batches, n_input_channels, n_input_rows, n_input_cols, n_output_channels, same_padding, weights, weights_storage, input, winograd_input, output, winograd_output) { } - T convolver; }; @@ -53,46 +65,62 @@ Winograd3x3F32::~Winograd3x3F32() { } -void Winograd3x3F32::transform_weights(const void *const kernel, void *transform_working_space) +void Winograd3x3F32::transform_output() { - _pimpl->convolver.transform_weights(reinterpret_cast(kernel), transform_working_space); + auto win = _pimpl->convolver.output_transform.get_window(); + _pimpl->convolver.output_transform.run(0, win); } -void Winograd3x3F32::reshape_input(const Tensor4DShape &input_shape, const PaddingType padding_type, const void *const input, void *working_space) +void Winograd3x3F32::transform_input() { - _pimpl->convolver.reshape_input(input_shape, padding_type, reinterpret_cast(input), working_space); + auto win = _pimpl->convolver.input_transform.get_window(); + _pimpl->convolver.input_transform.run(0, win); } -void Winograd3x3F32::reshape_output(const Tensor4DShape &input_shape, const PaddingType padding_type, void *const output) +void Winograd3x3F32::transform_weights() { -#if defined(__aarch64__) - _pimpl->convolver.reshape_output(input_shape, padding_type, reinterpret_cast(output)); -#else /* __aarch64__ */ - ARM_COMPUTE_UNUSED(input_shape); - ARM_COMPUTE_UNUSED(padding_type); - ARM_COMPUTE_UNUSED(output); - ARM_COMPUTE_ERROR("Not implemented"); -#endif /* __aarch64__ */ + auto win = _pimpl->convolver.weights_transform.get_window(); + _pimpl->convolver.weights_transform.run(0, win); } -Winograd3x3F32::Winograd3x3F32(const KernelShape &kernel_shape, const Tensor4DShape input_shape, const PaddingType padding_type, void *kernel_storage) - : _pimpl(support::cpp14::make_unique(kernel_shape, input_shape, padding_type, kernel_storage)) +Winograd3x3F32::Winograd3x3F32( + const int n_batches, /** Number of batches in the input and output tensors. */ + const int n_input_channels, /** Number of feature maps in a batch of the input tensor. */ + const int n_input_rows, /** Number of rows in a feature map of the input tensor. */ + const int n_input_cols, /** Number of columns in a feature map of the input tensor. */ + const int n_output_channels, /** Number of feature maps in the output tensor. */ + const bool same_padding, /** Use "SAME" padding, otherwise use "VALID". */ + const float *const weights, /** Pointer to weight tensor in spatial domain. Must be ordered as "Height x Rows x Input Feature Maps x Output Feature Maps. */ + float *const weights_storage, /** Pointer to storage for weight tensor in the Winograd domain. Must be at least the size returned by `get_weight_storage_size`. */ + const float *const input, /** Pointer to NHWC ordered input tensor, in the spatial domain. */ + float *const winograd_input, /** Pointer to working space for the input tensor in the Winograd domain. Must be at least the size returned by `get_input_storage_size`. */ + float *const output, /** Pointer to NHWC ordered output tensor, in the spatial domain. */ + float *const winograd_output /** Pointer to working space for the output tensor in the Winograd domain. Must be at least the size returned by `get_output_storage_size`. */ +) + : _pimpl(support::cpp14::make_unique(n_batches, n_input_channels, n_input_rows, n_input_cols, n_output_channels, same_padding, weights, weights_storage, input, winograd_input, output, + winograd_output)) { } -size_t NEWinogradLayerKernel::get_kernel_storage_size(const KernelShape &shape) +unsigned int NEWinogradLayerKernel::get_input_storage_size(const int n_batches, const int n_channels, const int n_rows, const int n_cols, const bool same_padding) { - return T::get_kernel_storage_size(shape); + return T::get_input_storage_size(n_batches, n_channels, n_rows, n_cols, same_padding); } -size_t NEWinogradLayerKernel::get_working_space_size(const Tensor4DShape &input_shape, const KernelShape &k_shape, const PaddingType padding) +unsigned int NEWinogradLayerKernel::get_output_storage_size( + const int n_batches, /** Number of batches in the output tensor. */ + const int n_rows, /** Number of rows in each feature map of the input tensor. */ + const int n_cols, /** Number of columns in each feature map of the input tensor. */ + const int n_output_channels, /** Number of feature maps in the output tensor. */ + const bool same_padding /** Use "SAME" padding, otherwise use "VALID". */ +) { - return T::get_working_space_size(input_shape, k_shape, padding); + return T::get_output_storage_size(n_batches, n_rows, n_cols, n_output_channels, same_padding); } -size_t NEWinogradLayerKernel::get_kernel_transform_working_size(const KernelShape &shape) +size_t NEWinogradLayerKernel::get_weight_storage_size(const int n_output_channels, const int n_input_channels) { - return T::get_kernel_transform_working_size(shape); + return T::get_weight_storage_size(n_output_channels, n_input_channels); } NEWinogradLayerKernel::NEWinogradLayerKernel() @@ -105,7 +133,8 @@ void NEWinogradLayerKernel::configure(Winograd3x3F32 *convolver) ARM_COMPUTE_ERROR_ON_NULLPTR(convolver); _convolver = convolver; Window win; - win.set(Window::DimX, Window::Dimension(0, 15, 1)); + auto win_last = _convolver->_pimpl->convolver.gemms.get_window(); + win.set(Window::DimX, Window::Dimension(0, win_last, 1)); INEKernel::configure(win); } @@ -115,6 +144,6 @@ void NEWinogradLayerKernel::run(const Window &window, const ThreadInfo &info) ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); const size_t first_gemm = window.x().start(); const size_t last_gemm = window.x().end(); - _convolver->_pimpl->convolver.execute(first_gemm, last_gemm); + _convolver->_pimpl->convolver.gemms.run(first_gemm, last_gemm); } } // namespace arm_compute diff --git a/src/core/NEON/kernels/winograd/batched_blocked_gemm.cpp b/src/core/NEON/kernels/winograd/batched_blocked_gemm.cpp new file mode 100644 index 0000000000..52c2db866a --- /dev/null +++ b/src/core/NEON/kernels/winograd/batched_blocked_gemm.cpp @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "batched_blocked_gemm.hpp" +#include "gemm.hpp" +using namespace winograd; + +template +BatchedBlockedGemm::BatchedBlockedGemm( + const unsigned int n_gemms, + const int M, const int K, const int N, + const int a_matrix_stride, + const int a_row_stride, + const int b_matrix_stride, + const int b_row_stride, + const int c_matrix_stride, + const int c_row_stride, + const TIn* const a_ptr, + const TIn* const b_ptr, + TOut* const c_ptr +) : n_gemms(n_gemms), M(M), N(N), K(K), + a_matrix_stride(a_matrix_stride), + a_row_stride(a_row_stride), + b_matrix_stride(b_matrix_stride), + b_row_stride(b_row_stride), + c_matrix_stride(c_matrix_stride), + c_row_stride(c_row_stride), + a_ptr(a_ptr), b_ptr(b_ptr), c_ptr(c_ptr) +{ +} + +template +unsigned int BatchedBlockedGemm::get_window() const +{ + return n_gemms; +} + +template +void BatchedBlockedGemm::run( + const unsigned int start, const unsigned int stop +) +{ + // Perform the specified GEMMs + for (unsigned int i = start; i < stop; i++) + { + // Get pointers to the relevant matrices + const TIn* const mtr_a = a_ptr + i*a_matrix_stride; + const TIn* const mtr_b = b_ptr + i*b_matrix_stride; + TOut* const mtr_c = c_ptr + i*c_matrix_stride; + + // Perform the GEMM + BlockedGemm( + mtr_a, mtr_b, mtr_c, M, K, N, + a_row_stride, b_row_stride, c_row_stride + ); + } +} + +template class winograd::BatchedBlockedGemm<4, 16, float, float>; + diff --git a/src/core/NEON/kernels/winograd/gemm.hpp b/src/core/NEON/kernels/winograd/gemm.hpp deleted file mode 100644 index 111e19602a..0000000000 --- a/src/core/NEON/kernels/winograd/gemm.hpp +++ /dev/null @@ -1,127 +0,0 @@ - -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once -#include "utils.hpp" - -template -void Gemm(const TIn* const a, const TIn* const b, TOut *c, - const int M, const int K, const int N, - const int a_row_stride, - const int b_row_stride, - const int c_row_stride, - const bool a_transposed=false, - const bool b_transposed=false) { - // Array access methods - const auto A = [a, a_transposed, M, K, a_row_stride] (const int i, const int j) -> TIn { - return a[(!a_transposed) ? i*a_row_stride + j : i + j*M]; - }; - - const auto B = [b, b_transposed, K, N, b_row_stride] (const int i, const int j) -> TIn { - return b[(!b_transposed) ? i*b_row_stride + j : i + j*N]; - }; - - const auto C = [c, c_row_stride] (const int i, const int j) -> TOut& { - return c[i*c_row_stride + j]; - }; - - // Perform the matrix multiplication - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - for (int k = 0; k < K; k++) { - C(i, j) += A(i, k) * B(k, j); - } - } - } -} - -template -void BlockedGemm( - const TIn* const a, const TIn* const b, TOut *c, - const int M, const int K, const int N, - const int a_row_stride, - const int b_row_stride, - const int c_row_stride -) { - // Array access methods - const auto A = [a, a_row_stride] (const int i, const int j) -> TIn { - return a[i*a_row_stride + j]; - }; - - const auto B = [b, b_row_stride] (const int i, const int j) -> TIn { - return b[i*b_row_stride + j]; - }; - - const auto C = [c, c_row_stride] (const int i, const int j) -> TOut& { - return c[i*c_row_stride + j]; - }; - - const int M_BLOCKS = iceildiv(M, M_BLOCK); - const int N_BLOCKS = iceildiv(N, N_BLOCK); - - // For each block of output rows - for (int mblock = 0; mblock < M_BLOCKS; mblock++) { - // For each block of output columns - for (int nblock = 0; nblock < N_BLOCKS; nblock++) { - // Create an appropriately sized block of accumulators - TOut accum[M_BLOCK][N_BLOCK]; - for (int i = 0; i < M_BLOCK; i++) { - for (int j = 0; j < N_BLOCK; j++) { - accum[i][j] = static_cast(0); - } - } - - // Perform this portion of the matrix multiply - for (int k = 0; k < K; k++) { - // Load elements of A - TIn elems_a[M_BLOCK]; - for (int i = 0; i < M_BLOCK; i++) { - elems_a[i] = A(mblock*M_BLOCK + i, k); - } - - // Load elements of B - TIn elems_b[N_BLOCK]; - for (int j = 0; j < N_BLOCK; j++) { - elems_b[j] = B(k, nblock*N_BLOCK + j); - } - - // Perform the partial matrix multiply - for (int i = 0; i < M_BLOCK; i++) { - for (int j = 0; j < N_BLOCK; j++) { - accum[i][j] += elems_a[i] * elems_b[j]; - } - } - } - - // Store the partial product - for (int i = 0; i < M_BLOCK; i++) { - for (int j = 0; j < N_BLOCK; j++) { - C(mblock*M_BLOCK + i, nblock*N_BLOCK + j) = accum[i][j]; - } - } - } - } -} - -#include "gemm/a64_sgemm.hpp" diff --git a/src/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp b/src/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp deleted file mode 100644 index e1b7488c31..0000000000 --- a/src/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp +++ /dev/null @@ -1,355 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once -#include -#include "../utils.hpp" - -#ifdef __aarch64__ - -template <> -inline void BlockedGemm<8, 12, float, float>( - const float* const a, const float* const b, float *c, - const int M, const int K, const int N, - const int a_row_stride, - const int b_row_stride, - const int c_row_stride -) { - const int M_BLOCK = 8; - const int N_BLOCK = 12; - - const int m_blocks = iceildiv(M, M_BLOCK); - const int n_blocks = iceildiv(N, N_BLOCK); - - // For each block of output rows - for (int mblock = 0; mblock < m_blocks; mblock++) { - // For each block of output columns - for (int nblock = 0; nblock < n_blocks; nblock++) { - const float *aptr = a + mblock*M_BLOCK*a_row_stride; - const float *bptr = b + nblock*N_BLOCK; - float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; - int k = K; - - asm volatile ( - // Create an 8x12 block of accumulators - " A_1 .req v27\n" - "sA_1 .req s27\n" - " A_2 .req v28\n" - "sA_2 .req s28\n" - " A_3 .req v29\n" - "sA_3 .req s29\n" - " A_4 .req v30\n" - "sA_4 .req s30\n" - - " B_1 .req v24\n" " B_2 .req v25\n" " B_3 .req v26\n" - "qB_1 .req q24\n" "qB_2 .req q25\n" "qB_3 .req q26\n" - - " C_11 .req v0\n" " C_12 .req v1\n" " C_13 .req v2\n" - " C_21 .req v3\n" " C_22 .req v4\n" " C_23 .req v5\n" - " C_31 .req v6\n" " C_32 .req v7\n" " C_33 .req v8\n" - " C_41 .req v9\n" " C_42 .req v10\n" " C_43 .req v11\n" - " C_51 .req v12\n" " C_52 .req v13\n" " C_53 .req v14\n" - " C_61 .req v15\n" " C_62 .req v16\n" " C_63 .req v17\n" - " C_71 .req v18\n" " C_72 .req v19\n" " C_73 .req v20\n" - " C_81 .req v21\n" " C_82 .req v22\n" " C_83 .req v23\n" - - "qC_11 .req q0\n" "qC_12 .req q1\n" "qC_13 .req q2\n" - "qC_21 .req q3\n" "qC_22 .req q4\n" "qC_23 .req q5\n" - "qC_31 .req q6\n" "qC_32 .req q7\n" "qC_33 .req q8\n" - "qC_41 .req q9\n" "qC_42 .req q10\n" "qC_43 .req q11\n" - "qC_51 .req q12\n" "qC_52 .req q13\n" "qC_53 .req q14\n" - "qC_61 .req q15\n" "qC_62 .req q16\n" "qC_63 .req q17\n" - "qC_71 .req q18\n" "qC_72 .req q19\n" "qC_73 .req q20\n" - "qC_81 .req q21\n" "qC_82 .req q22\n" "qC_83 .req q23\n" - - "aptr1 .req x17\n" - "aptr2 .req x18\n" - "aptr3 .req x19\n" - "aptr4 .req x20\n" - "aptr5 .req x21\n" - "aptr6 .req x22\n" - "aptr7 .req x23\n" - - // Initialise accumulators with 0 - // Initialise pointers - "movi C_11.4s, #0\n" - "add aptr1, %x[aptr], %x[a_row_stride]\n" - "movi C_12.4s, #0\n" - "add aptr2, aptr1, %x[a_row_stride]\n" - "movi C_13.4s, #0\n" - "add aptr3, aptr2, %x[a_row_stride]\n" - "movi C_21.4s, #0\n" - "add aptr4, aptr3, %x[a_row_stride]\n" - "movi C_22.4s, #0\n" - "add aptr5, aptr4, %x[a_row_stride]\n" - "movi C_23.4s, #0\n" - "add aptr6, aptr5, %x[a_row_stride]\n" - "movi C_31.4s, #0\n" - "add aptr7, aptr6, %x[a_row_stride]\n" - "movi C_32.4s, #0\n" - "ldr qB_1, [%x[bptr]]\n" - "movi C_33.4s, #0\n" - "ldr qB_2, [%x[bptr], #0x10]\n" - "movi C_41.4s, #0\n" - "prfm pldl1keep, [%x[bptr], #0x00]\n" - "movi C_42.4s, #0\n" - "prfm pldl1keep, [%x[bptr], #0x10]\n" - "movi C_43.4s, #0\n" - "prfm pldl1keep, [%x[bptr], #0x20]\n" - "movi C_51.4s, #0\n" - "prfm pldl1keep, [%x[aptr], #0x00]\n" - "movi C_52.4s, #0\n" - "prfm pldl1keep, [ aptr1, #0x00]\n" - "movi C_53.4s, #0\n" - "prfm pldl1keep, [ aptr2, #0x00]\n" - "movi C_61.4s, #0\n" - "prfm pldl1keep, [ aptr3, #0x00]\n" - "movi C_62.4s, #0\n" - "prfm pldl1keep, [ aptr4, #0x00]\n" - "movi C_63.4s, #0\n" - "prfm pldl1keep, [ aptr5, #0x00]\n" - "movi C_71.4s, #0\n" - "prfm pldl1keep, [ aptr6, #0x00]\n" - "movi C_72.4s, #0\n" - "prfm pldl1keep, [ aptr7, #0x00]\n" - "movi C_73.4s, #0\n" - "ldr sA_1, [%x[aptr]], #0x4\n" - "movi C_81.4s, #0\n" - "ldr sA_2, [ aptr1], #0x4\n" - "movi C_82.4s, #0\n" - "ldr sA_3, [ aptr2], #0x4\n" - "movi C_83.4s, #0\n" - "subs %x[k], %x[k], #1\n" - "beq 2f\n" - - "1:" - "fmla C_11.4s, B_1.4s, A_1.s[0]\n" - "ldr qB_3, [%x[bptr], #0x20]\n" - "fmla C_12.4s, B_2.4s, A_1.s[0]\n" - "ldr sA_4, [ aptr3], #0x4\n" - "fmla C_13.4s, B_3.4s, A_1.s[0]\n" - "ldr sA_1, [ aptr4], #0x04\n" - - "fmla C_21.4s, B_1.4s, A_2.s[0]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride]\n" - "fmla C_22.4s, B_2.4s, A_2.s[0]\n" - "prfm pldl1keep, [ aptr3, #0x10]\n" - "fmla C_23.4s, B_3.4s, A_2.s[0]\n" - "ldr sA_2, [ aptr5], #0x04\n" - - "fmla C_31.4s, B_1.4s, A_3.s[0]\n" - "prfm pldl1keep, [%x[bptr], #0x00]\n" - "fmla C_32.4s, B_2.4s, A_3.s[0]\n" - "prfm pldl1keep, [%x[bptr], #0x10]\n" - "fmla C_33.4s, B_3.4s, A_3.s[0]\n" - "ldr sA_3, [ aptr6], #0x04\n" - - "fmla C_41.4s, B_1.4s, A_4.s[0]\n" - "prfm pldl1keep, [%x[bptr], #0x20]\n" - "fmla C_42.4s, B_2.4s, A_4.s[0]\n" - "prfm pldl1keep, [ aptr4, #0x10]\n" - "fmla C_43.4s, B_3.4s, A_4.s[0]\n" - "ldr sA_4, [ aptr7], #0x04\n" - - "fmla C_51.4s, B_1.4s, A_1.s[0]\n" - "prfm pldl1keep, [ aptr5, #0x10]\n" - "fmla C_52.4s, B_2.4s, A_1.s[0]\n" - "prfm pldl1keep, [ aptr6, #0x10]\n" - "fmla C_53.4s, B_3.4s, A_1.s[0]\n" - "ldr sA_1, [%x[aptr]], #0x04\n" - - "fmla C_61.4s, B_1.4s, A_2.s[0]\n" - "prfm pldl1keep, [ aptr7, #0x10]\n" - "fmla C_62.4s, B_2.4s, A_2.s[0]\n" - "subs %x[k], %x[k], #1\n" - "fmla C_63.4s, B_3.4s, A_2.s[0]\n" - "ldr sA_2, [ aptr1], #0x04\n" - - "fmla C_71.4s, B_1.4s, A_3.s[0]\n" - "prfm pldl1keep, [%x[aptr], #0x10]\n" - "fmla C_72.4s, B_2.4s, A_3.s[0]\n" - "prfm pldl1keep, [ aptr1, #0x10]\n" - "fmla C_73.4s, B_3.4s, A_3.s[0]\n" - "ldr sA_3, [ aptr2], #0x04\n" - - "fmla C_81.4s, B_1.4s, A_4.s[0]\n" - "prfm pldl1keep, [ aptr2, #0x10]\n" - "fmla C_82.4s, B_2.4s, A_4.s[0]\n" - "ldp qB_1, qB_2, [%x[bptr]]\n" - "fmla C_83.4s, B_3.4s, A_4.s[0]\n" - "bne 1b\n" - - "2:" - "fmla C_11.4s, B_1.4s, A_1.s[0]\n" - "ldr qB_3, [%x[bptr], #0x20]\n" - "fmla C_12.4s, B_2.4s, A_1.s[0]\n" - "stp qC_11, qC_12, [%x[cptr]]\n" - "fmla C_13.4s, B_3.4s, A_1.s[0]\n" - "str qC_13, [%x[cptr], #0x20]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride]\n" - "ldr sA_1, [ aptr4], #0x04\n" - - "fmla C_21.4s, B_1.4s, A_2.s[0]\n" - "ldr sA_4, [ aptr3], #0x4\n" - "fmla C_22.4s, B_2.4s, A_2.s[0]\n" - "stp qC_21, qC_22, [%x[cptr]]\n" - "fmla C_23.4s, B_3.4s, A_2.s[0]\n" - "str qC_23, [%x[cptr], #0x20]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride]\n" - "ldr sA_2, [ aptr5], #0x04\n" - - "fmla C_31.4s, B_1.4s, A_3.s[0]\n" - "fmla C_32.4s, B_2.4s, A_3.s[0]\n" - "stp qC_31, qC_32, [%x[cptr]]\n" - "fmla C_33.4s, B_3.4s, A_3.s[0]\n" - "str qC_33, [%x[cptr], #0x20]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride]\n" - "ldr sA_3, [ aptr6], #0x04\n" - - "fmla C_41.4s, B_1.4s, A_4.s[0]\n" - "fmla C_42.4s, B_2.4s, A_4.s[0]\n" - "stp qC_41, qC_42, [%x[cptr]]\n" - "fmla C_43.4s, B_3.4s, A_4.s[0]\n" - "str qC_43, [%x[cptr], #0x20]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride]\n" - "ldr sA_4, [ aptr7], #0x04\n" - - "fmla C_51.4s, B_1.4s, A_1.s[0]\n" - "fmla C_52.4s, B_2.4s, A_1.s[0]\n" - "stp qC_51, qC_52, [%x[cptr]]\n" - "fmla C_53.4s, B_3.4s, A_1.s[0]\n" - "str qC_53, [%x[cptr], #0x20]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride]\n" - - "fmla C_61.4s, B_1.4s, A_2.s[0]\n" - "fmla C_62.4s, B_2.4s, A_2.s[0]\n" - "stp qC_61, qC_62, [%x[cptr]]\n" - "fmla C_63.4s, B_3.4s, A_2.s[0]\n" - "str qC_63, [%x[cptr], #0x20]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride]\n" - - "fmla C_71.4s, B_1.4s, A_3.s[0]\n" - "fmla C_72.4s, B_2.4s, A_3.s[0]\n" - "stp qC_71, qC_72, [%x[cptr]]\n" - "fmla C_73.4s, B_3.4s, A_3.s[0]\n" - "str qC_73, [%x[cptr], #0x20]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride]\n" - - "fmla C_81.4s, B_1.4s, A_4.s[0]\n" - "fmla C_82.4s, B_2.4s, A_4.s[0]\n" - "stp qC_81, qC_82, [%x[cptr]]\n" - "fmla C_83.4s, B_3.4s, A_4.s[0]\n" - "str qC_83, [%x[cptr], #0x20]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride]\n" - - // Clear aliases - ".unreq aptr1\n" - ".unreq aptr2\n" - ".unreq aptr3\n" - ".unreq aptr4\n" - ".unreq aptr5\n" - ".unreq aptr6\n" - ".unreq aptr7\n" - - ".unreq A_1\n" ".unreq A_2\n" ".unreq A_3\n" ".unreq A_4\n" - ".unreq sA_1\n" ".unreq sA_2\n" ".unreq sA_3\n" ".unreq sA_4\n" - - ".unreq B_1\n" ".unreq B_2\n" ".unreq B_3\n" - ".unreq qB_1\n" ".unreq qB_2\n" ".unreq qB_3\n" - - ".unreq C_11\n" ".unreq C_12\n" ".unreq C_13\n" - ".unreq C_21\n" ".unreq C_22\n" ".unreq C_23\n" - ".unreq C_31\n" ".unreq C_32\n" ".unreq C_33\n" - ".unreq C_41\n" ".unreq C_42\n" ".unreq C_43\n" - ".unreq C_51\n" ".unreq C_52\n" ".unreq C_53\n" - ".unreq C_61\n" ".unreq C_62\n" ".unreq C_63\n" - ".unreq C_71\n" ".unreq C_72\n" ".unreq C_73\n" - ".unreq C_81\n" ".unreq C_82\n" ".unreq C_83\n" - - ".unreq qC_11\n" ".unreq qC_12\n" ".unreq qC_13\n" - ".unreq qC_21\n" ".unreq qC_22\n" ".unreq qC_23\n" - ".unreq qC_31\n" ".unreq qC_32\n" ".unreq qC_33\n" - ".unreq qC_41\n" ".unreq qC_42\n" ".unreq qC_43\n" - ".unreq qC_51\n" ".unreq qC_52\n" ".unreq qC_53\n" - ".unreq qC_61\n" ".unreq qC_62\n" ".unreq qC_63\n" - ".unreq qC_71\n" ".unreq qC_72\n" ".unreq qC_73\n" - ".unreq qC_81\n" ".unreq qC_82\n" ".unreq qC_83\n" - : [aptr] "+r" (aptr), - [bptr] "+r" (bptr), - [cptr] "+r" (cptr), - [k] "+r" (k) - : [a_row_stride] "r" (a_row_stride * sizeof(float)), - [b_row_stride] "r" (b_row_stride * sizeof(float)), - [c_row_stride] "r" (c_row_stride * sizeof(float)) - : "cc", "memory", - "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", - "v29", "v30", "x17", "x18", "x19", "x20", "x21", "x22", "x23" - ); - } - } -} - -/*****************************************************************************/ -/* 4x16 blocked GEMM with specialised tails - */ -#include "a64_sgemm_4x16.hpp" - -template <> -inline void BlockedGemm<4, 16, float, float>( - const float* const a, const float* const b, float *c, - const int M, const int K, const int N, - const int a_row_stride, - const int b_row_stride, - const int c_row_stride -) { - // Despatch based on tail of K - switch (K % 4) { - case 3: - sgemm_4x16_impl<3>( - a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride - ); - break; - case 2: - sgemm_4x16_impl<2>( - a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride - ); - break; - case 1: - sgemm_4x16_impl<1>( - a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride - ); - break; - case 0: - sgemm_4x16_impl<0>( - a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride - ); - break; - default: - assert(0); - break; - } -} - -#endif // __aarch64__ diff --git a/src/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp b/src/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp deleted file mode 100644 index e74610ef27..0000000000 --- a/src/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp +++ /dev/null @@ -1,1445 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -template -inline void sgemm_4x16_impl( - const float* const a, const float* const b, float *c, - const int M, const int K, const int N, - const int a_row_stride, - const int b_row_stride, - const int c_row_stride -); - -template <> -inline void sgemm_4x16_impl<0>( - const float* const a, const float* const b, float *c, - const int M, const int K, const int N, - const int a_row_stride, - const int b_row_stride, - const int c_row_stride -) { - const int TAIL_SIZE = 0; - const int M_BLOCK = 4; - const int N_BLOCK = 16; - - const int m_blocks = iceildiv(M, M_BLOCK); - const int n_blocks = iceildiv(N, N_BLOCK); - - // For each block of output rows - for (int mblock = 0; mblock < m_blocks; mblock++) { - // For each block of output columns - for (int nblock = 0; nblock < n_blocks; nblock++) { - const float *aptr = a + mblock*M_BLOCK*a_row_stride; - const float *bptr = b + nblock*N_BLOCK; - float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; - int k = (K - TAIL_SIZE) / 4; - - asm volatile( - "aptr2 .req X20\n" - "aptr3 .req X21\n" - "aptr4 .req X22\n" - "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" - "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" - "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" - "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" - "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" - "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" - "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" - "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" - "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" - "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" - "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" - "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" - "vB1 .req v20\n" "qB1 .req q20\n" - "vB2 .req v21\n" "qB2 .req q21\n" - "vB3 .req v22\n" "qB3 .req q22\n" - "vB4 .req v23\n" "qB4 .req q23\n" - - // Clear accumulators, initialise pointers - "movi vC11.4s, #0\n" - "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n" - "movi vC12.4s, #0\n" - "add aptr3, aptr2, %x[a_row_stride_bytes]\n" - "movi vC13.4s, #0\n" - "add aptr4, aptr3, %x[a_row_stride_bytes]\n" - "movi vC14.4s, #0\n" - "ldr qA1, [%x[aptr]], #0x10\n" - "movi vC21.4s, #0\n" - "ldr qA2, [ aptr2], #0x10\n" - "movi vC22.4s, #0\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "movi vC23.4s, #0\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "movi vC24.4s, #0\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "movi vC31.4s, #0\n" - "movi vC32.4s, #0\n" - "movi vC33.4s, #0\n" - "movi vC34.4s, #0\n" - "movi vC41.4s, #0\n" - "movi vC42.4s, #0\n" - "movi vC43.4s, #0\n" - "movi vC44.4s, #0\n" - "subs %x[k], %x[k], #1\n" - "beq 2f\n" - - "1:" // Loop proper - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr qA3, [ aptr3], #0x10\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "ldr qA4, [ aptr4], #0x10\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "fmla vC33.4s, vB3.4s, vA3.s[0]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[0]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[1]\n" - "fmla vC21.4s, vB1.4s, vA2.s[1]\n" - "fmla vC31.4s, vB1.4s, vA3.s[1]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[1]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[1]\n" - "fmla vC22.4s, vB2.4s, vA2.s[1]\n" - "fmla vC32.4s, vB2.4s, vA3.s[1]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[1]\n" - "fmla vC13.4s, vB3.4s, vA1.s[1]\n" - "fmla vC23.4s, vB3.4s, vA2.s[1]\n" - "fmla vC33.4s, vB3.4s, vA3.s[1]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[1]\n" - "fmla vC14.4s, vB4.4s, vA1.s[1]\n" - "fmla vC24.4s, vB4.4s, vA2.s[1]\n" - "fmla vC34.4s, vB4.4s, vA3.s[1]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[1]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[2]\n" - "fmla vC21.4s, vB1.4s, vA2.s[2]\n" - "fmla vC31.4s, vB1.4s, vA3.s[2]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[2]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[2]\n" - "fmla vC22.4s, vB2.4s, vA2.s[2]\n" - "fmla vC32.4s, vB2.4s, vA3.s[2]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[2]\n" - "fmla vC13.4s, vB3.4s, vA1.s[2]\n" - "fmla vC23.4s, vB3.4s, vA2.s[2]\n" - "fmla vC33.4s, vB3.4s, vA3.s[2]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[2]\n" - "fmla vC14.4s, vB4.4s, vA1.s[2]\n" - "fmla vC24.4s, vB4.4s, vA2.s[2]\n" - "fmla vC34.4s, vB4.4s, vA3.s[2]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[2]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[3]\n" - "fmla vC21.4s, vB1.4s, vA2.s[3]\n" - "fmla vC31.4s, vB1.4s, vA3.s[3]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[3]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[3]\n" - "fmla vC22.4s, vB2.4s, vA2.s[3]\n" - "fmla vC32.4s, vB2.4s, vA3.s[3]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[3]\n" - "fmla vC13.4s, vB3.4s, vA1.s[3]\n" - "fmla vC23.4s, vB3.4s, vA2.s[3]\n" - "fmla vC33.4s, vB3.4s, vA3.s[3]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[3]\n" - "subs %x[k], %x[k], #1\n" - "fmla vC14.4s, vB4.4s, vA1.s[3]\n" - "ldr qA1, [%x[aptr]], #0x10\n" - "fmla vC24.4s, vB4.4s, vA2.s[3]\n" - "ldr qA2, [ aptr2], #0x10\n" - "fmla vC34.4s, vB4.4s, vA3.s[3]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[3]\n" - "bne 1b\n" - - "2:" // Tail - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr qA3, [ aptr3], #0x10\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "ldr qA4, [ aptr4], #0x10\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "fmla vC33.4s, vB3.4s, vA3.s[0]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[0]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[1]\n" - "fmla vC21.4s, vB1.4s, vA2.s[1]\n" - "fmla vC31.4s, vB1.4s, vA3.s[1]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[1]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[1]\n" - "fmla vC22.4s, vB2.4s, vA2.s[1]\n" - "fmla vC32.4s, vB2.4s, vA3.s[1]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[1]\n" - "fmla vC13.4s, vB3.4s, vA1.s[1]\n" - "fmla vC23.4s, vB3.4s, vA2.s[1]\n" - "fmla vC33.4s, vB3.4s, vA3.s[1]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[1]\n" - "fmla vC14.4s, vB4.4s, vA1.s[1]\n" - "fmla vC24.4s, vB4.4s, vA2.s[1]\n" - "fmla vC34.4s, vB4.4s, vA3.s[1]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[1]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[2]\n" - "fmla vC21.4s, vB1.4s, vA2.s[2]\n" - "fmla vC31.4s, vB1.4s, vA3.s[2]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[2]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[2]\n" - "fmla vC22.4s, vB2.4s, vA2.s[2]\n" - "fmla vC32.4s, vB2.4s, vA3.s[2]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[2]\n" - "fmla vC13.4s, vB3.4s, vA1.s[2]\n" - "fmla vC23.4s, vB3.4s, vA2.s[2]\n" - "fmla vC33.4s, vB3.4s, vA3.s[2]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[2]\n" - "fmla vC14.4s, vB4.4s, vA1.s[2]\n" - "fmla vC24.4s, vB4.4s, vA2.s[2]\n" - "fmla vC34.4s, vB4.4s, vA3.s[2]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[2]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[3]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC12.4s, vB2.4s, vA1.s[3]\n" - "stp qC11, qC12, [%x[cptr], #0x00]\n" - "fmla vC13.4s, vB3.4s, vA1.s[3]\n" - "fmla vC14.4s, vB4.4s, vA1.s[3]\n" - "stp qC13, qC14, [%x[cptr], #0x20]\n" - "fmla vC21.4s, vB1.4s, vA2.s[3]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC22.4s, vB2.4s, vA2.s[3]\n" - "stp qC21, qC22, [%x[cptr], #0x00]\n" - "fmla vC23.4s, vB3.4s, vA2.s[3]\n" - "fmla vC24.4s, vB4.4s, vA2.s[3]\n" - "stp qC23, qC24, [%x[cptr], #0x20]\n" - "fmla vC31.4s, vB1.4s, vA3.s[3]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC32.4s, vB2.4s, vA3.s[3]\n" - "stp qC31, qC32, [%x[cptr], #0x00]\n" - "fmla vC33.4s, vB3.4s, vA3.s[3]\n" - "fmla vC34.4s, vB4.4s, vA3.s[3]\n" - "stp qC33, qC34, [%x[cptr], #0x20]\n" - "fmla vC41.4s, vB1.4s, vA4.s[3]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC42.4s, vB2.4s, vA4.s[3]\n" - "stp qC41, qC42, [%x[cptr], #0x00]\n" - "fmla vC43.4s, vB3.4s, vA4.s[3]\n" - "fmla vC44.4s, vB4.4s, vA4.s[3]\n" - "stp qC43, qC44, [%x[cptr], #0x20]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - - ".unreq vB4\n" ".unreq qB4\n" - ".unreq vB3\n" ".unreq qB3\n" - ".unreq vB2\n" ".unreq qB2\n" - ".unreq vB1\n" ".unreq qB1\n" - ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" - ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" - ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" - ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" - ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" - ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n" - ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" - ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" - ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" - ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n" - ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" - ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" - ".unreq aptr2\n" - ".unreq aptr3\n" - ".unreq aptr4\n" - - : [aptr] "+r" (aptr), - [bptr] "+r" (bptr), - [cptr] "+r" (cptr), - [k] "+r" (k) - : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), - [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), - [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) - : "cc", "memory", "x20", "x21", "x22", - "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", - "v21", "v22", "v23" - ); - } - } -} - -template <> -inline void sgemm_4x16_impl<1>( - const float* const a, const float* const b, float *c, - const int M, const int K, const int N, - const int a_row_stride, - const int b_row_stride, - const int c_row_stride -) { - const int TAIL_SIZE = 1; - const int M_BLOCK = 4; - const int N_BLOCK = 16; - - const int m_blocks = iceildiv(M, M_BLOCK); - const int n_blocks = iceildiv(N, N_BLOCK); - - // For each block of output rows - for (int mblock = 0; mblock < m_blocks; mblock++) { - // For each block of output columns - for (int nblock = 0; nblock < n_blocks; nblock++) { - const float *aptr = a + mblock*M_BLOCK*a_row_stride; - const float *bptr = b + nblock*N_BLOCK; - float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; - int k = (K - TAIL_SIZE) / 4; - - asm volatile( - "aptr2 .req X20\n" - "aptr3 .req X21\n" - "aptr4 .req X22\n" - "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" - "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" - "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" - "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" - "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" - "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" - "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" - "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" - "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" - "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" - "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" - "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" - "vB1 .req v20\n" "qB1 .req q20\n" - "vB2 .req v21\n" "qB2 .req q21\n" - "vB3 .req v22\n" "qB3 .req q22\n" - "vB4 .req v23\n" "qB4 .req q23\n" - - // Clear accumulators, initialise pointers - "movi vC11.4s, #0\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "movi vC12.4s, #0\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "movi vC13.4s, #0\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "movi vC14.4s, #0\n" - "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n" - "movi vC21.4s, #0\n" - "add aptr3, aptr2, %x[a_row_stride_bytes]\n" - "movi vC22.4s, #0\n" - "add aptr4, aptr3, %x[a_row_stride_bytes]\n" - "movi vC23.4s, #0\n" - "cbnz %x[k], 3f\n" - - // Prepare for tail in K - "movi vC24.4s, #0\n" - "ldr sA1, [%x[aptr]], #0x04\n" - "movi vC31.4s, #0\n" - "ldr sA2, [ aptr2], #0x04\n" - "movi vC32.4s, #0\n" - "movi vC33.4s, #0\n" - "movi vC34.4s, #0\n" - "movi vC41.4s, #0\n" - "movi vC42.4s, #0\n" - "movi vC43.4s, #0\n" - "movi vC44.4s, #0\n" - "b 2f\n" // Jump to tail - - "3:" // Prepare for loop over K - "movi vC24.4s, #0\n" - "ldr qA1, [%x[aptr]], #0x10\n" - "movi vC31.4s, #0\n" - "ldr qA2, [ aptr2], #0x10\n" - "movi vC32.4s, #0\n" - "movi vC33.4s, #0\n" - "movi vC34.4s, #0\n" - "movi vC41.4s, #0\n" - "movi vC42.4s, #0\n" - "movi vC43.4s, #0\n" - "movi vC44.4s, #0\n" - "subs %x[k], %x[k], #1\n" - "beq 4f\n" - - "1:" // Loop proper - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr qA3, [ aptr3], #0x10\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "ldr qA4, [ aptr4], #0x10\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "fmla vC33.4s, vB3.4s, vA3.s[0]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[0]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[1]\n" - "fmla vC21.4s, vB1.4s, vA2.s[1]\n" - "fmla vC31.4s, vB1.4s, vA3.s[1]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[1]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[1]\n" - "fmla vC22.4s, vB2.4s, vA2.s[1]\n" - "fmla vC32.4s, vB2.4s, vA3.s[1]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[1]\n" - "fmla vC13.4s, vB3.4s, vA1.s[1]\n" - "fmla vC23.4s, vB3.4s, vA2.s[1]\n" - "fmla vC33.4s, vB3.4s, vA3.s[1]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[1]\n" - "fmla vC14.4s, vB4.4s, vA1.s[1]\n" - "fmla vC24.4s, vB4.4s, vA2.s[1]\n" - "fmla vC34.4s, vB4.4s, vA3.s[1]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[1]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[2]\n" - "fmla vC21.4s, vB1.4s, vA2.s[2]\n" - "fmla vC31.4s, vB1.4s, vA3.s[2]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[2]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[2]\n" - "fmla vC22.4s, vB2.4s, vA2.s[2]\n" - "fmla vC32.4s, vB2.4s, vA3.s[2]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[2]\n" - "fmla vC13.4s, vB3.4s, vA1.s[2]\n" - "fmla vC23.4s, vB3.4s, vA2.s[2]\n" - "fmla vC33.4s, vB3.4s, vA3.s[2]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[2]\n" - "fmla vC14.4s, vB4.4s, vA1.s[2]\n" - "fmla vC24.4s, vB4.4s, vA2.s[2]\n" - "fmla vC34.4s, vB4.4s, vA3.s[2]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[2]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[3]\n" - "fmla vC21.4s, vB1.4s, vA2.s[3]\n" - "fmla vC31.4s, vB1.4s, vA3.s[3]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[3]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[3]\n" - "fmla vC22.4s, vB2.4s, vA2.s[3]\n" - "fmla vC32.4s, vB2.4s, vA3.s[3]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[3]\n" - "fmla vC13.4s, vB3.4s, vA1.s[3]\n" - "fmla vC23.4s, vB3.4s, vA2.s[3]\n" - "fmla vC33.4s, vB3.4s, vA3.s[3]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[3]\n" - "subs %x[k], %x[k], #1\n" - "fmla vC14.4s, vB4.4s, vA1.s[3]\n" - "ldr qA1, [%x[aptr]], #0x10\n" - "fmla vC24.4s, vB4.4s, vA2.s[3]\n" - "ldr qA2, [ aptr2], #0x10\n" - "fmla vC34.4s, vB4.4s, vA3.s[3]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[3]\n" - "bne 1b\n" - - "4:" // Tail iteration - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr qA3, [ aptr3], #0x10\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "ldr qA4, [ aptr4], #0x10\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "fmla vC33.4s, vB3.4s, vA3.s[0]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[0]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[1]\n" - "fmla vC21.4s, vB1.4s, vA2.s[1]\n" - "fmla vC31.4s, vB1.4s, vA3.s[1]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[1]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[1]\n" - "fmla vC22.4s, vB2.4s, vA2.s[1]\n" - "fmla vC32.4s, vB2.4s, vA3.s[1]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[1]\n" - "fmla vC13.4s, vB3.4s, vA1.s[1]\n" - "fmla vC23.4s, vB3.4s, vA2.s[1]\n" - "fmla vC33.4s, vB3.4s, vA3.s[1]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[1]\n" - "fmla vC14.4s, vB4.4s, vA1.s[1]\n" - "fmla vC24.4s, vB4.4s, vA2.s[1]\n" - "fmla vC34.4s, vB4.4s, vA3.s[1]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[1]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[2]\n" - "fmla vC21.4s, vB1.4s, vA2.s[2]\n" - "fmla vC31.4s, vB1.4s, vA3.s[2]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[2]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[2]\n" - "fmla vC22.4s, vB2.4s, vA2.s[2]\n" - "fmla vC32.4s, vB2.4s, vA3.s[2]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[2]\n" - "fmla vC13.4s, vB3.4s, vA1.s[2]\n" - "fmla vC23.4s, vB3.4s, vA2.s[2]\n" - "fmla vC33.4s, vB3.4s, vA3.s[2]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[2]\n" - "fmla vC14.4s, vB4.4s, vA1.s[2]\n" - "fmla vC24.4s, vB4.4s, vA2.s[2]\n" - "fmla vC34.4s, vB4.4s, vA3.s[2]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[2]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[3]\n" - "fmla vC21.4s, vB1.4s, vA2.s[3]\n" - "fmla vC31.4s, vB1.4s, vA3.s[3]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[3]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[3]\n" - "fmla vC22.4s, vB2.4s, vA2.s[3]\n" - "fmla vC32.4s, vB2.4s, vA3.s[3]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[3]\n" - "fmla vC13.4s, vB3.4s, vA1.s[3]\n" - "fmla vC23.4s, vB3.4s, vA2.s[3]\n" - "fmla vC33.4s, vB3.4s, vA3.s[3]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[3]\n" - "fmla vC14.4s, vB4.4s, vA1.s[3]\n" - "ldr sA1, [%x[aptr]], #0x10\n" - "fmla vC24.4s, vB4.4s, vA2.s[3]\n" - "ldr sA2, [ aptr2], #0x10\n" - "fmla vC34.4s, vB4.4s, vA3.s[3]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[3]\n" - - "2:" // Common tail - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "stp qC11, qC12, [%x[cptr], #0x00]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "ldr sA3, [ aptr3], #0x10\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "stp qC13, qC14, [%x[cptr], #0x20]\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "stp qC21, qC22, [%x[cptr], #0x00]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "ldr sA4, [ aptr4], #0x10\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "stp qC23, qC24, [%x[cptr], #0x20]\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "stp qC31, qC32, [%x[cptr], #0x00]\n" - "fmla vC33.4s, vB3.4s, vA3.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "stp qC33, qC34, [%x[cptr], #0x20]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "stp qC41, qC42, [%x[cptr], #0x00]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC44.4s, vB4.4s, vA4.s[0]\n" - "stp qC43, qC44, [%x[cptr], #0x20]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - - ".unreq vB4\n" ".unreq qB4\n" - ".unreq vB3\n" ".unreq qB3\n" - ".unreq vB2\n" ".unreq qB2\n" - ".unreq vB1\n" ".unreq qB1\n" - ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" - ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" - ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" - ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" - ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" - ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n" - ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" - ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" - ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" - ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n" - ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" - ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" - ".unreq aptr2\n" - ".unreq aptr3\n" - ".unreq aptr4\n" - - : [aptr] "+r" (aptr), - [bptr] "+r" (bptr), - [cptr] "+r" (cptr), - [k] "+r" (k) - : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), - [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), - [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) - : "cc", "memory", "x20", "x21", "x22", - "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", - "v21", "v22", "v23" - ); - } - } -} - -template <> -inline void sgemm_4x16_impl<2>( - const float* const a, const float* const b, float *c, - const int M, const int K, const int N, - const int a_row_stride, - const int b_row_stride, - const int c_row_stride -) { - const int TAIL_SIZE = 2; - const int M_BLOCK = 4; - const int N_BLOCK = 16; - - const int m_blocks = iceildiv(M, M_BLOCK); - const int n_blocks = iceildiv(N, N_BLOCK); - - // For each block of output rows - for (int mblock = 0; mblock < m_blocks; mblock++) { - // For each block of output columns - for (int nblock = 0; nblock < n_blocks; nblock++) { - const float *aptr = a + mblock*M_BLOCK*a_row_stride; - const float *bptr = b + nblock*N_BLOCK; - float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; - int k = (K - TAIL_SIZE) / 4; - - asm volatile( - "aptr2 .req X20\n" - "aptr3 .req X21\n" - "aptr4 .req X22\n" - "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" - "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" - "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" - "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" - "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" - "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" - "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" - "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" - "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" - "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" - "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" - "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" - "vB1 .req v20\n" "qB1 .req q20\n" - "vB2 .req v21\n" "qB2 .req q21\n" - "vB3 .req v22\n" "qB3 .req q22\n" - "vB4 .req v23\n" "qB4 .req q23\n" - - // Clear accumulators, initialise pointers - "movi vC11.4s, #0\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "movi vC12.4s, #0\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "movi vC13.4s, #0\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "movi vC14.4s, #0\n" - "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n" - "movi vC21.4s, #0\n" - "add aptr3, aptr2, %x[a_row_stride_bytes]\n" - "movi vC22.4s, #0\n" - "add aptr4, aptr3, %x[a_row_stride_bytes]\n" - "movi vC23.4s, #0\n" - "cbnz %x[k], 3f\n" - - // Prepare for tail in K - "movi vC24.4s, #0\n" - "ldr dA1, [%x[aptr]], #0x08\n" - "movi vC31.4s, #0\n" - "ldr dA2, [ aptr2], #0x08\n" - "movi vC32.4s, #0\n" - "movi vC33.4s, #0\n" - "movi vC34.4s, #0\n" - "movi vC41.4s, #0\n" - "movi vC42.4s, #0\n" - "movi vC43.4s, #0\n" - "movi vC44.4s, #0\n" - "b 2f\n" // Jump to tail - - "3:" // Prepare for loop over K - "movi vC24.4s, #0\n" - "ldr qA1, [%x[aptr]], #0x10\n" - "movi vC31.4s, #0\n" - "ldr qA2, [ aptr2], #0x10\n" - "movi vC32.4s, #0\n" - "movi vC33.4s, #0\n" - "movi vC34.4s, #0\n" - "movi vC41.4s, #0\n" - "movi vC42.4s, #0\n" - "movi vC43.4s, #0\n" - "movi vC44.4s, #0\n" - "subs %x[k], %x[k], #1\n" - "beq 4f\n" - - "1:" // Loop proper - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr qA3, [ aptr3], #0x10\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "ldr qA4, [ aptr4], #0x10\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "fmla vC33.4s, vB3.4s, vA3.s[0]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[0]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[1]\n" - "fmla vC21.4s, vB1.4s, vA2.s[1]\n" - "fmla vC31.4s, vB1.4s, vA3.s[1]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[1]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[1]\n" - "fmla vC22.4s, vB2.4s, vA2.s[1]\n" - "fmla vC32.4s, vB2.4s, vA3.s[1]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[1]\n" - "fmla vC13.4s, vB3.4s, vA1.s[1]\n" - "fmla vC23.4s, vB3.4s, vA2.s[1]\n" - "fmla vC33.4s, vB3.4s, vA3.s[1]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[1]\n" - "fmla vC14.4s, vB4.4s, vA1.s[1]\n" - "fmla vC24.4s, vB4.4s, vA2.s[1]\n" - "fmla vC34.4s, vB4.4s, vA3.s[1]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[1]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[2]\n" - "fmla vC21.4s, vB1.4s, vA2.s[2]\n" - "fmla vC31.4s, vB1.4s, vA3.s[2]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[2]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[2]\n" - "fmla vC22.4s, vB2.4s, vA2.s[2]\n" - "fmla vC32.4s, vB2.4s, vA3.s[2]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[2]\n" - "fmla vC13.4s, vB3.4s, vA1.s[2]\n" - "fmla vC23.4s, vB3.4s, vA2.s[2]\n" - "fmla vC33.4s, vB3.4s, vA3.s[2]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[2]\n" - "fmla vC14.4s, vB4.4s, vA1.s[2]\n" - "fmla vC24.4s, vB4.4s, vA2.s[2]\n" - "fmla vC34.4s, vB4.4s, vA3.s[2]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[2]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[3]\n" - "fmla vC21.4s, vB1.4s, vA2.s[3]\n" - "fmla vC31.4s, vB1.4s, vA3.s[3]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[3]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[3]\n" - "fmla vC22.4s, vB2.4s, vA2.s[3]\n" - "fmla vC32.4s, vB2.4s, vA3.s[3]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[3]\n" - "fmla vC13.4s, vB3.4s, vA1.s[3]\n" - "fmla vC23.4s, vB3.4s, vA2.s[3]\n" - "fmla vC33.4s, vB3.4s, vA3.s[3]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[3]\n" - "subs %x[k], %x[k], #1\n" - "fmla vC14.4s, vB4.4s, vA1.s[3]\n" - "ldr qA1, [%x[aptr]], #0x10\n" - "fmla vC24.4s, vB4.4s, vA2.s[3]\n" - "ldr qA2, [ aptr2], #0x10\n" - "fmla vC34.4s, vB4.4s, vA3.s[3]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[3]\n" - "bne 1b\n" - - "4:" // Tail iteration - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr qA3, [ aptr3], #0x10\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "ldr qA4, [ aptr4], #0x10\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "fmla vC33.4s, vB3.4s, vA3.s[0]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[0]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[1]\n" - "fmla vC21.4s, vB1.4s, vA2.s[1]\n" - "fmla vC31.4s, vB1.4s, vA3.s[1]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[1]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[1]\n" - "fmla vC22.4s, vB2.4s, vA2.s[1]\n" - "fmla vC32.4s, vB2.4s, vA3.s[1]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[1]\n" - "fmla vC13.4s, vB3.4s, vA1.s[1]\n" - "fmla vC23.4s, vB3.4s, vA2.s[1]\n" - "fmla vC33.4s, vB3.4s, vA3.s[1]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[1]\n" - "fmla vC14.4s, vB4.4s, vA1.s[1]\n" - "fmla vC24.4s, vB4.4s, vA2.s[1]\n" - "fmla vC34.4s, vB4.4s, vA3.s[1]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[1]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[2]\n" - "fmla vC21.4s, vB1.4s, vA2.s[2]\n" - "fmla vC31.4s, vB1.4s, vA3.s[2]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[2]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[2]\n" - "fmla vC22.4s, vB2.4s, vA2.s[2]\n" - "fmla vC32.4s, vB2.4s, vA3.s[2]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[2]\n" - "fmla vC13.4s, vB3.4s, vA1.s[2]\n" - "fmla vC23.4s, vB3.4s, vA2.s[2]\n" - "fmla vC33.4s, vB3.4s, vA3.s[2]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[2]\n" - "fmla vC14.4s, vB4.4s, vA1.s[2]\n" - "fmla vC24.4s, vB4.4s, vA2.s[2]\n" - "fmla vC34.4s, vB4.4s, vA3.s[2]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[2]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[3]\n" - "fmla vC21.4s, vB1.4s, vA2.s[3]\n" - "fmla vC31.4s, vB1.4s, vA3.s[3]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[3]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[3]\n" - "fmla vC22.4s, vB2.4s, vA2.s[3]\n" - "fmla vC32.4s, vB2.4s, vA3.s[3]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[3]\n" - "fmla vC13.4s, vB3.4s, vA1.s[3]\n" - "fmla vC23.4s, vB3.4s, vA2.s[3]\n" - "fmla vC33.4s, vB3.4s, vA3.s[3]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[3]\n" - "fmla vC14.4s, vB4.4s, vA1.s[3]\n" - "ldr dA1, [%x[aptr]], #0x10\n" - "fmla vC24.4s, vB4.4s, vA2.s[3]\n" - "ldr dA2, [ aptr2], #0x10\n" - "fmla vC34.4s, vB4.4s, vA3.s[3]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[3]\n" - - "2:" // Common tail - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr dA3, [ aptr3], #0x10\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "ldr dA4, [ aptr4], #0x10\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "fmla vC33.4s, vB3.4s, vA3.s[0]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[0]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[1]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC12.4s, vB2.4s, vA1.s[1]\n" - "stp qC11, qC12, [%x[cptr], #0x00]\n" - "fmla vC13.4s, vB3.4s, vA1.s[1]\n" - "fmla vC14.4s, vB4.4s, vA1.s[1]\n" - "stp qC13, qC14, [%x[cptr], #0x20]\n" - "fmla vC21.4s, vB1.4s, vA2.s[1]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC22.4s, vB2.4s, vA2.s[1]\n" - "stp qC21, qC22, [%x[cptr], #0x00]\n" - "fmla vC23.4s, vB3.4s, vA2.s[1]\n" - "fmla vC24.4s, vB4.4s, vA2.s[1]\n" - "stp qC23, qC24, [%x[cptr], #0x20]\n" - "fmla vC31.4s, vB1.4s, vA3.s[1]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC32.4s, vB2.4s, vA3.s[1]\n" - "stp qC31, qC32, [%x[cptr], #0x00]\n" - "fmla vC33.4s, vB3.4s, vA3.s[1]\n" - "fmla vC34.4s, vB4.4s, vA3.s[1]\n" - "stp qC33, qC34, [%x[cptr], #0x20]\n" - "fmla vC41.4s, vB1.4s, vA4.s[1]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC42.4s, vB2.4s, vA4.s[1]\n" - "stp qC41, qC42, [%x[cptr], #0x00]\n" - "fmla vC43.4s, vB3.4s, vA4.s[1]\n" - "fmla vC44.4s, vB4.4s, vA4.s[1]\n" - "stp qC43, qC44, [%x[cptr], #0x20]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - - ".unreq vB4\n" ".unreq qB4\n" - ".unreq vB3\n" ".unreq qB3\n" - ".unreq vB2\n" ".unreq qB2\n" - ".unreq vB1\n" ".unreq qB1\n" - ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" - ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" - ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" - ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" - ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" - ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n" - ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" - ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" - ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" - ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n" - ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" - ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" - ".unreq aptr2\n" - ".unreq aptr3\n" - ".unreq aptr4\n" - - : [aptr] "+r" (aptr), - [bptr] "+r" (bptr), - [cptr] "+r" (cptr), - [k] "+r" (k) - : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), - [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), - [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) - : "cc", "memory", "x20", "x21", "x22", - "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", - "v21", "v22", "v23" - ); - } - } -} - -template <> -inline void sgemm_4x16_impl<3>( - const float* const a, const float* const b, float *c, - const int M, const int K, const int N, - const int a_row_stride, - const int b_row_stride, - const int c_row_stride -) { - const int TAIL_SIZE = 3; - const int M_BLOCK = 4; - const int N_BLOCK = 16; - - const int m_blocks = iceildiv(M, M_BLOCK); - const int n_blocks = iceildiv(N, N_BLOCK); - - // For each block of output rows - for (int mblock = 0; mblock < m_blocks; mblock++) { - // For each block of output columns - for (int nblock = 0; nblock < n_blocks; nblock++) { - const float *aptr = a + mblock*M_BLOCK*a_row_stride; - const float *bptr = b + nblock*N_BLOCK; - float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK; - int k = (K - TAIL_SIZE) / 4; - - asm volatile( - "aptr2 .req X20\n" - "aptr3 .req X21\n" - "aptr4 .req X22\n" - "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n" - "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n" - "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n" - "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n" - "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n" - "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n" - "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n" - "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n" - "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n" - "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n" - "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n" - "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n" - "vB1 .req v20\n" "qB1 .req q20\n" - "vB2 .req v21\n" "qB2 .req q21\n" - "vB3 .req v22\n" "qB3 .req q22\n" - "vB4 .req v23\n" "qB4 .req q23\n" - - // Clear accumulators, initialise pointers - "movi vC11.4s, #0\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "movi vC12.4s, #0\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "movi vC13.4s, #0\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "movi vC14.4s, #0\n" - "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n" - "movi vC21.4s, #0\n" - "add aptr3, aptr2, %x[a_row_stride_bytes]\n" - "movi vC22.4s, #0\n" - "add aptr4, aptr3, %x[a_row_stride_bytes]\n" - "movi vC23.4s, #0\n" - "cbnz %x[k], 3f\n" - - // Prepare for tail in K - "movi vC24.4s, #0\n" - "ldr dA1, [%x[aptr]], #0x08\n" - "movi vC31.4s, #0\n" - "ldr dA2, [ aptr2], #0x08\n" - "movi vC32.4s, #0\n" - "movi vC33.4s, #0\n" - "movi vC34.4s, #0\n" - "movi vC41.4s, #0\n" - "movi vC42.4s, #0\n" - "movi vC43.4s, #0\n" - "movi vC44.4s, #0\n" - "b 2f\n" // Jump to tail - - "3:" // Prepare for loop over K - "movi vC24.4s, #0\n" - "ldr qA1, [%x[aptr]], #0x10\n" - "movi vC31.4s, #0\n" - "ldr qA2, [ aptr2], #0x10\n" - "movi vC32.4s, #0\n" - "movi vC33.4s, #0\n" - "movi vC34.4s, #0\n" - "movi vC41.4s, #0\n" - "movi vC42.4s, #0\n" - "movi vC43.4s, #0\n" - "movi vC44.4s, #0\n" - "subs %x[k], %x[k], #1\n" - "beq 4f\n" - - "1:" // Loop proper - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr qA3, [ aptr3], #0x10\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "ldr qA4, [ aptr4], #0x10\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "fmla vC33.4s, vB3.4s, vA3.s[0]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[0]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[1]\n" - "fmla vC21.4s, vB1.4s, vA2.s[1]\n" - "fmla vC31.4s, vB1.4s, vA3.s[1]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[1]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[1]\n" - "fmla vC22.4s, vB2.4s, vA2.s[1]\n" - "fmla vC32.4s, vB2.4s, vA3.s[1]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[1]\n" - "fmla vC13.4s, vB3.4s, vA1.s[1]\n" - "fmla vC23.4s, vB3.4s, vA2.s[1]\n" - "fmla vC33.4s, vB3.4s, vA3.s[1]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[1]\n" - "fmla vC14.4s, vB4.4s, vA1.s[1]\n" - "fmla vC24.4s, vB4.4s, vA2.s[1]\n" - "fmla vC34.4s, vB4.4s, vA3.s[1]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[1]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[2]\n" - "fmla vC21.4s, vB1.4s, vA2.s[2]\n" - "fmla vC31.4s, vB1.4s, vA3.s[2]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[2]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[2]\n" - "fmla vC22.4s, vB2.4s, vA2.s[2]\n" - "fmla vC32.4s, vB2.4s, vA3.s[2]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[2]\n" - "fmla vC13.4s, vB3.4s, vA1.s[2]\n" - "fmla vC23.4s, vB3.4s, vA2.s[2]\n" - "fmla vC33.4s, vB3.4s, vA3.s[2]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[2]\n" - "fmla vC14.4s, vB4.4s, vA1.s[2]\n" - "fmla vC24.4s, vB4.4s, vA2.s[2]\n" - "fmla vC34.4s, vB4.4s, vA3.s[2]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[2]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[3]\n" - "fmla vC21.4s, vB1.4s, vA2.s[3]\n" - "fmla vC31.4s, vB1.4s, vA3.s[3]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[3]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[3]\n" - "fmla vC22.4s, vB2.4s, vA2.s[3]\n" - "fmla vC32.4s, vB2.4s, vA3.s[3]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[3]\n" - "fmla vC13.4s, vB3.4s, vA1.s[3]\n" - "fmla vC23.4s, vB3.4s, vA2.s[3]\n" - "fmla vC33.4s, vB3.4s, vA3.s[3]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[3]\n" - "subs %x[k], %x[k], #1\n" - "fmla vC14.4s, vB4.4s, vA1.s[3]\n" - "ldr qA1, [%x[aptr]], #0x10\n" - "fmla vC24.4s, vB4.4s, vA2.s[3]\n" - "ldr qA2, [ aptr2], #0x10\n" - "fmla vC34.4s, vB4.4s, vA3.s[3]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[3]\n" - "bne 1b\n" - - "4:" // Tail iteration - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr qA3, [ aptr3], #0x10\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "ldr qA4, [ aptr4], #0x10\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "fmla vC33.4s, vB3.4s, vA3.s[0]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[0]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[1]\n" - "fmla vC21.4s, vB1.4s, vA2.s[1]\n" - "fmla vC31.4s, vB1.4s, vA3.s[1]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[1]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[1]\n" - "fmla vC22.4s, vB2.4s, vA2.s[1]\n" - "fmla vC32.4s, vB2.4s, vA3.s[1]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[1]\n" - "fmla vC13.4s, vB3.4s, vA1.s[1]\n" - "fmla vC23.4s, vB3.4s, vA2.s[1]\n" - "fmla vC33.4s, vB3.4s, vA3.s[1]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[1]\n" - "fmla vC14.4s, vB4.4s, vA1.s[1]\n" - "fmla vC24.4s, vB4.4s, vA2.s[1]\n" - "fmla vC34.4s, vB4.4s, vA3.s[1]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[1]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[2]\n" - "fmla vC21.4s, vB1.4s, vA2.s[2]\n" - "fmla vC31.4s, vB1.4s, vA3.s[2]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[2]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[2]\n" - "fmla vC22.4s, vB2.4s, vA2.s[2]\n" - "fmla vC32.4s, vB2.4s, vA3.s[2]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[2]\n" - "fmla vC13.4s, vB3.4s, vA1.s[2]\n" - "fmla vC23.4s, vB3.4s, vA2.s[2]\n" - "fmla vC33.4s, vB3.4s, vA3.s[2]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[2]\n" - "fmla vC14.4s, vB4.4s, vA1.s[2]\n" - "fmla vC24.4s, vB4.4s, vA2.s[2]\n" - "fmla vC34.4s, vB4.4s, vA3.s[2]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[2]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[3]\n" - "fmla vC21.4s, vB1.4s, vA2.s[3]\n" - "fmla vC31.4s, vB1.4s, vA3.s[3]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[3]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[3]\n" - "fmla vC22.4s, vB2.4s, vA2.s[3]\n" - "fmla vC32.4s, vB2.4s, vA3.s[3]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[3]\n" - "fmla vC13.4s, vB3.4s, vA1.s[3]\n" - "fmla vC23.4s, vB3.4s, vA2.s[3]\n" - "fmla vC33.4s, vB3.4s, vA3.s[3]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[3]\n" - "fmla vC14.4s, vB4.4s, vA1.s[3]\n" - "ldr dA1, [%x[aptr]], #0x10\n" - "fmla vC24.4s, vB4.4s, vA2.s[3]\n" - "ldr dA2, [ aptr2], #0x10\n" - "fmla vC34.4s, vB4.4s, vA3.s[3]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[3]\n" - - "2:" // Common tail - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr dA3, [ aptr3], #0x10\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "ldr dA4, [ aptr4], #0x10\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "fmla vC33.4s, vB3.4s, vA3.s[0]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[0]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[1]\n" - "fmla vC21.4s, vB1.4s, vA2.s[1]\n" - "fmla vC31.4s, vB1.4s, vA3.s[1]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC41.4s, vB1.4s, vA4.s[1]\n" - "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n" - "fmla vC12.4s, vB2.4s, vA1.s[1]\n" - "fmla vC22.4s, vB2.4s, vA2.s[1]\n" - "fmla vC32.4s, vB2.4s, vA3.s[1]\n" - "ldr qB1, [%x[bptr], #0x00]\n" - "fmla vC42.4s, vB2.4s, vA4.s[1]\n" - "fmla vC13.4s, vB3.4s, vA1.s[1]\n" - "fmla vC23.4s, vB3.4s, vA2.s[1]\n" - "fmla vC33.4s, vB3.4s, vA3.s[1]\n" - "ldr qB2, [%x[bptr], #0x10]\n" - "fmla vC43.4s, vB3.4s, vA4.s[1]\n" - "fmla vC14.4s, vB4.4s, vA1.s[1]\n" - "ldr sA1, [%x[aptr]], #0x10\n" - "fmla vC24.4s, vB4.4s, vA2.s[1]\n" - "ldr sA2, [ aptr2], #0x10\n" - "fmla vC34.4s, vB4.4s, vA3.s[1]\n" - "ldr qB3, [%x[bptr], #0x20]\n" - "fmla vC44.4s, vB4.4s, vA4.s[1]\n" - - "fmla vC11.4s, vB1.4s, vA1.s[0]\n" - "ldr qB4, [%x[bptr], #0x30]\n" - "fmla vC12.4s, vB2.4s, vA1.s[0]\n" - "stp qC11, qC12, [%x[cptr], #0x00]\n" - "fmla vC13.4s, vB3.4s, vA1.s[0]\n" - "ldr sA3, [ aptr3], #0x10\n" - "fmla vC14.4s, vB4.4s, vA1.s[0]\n" - "stp qC13, qC14, [%x[cptr], #0x20]\n" - "fmla vC21.4s, vB1.4s, vA2.s[0]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC22.4s, vB2.4s, vA2.s[0]\n" - "stp qC21, qC22, [%x[cptr], #0x00]\n" - "fmla vC23.4s, vB3.4s, vA2.s[0]\n" - "ldr sA4, [ aptr4], #0x10\n" - "fmla vC24.4s, vB4.4s, vA2.s[0]\n" - "stp qC23, qC24, [%x[cptr], #0x20]\n" - "fmla vC31.4s, vB1.4s, vA3.s[0]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC32.4s, vB2.4s, vA3.s[0]\n" - "stp qC31, qC32, [%x[cptr], #0x00]\n" - "fmla vC33.4s, vB3.4s, vA3.s[0]\n" - "fmla vC34.4s, vB4.4s, vA3.s[0]\n" - "stp qC33, qC34, [%x[cptr], #0x20]\n" - "fmla vC41.4s, vB1.4s, vA4.s[0]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - "fmla vC42.4s, vB2.4s, vA4.s[0]\n" - "stp qC41, qC42, [%x[cptr], #0x00]\n" - "fmla vC43.4s, vB3.4s, vA4.s[0]\n" - "fmla vC44.4s, vB4.4s, vA4.s[0]\n" - "stp qC43, qC44, [%x[cptr], #0x20]\n" - "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n" - - ".unreq vB4\n" ".unreq qB4\n" - ".unreq vB3\n" ".unreq qB3\n" - ".unreq vB2\n" ".unreq qB2\n" - ".unreq vB1\n" ".unreq qB1\n" - ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n" - ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n" - ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n" - ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n" - ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n" - ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n" - ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n" - ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n" - ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n" - ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n" - ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n" - ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n" - ".unreq aptr2\n" - ".unreq aptr3\n" - ".unreq aptr4\n" - - : [aptr] "+r" (aptr), - [bptr] "+r" (bptr), - [cptr] "+r" (cptr), - [k] "+r" (k) - : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)), - [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)), - [c_row_stride_bytes] "r" (c_row_stride * sizeof(float)) - : "cc", "memory", "x20", "x21", "x22", - "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", - "v21", "v22", "v23" - ); - } - } -} diff --git a/src/core/NEON/kernels/winograd/perf.h b/src/core/NEON/kernels/winograd/perf.h deleted file mode 100644 index 11fb0c452f..0000000000 --- a/src/core/NEON/kernels/winograd/perf.h +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -/* Prototypes from perf.c */ - -void start_counter(int fd); -long long get_counter(int fd); -long long stop_counter(int fd); -int open_instruction_counter(void); -int open_cycle_counter(void); diff --git a/src/core/NEON/kernels/winograd/profiler.hpp b/src/core/NEON/kernels/winograd/profiler.hpp deleted file mode 100644 index 143192b589..0000000000 --- a/src/core/NEON/kernels/winograd/profiler.hpp +++ /dev/null @@ -1,244 +0,0 @@ - -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -#include -#include -#include -#include -#include -#include - -#include "perf.h" -#include - -class profiler { -private: -#ifdef CYCLE_PROFILING - struct ProfileEntry { - int event_id; - long int bytes_read, ops, bytes_written; - long int duration; - }; - - static const int maxevents = 10000; - ProfileEntry events[maxevents]; - int currentevent; - int countfd; - - std::map event_ids; - - int get_event_id(const char *id) { - if (!event_ids.count(id)) { - event_ids.emplace(id, event_ids.size()); - } - return event_ids[id]; - } -#endif // CYCLE_PROFILING - -public: -#ifdef CYCLE_PROFILING - profiler() { - currentevent = 0; - countfd = open_cycle_counter(); - } - - ~profiler() { - close(countfd); - - // Compute performance from recorded events - struct ProfileResult { - ProfileResult() : total_calls(0), - total_duration(0), - total_bytes_read(0), - total_ops(0), - total_bytes_written(0) { - } - - void operator+=(const ProfileEntry &rhs) { - total_calls++; - total_duration += rhs.duration; - total_bytes_read += rhs.bytes_read; - total_ops += rhs.ops; - total_bytes_written = rhs.bytes_written; - } - - float avg_duration(void) const { - return static_cast(total_duration) / - static_cast(total_calls); - } - - float bytes_read_per_cycle(void) const { - return static_cast(total_bytes_read) / - static_cast(total_duration); - } - - float ops_per_cycle(void) const { - return static_cast(total_ops) / - static_cast(total_duration); - } - - float bytes_written_per_cycle(void) const { - return static_cast(total_bytes_written) / - static_cast(total_duration); - } - - long int total_calls, - total_duration, - total_bytes_read, - total_ops, - total_bytes_written; - }; - - std::vector totals; - totals.resize(event_ids.size()); - for (int i = 0; i < currentevent; i++) { - const auto &event = events[i]; - totals[event.event_id] += event; - } - - // Get the longest label - int len_label = 0; - for (const auto &kv : event_ids) { - len_label = std::max(len_label, static_cast(strlen(kv.first))); - } - - // Get the longest values for every other field - const auto get_length_of_field = - [totals] (const char *title, auto f, auto len) -> size_t { - size_t l = strlen(title); - for (const auto &v : totals) { - l = std::max(l, len(f(v))); - } - return l; - }; - - // Get the strlen for an int - const auto intlen = [] (long int x) -> size_t { - size_t len = 0; - do { - x /= 10; - len++; - } while (x); - return len; - }; - - // Get the strlen for a float - const auto floatlen = [] (const int precision) { - return [precision] (float x) { - size_t len = 0; - - if (!std::isfinite(x)) { - return static_cast(3); - } - - do { - x /= 10.0f; - len++; - } while (x > 1.0f); - return len + 1 + precision; - }; - }; - - const int len_calls = get_length_of_field( - "Calls", [] (const auto &v) {return v.total_calls;}, - intlen - ); - const int len_duration = get_length_of_field( - "Duration", [] (const auto &v) {return v.total_duration;}, - intlen - ); - const int len_average_duration = get_length_of_field( - "Average", [] (const auto &v) {return v.avg_duration();}, - floatlen(2) - ); - const int len_reads_per_cycle = get_length_of_field( - "Reads / cycle", - [] (const auto &v) {return v.bytes_read_per_cycle();}, - floatlen(6) - ); - const int len_ops_per_cycle = get_length_of_field( - "Ops / cycle", - [] (const auto &v) {return v.ops_per_cycle();}, - floatlen(6) - ); - const int len_writes_per_cycle = get_length_of_field( - "Writes / cycle", - [] (const auto &v) {return v.bytes_written_per_cycle();}, - floatlen(6) - ); - - // Print header - printf( - "%*s %*s %*s %*s %*s %*s %*s\n", - len_label, "", - len_calls, "Calls", - len_duration, "Duration", - len_average_duration, "Average", - len_reads_per_cycle, "Reads / cycle", - len_ops_per_cycle, "Ops / cycle", - len_writes_per_cycle, "Writes / cycle" - ); - for (const auto &kv : event_ids) { - const auto id = kv.second; - printf( - "%*s %*ld %*ld %*.2f %*.6f %*.6f %*.6f\n", - len_label, kv.first, - len_calls, totals[id].total_calls, - len_duration, totals[id].total_duration, - len_average_duration, totals[id].avg_duration(), - len_reads_per_cycle, totals[id].bytes_read_per_cycle(), - len_ops_per_cycle, totals[id].ops_per_cycle(), - len_writes_per_cycle, totals[id].bytes_written_per_cycle() - ); - } - printf("\n"); - } -#endif // CYCLE_PROFILING - - template - void operator() (const char * event, - T func, - long int bytes_read = 0, - long int ops = 0, - long int bytes_written = 0) { -#ifdef CYCLE_PROFILING - if (currentevent==maxevents) { - func(); - } else { - start_counter(countfd); - func(); - long long cycs = stop_counter(countfd); - - // Store the profiling data - events[currentevent++] = { - get_event_id(event), bytes_read, ops, bytes_written, cycs - }; - } -#else - func(); -#endif // CYCLE_PROFILING - } -}; diff --git a/src/core/NEON/kernels/winograd/shims.hpp b/src/core/NEON/kernels/winograd/shims.hpp deleted file mode 100644 index 249e5757f0..0000000000 --- a/src/core/NEON/kernels/winograd/shims.hpp +++ /dev/null @@ -1,319 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -#pragma once - -/** Re-order a weight tensor from [Output feature map x Input feature map x - * Height x Width] format to [Height x Width x Input feature map x Output - * feature map] format. - */ -template -inline void ofm_ifm_h_w_to_h_w_ifm_ofm( - const T* const in, // Input in [Output x Input x Height x Width] form - T* const out, // Output in [Height x Width x Input x Output] form - const int n_output_feature_maps, - const int n_input_feature_maps, - const int n_rows, - const int n_cols, - int in_output_feature_map_stride=0, - int in_input_feature_map_stride=0, - int in_row_stride=0, - int out_row_stride=0, - int out_col_stride=0, - int out_input_feature_map_stride=0 -); - -/** Re-order a weight tensor from [Height x Width x Input feature map x Output - * feature map] format to [Output feature map x Input feature map x Height x - * Width] format. - */ -template -inline void h_w_ifm_ofm_to_ofm_ifm_h_w( - const T* const in, // Input in [Height x Width x Input x Output] form - T* const out, // Output in [Output x Input x Height x Width] form - const int n_rows, - const int n_cols, - const int n_input_feature_maps, - const int n_output_feature_maps, - int in_row_stride=0, - int in_col_stride=0, - int in_input_feature_map_stride=0, - int out_output_feature_map_stride=0, - int out_input_feature_map_stride=0, - int out_row_stride=0 -); - - -/* Re-order a tensor from NCHW format to NHWC. - */ -template -inline void nchw_to_nhwc( - const T* const in, - T* const out, - const int n_batches, - const int n_channels, - const int n_rows, - const int n_cols, - int in_batch_stride=0, - int in_channel_stride=0, - int in_row_stride=0, - int out_batch_stride=0, - int out_row_stride=0, - int out_col_stride=0 -) -{ - // Fill in the stride values - in_row_stride = (in_row_stride) ? in_row_stride : n_cols; - in_channel_stride = (in_channel_stride) ? in_channel_stride - : n_rows * in_row_stride; - in_batch_stride = (in_batch_stride) ? in_batch_stride - : n_channels * in_channel_stride; - - out_col_stride = (out_col_stride) ? out_col_stride : n_channels; - out_row_stride = (out_row_stride) ? out_row_stride : n_cols * out_col_stride; - out_batch_stride = (out_batch_stride) ? out_batch_stride - : n_rows * out_row_stride; - - // Perform the re-ordering - for (int n = 0; n < n_batches; n++) - { - const T* const in_batch = in + n*in_batch_stride; - T* const out_batch = out + n*out_batch_stride; - - for (int i = 0; i < n_rows; i++) - { - const T* const in_row = in_batch + i*in_row_stride; - T* const out_row = out_batch + i*out_row_stride; - - for (int j = 0; j < n_cols; j++) - { - const T* const in_col = in_row + j; - T* const out_col = out_row + j*out_col_stride; - - for (int c = 0; c < n_channels; c++) - { - const T* const in_channel = in_col + c*in_channel_stride; - out_col[c] = *(in_channel); - } - } - } - } -} - -/* Re-order a tensor from NHWC format to NCHW. - */ -template -inline void nhwc_to_nchw( - const T* const in, // Input data in NHWC form - T* const out, // Output data in NCHW form - const int n_batches, - const int n_rows, - const int n_cols, - const int n_channels, - int in_batch_stride=0, - int in_row_stride=0, - int in_col_stride=0, - int out_batch_stride=0, - int out_channel_stride=0, - int out_row_stride=0 -) -{ - // Fill in stride values - in_col_stride = (in_col_stride) ? in_col_stride : n_channels; - in_row_stride = (in_row_stride) ? in_row_stride : n_cols * in_col_stride; - in_batch_stride = (in_batch_stride) ? in_batch_stride - : n_rows * in_row_stride; - - out_row_stride = (out_row_stride) ? out_row_stride : n_cols; - out_channel_stride = (out_channel_stride) ? out_channel_stride - : n_rows * out_row_stride; - out_batch_stride = (out_batch_stride) ? out_batch_stride - : n_channels * out_channel_stride; - - // Perform the re-ordering - // For every batch - for (int n = 0; n < n_batches; n++) - { - const T* const in_batch = in + n*in_batch_stride; - T* const out_batch = out + n*out_batch_stride; - - // For every row - for (int i = 0; i < n_rows; i++) - { - const T* const in_i = in_batch + i*in_row_stride; - T* const out_i = out_batch + i*out_row_stride; - - // For every column - for (int j = 0; j < n_cols; j++) - { - const T* const in_j = in_i + j*in_col_stride; - T* const out_j = out_i + j; - - // For every channel - for (int c = 0; c < n_channels; c++) - { - const T* const in_channel = in_j + c; - T* const out_channel = out_j + c*out_channel_stride; - *(out_channel) = *(in_channel); - } - } - } - } -} - - -/*****************************************************************************/ -/* Generic weight re-order implementation. - */ -template -inline void ofm_ifm_h_w_to_h_w_ifm_ofm( - const T* const in, // Input in [Output x Input x Height x Width] form - T* const out, // Output in [Height x Width x Input x Output] form - const int n_output_feature_maps, - const int n_input_feature_maps, - const int n_rows, - const int n_cols, - int in_output_feature_map_stride, - int in_input_feature_map_stride, - int in_row_stride, - int out_row_stride, - int out_col_stride, - int out_input_feature_map_stride -) -{ - // Fill in stride values - in_row_stride = (in_row_stride) - ? in_row_stride - : n_cols; - in_input_feature_map_stride = (in_input_feature_map_stride) - ? in_input_feature_map_stride - : n_rows * in_row_stride; - in_output_feature_map_stride = (in_output_feature_map_stride) - ? in_output_feature_map_stride - : n_input_feature_maps * in_input_feature_map_stride; - - out_input_feature_map_stride = (out_input_feature_map_stride) - ? out_input_feature_map_stride - : n_output_feature_maps; - out_col_stride = (out_col_stride) - ? out_col_stride - : n_input_feature_maps * out_input_feature_map_stride; - out_row_stride = (out_row_stride) - ? out_row_stride - : n_cols * out_col_stride; - - // Perform the re-ordering - for (int i = 0; i < n_rows; i++) - { - const T* const in_row = in + i * in_row_stride; - T* out_row = out + i * out_row_stride; - - for (int j = 0; j < n_cols; j++) - { - const T* const in_col = in_row + j; - T* const out_col = out_row + j * out_col_stride; - - for (int ifm = 0; ifm < n_input_feature_maps; ifm++) - { - const T* const in_ifm = in_col + ifm * in_input_feature_map_stride; - T* const out_ifm = out_col + ifm * out_input_feature_map_stride; - - for (int ofm = 0; ofm < n_output_feature_maps; ofm++) - { - const T* const in_ofm = in_ifm + ofm * in_output_feature_map_stride; - T* const out_ofm = out_ifm + ofm; - *(out_ofm) = *(in_ofm); - } - } - } - } -} - -/*****************************************************************************/ -/* Generic weight re-order implementation. - */ -template -inline void h_w_ifm_ofm_to_ofm_ifm_h_w( - const T* const in, // Input in [Height x Width x Input x Output] form - T* const out, // Output in [Output x Input x Height x Width] form - const int n_rows, - const int n_cols, - const int n_input_feature_maps, - const int n_output_feature_maps, - int in_row_stride, - int in_col_stride, - int in_input_feature_map_stride, - int out_output_feature_map_stride, - int out_input_feature_map_stride, - int out_row_stride -) -{ - // Fill in the stride values - in_input_feature_map_stride = (in_input_feature_map_stride) - ? in_input_feature_map_stride - : n_output_feature_maps; - in_col_stride = (in_col_stride) - ? in_col_stride - : n_input_feature_maps * in_input_feature_map_stride; - in_row_stride = (in_row_stride) - ? in_row_stride - : n_cols * in_col_stride; - - out_row_stride = (out_row_stride) - ? out_row_stride - : n_cols; - out_input_feature_map_stride = (out_input_feature_map_stride) - ? out_input_feature_map_stride - : n_rows * out_row_stride; - out_output_feature_map_stride = (out_output_feature_map_stride) - ? out_output_feature_map_stride - : n_input_feature_maps * out_input_feature_map_stride; - - // Perform the re-ordering - for (int i = 0; i < n_rows; i++) - { - const T* const in_row = in + i * in_row_stride; - T* const out_row = out + i * out_row_stride; - - for (int j = 0; j < n_cols; j++) - { - const T* const in_col = in_row + j * in_col_stride; - T* const out_col = out_row + j; - - for (int ifm = 0; ifm < n_input_feature_maps; ifm++) - { - const T* const in_ifm = in_col + ifm * in_input_feature_map_stride; - T* const out_ifm = out_col + ifm * out_input_feature_map_stride; - - for (int ofm = 0; ofm < n_output_feature_maps; ofm++) - { - const T* const in_ofm = in_ifm + ofm; - T* const out_ofm = out_ifm + ofm * out_output_feature_map_stride; - *(out_ofm) = *(in_ofm); - } - } - } - } -} - diff --git a/src/core/NEON/kernels/winograd/transforms.hpp b/src/core/NEON/kernels/winograd/transforms.hpp deleted file mode 100644 index 8546ee9e2e..0000000000 --- a/src/core/NEON/kernels/winograd/transforms.hpp +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -#pragma once - -#include "transforms/input_2x2_3x3.hpp" -#include "transforms/kernel_2x2_3x3.hpp" -#include "transforms/output_2x2_3x3.hpp" diff --git a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp b/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp deleted file mode 100644 index ca8d012e5e..0000000000 --- a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp +++ /dev/null @@ -1,639 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once -#include "arm_compute/core/NEON/kernels/winograd/tensor.hpp" - - -namespace winograd { - /* Transform an input tensor into the Winograd domain. - */ - template - struct Winograd2x2_3x3GemmInput { - static void execute( - const T *inptr, - const Tensor4DShape& input_shape, - const PaddingType padding_type, - const int tile_M, - const int tile_N, - T *outptr_base, - const int matrix_stride, - const int matrix_batch_stride, - const int matrix_row_stride - ); - - static size_t bytes_read(const Tensor4DShape &input_shape, - const Tensor4DShape &output_shape) { - const int tile_rows = iceildiv(output_shape.n_rows, 2); - const int tile_cols = iceildiv(output_shape.n_cols, 2); - return input_shape.n_batches * tile_rows * (16 + 8*(tile_cols - 1)) * input_shape.n_channels * sizeof(T); - } - - static int flops_performed(const Tensor4DShape &input_shape, - const Tensor4DShape &output_shape) { - const int tile_rows = iceildiv(output_shape.n_rows, 2); - const int tile_cols = iceildiv(output_shape.n_cols, 2); - return input_shape.n_batches * tile_rows * (32 + 24*(tile_cols - 1)) * input_shape.n_channels; - } - - static size_t bytes_written(const Tensor4DShape &input_shape, - const Tensor4DShape &output_shape) { - const int tile_rows = iceildiv(output_shape.n_rows, 2); - const int tile_cols = iceildiv(output_shape.n_cols, 2); - const int M = input_shape.n_batches * tile_rows * tile_cols; - return 16 * M * input_shape.n_channels * sizeof(T); - } - - protected: - template - static void process_tile_tensor( - const int tile_M, // Number of rows of tiles - const int tile_N, // Number of columns of tiles - int n_channels, // Number of input channels - const T* const input, // Base input pointer (appropriate to batch and channel) - const int input_row_stride, // Stride between rows of the input - const int input_col_stride, // Stride between columns of the input - T* const matrix, // 1st output matrix (appropriate to batch and channel) - const int matrix_stride, // Stride between matrices - const int matrix_row_stride // Stride between rows of the output matrix - ); - - template - static void process_tile_row( - const int tile_N, // Number of tiles in the row - const T* const input, // Base input pointer (appropriate to batch, channel and row) - const int input_row_stride, // Stride between rows of the input - const int input_col_stride, // Stride between columns of the input - T* const matrix, // 1st output matrix (appropriate to batch, channel and row) - const int matrix_stride, // Stride between matrices - const int matrix_row_stride // Stride between rows of the output matrix - ); - }; - - template - struct Winograd2x2_3x3GemmInputChannelwise { - static void execute( - const T *inptr, - const Tensor4DShape& input_shape, - const PaddingType padding_type, - const int tile_M, - const int tile_N, - T *outptr_base, - const int matrix_stride, - const int matrix_batch_stride, - const int matrix_row_stride - ); - - static size_t bytes_read(const Tensor4DShape &input_shape, - const Tensor4DShape &output_shape) { - // We read as many bytes as we write - return bytes_written(input_shape, output_shape); - } - - static int flops_performed(const Tensor4DShape &input_shape, - const Tensor4DShape &output_shape) { - const int tile_rows = iceildiv(output_shape.n_rows, 2); - const int tile_cols = iceildiv(output_shape.n_cols, 2); - return input_shape.n_batches * tile_rows * 32 * tile_cols * input_shape.n_channels; - } - - static size_t bytes_written(const Tensor4DShape &input_shape, - const Tensor4DShape &output_shape) { - return winograd::Winograd2x2_3x3GemmInput::bytes_written(input_shape, output_shape); - } - - protected: - typedef void (*tilefunc)(int, const T*, int, int, T*, int); - template - static void process_tile( - int n_channels, // Number of channels in the tile - const T* const input_base, - const int input_row_stride, - const int input_col_stride, - T* const matrix_base, - const int matrix_stride - ); - - private: - template - static void _process_tile( - int &n_channels, const T* &inptr, - const int input_row_stride, const int input_col_stride, - T* &outptr, const int matrix_stride - ); - }; -} - -/*****************************************************************************/ -// Include specialised implementations here -#include "input_2x2_3x3/a64_float.hpp" -#include "input_2x2_3x3/a64_float_channelwise.hpp" -/*****************************************************************************/ - -/*****************************************************************************/ -template -void winograd::Winograd2x2_3x3GemmInput::execute( - const T *inptr_base, - const Tensor4DShape& input_shape, - const PaddingType padding_type, - const int tile_M, - const int tile_N, - T *outptr_base, - const int matrix_stride, - const int matrix_batch_stride, - const int matrix_row_stride -) { - // Select an appropriate matrix processing method for the shape and padding - // of the input tensor. - typedef void (*tensorfunc)(int, int, int, const T*, int, int, T*, int, int); - const auto process_tensor = [&padding_type, &input_shape] () -> tensorfunc { - if (padding_type == PADDING_VALID) { - const int pad_bottom = input_shape.n_rows % 2; - const int pad_right = input_shape.n_cols % 2; - - if (pad_bottom == 0 && pad_right == 0) { - return process_tile_tensor; - } else if (pad_bottom == 0 && pad_right == 1) { - return process_tile_tensor; - } else if (pad_bottom == 1 && pad_right == 0) { - return process_tile_tensor; - } else if (pad_bottom == 1 && pad_right == 1) { - return process_tile_tensor; - } - } else { // PADDING_SAME - const int pad_bottom = 1 + input_shape.n_rows % 2; - const int pad_right = 1 + input_shape.n_cols % 2; - - if (pad_bottom == 1 && pad_right == 1) { - return process_tile_tensor; - } else if (pad_bottom == 1 && pad_right == 2) { - return process_tile_tensor; - } else if (pad_bottom == 2 && pad_right == 1) { - return process_tile_tensor; - } else if (pad_bottom == 2 && pad_right == 2) { - return process_tile_tensor; - } - } - - printf("%s::%u Uncovered case.\n", __FILE__, __LINE__); - exit(-1); - return NULL; // No function found - } (); - - // Compute strides - const int input_row_stride = input_shape.n_cols * input_shape.n_channels; - const int input_col_stride = input_shape.n_channels; - - // Process each batch of the tensor in turn. - for (int batch = 0; batch < input_shape.n_batches; batch++) { - // Work out pointers - const T *inptr = inptr_base + (batch * input_shape.n_rows * - input_shape.n_cols * input_shape.n_channels); - T *outptr = outptr_base + batch * matrix_batch_stride; - - // Delegate doing the actual work - process_tensor( - tile_M, tile_N, input_shape.n_channels, - inptr, input_row_stride, input_col_stride, - outptr, matrix_stride, matrix_row_stride - ); - } -} - -/*****************************************************************************/ -template -template -void winograd::Winograd2x2_3x3GemmInput::process_tile_tensor( - const int tile_M, // Number of rows of tiles - const int tile_N, // Number of columns of tiles - int n_channels, // Number of input channels - const T* const input, // Base input pointer (appropriate to batch and channel) - const int input_row_stride, // Stride between rows of the input - const int input_col_stride, // Stride between columns of the input - T* const matrix, // 1st output matrix (appropriate to batch and channel) - const int matrix_stride, // Stride between matrices - const int matrix_row_stride // Stride between rows of the output matrix -) { - // Base row processing functions - typedef void (*rowfunc)(int, const T*, int, int, T*, int, int); - const rowfunc process_top_row[3] = { - (padding == PADDING_VALID) - ? process_tile_row<0, 0, 0, pad_right, 1> - : process_tile_row<1, 1, 0, pad_right, 1>, - (padding == PADDING_VALID) - ? process_tile_row<0, 0, 0, pad_right, 2> - : process_tile_row<1, 1, 0, pad_right, 2>, - (padding == PADDING_VALID) - ? process_tile_row<0, 0, 0, pad_right, 4> - : process_tile_row<1, 1, 0, pad_right, 4>, - }; - const rowfunc process_middle_row[3] = { - (padding == PADDING_VALID) - ? process_tile_row<0, 0, 0, pad_right, 1> - : process_tile_row<0, 1, 0, pad_right, 1>, - (padding == PADDING_VALID) - ? process_tile_row<0, 0, 0, pad_right, 2> - : process_tile_row<0, 1, 0, pad_right, 2>, - (padding == PADDING_VALID) - ? process_tile_row<0, 0, 0, pad_right, 4> - : process_tile_row<0, 1, 0, pad_right, 4>, - }; - const rowfunc process_bottom_row[3] = { - (padding == PADDING_VALID) - ? process_tile_row<0, 0, pad_bottom, pad_right, 1> - : process_tile_row<0, 1, pad_bottom, pad_right, 1>, - (padding == PADDING_VALID) - ? process_tile_row<0, 0, pad_bottom, pad_right, 2> - : process_tile_row<0, 1, pad_bottom, pad_right, 2>, - (padding == PADDING_VALID) - ? process_tile_row<0, 0, pad_bottom, pad_right, 4> - : process_tile_row<0, 1, pad_bottom, pad_right, 4>, - }; - - // Method to get an input pointer for the given tile row - const auto get_inptr = [&input, &input_row_stride] (const int tile_i) { - if (padding == PADDING_VALID) { - return input + 2 * tile_i * input_row_stride; - } else { - return input + (2 * tile_i - (tile_i ? 1 : 0)) * input_row_stride; - } - }; - - // Wrapper to process a row of tiles, covering all channels. - const auto process_row = - [tile_N, input_row_stride, input_col_stride, matrix_stride, matrix_row_stride, n_channels] - (const rowfunc f[3], const T *inptr, T *outptr) { - int rem_channels = n_channels; - - // While there remain channels to process continue to process the - // row. - for (; rem_channels >= 4; rem_channels -= 4, inptr += 4, outptr += 4) { - f[2](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride); - } - for (; rem_channels >= 2; rem_channels -= 2, inptr += 2, outptr += 2) { - f[1](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride); - } - if (rem_channels) { - f[0](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride); - } - }; - - // Process all rows of tiles in the tensor - for (int tile_i = 0; tile_i < tile_M; tile_i++) { - T* const m_row = matrix + tile_i * tile_N * matrix_row_stride; - const T *row_inptr = get_inptr(tile_i); - - if (tile_i == 0) { - // Top row of the input - process_row(process_top_row, row_inptr, m_row); - } else if (tile_i == tile_M - 1) { - // Bottom row of the input - process_row(process_bottom_row, row_inptr, m_row); - } else { - // Any other row of the input - process_row(process_middle_row, row_inptr, m_row); - } - } -} - -/*****************************************************************************/ -template -template -void winograd::Winograd2x2_3x3GemmInput::process_tile_row( - const int tile_N, // Number of tiles in the row - const T* const input, // Base input pointer (appropriate to batch, channel and row) - const int input_row_stride, // Stride between rows of the input - const int input_col_stride, // Stride between columns of the input - T* const matrix, // 1st output matrix (appropriate to batch, channel and row) - const int matrix_stride, // Stride between matrices - const int matrix_row_stride // Stride between rows of the output matrix -) { - // Construct copies of the pointers - const T *inptr = input; - T *outptr = matrix; - - // Storage for the tensors x, X.T x, and X.T x X. - T x[4][4][proc_channels], XTx[4][4][proc_channels], XTxX[4][4][proc_channels]; - - // For every tile in the row - for (int tile_j = 0; tile_j < tile_N; tile_j++) { - // Determine the padding for the tile - const int tile_pad_left = (tile_j == 0) ? pad_left : 0; - const int tile_pad_right = (tile_j == tile_N - 1) ? pad_right : 0; - - // Load tile values. If this is the first tile in the row then we must load - // all values, otherwise we can just load the final two columns of the input. - for (int i = 0; i < 4; i++) { - for (int j = ((tile_j == 0) ? 0 : 2); j < 4; j++) { - // Fill with padding if required - if (i < pad_top || 4 - pad_bottom <= i || - j < tile_pad_left || 4 - tile_pad_right <= j) { - for (int c = 0; c < proc_channels; c++) { - x[i][j][c] = static_cast(0); // Padding - } - } else { - // Load values, note that the initial padding offsets the pointer we - // were provided. - for (int c = 0; c < proc_channels; c++) { - const int row_offset = (i - pad_top) * input_row_stride; - const int col_offset = (j - tile_pad_left) * input_col_stride; - x[i][j][c] = inptr[row_offset + col_offset + c]; - } - } - } - } - - // Compute the matrix X.T x. Note, can elide operations depending on the - // padding. Furthermore, if this isn't the left-most tile we can skip half - // of the operations by copying results from the previous version of X.T x. - // This latter optimisation can be simplified by unrolling the outermost - // loop by two and by renaming the registers containing XTx. - if (tile_j == 0) { - for (int j = 0; j < 4; j++) { - for (int c = 0; c < proc_channels; c++) { - XTx[0][j][c] = x[0][j][c] - x[2][j][c]; - XTx[1][j][c] = x[1][j][c] + x[2][j][c]; - XTx[2][j][c] = -x[1][j][c] + x[2][j][c]; - XTx[3][j][c] = x[1][j][c] - x[3][j][c]; - } - } - } else { - for (int j = 0; j < 2; j++) { - for (int c = 0; c < proc_channels; c++) { - XTx[0][j][c] = XTx[0][j + 2][c]; - XTx[1][j][c] = XTx[1][j + 2][c]; - XTx[2][j][c] = XTx[2][j + 2][c]; - XTx[3][j][c] = XTx[3][j + 2][c]; - } - } - for (int j = 2; j < 4; j++) { - for (int c = 0; c < proc_channels; c++) { - XTx[0][j][c] = x[0][j][c] - x[2][j][c]; - XTx[1][j][c] = x[1][j][c] + x[2][j][c]; - XTx[2][j][c] = -x[1][j][c] + x[2][j][c]; - XTx[3][j][c] = x[1][j][c] - x[3][j][c]; - } - } - } - - // Compute the matrix X.T x X. Note, can elide operations based on the - // padding. - for (int i = 0; i < 4; i++) { - for (int c = 0; c < proc_channels; c++) { - XTxX[i][0][c] = XTx[i][0][c] - XTx[i][2][c]; - XTxX[i][1][c] = XTx[i][1][c] + XTx[i][2][c]; - XTxX[i][2][c] = -XTx[i][1][c] + XTx[i][2][c]; - XTxX[i][3][c] = XTx[i][1][c] - XTx[i][3][c]; - } - } - - // Store the output matrix (X.T x X) - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 4; j++) { - // Get a pointer to the relevant output matrix - T *mptr = outptr + (i*4 + j)*matrix_stride; - - // Write out the channels - for (int c = 0; c < proc_channels; c++) { - mptr[c] = XTxX[i][j][c]; - } - } - } - - // Update the pointers - inptr += input_col_stride * ((tile_j == 0 && pad_left) ? 1 : 2); - outptr += matrix_row_stride; - } -} - -/*****************************************************************************/ -template -void winograd::Winograd2x2_3x3GemmInputChannelwise::execute( - const T *inptr, - const Tensor4DShape& input_shape, - const PaddingType padding_type, - const int tile_M, - const int tile_N, - T *outptr_base, - const int matrix_stride, - const int matrix_batch_stride, - const int matrix_row_stride -) { - const int n_channels = input_shape.n_channels; - const int input_col_stride = n_channels; - const int input_row_stride = input_shape.n_cols * input_col_stride; - - // Determine the padding and hence select appropriate methods for each tile. - tilefunc fs[3][3]; - - if (padding_type == PADDING_VALID) { - constexpr int pad_top = 0; - constexpr int pad_left = 0; - const int pad_right = input_shape.n_cols % 2 == 0; - - fs[0][0] = process_tile; - fs[0][1] = process_tile; - fs[0][2] = (pad_right) ? process_tile : process_tile; - - fs[1][0] = process_tile<0, pad_left, 0, 0>; - fs[1][1] = process_tile<0, 0, 0, 0>; - fs[1][2] = (pad_right) ? process_tile<0, 0, 0, 0> : process_tile<0, 0, 0, 1>; - - if (input_shape.n_rows % 2 == 0) { - constexpr int pad_bottom = 0; - fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>; - fs[2][1] = process_tile<0, 0, pad_bottom, 0>; - fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 0> : process_tile<0, 0, pad_bottom, 1>; - } else { - constexpr int pad_bottom = 1; - fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>; - fs[2][1] = process_tile<0, 0, pad_bottom, 0>; - fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 0> : process_tile<0, 0, pad_bottom, 1>; - } - } else { - constexpr int pad_top = 1; - constexpr int pad_left = 1; - const int pad_right = input_shape.n_cols % 2 == 0; - - fs[0][0] = process_tile; - fs[0][1] = process_tile; - fs[0][2] = (pad_right) ? process_tile : process_tile; - - fs[1][0] = process_tile<0, pad_left, 0, 0>; - fs[1][1] = process_tile<0, 0, 0, 0>; - fs[1][2] = (pad_right) ? process_tile<0, 0, 0, 1> : process_tile<0, 0, 0, 2>; - - if (input_shape.n_rows % 2 == 0) { - constexpr int pad_bottom = 1; - fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>; - fs[2][1] = process_tile<0, 0, pad_bottom, 0>; - fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 1> : process_tile<0, 0, pad_bottom, 2>; - } else { - constexpr int pad_bottom = 2; - fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>; - fs[2][1] = process_tile<0, 0, pad_bottom, 0>; - fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 1> : process_tile<0, 0, pad_bottom, 2>; - } - } - - // Process each tile in turn - for (int batch = 0; batch < input_shape.n_batches; batch++) { - const T* const input_base_batch = inptr + batch*input_shape.n_rows*input_shape.n_cols*n_channels; - - for (int tile_i = 0; tile_i < tile_M; tile_i++) { - const int row_offset = (tile_i == 0) ? 0 : ((padding_type == PADDING_VALID) ? 0 : 1); - const T* const input_base_row = input_base_batch + (2*tile_i - row_offset)*input_shape.n_cols*n_channels; - - // Select the set of functions for the row - const int fs_i = (tile_i == 0) ? 0 : ((tile_i < tile_M - 1) ? 1 : 2); - - for (int tile_j = 0; tile_j < tile_N; tile_j++) { - // Select the function for the column - const int fs_j = (tile_j == 0) ? 0 : ((tile_j < tile_N - 1) ? 1 : 2); - const auto f = fs[fs_i][fs_j]; - - // Get pointers into the input and outputs - const int col_offset = (tile_j == 0) ? 0 : ((padding_type == PADDING_VALID) ? 0 : 1); - const T* const input_base_col = input_base_row + (2*tile_j - col_offset)*n_channels; - T* const matrix_base = outptr_base + batch*matrix_batch_stride + (tile_i*tile_N + tile_j)*matrix_row_stride; - f(n_channels, input_base_col, input_row_stride, input_col_stride, - matrix_base, matrix_stride); - } - } - } -} - -template -template -void winograd::Winograd2x2_3x3GemmInputChannelwise::process_tile( - int n_channels, // Number of channels in the tile - const T* const input_base, - const int input_row_stride, - const int input_col_stride, - T* const matrix_base, - const int matrix_stride -) { - // Copy pointers - const T *inptr = input_base; - T *outptr = matrix_base; - - // Process channels (modifies inptr, outptr and n_channels) - _process_tile( - n_channels, inptr, input_row_stride, input_col_stride, - outptr, matrix_stride - ); - _process_tile( - n_channels, inptr, input_row_stride, input_col_stride, - outptr, matrix_stride - ); - _process_tile( - n_channels, inptr, input_row_stride, input_col_stride, - outptr, matrix_stride - ); -} - -template -template -void winograd::Winograd2x2_3x3GemmInputChannelwise::_process_tile( - int &n_channels, - const T* &inptr, const int input_row_stride, const int input_col_stride, - T* &outptr, const int matrix_stride -) { - // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three - // offsets to access the intermediate matrices. - T* outptrs[4] = { - outptr, - outptr + matrix_stride * 4, - outptr + matrix_stride * 8, - outptr + matrix_stride * 12 - }; - - // The matrix X; zeroed to account for padding. - T x[4][4]; - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 4; j++) { - x[i][j] = 0; - } - } - - // The matrices X.T x and U - T XTx[4][4], U[4][4]; - - // Now progress through each channel - for (; n_channels >= proc_channels; n_channels -= proc_channels) { - for (int n = 0; n < proc_channels; n++) { - // Load the matrix X - for (int cell_i = pad_top, i = 0; cell_i < 4 - pad_bottom; cell_i++, i++) { - for (int cell_j = pad_left, j = 0; cell_j < 4 - pad_right; cell_j++, j++) { - x[cell_i][cell_j] = inptr[i*input_row_stride + j*input_col_stride]; - } - } - inptr++; - - // Compute the matrix X.T - for (int j = 0; j < 4; j++) { - XTx[0][j] = x[0][j] - x[2][j]; - XTx[1][j] = x[1][j] + x[2][j]; - XTx[2][j] = x[2][j] - x[1][j]; - XTx[3][j] = x[1][j] - x[3][j]; - } - - // Hence compute the matrix U - for (int i = 0; i < 4; i++) { - U[i][0] = XTx[i][0] - XTx[i][2]; - U[i][1] = XTx[i][1] + XTx[i][2]; - U[i][2] = XTx[i][2] - XTx[i][1]; - U[i][3] = XTx[i][1] - XTx[i][3]; - } - - // Store the matrix U - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 4; j++) { - outptrs[i][j * matrix_stride] = U[i][j]; - } - outptrs[i]++; - } - } - } - - // Update the output pointer for future calls - outptr = outptrs[0]; -} diff --git a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp b/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp deleted file mode 100644 index a99cbe325b..0000000000 --- a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp +++ /dev/null @@ -1,1498 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once -#include "../input_2x2_3x3.hpp" - -#ifdef __aarch64__ -namespace winograd { - -// Pad left by one column, pad right by one column, no upper or lower padding, 4 channels -template <> -template <> -inline void Winograd2x2_3x3GemmInput::process_tile_row<0, 1, 0, 1, 4>( - const int tile_N, // Number of tiles in the row - const float* const input, // Base input pointer (appropriate to batch, channel and row) - const int input_row_stride, // Stride between rows of the input - const int input_col_stride, // Stride between columns of the input - float* const matrix, // 1st output matrix (appropriate to batch, channel and row) - const int matrix_stride, // Stride between matrices - const int matrix_row_stride // Stride between rows of the output matrix -) { - /* SIMD register allocation - * ======================== - * - * In the following code we read 4x4 tiles of a matrix `x`, with which we - * compute another matrix `X.T x` where: - * - * / 1 0 0 0 \ - * X = | 0 1 -1 1 | - * | -1 1 1 0 | - * \ 0 0 0 -1 / - * - * Hence, `X.T` is a program which operates upon rows of the matrix `X`. - * We subsequently compute and store the matrix `U = (X.T x) X`. - * - * Importantly, each iteration of the loop below loads a new matrix `x'` - * where the final two columns of `x'` are the first two columns of the - * previous `x`. That is: - * - * x11 x12 x13 x14 - * x21 x22 x23 x24 - * x31 x32 x33 x34 - * x41 x42 x43 x44 - * - * x'11 x'12 x'13 x'14 - * x'21 x'22 x'23 x'24 - * x'31 x'32 x'33 x'34 - * x'41 x'42 x'43 x'44 - * - * Consequently, while the first iteration of the below loop must load 16 - * values for `x`, the second need load only 8. *Furthermore*, since we noted - * above that the operation `X.T x` was a program which operated upon *rows* - * of the matrix `x` it follows that that the relation that `x'[i][1] = - * x[i][3]` and `x'[i][2] = x[i][4]` applies also the matrices `X.T x'` and - * `X.T x`. That is: - * - * (X.T x)11 (X.T x)12 (X.T x)13 (X.T x)14 - * (X.T x)21 (X.T x)22 (X.T x)23 (X.T x)24 - * (X.T x)31 (X.T x)32 (X.T x)33 (X.T x)34 - * (X.T x)41 (X.T x)42 (X.T x)43 (X.T x)44 - * - * (X.T x')11 (X.T x')12 (X.T x')13 (X.T x')14 - * (X.T x')12 (X.T x')12 (X.T x')12 (X.T x')12 - * (X.T x')13 (X.T x')13 (X.T x')13 (X.T x')13 - * (X.T x')14 (X.T x')14 (X.T x')14 (X.T x')14 - * - * Hence, as well as not needing to load new values for x'[i][1..2] it is - * also unnecessary to recompute values for (X.T x')[i][1..2]. - * - * Following this we break the registers into blocks `A` and `B` used by the - * two stages of the unrolled loop. These registers named such that the - * latter columns of `A` become the earlier columns of `B` and vice-versa: - * - * AXTx11 AXTx12 > AXTx13 AXTx14 | - * AXTx21 AXTx22 > AXTx23 AXTx24 | - * AXTx31 AXTx32 > AXTx33 AXTx34 | - * AXTx41 AXTx42 > AXTx43 AXTx44 | - * - * BXTx13 BXTx14 | BXTx11 BXTx12 > - * BXTx23 BXTx24 | BXTx21 BXTx22 > - * BXTx33 BXTx34 | BXTx31 BXTx32 > - * BXTx43 BXTx44 | BXTx41 BXTx42 > - * - * These 32 named registers require only 16 architectural registers. 1 - * additional architectural register is used as scratch space and 8 - * architectural registers are used to load in the values x[1..4][3,4]. - * - * Input and output addressing - * =========================== - * TODO Description - */ - const float *inptr0 = input; - const float *inptr1 = input + input_row_stride; - const float *inptr2 = input + input_row_stride * 2; - const float *inptr3 = input + input_row_stride * 3; - - float *outptr0 = matrix; - float *outptr4 = matrix + matrix_stride * 4; - float *outptr8 = matrix + matrix_stride * 8; - float *outptr12 = matrix + matrix_stride * 12; - - int tile_j = tile_N; // Tiles to process - - asm volatile ( - // Named SIMD registers according to the policy given above - // Registers into which to load the latter two columns of `x` - "x_13 .req v0\n qx_13 .req q0\n" "x_14 .req v4\n qx_14 .req q4\n" - "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n" - "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n" - "x_43 .req v3\n qx_43 .req q3\n" "x_44 .req v7\n qx_44 .req q7\n" - - // Registers for storing X.T x (both A and B halves) - "AXTx11 .req v8\n" "BXTx13 .req v8\n" - "AXTx12 .req v9\n" "BXTx14 .req v9\n" "qAXTx12 .req q9\n" - "AXTx21 .req v10\n" "BXTx23 .req v10\n" - "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n" - "AXTx31 .req v12\n" "BXTx33 .req v12\n" - "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n" - "AXTx41 .req v14\n" "BXTx43 .req v14\n" - "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n" - "AXTx13 .req v16\n" "BXTx11 .req v16\n" - "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n" - "AXTx23 .req v18\n" "BXTx21 .req v18\n" - "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n" - "AXTx33 .req v20\n" "BXTx31 .req v20\n" - "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n" - "AXTx43 .req v22\n" "BXTx41 .req v22\n" - "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n" - - // Result register (TODO Does using more registers yield better - // performance) - "U .req v24\n qU .req q24\n" - - // ---------------------------------------------------------------------- - // Head of loop - // Loads a complete 4x4 tile of x, computes X.T x, computes and stores - // `U = X.T x X`. Prepares for the 'A' half of the loop. - // NOTE: Since the first tile has the leftmost column padded we can - // skip 4 loads and 4 calculations for the matrix X.T x X. - - // Temporarily alias registers for computing the first (non-padded) - // column of x. - "x_12 .req v0\n qx_12 .req q0\n" - "x_22 .req v1\n qx_22 .req q1\n" - "x_32 .req v2\n qx_32 .req q2\n" - "x_42 .req v3\n qx_42 .req q3\n" - - "ldr qx_12, [%x[inptr0]]\n" - "ldr qx_22, [%x[inptr1]]\n" - "ldr qx_32, [%x[inptr2]]\n" - "ldr qx_42, [%x[inptr3]]\n" - - "fsub BXTx12.4s, x_12.4s, x_32.4s\n" - "fadd BXTx22.4s, x_22.4s, x_32.4s\n" - "fsub BXTx32.4s, x_32.4s, x_22.4s\n" - "fsub BXTx42.4s, x_22.4s, x_42.4s\n" - - ".unreq x_12\n .unreq qx_12\n" - ".unreq x_22\n .unreq qx_22\n" - ".unreq x_32\n .unreq qx_32\n" - ".unreq x_42\n .unreq qx_42\n" - - // Load and compute latter two columns of the first tile. Progress the - // input pointers (by three columns so that the each points are the - // second column of the next tile, that is, each points at the first - // column which must be read for the next tile. - "ldr qx_13, [%x[inptr0], %x[colstride1]]\n" - "ldr qx_23, [%x[inptr1], %x[colstride1]]\n" - "ldr qx_33, [%x[inptr2], %x[colstride1]]\n" - "ldr qx_43, [%x[inptr3], %x[colstride1]]\n" - - "fsub BXTx13.4s, x_13.4s, x_33.4s\n" - "ldr qx_14, [%x[inptr0], %x[colstride2]]\n" - - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride2]]\n" - - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride2]]\n" - - "fsub BXTx43.4s, x_23.4s, x_43.4s\n" - "ldr qx_44, [%x[inptr3], %x[colstride2]]\n" - - "fsub BXTx14.4s, x_14.4s, x_34.4s\n" - "add %x[inptr0], %x[inptr0], %x[colstride3]\n" - - "fadd BXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride3]\n" - - "fsub BXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride3]\n" - - "fsub BXTx44.4s, x_24.4s, x_44.4s\n" - "add %x[inptr3], %x[inptr3], %x[colstride3]\n" - - // Compute and store U for the first tile - // First row - "fneg U.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, BXTx12.4s, BXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fneg U.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, BXTx22.4s, BXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fneg U.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, BXTx32.4s, BXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row, simultaneously load the first column of inputs for the - // next tile. - "fneg U.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "ldr qx_13, [%x[inptr0]]\n" - - "fadd U.4s, BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, BXTx42.4s, BXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "ldr qx_43, [%x[inptr3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - - // Update the loop counter, subtract two to account for both the head and - // the tail. - "subs %x[tile_j], %x[tile_j], #2\n" - "beq 2f\n" // Jump to "A" tail if out of tiles - - // ---------------------------------------------------------------------- - "1:" - // Start part A - // Load last column of this tile (the first column has already been - // loaded) and compute latter two columns of X.T x. - "fsub AXTx13.4s, x_13.4s, x_33.4s\n" - "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" - "fadd AXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" - "fsub AXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" - "fsub AXTx43.4s, x_23.4s, x_43.4s\n" - "ldr qx_44, [%x[inptr3], %x[colstride1]]\n" - "fsub AXTx14.4s, x_14.4s, x_34.4s\n" - "add %x[inptr0], %x[inptr0], %x[colstride2]\n" - "fadd AXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride2]\n" - "fsub AXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride2]\n" - "fsub AXTx44.4s, x_24.4s, x_44.4s\n" - "add %x[inptr3], %x[inptr3], %x[colstride2]\n" - - // Compute and store U. - // First row - "fsub U.4s, AXTx11.4s, AXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, AXTx12.4s, AXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, AXTx13.4s, AXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, AXTx12.4s, AXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fsub U.4s, AXTx21.4s, AXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, AXTx22.4s, AXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, AXTx23.4s, AXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, AXTx22.4s, AXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fsub U.4s, AXTx31.4s, AXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, AXTx32.4s, AXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, AXTx33.4s, AXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, AXTx32.4s, AXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row - "fsub U.4s, AXTx41.4s, AXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "ldr qx_13, [%x[inptr0]]\n" - - "fadd U.4s, AXTx42.4s, AXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, AXTx43.4s, AXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, AXTx42.4s, AXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "ldr qx_43, [%x[inptr3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - - "subs %x[tile_j], %x[tile_j], #1\n" - "beq 3f\n" // Jump to 'B' tail - - // Start part B - // Load last column of this tile (the first column has already been - // loaded) and compute latter two columns of X.T x. - "fsub BXTx13.4s, x_13.4s, x_33.4s\n" - "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" - "fsub BXTx43.4s, x_23.4s, x_43.4s\n" - "ldr qx_44, [%x[inptr3], %x[colstride1]]\n" - "fsub BXTx14.4s, x_14.4s, x_34.4s\n" - "add %x[inptr0], %x[inptr0], %x[colstride2]\n" - "fadd BXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride2]\n" - "fsub BXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride2]\n" - "fsub BXTx44.4s, x_24.4s, x_44.4s\n" - "add %x[inptr3], %x[inptr3], %x[colstride2]\n" - - // Compute and store U. - // First row - "fsub U.4s, BXTx11.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, BXTx12.4s, BXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fsub U.4s, BXTx21.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, BXTx22.4s, BXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fsub U.4s, BXTx31.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, BXTx32.4s, BXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row - "fsub U.4s, BXTx41.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "ldr qx_13, [%x[inptr0]]\n" - - "fadd U.4s, BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, BXTx42.4s, BXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "ldr qx_43, [%x[inptr3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - "subs %x[tile_j], %x[tile_j], #1\n" - "bne 1b\n" // Continue loop, otherwise flow into 'A' tail - - // ---------------------------------------------------------------------- - "2:" - // 'A' tail - // Since the final column is padding and the last-but-one column has - // already been loaded just compute the 3rd column of `X.T x'. - "fsub AXTx13.4s, x_13.4s, x_33.4s\n" - "fadd AXTx23.4s, x_23.4s, x_33.4s\n" - "fsub AXTx33.4s, x_33.4s, x_23.4s\n" - "fsub AXTx43.4s, x_23.4s, x_43.4s\n" - - // Compute and store U. Modified to account for the final column of X.T - // x containing padding. Note, it is also unnecessary to update the - // output pointers. - // First row - "fsub U.4s, AXTx11.4s, AXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, AXTx12.4s, AXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, AXTx13.4s, AXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "str qAXTx12, [%x[outptr0], %x[mstride3]]\n" - - // Second row - "fsub U.4s, AXTx21.4s, AXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, AXTx22.4s, AXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, AXTx23.4s, AXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "str qAXTx22, [%x[outptr4], %x[mstride3]]\n" - - // Third row - "fsub U.4s, AXTx31.4s, AXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, AXTx32.4s, AXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, AXTx33.4s, AXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "str qAXTx32, [%x[outptr8], %x[mstride3]]\n" - - // Fourth row - "fsub U.4s, AXTx41.4s, AXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "fadd U.4s, AXTx42.4s, AXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, AXTx43.4s, AXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "str qAXTx42, [%x[outptr12], %x[mstride3]]\n" - - "b 4f\n" // Jump to end of function - - // ---------------------------------------------------------------------- - "3:" - // 'B' tail - // Since the final column is padding and the last-but-one column has - // already been loaded just compute the 3rd column of `X.T x'. - "fsub BXTx13.4s, x_13.4s, x_33.4s\n" - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "fsub BXTx43.4s, x_23.4s, x_43.4s\n" - - // Compute and store U. Modified to account for the final column of X.T - // x containing padding. Note, it is also unnecessary to update the - // output pointers. - // First row - "fsub U.4s, BXTx11.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "str qBXTx12, [%x[outptr0], %x[mstride3]]\n" - - // Second row - "fsub U.4s, BXTx21.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "str qBXTx22, [%x[outptr4], %x[mstride3]]\n" - - // Third row - "fsub U.4s, BXTx31.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "str qBXTx32, [%x[outptr8], %x[mstride3]]\n" - - // Fourth row - "fsub U.4s, BXTx41.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "fadd U.4s, BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "str qBXTx42, [%x[outptr12], %x[mstride3]]\n" - - // ---------------------------------------------------------------------- - "4:" - // End of function - - // Clear names - ".unreq x_13\n" ".unreq qx_13\n" ".unreq x_14\n" ".unreq qx_14\n" - ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n" - ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n" - ".unreq x_43\n" ".unreq qx_43\n" ".unreq x_44\n" ".unreq qx_44\n" - ".unreq AXTx11\n" ".unreq BXTx13\n" - ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n" - ".unreq AXTx21\n" ".unreq BXTx23\n" - ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n" - ".unreq AXTx31\n" ".unreq BXTx33\n" - ".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n" - ".unreq AXTx41\n" ".unreq BXTx43\n" - ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n" - ".unreq AXTx13\n" ".unreq BXTx11\n" - ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n" - ".unreq AXTx23\n" ".unreq BXTx21\n" - ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n" - ".unreq AXTx33\n" ".unreq BXTx31\n" - ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n" - ".unreq AXTx43\n" ".unreq BXTx41\n" - ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n" - ".unreq U\n" ".unreq qU\n" - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [inptr3] "+r" (inptr3), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr4), - [outptr8] "+r" (outptr8), - [outptr12] "+r" (outptr12), - [tile_j] "+r" (tile_j) // Tile counter - : [colstride1] "r" (1 * input_col_stride * sizeof(float)), - [colstride2] "r" (2 * input_col_stride * sizeof(float)), - [colstride3] "r" (3 * input_col_stride * sizeof(float)), - [mstride1] "r" (1 * matrix_stride * sizeof(float)), - [mstride2] "r" (2 * matrix_stride * sizeof(float)), - [mstride3] "r" (3 * matrix_stride * sizeof(float)), - [matrix_row_stride] "r" (matrix_row_stride * sizeof(float)) - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", - "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24" - ); -} - -// Pad top, left and right by 1. -template <> -template <> -inline void Winograd2x2_3x3GemmInput::process_tile_row<1, 1, 0, 1, 4>( - const int tile_N, - const float* const input, - const int input_row_stride, - const int input_col_stride, - float* const matrix, - const int matrix_stride, - const int matrix_row_stride -) { - const float *inptr0 = input; - const float *inptr1 = input + input_row_stride; - const float *inptr2 = input + input_row_stride * 2; - - float *outptr0 = matrix; - float *outptr4 = matrix + matrix_stride * 4; - float *outptr8 = matrix + matrix_stride * 8; - float *outptr12 = matrix + matrix_stride * 12; - - int tile_j = tile_N; // Tiles to process - - asm volatile ( - // Named SIMD registers according to the policy given above - // Registers into which to load the latter two columns of `x` - // NOTE: We need only load the latter three rows since we know that the - // first row is padded. - "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n" - "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n" - "x_43 .req v3\n qx_43 .req q3\n" "x_44 .req v7\n qx_44 .req q7\n" - - // Registers for storing X.T x (both A and B halves) - "AXTx11 .req v8\n" "BXTx13 .req v8\n" - "AXTx12 .req v9\n" "BXTx14 .req v9\n" "qAXTx12 .req q9\n" - "AXTx21 .req v10\n" "BXTx23 .req v10\n" - "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n" - "AXTx31 .req v12\n" "BXTx33 .req v12\n" - "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n" - "AXTx41 .req v14\n" "BXTx43 .req v14\n" - "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n" - "AXTx13 .req v16\n" "BXTx11 .req v16\n" - "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n" - "AXTx23 .req v18\n" "BXTx21 .req v18\n" - "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n" - "AXTx33 .req v20\n" "BXTx31 .req v20\n" - "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n" - "AXTx43 .req v22\n" "BXTx41 .req v22\n" - "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n" - - // Result register (TODO Does using more registers yield better - // performance) - "U .req v24\n qU .req q24\n" - - // ---------------------------------------------------------------------- - // Head of loop - // Loads a complete 4x4 tile of x, computes X.T x, computes and stores - // `U = X.T x X`. Prepares for the 'A' half of the loop. - // NOTE: Since the first tile has the leftmost column padded we can - // skip 4 loads and 4 calculations for the matrix X.T x X. - - // Temporarily alias registers for computing the first (non-padded) - // column of x. - "x_22 .req v1\n qx_22 .req q1\n" - "x_32 .req v2\n qx_32 .req q2\n" - "x_42 .req v3\n qx_42 .req q3\n" - - "ldr qx_22, [%x[inptr1]]\n" - "ldr qx_32, [%x[inptr2]]\n" - "ldr qx_42, [%x[inptr3]]\n" - - "fneg BXTx12.4s, x_32.4s\n" - "fadd BXTx22.4s, x_22.4s, x_32.4s\n" - "fsub BXTx32.4s, x_32.4s, x_22.4s\n" - "fsub BXTx42.4s, x_22.4s, x_42.4s\n" - - ".unreq x_22\n .unreq qx_22\n" - ".unreq x_32\n .unreq qx_32\n" - ".unreq x_42\n .unreq qx_42\n" - - // Load and compute latter two columns of the first tile. Progress the - // input pointers (by three columns so that the each points are the - // second column of the next tile, that is, each points at the first - // column which must be read for the next tile. - "ldr qx_23, [%x[inptr1], %x[colstride1]]\n" - "ldr qx_33, [%x[inptr2], %x[colstride1]]\n" - "ldr qx_43, [%x[inptr3], %x[colstride1]]\n" - - "fneg BXTx13.4s, x_33.4s\n" - - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride2]]\n" - - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride2]]\n" - - "fsub BXTx43.4s, x_23.4s, x_43.4s\n" - "ldr qx_44, [%x[inptr3], %x[colstride2]]\n" - - "fneg BXTx14.4s, x_34.4s\n" - - "fadd BXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride3]\n" - - "fsub BXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride3]\n" - - "fsub BXTx44.4s, x_24.4s, x_44.4s\n" - "add %x[inptr3], %x[inptr3], %x[colstride3]\n" - - // Compute and store U for the first tile - // First row - "fneg U.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, BXTx12.4s, BXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fneg U.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, BXTx22.4s, BXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fneg U.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, BXTx32.4s, BXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row, simultaneously load the first column of inputs for the - // next tile. - "fneg U.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - - "fadd U.4s, BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, BXTx42.4s, BXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "ldr qx_43, [%x[inptr3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - - // Update the loop counter, subtract two to account for both the head and - // the tail. - "subs %x[tile_j], %x[tile_j], #2\n" - "beq 2f\n" // Jump to "A" tail if out of tiles - - // ---------------------------------------------------------------------- - "1:" - // Start part A - // Load last column of this tile (the first column has already been - // loaded) and compute latter two columns of X.T x. - "fneg AXTx13.4s, x_33.4s\n" - "fadd AXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" - "fsub AXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" - "fsub AXTx43.4s, x_23.4s, x_43.4s\n" - "ldr qx_44, [%x[inptr3], %x[colstride1]]\n" - "fneg AXTx14.4s, x_34.4s\n" - "fadd AXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride2]\n" - "fsub AXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride2]\n" - "fsub AXTx44.4s, x_24.4s, x_44.4s\n" - "add %x[inptr3], %x[inptr3], %x[colstride2]\n" - - // Compute and store U. - // First row - "fsub U.4s, AXTx11.4s, AXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, AXTx12.4s, AXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, AXTx13.4s, AXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, AXTx12.4s, AXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fsub U.4s, AXTx21.4s, AXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, AXTx22.4s, AXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, AXTx23.4s, AXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, AXTx22.4s, AXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fsub U.4s, AXTx31.4s, AXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, AXTx32.4s, AXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, AXTx33.4s, AXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, AXTx32.4s, AXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row - "fsub U.4s, AXTx41.4s, AXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - - "fadd U.4s, AXTx42.4s, AXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, AXTx43.4s, AXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, AXTx42.4s, AXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "ldr qx_43, [%x[inptr3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - - "subs %x[tile_j], %x[tile_j], #1\n" - "beq 3f\n" // Jump to 'B' tail - - // Start part B - // Load last column of this tile (the first column has already been - // loaded) and compute latter two columns of X.T x. - "fneg BXTx13.4s, x_33.4s\n" - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" - "fsub BXTx43.4s, x_23.4s, x_43.4s\n" - "ldr qx_44, [%x[inptr3], %x[colstride1]]\n" - "fneg BXTx14.4s, x_34.4s\n" - "fadd BXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride2]\n" - "fsub BXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride2]\n" - "fsub BXTx44.4s, x_24.4s, x_44.4s\n" - "add %x[inptr3], %x[inptr3], %x[colstride2]\n" - - // Compute and store U. - // First row - "fsub U.4s, BXTx11.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, BXTx12.4s, BXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fsub U.4s, BXTx21.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, BXTx22.4s, BXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fsub U.4s, BXTx31.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, BXTx32.4s, BXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row - "fsub U.4s, BXTx41.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - - "fadd U.4s, BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, BXTx42.4s, BXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "ldr qx_43, [%x[inptr3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - "subs %x[tile_j], %x[tile_j], #1\n" - "bne 1b\n" // Continue loop, otherwise flow into 'A' tail - - // ---------------------------------------------------------------------- - "2:" - // 'A' tail - // Since the final column is padding and the last-but-one column has - // already been loaded just compute the 3rd column of `X.T x'. - "fneg AXTx13.4s, x_33.4s\n" - "fadd AXTx23.4s, x_23.4s, x_33.4s\n" - "fsub AXTx33.4s, x_33.4s, x_23.4s\n" - "fsub AXTx43.4s, x_23.4s, x_43.4s\n" - - // Compute and store U. Modified to account for the final column of X.T - // x containing padding. Note, it is also unnecessary to update the - // output pointers. - // First row - "fsub U.4s, AXTx11.4s, AXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, AXTx12.4s, AXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, AXTx13.4s, AXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "str qAXTx12, [%x[outptr0], %x[mstride3]]\n" - - // Second row - "fsub U.4s, AXTx21.4s, AXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, AXTx22.4s, AXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, AXTx23.4s, AXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "str qAXTx22, [%x[outptr4], %x[mstride3]]\n" - - // Third row - "fsub U.4s, AXTx31.4s, AXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, AXTx32.4s, AXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, AXTx33.4s, AXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "str qAXTx32, [%x[outptr8], %x[mstride3]]\n" - - // Fourth row - "fsub U.4s, AXTx41.4s, AXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "fadd U.4s, AXTx42.4s, AXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, AXTx43.4s, AXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "str qAXTx42, [%x[outptr12], %x[mstride3]]\n" - - "b 4f\n" // Jump to end of function - - // ---------------------------------------------------------------------- - "3:" - // 'B' tail - // Since the final column is padding and the last-but-one column has - // already been loaded just compute the 3rd column of `X.T x'. - "fneg BXTx13.4s, x_33.4s\n" - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "fsub BXTx43.4s, x_23.4s, x_43.4s\n" - - // Compute and store U. Modified to account for the final column of X.T - // x containing padding. Note, it is also unnecessary to update the - // output pointers. - // First row - "fsub U.4s, BXTx11.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "str qBXTx12, [%x[outptr0], %x[mstride3]]\n" - - // Second row - "fsub U.4s, BXTx21.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "str qBXTx22, [%x[outptr4], %x[mstride3]]\n" - - // Third row - "fsub U.4s, BXTx31.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "str qBXTx32, [%x[outptr8], %x[mstride3]]\n" - - // Fourth row - "fsub U.4s, BXTx41.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "fadd U.4s, BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "str qBXTx42, [%x[outptr12], %x[mstride3]]\n" - - // ---------------------------------------------------------------------- - "4:" - // End of function - - // Clear names - ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n" - ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n" - ".unreq x_43\n" ".unreq qx_43\n" ".unreq x_44\n" ".unreq qx_44\n" - ".unreq AXTx11\n" ".unreq BXTx13\n" - ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n" - ".unreq AXTx21\n" ".unreq BXTx23\n" - ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n" - ".unreq AXTx31\n" ".unreq BXTx33\n" - ".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n" - ".unreq AXTx41\n" ".unreq BXTx43\n" - ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n" - ".unreq AXTx13\n" ".unreq BXTx11\n" - ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n" - ".unreq AXTx23\n" ".unreq BXTx21\n" - ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n" - ".unreq AXTx33\n" ".unreq BXTx31\n" - ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n" - ".unreq AXTx43\n" ".unreq BXTx41\n" - ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n" - ".unreq U\n" ".unreq qU\n" - : [inptr1] "+r" (inptr0), // Offset to account for padded row - [inptr2] "+r" (inptr1), // Offset to account for padded row - [inptr3] "+r" (inptr2), // Offset to account for padded row - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr4), - [outptr8] "+r" (outptr8), - [outptr12] "+r" (outptr12), - [tile_j] "+r" (tile_j) // Tile counter - : [colstride1] "r" (1 * input_col_stride * sizeof(float)), - [colstride2] "r" (2 * input_col_stride * sizeof(float)), - [colstride3] "r" (3 * input_col_stride * sizeof(float)), - [mstride1] "r" (1 * matrix_stride * sizeof(float)), - [mstride2] "r" (2 * matrix_stride * sizeof(float)), - [mstride3] "r" (3 * matrix_stride * sizeof(float)), - [matrix_row_stride] "r" (matrix_row_stride * sizeof(float)) - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", - "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24" - ); -} - -// Pad left, right and bottom by 1. -template <> -template <> -inline void Winograd2x2_3x3GemmInput::process_tile_row<0, 1, 1, 1, 4>( - const int tile_N, - const float* const input, - const int input_row_stride, - const int input_col_stride, - float* const matrix, - const int matrix_stride, - const int matrix_row_stride -) { - const float *inptr0 = input; - const float *inptr1 = input + input_row_stride; - const float *inptr2 = input + input_row_stride * 2; - - float *outptr0 = matrix; - float *outptr4 = matrix + matrix_stride * 4; - float *outptr8 = matrix + matrix_stride * 8; - float *outptr12 = matrix + matrix_stride * 12; - - int tile_j = tile_N; // Tiles to process - - asm volatile ( - // Named SIMD registers according to the policy given above - // Registers into which to load the latter two columns of `x` - // NOTE: Bottom row is not required since since it is padded. - "x_13 .req v0\n qx_13 .req q0\n" "x_14 .req v4\n qx_14 .req q4\n" - "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n" - "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n" - - // Registers for storing X.T x (both A and B halves) - "AXTx11 .req v8\n" "BXTx13 .req v8\n" - "AXTx12 .req v9\n" "BXTx14 .req v9\n" "qAXTx12 .req q9\n" - "AXTx21 .req v10\n" "BXTx23 .req v10\n" - "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n" - "AXTx31 .req v12\n" "BXTx33 .req v12\n" - "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n" - "AXTx41 .req v14\n" "BXTx43 .req v14\n" - "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n" - "AXTx13 .req v16\n" "BXTx11 .req v16\n" - "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n" - "AXTx23 .req v18\n" "BXTx21 .req v18\n" - "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n" - "AXTx33 .req v20\n" "BXTx31 .req v20\n" - "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n" - "AXTx43 .req v22\n" "BXTx41 .req v22\n" - "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n" - - // Result register (TODO Does using more registers yield better - // performance) - "U .req v24\n qU .req q24\n" - - // ---------------------------------------------------------------------- - // Head of loop - // Loads a complete 4x4 tile of x, computes X.T x, computes and stores - // `U = X.T x X`. Prepares for the 'A' half of the loop. - // NOTE: Since the first tile has the leftmost column padded we can - // skip 4 loads and 4 calculations for the matrix X.T x X. - - // Temporarily alias registers for computing the first (non-padded) - // column of x. - "x_12 .req v0\n qx_12 .req q0\n" - "x_22 .req v1\n qx_22 .req q1\n" - "x_32 .req v2\n qx_32 .req q2\n" - - "ldr qx_12, [%x[inptr0]]\n" - "ldr qx_22, [%x[inptr1]]\n" - "ldr qx_32, [%x[inptr2]]\n" - - "fsub BXTx12.4s, x_12.4s, x_32.4s\n" - "fadd BXTx22.4s, x_22.4s, x_32.4s\n" - "fsub BXTx32.4s, x_32.4s, x_22.4s\n" - "mov BXTx42.16b, x_22.16b\n" // Probably should do better - - ".unreq x_12\n .unreq qx_12\n" - ".unreq x_22\n .unreq qx_22\n" - ".unreq x_32\n .unreq qx_32\n" - - // Load and compute latter two columns of the first tile. Progress the - // input pointers (by three columns so that the each points are the - // second column of the next tile, that is, each points at the first - // column which must be read for the next tile. - "ldr qx_13, [%x[inptr0], %x[colstride1]]\n" - "ldr qx_23, [%x[inptr1], %x[colstride1]]\n" - "ldr qx_33, [%x[inptr2], %x[colstride1]]\n" - - "fsub BXTx13.4s, x_13.4s, x_33.4s\n" - "ldr qx_14, [%x[inptr0], %x[colstride2]]\n" - - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride2]]\n" - - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride2]]\n" - - "mov BXTx43.16b, x_23.16b\n" - "fsub BXTx14.4s, x_14.4s, x_34.4s\n" - "add %x[inptr0], %x[inptr0], %x[colstride3]\n" - - "fadd BXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride3]\n" - - "fsub BXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride3]\n" - - "mov BXTx44.16b, x_24.16b\n" - - // Compute and store U for the first tile - // First row - "fneg U.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, BXTx12.4s, BXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fneg U.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, BXTx22.4s, BXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fneg U.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, BXTx32.4s, BXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row, simultaneously load the first column of inputs for the - // next tile. - "fneg U.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "ldr qx_13, [%x[inptr0]]\n" - - "fadd U.4s, BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, BXTx42.4s, BXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - - // Update the loop counter, subtract two to account for both the head and - // the tail. - "subs %x[tile_j], %x[tile_j], #2\n" - "beq 2f\n" // Jump to "A" tail if out of tiles - - // ---------------------------------------------------------------------- - "1:" - // Start part A - // Load last column of this tile (the first column has already been - // loaded) and compute latter two columns of X.T x. - "fsub AXTx13.4s, x_13.4s, x_33.4s\n" - "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" - "fadd AXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" - "fsub AXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" - "mov AXTx43.16b, x_23.16b\n" - - "fsub AXTx14.4s, x_14.4s, x_34.4s\n" - "add %x[inptr0], %x[inptr0], %x[colstride2]\n" - "fadd AXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride2]\n" - "fsub AXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride2]\n" - "mov AXTx44.16b, x_24.16b\n" - - // Compute and store U. - // First row - "fsub U.4s, AXTx11.4s, AXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, AXTx12.4s, AXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, AXTx13.4s, AXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, AXTx12.4s, AXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fsub U.4s, AXTx21.4s, AXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, AXTx22.4s, AXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, AXTx23.4s, AXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, AXTx22.4s, AXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fsub U.4s, AXTx31.4s, AXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, AXTx32.4s, AXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, AXTx33.4s, AXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, AXTx32.4s, AXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row - "fsub U.4s, AXTx41.4s, AXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "ldr qx_13, [%x[inptr0]]\n" - - "fadd U.4s, AXTx42.4s, AXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, AXTx43.4s, AXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, AXTx42.4s, AXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - - "subs %x[tile_j], %x[tile_j], #1\n" - "beq 3f\n" // Jump to 'B' tail - - // Start part B - // Load last column of this tile (the first column has already been - // loaded) and compute latter two columns of X.T x. - "fsub BXTx13.4s, x_13.4s, x_33.4s\n" - "ldr qx_14, [%x[inptr0], %x[colstride1]]\n" - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "ldr qx_24, [%x[inptr1], %x[colstride1]]\n" - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "ldr qx_34, [%x[inptr2], %x[colstride1]]\n" - "mov BXTx43.16b, x_23.16b\n" - - "fsub BXTx14.4s, x_14.4s, x_34.4s\n" - "add %x[inptr0], %x[inptr0], %x[colstride2]\n" - "fadd BXTx24.4s, x_24.4s, x_34.4s\n" - "add %x[inptr1], %x[inptr1], %x[colstride2]\n" - "fsub BXTx34.4s, x_34.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], %x[colstride2]\n" - "mov BXTx44.16b, x_24.16b\n" - - // Compute and store U. - // First row - "fsub U.4s, BXTx11.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, BXTx12.4s, BXTx14.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n" - - // Second row - "fsub U.4s, BXTx21.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fsub U.4s, BXTx22.4s, BXTx24.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n" - - // Third row - "fsub U.4s, BXTx31.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, BXTx32.4s, BXTx34.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n" - - // Fourth row - "fsub U.4s, BXTx41.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "ldr qx_13, [%x[inptr0]]\n" - - "fadd U.4s, BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "ldr qx_23, [%x[inptr1]]\n" - - "fsub U.4s, BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "ldr qx_33, [%x[inptr2]]\n" - - "fsub U.4s, BXTx42.4s, BXTx44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - - "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n" - "subs %x[tile_j], %x[tile_j], #1\n" - "bne 1b\n" // Continue loop, otherwise flow into 'A' tail - - // ---------------------------------------------------------------------- - "2:" - // 'A' tail - // Since the final column is padding and the last-but-one column has - // already been loaded just compute the 3rd column of `X.T x'. - "fsub AXTx13.4s, x_13.4s, x_33.4s\n" - "fadd AXTx23.4s, x_23.4s, x_33.4s\n" - "fsub AXTx33.4s, x_33.4s, x_23.4s\n" - "mov AXTx43.16b, x_23.16b\n" - - // Compute and store U. Modified to account for the final column of X.T - // x containing padding. Note, it is also unnecessary to update the - // output pointers. - // First row - "fsub U.4s, AXTx11.4s, AXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, AXTx12.4s, AXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, AXTx13.4s, AXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "str qAXTx12, [%x[outptr0], %x[mstride3]]\n" - - // Second row - "fsub U.4s, AXTx21.4s, AXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, AXTx22.4s, AXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, AXTx23.4s, AXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "str qAXTx22, [%x[outptr4], %x[mstride3]]\n" - - // Third row - "fsub U.4s, AXTx31.4s, AXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, AXTx32.4s, AXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, AXTx33.4s, AXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "str qAXTx32, [%x[outptr8], %x[mstride3]]\n" - - // Fourth row - "fsub U.4s, AXTx41.4s, AXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "fadd U.4s, AXTx42.4s, AXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, AXTx43.4s, AXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "str qAXTx42, [%x[outptr12], %x[mstride3]]\n" - - "b 4f\n" // Jump to end of function - - // ---------------------------------------------------------------------- - "3:" - // 'B' tail - // Since the final column is padding and the last-but-one column has - // already been loaded just compute the 3rd column of `X.T x'. - "fsub BXTx13.4s, x_13.4s, x_33.4s\n" - "fadd BXTx23.4s, x_23.4s, x_33.4s\n" - "fsub BXTx33.4s, x_33.4s, x_23.4s\n" - "mov BXTx43.16b, x_23.16b\n" - - // Compute and store U. Modified to account for the final column of X.T - // x containing padding. Note, it is also unnecessary to update the - // output pointers. - // First row - "fsub U.4s, BXTx11.4s, BXTx13.4s\n" - "str qU, [%x[outptr0]]\n" - "fadd U.4s, BXTx12.4s, BXTx13.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, BXTx13.4s, BXTx12.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "str qBXTx12, [%x[outptr0], %x[mstride3]]\n" - - // Second row - "fsub U.4s, BXTx21.4s, BXTx23.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, BXTx22.4s, BXTx23.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fsub U.4s, BXTx23.4s, BXTx22.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "str qBXTx22, [%x[outptr4], %x[mstride3]]\n" - - // Third row - "fsub U.4s, BXTx31.4s, BXTx33.4s\n" - "str qU, [%x[outptr8]]\n" - "fadd U.4s, BXTx32.4s, BXTx33.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, BXTx33.4s, BXTx32.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "str qBXTx32, [%x[outptr8], %x[mstride3]]\n" - - // Fourth row - "fsub U.4s, BXTx41.4s, BXTx43.4s\n" - "str qU, [%x[outptr12]]\n" - "fadd U.4s, BXTx42.4s, BXTx43.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, BXTx43.4s, BXTx42.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "str qBXTx42, [%x[outptr12], %x[mstride3]]\n" - - // ---------------------------------------------------------------------- - "4:" - // End of function - - // Clear names - ".unreq x_13\n" ".unreq qx_13\n" ".unreq x_14\n" ".unreq qx_14\n" - ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n" - ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n" - ".unreq AXTx11\n" ".unreq BXTx13\n" - ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n" - ".unreq AXTx21\n" ".unreq BXTx23\n" - ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n" - ".unreq AXTx31\n" ".unreq BXTx33\n" - ".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n" - ".unreq AXTx41\n" ".unreq BXTx43\n" - ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n" - ".unreq AXTx13\n" ".unreq BXTx11\n" - ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n" - ".unreq AXTx23\n" ".unreq BXTx21\n" - ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n" - ".unreq AXTx33\n" ".unreq BXTx31\n" - ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n" - ".unreq AXTx43\n" ".unreq BXTx41\n" - ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n" - ".unreq U\n" ".unreq qU\n" - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr4), - [outptr8] "+r" (outptr8), - [outptr12] "+r" (outptr12), - [tile_j] "+r" (tile_j) // Tile counter - : [colstride1] "r" (1 * input_col_stride * sizeof(float)), - [colstride2] "r" (2 * input_col_stride * sizeof(float)), - [colstride3] "r" (3 * input_col_stride * sizeof(float)), - [mstride1] "r" (1 * matrix_stride * sizeof(float)), - [mstride2] "r" (2 * matrix_stride * sizeof(float)), - [mstride3] "r" (3 * matrix_stride * sizeof(float)), - [matrix_row_stride] "r" (matrix_row_stride * sizeof(float)) - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", - "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24" - ); -} -} -#endif // __aarch64__ diff --git a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp b/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp deleted file mode 100644 index ad1ad55291..0000000000 --- a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp +++ /dev/null @@ -1,961 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once -#include "../input_2x2_3x3.hpp" - -#ifdef __aarch64__ - -namespace winograd { - -template <> -template <> -inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<0, 0, 0, 0, 4>( - int &n_channels, // Number of channels in the tile - const float* &inptr0, - const int input_row_stride, - const int input_col_stride, - float* &outptr0, - const int matrix_stride -) { - // We use 4 pointers to point to the starting position on each row and use - // three offsets to extract elements from each of the other 3 columns. - auto inptr1 = inptr0 + 1*input_row_stride; - auto inptr2 = inptr0 + 2*input_row_stride; - auto inptr3 = inptr0 + 3*input_row_stride; - - // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three - // offsets to access the intermediate matrices. - auto outptr1 = outptr0 + matrix_stride * 4; - auto outptr2 = outptr0 + matrix_stride * 8; - auto outptr3 = outptr0 + matrix_stride * 12; - - for (; n_channels > 3; n_channels -= 4) { - asm volatile ( - "X_11 .req v0\n" "qX_11 .req q0\n" - "X_12 .req v1\n" "qX_12 .req q1\n" - "X_13 .req v2\n" "qX_13 .req q2\n" - "X_14 .req v3\n" "qX_14 .req q3\n" - "X_21 .req v4\n" "qX_21 .req q4\n" - "X_22 .req v5\n" "qX_22 .req q5\n" - "X_23 .req v6\n" "qX_23 .req q6\n" - "X_24 .req v7\n" "qX_24 .req q7\n" - "X_31 .req v8\n" "qX_31 .req q8\n" - "X_32 .req v9\n" "qX_32 .req q9\n" - "X_33 .req v10\n" "qX_33 .req q10\n" - "X_34 .req v11\n" "qX_34 .req q11\n" - "X_41 .req v12\n" "qX_41 .req q12\n" - "X_42 .req v13\n" "qX_42 .req q13\n" - "X_43 .req v14\n" "qX_43 .req q14\n" - "X_44 .req v15\n" "qX_44 .req q15\n" - "xX_11 .req v16\n" - "xX_12 .req v17\n" - "xX_13 .req v18\n" - "xX_14 .req v19\n" - "xX_21 .req v20\n" - "xX_22 .req v21\n" - "xX_23 .req v22\n" - "xX_24 .req v23\n" - "xX_31 .req v24\n" - "xX_32 .req v25\n" - "xX_33 .req v26\n" - "xX_34 .req v27\n" - "xX_41 .req v28\n" - "xX_42 .req v29\n" - "xX_43 .req v30\n" - "xX_44 .req v31\n" - " U .req v0\n" - "qU .req q0\n" - - // Load the tile, and compute compute the matrix xX - "ldr qX_11, [%x[inptr0]]\n" - "ldr qX_12, [%x[inptr0], %x[colstride1]]\n" - "ldr qX_13, [%x[inptr0], %x[colstride2]]\n" - "ldr qX_14, [%x[inptr0], %x[colstride3]]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "ldr qX_21, [%x[inptr1]]\n" - "fsub xX_11.4s, x_11.4s, x_13.4s\n" - "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" - "fadd xX_12.4s, x_12.4s, x_13.4s\n" - "ldr qX_23, [%x[inptr1], %x[colstride2]]\n" - "fsub xX_13.4s, x_13.4s, x_12.4s\n" - "ldr qX_24, [%x[inptr1], %x[colstride3]]\n" - "fsub xX_14.4s, x_12.4s, x_14.4s\n" - "add %x[inptr1], %x[inptr1], #0x10\n" - - "ldr qX_31, [%x[inptr2]]\n" - "fsub xX_21.4s, x_21.4s, x_23.4s\n" - "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" - "fadd xX_22.4s, x_22.4s, x_23.4s\n" - "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" - "fsub xX_23.4s, x_23.4s, x_22.4s\n" - "ldr qX_34, [%x[inptr2], %x[colstride3]]\n" - "fsub xX_24.4s, x_22.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - "ldr qX_41, [%x[inptr3]]\n" - "fsub xX_31.4s, x_31.4s, x_33.4s\n" - "ldr qX_42, [%x[inptr3], %x[colstride1]]\n" - "fadd xX_32.4s, x_32.4s, x_33.4s\n" - "ldr qX_43, [%x[inptr3], %x[colstride2]]\n" - "fsub xX_33.4s, x_33.4s, x_32.4s\n" - "ldr qX_44, [%x[inptr3], %x[colstride3]]\n" - "fsub xX_34.4s, x_32.4s, x_34.4s\n" - "add %x[inptr3], %x[inptr3], #0x10\n" - - // Complete computing xX while beginning to compute and store - // $U = X.T x X$ - - "fsub xX_41.4s, x_41.4s, x_43.4s\n" - - "fsub U.4s, xX_11.4s, xX_31.4s\n" - "str qU, [%x[outptr0]]\n" - "fsub U.4s, xX_12.4s, xX_32.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, xX_13.4s, xX_33.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, xX_14.4s, xX_34.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd xX_42.4s, x_42.4s, x_43.4s\n" - - "fadd U.4s, xX_21.4s, xX_31.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, xX_22.4s, xX_32.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fadd U.4s, xX_23.4s, xX_33.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fadd U.4s, xX_24.4s, xX_34.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fsub xX_43.4s, x_43.4s, x_42.4s\n" - - "fsub U.4s, xX_31.4s, xX_21.4s\n" - "str qU, [%x[outptr8]]\n" - "fsub U.4s, xX_32.4s, xX_22.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, xX_33.4s, xX_23.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, xX_34.4s, xX_24.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], #0x10\n" - - "fsub xX_44.4s, x_42.4s, x_44.4s\n" - - "fsub U.4s, xX_21.4s, xX_41.4s\n" - "str qU, [%x[outptr12]]\n" - "fsub U.4s, xX_22.4s, xX_42.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, xX_23.4s, xX_43.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "fsub U.4s, xX_24.4s, xX_44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - ".unreq qU\n" - ".unreq U\n" - ".unreq X_11\n" ".unreq qX_11\n" - ".unreq X_12\n" ".unreq qX_12\n" - ".unreq X_13\n" ".unreq qX_13\n" - ".unreq X_14\n" ".unreq qX_14\n" - ".unreq X_21\n" ".unreq qX_21\n" - ".unreq X_22\n" ".unreq qX_22\n" - ".unreq X_23\n" ".unreq qX_23\n" - ".unreq X_24\n" ".unreq qX_24\n" - ".unreq X_31\n" ".unreq qX_31\n" - ".unreq X_32\n" ".unreq qX_32\n" - ".unreq X_33\n" ".unreq qX_33\n" - ".unreq X_34\n" ".unreq qX_34\n" - ".unreq X_41\n" ".unreq qX_41\n" - ".unreq X_42\n" ".unreq qX_42\n" - ".unreq X_43\n" ".unreq qX_43\n" - ".unreq X_44\n" ".unreq qX_44\n" - ".unreq xX_11\n" - ".unreq xX_12\n" - ".unreq xX_13\n" - ".unreq xX_14\n" - ".unreq xX_21\n" - ".unreq xX_22\n" - ".unreq xX_23\n" - ".unreq xX_24\n" - ".unreq xX_31\n" - ".unreq xX_32\n" - ".unreq xX_33\n" - ".unreq xX_34\n" - ".unreq xX_41\n" - ".unreq xX_42\n" - ".unreq xX_43\n" - ".unreq xX_44\n" - - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [inptr3] "+r" (inptr3), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr1), - [outptr8] "+r" (outptr2), - [outptr12] "+r" (outptr3) - : [colstride1] "r" (input_col_stride * sizeof(float)), - [colstride2] "r" (input_col_stride * sizeof(float) * 2), - [colstride3] "r" (input_col_stride * sizeof(float) * 3), - [mstride1] "r" (matrix_stride * sizeof(float)), - [mstride2] "r" (matrix_stride * sizeof(float) * 2), - [mstride3] "r" (matrix_stride * sizeof(float) * 3) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", - "v30", "v31" - ); - } -} - -// Pad top by 1 -template <> -template <> -inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<1, 0, 0, 0, 4>( - int &n_channels, // Number of channels in the tile - const float* &inptr0, - const int input_row_stride, - const int input_col_stride, - float* &outptr0, - const int matrix_stride -) { - // We use 4 pointers to point to the starting position on each row and use - // three offsets to extract elements from each of the other 3 columns. - auto inptr1 = inptr0 + 0*input_row_stride; - auto inptr2 = inptr0 + 1*input_row_stride; - - // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three - // offsets to access the intermediate matrices. - auto outptr1 = outptr0 + matrix_stride * 4; - auto outptr2 = outptr0 + matrix_stride * 8; - auto outptr3 = outptr0 + matrix_stride * 12; - - for (; n_channels > 3; n_channels -= 4) { - asm volatile ( - "X_21 .req v4\n" "qX_21 .req q4\n" - "X_22 .req v5\n" "qX_22 .req q5\n" - "X_23 .req v6\n" "qX_23 .req q6\n" - "X_24 .req v7\n" "qX_24 .req q7\n" - "X_31 .req v8\n" "qX_31 .req q8\n" - "X_32 .req v9\n" "qX_32 .req q9\n" - "X_33 .req v10\n" "qX_33 .req q10\n" - "X_34 .req v11\n" "qX_34 .req q11\n" - "X_41 .req v12\n" "qX_41 .req q12\n" - "X_42 .req v13\n" "qX_42 .req q13\n" - "X_43 .req v14\n" "qX_43 .req q14\n" - "X_44 .req v15\n" "qX_44 .req q15\n" - "xX_21 .req v20\n" - "xX_22 .req v21\n" - "xX_23 .req v22\n" - "xX_24 .req v23\n" - "xX_31 .req v24\n" - "xX_32 .req v25\n" - "xX_33 .req v26\n" - "xX_34 .req v27\n" - "xX_41 .req v28\n" - "xX_42 .req v29\n" - "xX_43 .req v30\n" - "xX_44 .req v31\n" - " U .req v0\n" - "qU .req q0\n" - - // Load the tile, and compute compute the matrix xX - "ldr qX_21, [%x[inptr1]]\n" - "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" - "ldr qX_23, [%x[inptr1], %x[colstride2]]\n" - "ldr qX_24, [%x[inptr1], %x[colstride3]]\n" - "add %x[inptr1], %x[inptr1], #0x10\n" - - "ldr qX_31, [%x[inptr2]]\n" - "fsub xX_21.4s, x_21.4s, x_23.4s\n" - "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" - "fadd xX_22.4s, x_22.4s, x_23.4s\n" - "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" - "fsub xX_23.4s, x_23.4s, x_22.4s\n" - "ldr qX_34, [%x[inptr2], %x[colstride3]]\n" - "fsub xX_24.4s, x_22.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - "ldr qX_41, [%x[inptr3]]\n" - "fsub xX_31.4s, x_31.4s, x_33.4s\n" - "ldr qX_42, [%x[inptr3], %x[colstride1]]\n" - "fadd xX_32.4s, x_32.4s, x_33.4s\n" - "ldr qX_43, [%x[inptr3], %x[colstride2]]\n" - "fsub xX_33.4s, x_33.4s, x_32.4s\n" - "ldr qX_44, [%x[inptr3], %x[colstride3]]\n" - "fsub xX_34.4s, x_32.4s, x_34.4s\n" - "add %x[inptr3], %x[inptr3], #0x10\n" - - // Complete computing xX while beginning to compute and store - // $U = X.T x X$ - - "fsub xX_41.4s, x_41.4s, x_43.4s\n" - - "fneg U.4s, xX_31.4s\n" - "str qU, [%x[outptr0]]\n" - "fneg U.4s, xX_32.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fneg U.4s, xX_33.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fneg U.4s, xX_34.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd xX_42.4s, x_42.4s, x_43.4s\n" - - "fadd U.4s, xX_21.4s, xX_31.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, xX_22.4s, xX_32.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fadd U.4s, xX_23.4s, xX_33.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fadd U.4s, xX_24.4s, xX_34.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fsub xX_43.4s, x_43.4s, x_42.4s\n" - - "fsub U.4s, xX_31.4s, xX_21.4s\n" - "str qU, [%x[outptr8]]\n" - "fsub U.4s, xX_32.4s, xX_22.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, xX_33.4s, xX_23.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, xX_34.4s, xX_24.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], #0x10\n" - - "fsub xX_44.4s, x_42.4s, x_44.4s\n" - - "fsub U.4s, xX_21.4s, xX_41.4s\n" - "str qU, [%x[outptr12]]\n" - "fsub U.4s, xX_22.4s, xX_42.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, xX_23.4s, xX_43.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "fsub U.4s, xX_24.4s, xX_44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - ".unreq qU\n" - ".unreq U\n" - ".unreq X_21\n" ".unreq qX_21\n" - ".unreq X_22\n" ".unreq qX_22\n" - ".unreq X_23\n" ".unreq qX_23\n" - ".unreq X_24\n" ".unreq qX_24\n" - ".unreq X_31\n" ".unreq qX_31\n" - ".unreq X_32\n" ".unreq qX_32\n" - ".unreq X_33\n" ".unreq qX_33\n" - ".unreq X_34\n" ".unreq qX_34\n" - ".unreq X_41\n" ".unreq qX_41\n" - ".unreq X_42\n" ".unreq qX_42\n" - ".unreq X_43\n" ".unreq qX_43\n" - ".unreq X_44\n" ".unreq qX_44\n" - ".unreq xX_21\n" - ".unreq xX_22\n" - ".unreq xX_23\n" - ".unreq xX_24\n" - ".unreq xX_31\n" - ".unreq xX_32\n" - ".unreq xX_33\n" - ".unreq xX_34\n" - ".unreq xX_41\n" - ".unreq xX_42\n" - ".unreq xX_43\n" - ".unreq xX_44\n" - - : [inptr1] "+r" (inptr0), // Offset for missing row - [inptr2] "+r" (inptr1), // Offset for missing row - [inptr3] "+r" (inptr2), // Offset for missing row - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr1), - [outptr8] "+r" (outptr2), - [outptr12] "+r" (outptr3) - : [colstride1] "r" (input_col_stride * sizeof(float)), - [colstride2] "r" (input_col_stride * sizeof(float) * 2), - [colstride3] "r" (input_col_stride * sizeof(float) * 3), - [mstride1] "r" (matrix_stride * sizeof(float)), - [mstride2] "r" (matrix_stride * sizeof(float) * 2), - [mstride3] "r" (matrix_stride * sizeof(float) * 3) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", - "v30", "v31" - ); - } -} - -// Pad left by 1 -template <> -template <> -inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<0, 1, 0, 0, 4>( - int &n_channels, // Number of channels in the tile - const float* &inptr0, - const int input_row_stride, - const int input_col_stride, - float* &outptr0, - const int matrix_stride -) { - // We use 4 pointers to point to the starting position on each row and use - // three offsets to extract elements from each of the other 3 columns. - auto inptr1 = inptr0 + 1*input_row_stride; - auto inptr2 = inptr0 + 2*input_row_stride; - auto inptr3 = inptr0 + 3*input_row_stride; - - // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three - // offsets to access the intermediate matrices. - auto outptr1 = outptr0 + matrix_stride * 4; - auto outptr2 = outptr0 + matrix_stride * 8; - auto outptr3 = outptr0 + matrix_stride * 12; - - for (; n_channels > 3; n_channels -= 4) { - asm volatile ( - "X_12 .req v1\n" "qX_12 .req q1\n" - "X_13 .req v2\n" "qX_13 .req q2\n" - "X_14 .req v3\n" "qX_14 .req q3\n" - "X_22 .req v5\n" "qX_22 .req q5\n" - "X_23 .req v6\n" "qX_23 .req q6\n" - "X_24 .req v7\n" "qX_24 .req q7\n" - "X_32 .req v9\n" "qX_32 .req q9\n" - "X_33 .req v10\n" "qX_33 .req q10\n" - "X_34 .req v11\n" "qX_34 .req q11\n" - "X_42 .req v13\n" "qX_42 .req q13\n" - "X_43 .req v14\n" "qX_43 .req q14\n" - "X_44 .req v15\n" "qX_44 .req q15\n" - "xX_11 .req v16\n" - "xX_12 .req v17\n" - "xX_13 .req v18\n" - "xX_14 .req v19\n" - "xX_21 .req v20\n" - "xX_22 .req v21\n" - "xX_23 .req v22\n" - "xX_24 .req v23\n" - "xX_31 .req v24\n" - "xX_32 .req v25\n" - "xX_33 .req v26\n" - "xX_34 .req v27\n" - "xX_41 .req v28\n" - "xX_42 .req v29\n" - "xX_43 .req v30\n" - "xX_44 .req v31\n" - " U .req v0\n" - "qU .req q0\n" - - // Load the tile, and compute compute the matrix xX - "ldr qX_12, [%x[inptr0]]\n" - "ldr qX_13, [%x[inptr0], %x[colstride1]]\n" - "ldr qX_14, [%x[inptr0], %x[colstride2]]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "fneg xX_11.4s, x_13.4s\n" - "ldr qX_22, [%x[inptr1]]\n" - "fadd xX_12.4s, x_12.4s, x_13.4s\n" - "ldr qX_23, [%x[inptr1], %x[colstride1]]\n" - "fsub xX_13.4s, x_13.4s, x_12.4s\n" - "ldr qX_24, [%x[inptr1], %x[colstride2]]\n" - "fsub xX_14.4s, x_12.4s, x_14.4s\n" - "add %x[inptr1], %x[inptr1], #0x10\n" - - "fneg xX_21.4s, x_23.4s\n" - "ldr qX_32, [%x[inptr2]]\n" - "fadd xX_22.4s, x_22.4s, x_23.4s\n" - "ldr qX_33, [%x[inptr2], %x[colstride1]]\n" - "fsub xX_23.4s, x_23.4s, x_22.4s\n" - "ldr qX_34, [%x[inptr2], %x[colstride2]]\n" - "fsub xX_24.4s, x_22.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - "fneg xX_31.4s, x_33.4s\n" - "ldr qX_42, [%x[inptr3]]\n" - "fadd xX_32.4s, x_32.4s, x_33.4s\n" - "ldr qX_43, [%x[inptr3], %x[colstride1]]\n" - "fsub xX_33.4s, x_33.4s, x_32.4s\n" - "ldr qX_44, [%x[inptr3], %x[colstride2]]\n" - "fsub xX_34.4s, x_32.4s, x_34.4s\n" - "add %x[inptr3], %x[inptr3], #0x10\n" - - // Complete computing xX while beginning to compute and store - // $U = X.T x X$ - - "fneg xX_41.4s, x_43.4s\n" - - "fsub U.4s, xX_11.4s, xX_31.4s\n" - "str qU, [%x[outptr0]]\n" - "fsub U.4s, xX_12.4s, xX_32.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, xX_13.4s, xX_33.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, xX_14.4s, xX_34.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd xX_42.4s, x_42.4s, x_43.4s\n" - - "fadd U.4s, xX_21.4s, xX_31.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, xX_22.4s, xX_32.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fadd U.4s, xX_23.4s, xX_33.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fadd U.4s, xX_24.4s, xX_34.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fsub xX_43.4s, x_43.4s, x_42.4s\n" - - "fsub U.4s, xX_31.4s, xX_21.4s\n" - "str qU, [%x[outptr8]]\n" - "fsub U.4s, xX_32.4s, xX_22.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, xX_33.4s, xX_23.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, xX_34.4s, xX_24.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], #0x10\n" - - "fsub xX_44.4s, x_42.4s, x_44.4s\n" - - "fsub U.4s, xX_21.4s, xX_41.4s\n" - "str qU, [%x[outptr12]]\n" - "fsub U.4s, xX_22.4s, xX_42.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, xX_23.4s, xX_43.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "fsub U.4s, xX_24.4s, xX_44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - ".unreq X_12\n" ".unreq qX_12\n" - ".unreq X_13\n" ".unreq qX_13\n" - ".unreq X_14\n" ".unreq qX_14\n" - ".unreq X_22\n" ".unreq qX_22\n" - ".unreq X_23\n" ".unreq qX_23\n" - ".unreq X_24\n" ".unreq qX_24\n" - ".unreq X_32\n" ".unreq qX_32\n" - ".unreq X_33\n" ".unreq qX_33\n" - ".unreq X_34\n" ".unreq qX_34\n" - ".unreq X_42\n" ".unreq qX_42\n" - ".unreq X_43\n" ".unreq qX_43\n" - ".unreq X_44\n" ".unreq qX_44\n" - ".unreq xX_11\n" - ".unreq xX_12\n" - ".unreq xX_13\n" - ".unreq xX_14\n" - ".unreq xX_21\n" - ".unreq xX_22\n" - ".unreq xX_23\n" - ".unreq xX_24\n" - ".unreq xX_31\n" - ".unreq xX_32\n" - ".unreq xX_33\n" - ".unreq xX_34\n" - ".unreq xX_41\n" - ".unreq xX_42\n" - ".unreq xX_43\n" - ".unreq xX_44\n" - ".unreq U\n" - ".unreq qU\n" - - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [inptr3] "+r" (inptr3), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr1), - [outptr8] "+r" (outptr2), - [outptr12] "+r" (outptr3) - : [colstride1] "r" (input_col_stride * sizeof(float)), - [colstride2] "r" (input_col_stride * sizeof(float) * 2), - [mstride1] "r" (matrix_stride * sizeof(float)), - [mstride2] "r" (matrix_stride * sizeof(float) * 2), - [mstride3] "r" (matrix_stride * sizeof(float) * 3) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", - "v30", "v31" - ); - } -} - -// Pad bottom by 1 -template <> -template <> -inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<0, 0, 1, 0, 4>( - int &n_channels, // Number of channels in the tile - const float* &inptr0, - const int input_row_stride, - const int input_col_stride, - float* &outptr0, - const int matrix_stride -) { - // We use 4 pointers to point to the starting position on each row and use - // three offsets to extract elements from each of the other 3 columns. - auto inptr1 = inptr0 + 1*input_row_stride; - auto inptr2 = inptr0 + 2*input_row_stride; - - // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three - // offsets to access the intermediate matrices. - auto outptr1 = outptr0 + matrix_stride * 4; - auto outptr2 = outptr0 + matrix_stride * 8; - auto outptr3 = outptr0 + matrix_stride * 12; - - for (; n_channels > 3; n_channels -= 4) { - asm volatile ( - "X_11 .req v0\n" "qX_11 .req q0\n" - "X_12 .req v1\n" "qX_12 .req q1\n" - "X_13 .req v2\n" "qX_13 .req q2\n" - "X_14 .req v3\n" "qX_14 .req q3\n" - "X_21 .req v4\n" "qX_21 .req q4\n" - "X_22 .req v5\n" "qX_22 .req q5\n" - "X_23 .req v6\n" "qX_23 .req q6\n" - "X_24 .req v7\n" "qX_24 .req q7\n" - "X_31 .req v8\n" "qX_31 .req q8\n" - "X_32 .req v9\n" "qX_32 .req q9\n" - "X_33 .req v10\n" "qX_33 .req q10\n" - "X_34 .req v11\n" "qX_34 .req q11\n" - "xX_11 .req v16\n" - "xX_12 .req v17\n" - "xX_13 .req v18\n" - "xX_14 .req v19\n" - "xX_21 .req v20\n" "qxX_21 .req q20\n" - "xX_22 .req v21\n" "qxX_22 .req q21\n" - "xX_23 .req v22\n" "qxX_23 .req q22\n" - "xX_24 .req v23\n" "qxX_24 .req q23\n" - "xX_31 .req v24\n" - "xX_32 .req v25\n" - "xX_33 .req v26\n" - "xX_34 .req v27\n" - " U .req v0\n" - "qU .req q0\n" - - // Load the tile, and compute compute the matrix xX - "ldr qX_11, [%x[inptr0]]\n" - "ldr qX_12, [%x[inptr0], %x[colstride1]]\n" - "ldr qX_13, [%x[inptr0], %x[colstride2]]\n" - "ldr qX_14, [%x[inptr0], %x[colstride3]]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "ldr qX_21, [%x[inptr1]]\n" - "fsub xX_11.4s, x_11.4s, x_13.4s\n" - "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" - "fadd xX_12.4s, x_12.4s, x_13.4s\n" - "ldr qX_23, [%x[inptr1], %x[colstride2]]\n" - "fsub xX_13.4s, x_13.4s, x_12.4s\n" - "ldr qX_24, [%x[inptr1], %x[colstride3]]\n" - "fsub xX_14.4s, x_12.4s, x_14.4s\n" - "add %x[inptr1], %x[inptr1], #0x10\n" - - "ldr qX_31, [%x[inptr2]]\n" - "fsub xX_21.4s, x_21.4s, x_23.4s\n" - "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" - "fadd xX_22.4s, x_22.4s, x_23.4s\n" - "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" - "fsub xX_23.4s, x_23.4s, x_22.4s\n" - "ldr qX_34, [%x[inptr2], %x[colstride3]]\n" - "fsub xX_24.4s, x_22.4s, x_24.4s\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - "fsub xX_31.4s, x_31.4s, x_33.4s\n" - "fadd xX_32.4s, x_32.4s, x_33.4s\n" - "fsub xX_33.4s, x_33.4s, x_32.4s\n" - "fsub xX_34.4s, x_32.4s, x_34.4s\n" - - // Complete computing xX while beginning to compute and store - // $U = X.T x X$ - - "fsub U.4s, xX_11.4s, xX_31.4s\n" - "str qU, [%x[outptr0]]\n" - "fsub U.4s, xX_12.4s, xX_32.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, xX_13.4s, xX_33.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, xX_14.4s, xX_34.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd U.4s, xX_21.4s, xX_31.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, xX_22.4s, xX_32.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fadd U.4s, xX_23.4s, xX_33.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fadd U.4s, xX_24.4s, xX_34.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fsub U.4s, xX_31.4s, xX_21.4s\n" - "str qU, [%x[outptr8]]\n" - "fsub U.4s, xX_32.4s, xX_22.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, xX_33.4s, xX_23.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, xX_34.4s, xX_24.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], #0x10\n" - - "str qxX_21, [%x[outptr12]]\n" - "str qxX_22, [%x[outptr12], %x[mstride1]]\n" - "str qxX_23, [%x[outptr12], %x[mstride2]]\n" - "str qxX_24, [%x[outptr12], %x[mstride3]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - ".unreq qU\n" - ".unreq U\n" - ".unreq X_11\n" ".unreq qX_11\n" - ".unreq X_12\n" ".unreq qX_12\n" - ".unreq X_13\n" ".unreq qX_13\n" - ".unreq X_14\n" ".unreq qX_14\n" - ".unreq X_21\n" ".unreq qX_21\n" - ".unreq X_22\n" ".unreq qX_22\n" - ".unreq X_23\n" ".unreq qX_23\n" - ".unreq X_24\n" ".unreq qX_24\n" - ".unreq X_31\n" ".unreq qX_31\n" - ".unreq X_32\n" ".unreq qX_32\n" - ".unreq X_33\n" ".unreq qX_33\n" - ".unreq X_34\n" ".unreq qX_34\n" - ".unreq xX_11\n" - ".unreq xX_12\n" - ".unreq xX_13\n" - ".unreq xX_14\n" - ".unreq xX_21\n" ".unreq qxX_21\n" - ".unreq xX_22\n" ".unreq qxX_22\n" - ".unreq xX_23\n" ".unreq qxX_23\n" - ".unreq xX_24\n" ".unreq qxX_24\n" - ".unreq xX_31\n" - ".unreq xX_32\n" - ".unreq xX_33\n" - ".unreq xX_34\n" - - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr1), - [outptr8] "+r" (outptr2), - [outptr12] "+r" (outptr3) - : [colstride1] "r" (input_col_stride * sizeof(float)), - [colstride2] "r" (input_col_stride * sizeof(float) * 2), - [colstride3] "r" (input_col_stride * sizeof(float) * 3), - [mstride1] "r" (matrix_stride * sizeof(float)), - [mstride2] "r" (matrix_stride * sizeof(float) * 2), - [mstride3] "r" (matrix_stride * sizeof(float) * 3) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", - "v30", "v31" - ); - } -} - -// Pad right by 1 -template <> -template <> -inline void Winograd2x2_3x3GemmInputChannelwise::_process_tile<0, 0, 0, 1, 4>( - int &n_channels, // Number of channels in the tile - const float* &inptr0, - const int input_row_stride, - const int input_col_stride, - float* &outptr0, - const int matrix_stride -) { - // We use 4 pointers to point to the starting position on each row and use - // three offsets to extract elements from each of the other 3 columns. - auto inptr1 = inptr0 + 1*input_row_stride; - auto inptr2 = inptr0 + 2*input_row_stride; - auto inptr3 = inptr0 + 3*input_row_stride; - - // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three - // offsets to access the intermediate matrices. - auto outptr1 = outptr0 + matrix_stride * 4; - auto outptr2 = outptr0 + matrix_stride * 8; - auto outptr3 = outptr0 + matrix_stride * 12; - - for (; n_channels > 3; n_channels -= 4) { - asm volatile ( - "X_11 .req v0\n" "qX_11 .req q0\n" - "X_12 .req v1\n" "qX_12 .req q1\n" - "X_13 .req v2\n" "qX_13 .req q2\n" - "X_21 .req v4\n" "qX_21 .req q4\n" - "X_22 .req v5\n" "qX_22 .req q5\n" - "X_23 .req v6\n" "qX_23 .req q6\n" - "X_31 .req v8\n" "qX_31 .req q8\n" - "X_32 .req v9\n" "qX_32 .req q9\n" - "X_33 .req v10\n" "qX_33 .req q10\n" - "X_41 .req v12\n" "qX_41 .req q12\n" - "X_42 .req v13\n" "qX_42 .req q13\n" - "X_43 .req v14\n" "qX_43 .req q14\n" - "xX_11 .req v16\n" - "xX_12 .req v17\n" - "xX_13 .req v18\n" - "xX_14 .req x_12\n" - "xX_21 .req v20\n" - "xX_22 .req v21\n" - "xX_23 .req v22\n" - "xX_24 .req x_22\n" - "xX_31 .req v24\n" - "xX_32 .req v25\n" - "xX_33 .req v26\n" - "xX_34 .req x_32\n" - "xX_41 .req v28\n" - "xX_42 .req v29\n" - "xX_43 .req v30\n" - "xX_44 .req x_42\n" - " U .req v0\n" - "qU .req q0\n" - - // Load the tile, and compute compute the matrix xX - "ldr qX_11, [%x[inptr0]]\n" - "ldr qX_12, [%x[inptr0], %x[colstride1]]\n" - "ldr qX_13, [%x[inptr0], %x[colstride2]]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "ldr qX_21, [%x[inptr1]]\n" - "fsub xX_11.4s, x_11.4s, x_13.4s\n" - "ldr qX_22, [%x[inptr1], %x[colstride1]]\n" - "fadd xX_12.4s, x_12.4s, x_13.4s\n" - "ldr qX_23, [%x[inptr1], %x[colstride2]]\n" - "fsub xX_13.4s, x_13.4s, x_12.4s\n" - "add %x[inptr1], %x[inptr1], #0x10\n" - - "ldr qX_31, [%x[inptr2]]\n" - "fsub xX_21.4s, x_21.4s, x_23.4s\n" - "ldr qX_32, [%x[inptr2], %x[colstride1]]\n" - "fadd xX_22.4s, x_22.4s, x_23.4s\n" - "ldr qX_33, [%x[inptr2], %x[colstride2]]\n" - "fsub xX_23.4s, x_23.4s, x_22.4s\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - "ldr qX_41, [%x[inptr3]]\n" - "fsub xX_31.4s, x_31.4s, x_33.4s\n" - "ldr qX_42, [%x[inptr3], %x[colstride1]]\n" - "fadd xX_32.4s, x_32.4s, x_33.4s\n" - "ldr qX_43, [%x[inptr3], %x[colstride2]]\n" - "fsub xX_33.4s, x_33.4s, x_32.4s\n" - "add %x[inptr3], %x[inptr3], #0x10\n" - - // Complete computing xX while beginning to compute and store - // $U = X.T x X$ - - "fsub xX_41.4s, x_41.4s, x_43.4s\n" - - "fsub U.4s, xX_11.4s, xX_31.4s\n" - "str qU, [%x[outptr0]]\n" - "fsub U.4s, xX_12.4s, xX_32.4s\n" - "str qU, [%x[outptr0], %x[mstride1]]\n" - "fsub U.4s, xX_13.4s, xX_33.4s\n" - "str qU, [%x[outptr0], %x[mstride2]]\n" - "fsub U.4s, xX_14.4s, xX_34.4s\n" - "str qU, [%x[outptr0], %x[mstride3]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd xX_42.4s, x_42.4s, x_43.4s\n" - - "fadd U.4s, xX_21.4s, xX_31.4s\n" - "str qU, [%x[outptr4]]\n" - "fadd U.4s, xX_22.4s, xX_32.4s\n" - "str qU, [%x[outptr4], %x[mstride1]]\n" - "fadd U.4s, xX_23.4s, xX_33.4s\n" - "str qU, [%x[outptr4], %x[mstride2]]\n" - "fadd U.4s, xX_24.4s, xX_34.4s\n" - "str qU, [%x[outptr4], %x[mstride3]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fsub xX_43.4s, x_43.4s, x_42.4s\n" - - "fsub U.4s, xX_31.4s, xX_21.4s\n" - "str qU, [%x[outptr8]]\n" - "fsub U.4s, xX_32.4s, xX_22.4s\n" - "str qU, [%x[outptr8], %x[mstride1]]\n" - "fsub U.4s, xX_33.4s, xX_23.4s\n" - "str qU, [%x[outptr8], %x[mstride2]]\n" - "fsub U.4s, xX_34.4s, xX_24.4s\n" - "str qU, [%x[outptr8], %x[mstride3]]\n" - "add %x[outptr8], %x[outptr8], #0x10\n" - - "fsub U.4s, xX_21.4s, xX_41.4s\n" - "str qU, [%x[outptr12]]\n" - "fsub U.4s, xX_22.4s, xX_42.4s\n" - "str qU, [%x[outptr12], %x[mstride1]]\n" - "fsub U.4s, xX_23.4s, xX_43.4s\n" - "str qU, [%x[outptr12], %x[mstride2]]\n" - "fsub U.4s, xX_24.4s, xX_44.4s\n" - "str qU, [%x[outptr12], %x[mstride3]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - ".unreq qU\n" - ".unreq U\n" - ".unreq X_11\n" ".unreq qX_11\n" - ".unreq X_12\n" ".unreq qX_12\n" - ".unreq X_13\n" ".unreq qX_13\n" - ".unreq X_21\n" ".unreq qX_21\n" - ".unreq X_22\n" ".unreq qX_22\n" - ".unreq X_23\n" ".unreq qX_23\n" - ".unreq X_31\n" ".unreq qX_31\n" - ".unreq X_32\n" ".unreq qX_32\n" - ".unreq X_33\n" ".unreq qX_33\n" - ".unreq X_41\n" ".unreq qX_41\n" - ".unreq X_42\n" ".unreq qX_42\n" - ".unreq X_43\n" ".unreq qX_43\n" - ".unreq xX_11\n" - ".unreq xX_12\n" - ".unreq xX_13\n" - ".unreq xX_14\n" - ".unreq xX_21\n" - ".unreq xX_22\n" - ".unreq xX_23\n" - ".unreq xX_24\n" - ".unreq xX_31\n" - ".unreq xX_32\n" - ".unreq xX_33\n" - ".unreq xX_34\n" - ".unreq xX_41\n" - ".unreq xX_42\n" - ".unreq xX_43\n" - ".unreq xX_44\n" - - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [inptr3] "+r" (inptr3), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr1), - [outptr8] "+r" (outptr2), - [outptr12] "+r" (outptr3) - : [colstride1] "r" (input_col_stride * sizeof(float)), - [colstride2] "r" (input_col_stride * sizeof(float) * 2), - [mstride1] "r" (matrix_stride * sizeof(float)), - [mstride2] "r" (matrix_stride * sizeof(float) * 2), - [mstride3] "r" (matrix_stride * sizeof(float) * 3) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", - "v30", "v31" - ); - } -} -} -#endif diff --git a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3_fp32.cpp new file mode 100644 index 0000000000..381ae92182 --- /dev/null +++ b/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3_fp32.cpp @@ -0,0 +1,409 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "transforms/input.hpp" +#include "winograd_gemm.hpp" +#include "arm.hpp" + +namespace winograd +{ + +using Transform = WinogradGEMM<2, 2, 3, 3>::InputTransform; + +/****************************************************************************** + * Cost methods for the input transform. + * ===================================== + */ +template <> +template <> +int Transform::ops_performed(const Tensor4DShape &input_shape) +{ + // NOTE: Cost in FLOPs rather than instructions or uops. + const int tile_M = iceildiv(input_shape.n_rows, inner_tile_rows); + const int tile_N = iceildiv(input_shape.n_cols, inner_tile_cols); + return 16 * 16 * tile_M * tile_N * input_shape.n_channels; +} +/*****************************************************************************/ + +/***************************************************************************** +* F(2x2, 3x3) implies the use of a 4x4 input tile. Such tiles can require a +* variety of padding types. For example, tiles at the top and left of an image +* can require one row or column of padding on their top and left sides if the +* padding type is SAME (where X represents a padded value): +* +* _______ _______ +* |X X X X| |X X X X| +* |X | | | . . . +* |X | | | +* |X______| |_______| +* _______ +* |X | . +* |X | . . . . +* |X | . +* |X______| +* +* For tiles near the right or bottom of the image it is more complicated. Such +* tiles might require padding by 0 or 1 rows or columns if the padding type is +* VALID or 1 or 2 rows or columns if the padding type is SAME: +* +* _______ _______ _______ _______ +* |X X X X| |X X X X| |X X X X| |X X X X| +* |X | | | | X| | X X| +* |X | | | | X| | X X| +* |X______| |_______| |______X| |____X_X| +* _______ _______ _______ _______ +* |X | | | | X| | X X| +* |X | | | | X| | X X| +* |X | | | | X| | X X| +* |X______| |_______| |______X| |____X_X| +* _______ _______ _______ _______ +* |X | | | | X| | X X| +* |X | | | | X| | X X| +* |X | | | | X| | X X| +* |X_X_X_X| |X_X_X_X| |X_X_X_X| |X_X_X_X| +* _______ _______ _______ _______ +* |X | | | | X| | X X| +* |X | | | | X| | X X| +* |X X X X| |X X X X| |X X X X| |X X X X| +* |X_X_X_X| |X_X_X_X| |X_X_X_X| |X_X_X_X| +* +* Additional tiles are required for especially small input images. +* +* Build an array of the specialised methods that deal with each of the +* different padding combinations which may be required. These padding +* constraints are the space: +* +* Padding top in {0, 1} +* Padding left in {0, 1} +* Padding bottom in {0, 1, 2} +* Padding right in {0, 1, 2} +*/ +template <> +template <> +template +void Transform::process_tile( + int n_channels, + const float* const input_base, + const int input_row_stride, + const int input_col_stride, + float* const matrix_base, + const int matrix_stride +) +{ + constexpr int inner_tile_i = 4, inner_tile_j = 4; + constexpr int cells_i = inner_tile_i - pad_bottom; + constexpr int cells_j = inner_tile_i - pad_right; + + float *outptr = matrix_base; + + // Get pointers into the input tile + const float *x_ptrs[inner_tile_i][inner_tile_j]; + for (int i = pad_top, xi = 0; i < cells_i; i++, xi++) + { + // Get a pointer into the row + const float* const row_ptr = input_base + xi*input_row_stride; + + for (int j = pad_left, xj = 0; j < cells_j; j++, xj++) + { + x_ptrs[i][j] = row_ptr + xj*input_col_stride; + } + } + + // Matrices used/computed in this kernel. + float x[inner_tile_i][inner_tile_j]; + float XTx[inner_tile_i][inner_tile_j]; + float U[inner_tile_i][inner_tile_j]; + + for (int i = 0; i < inner_tile_i; i++) + { + for (int j = 0; j < inner_tile_j; j++) + { + x[i][j] = XTx[i][j] = 0.0f; + } + } + + // Perform the Winograd input transformation for each channel in the input + // tensor. + int channels_remaining = n_channels; +#ifdef __aarch64__ + for (; channels_remaining >= 4; channels_remaining -= 4) + { + // Matrices used/computed in this kernel. + float32x4_t x[inner_tile_i][inner_tile_j]; + float32x4_t XTx[inner_tile_i][inner_tile_j]; + float32x4_t U[inner_tile_i][inner_tile_j]; + + for (int i = 0; i < inner_tile_i; i++) + { + for (int j = 0; j < inner_tile_j; j++) + { + x[i][j] = vdupq_n_f32(0.0f); + XTx[i][j] = vdupq_n_f32(0.0f); + } + } + + // Load x + for (int i = pad_top; i < cells_i; i++) + { + for (int j = pad_left; j < cells_j; j++) + { + x[i][j] = vld1q_f32(x_ptrs[i][j]); + x_ptrs[i][j] += 4; + } + } + + // Compute XT . x + for (int j = pad_left; j < cells_j; j++) + { + // XTx[0][j] = x[0][j] - x[2][j]; + XTx[0][j] = vsubq_f32(x[0][j], x[2][j]); + + // XTx[1][j] = x[1][j] + x[2][j]; + XTx[1][j] = vaddq_f32(x[1][j], x[2][j]); + + // XTx[2][j] = x[2][j] - x[1][j]; + XTx[2][j] = vsubq_f32(x[2][j], x[1][j]); + + // XTx[3][j] = x[1][j] - x[3][j]; + XTx[3][j] = vsubq_f32(x[1][j], x[3][j]); + } + + // Compute U = XT . x . X + for (int i = 0; i < inner_tile_i; i++) + { + // U[i][0] = XTx[i][0] - XTx[i][2]; + U[i][0] = vsubq_f32(XTx[i][0], XTx[i][2]); + + // U[i][1] = XTx[i][1] + XTx[i][2]; + U[i][1] = vaddq_f32(XTx[i][1], XTx[i][2]); + + // U[i][2] = XTx[i][2] - XTx[i][1]; + U[i][2] = vsubq_f32(XTx[i][2], XTx[i][1]); + + // U[i][3] = XTx[i][1] - XTx[i][3]; + U[i][3] = vsubq_f32(XTx[i][1], XTx[i][3]); + } + + // Store the transformed matrix + for (int i = 0, m = 0; i < inner_tile_i; i++) + { + for (int j = 0; j < inner_tile_j; j++, m++) + { + vst1q_f32(outptr + m*matrix_stride, U[i][j]); + } + } + outptr += 4; + } +#endif // __aarch64__ +#ifdef __arm_any__ + for (; channels_remaining >= 2; channels_remaining -= 2) + { + // Matrices used/computed in this kernel. + float32x2_t x[inner_tile_i][inner_tile_j]; + float32x2_t XTx[inner_tile_i][inner_tile_j]; + float32x2_t U[inner_tile_i][inner_tile_j]; + + for (int i = 0; i < inner_tile_i; i++) + { + for (int j = 0; j < inner_tile_j; j++) + { + x[i][j] = vdup_n_f32(0.0f); + XTx[i][j] = vdup_n_f32(0.0f); + } + } + + // Load x + for (int i = pad_top; i < cells_i; i++) + { + for (int j = pad_left; j < cells_j; j++) + { + x[i][j] = vld1_f32(x_ptrs[i][j]); + x_ptrs[i][j] += 2; + } + } + + // Compute XT . x + for (int j = pad_left; j < cells_j; j++) + { + // XTx[0][j] = x[0][j] - x[2][j]; + XTx[0][j] = vsub_f32(x[0][j], x[2][j]); + + // XTx[1][j] = x[1][j] + x[2][j]; + XTx[1][j] = vadd_f32(x[1][j], x[2][j]); + + // XTx[2][j] = x[2][j] - x[1][j]; + XTx[2][j] = vsub_f32(x[2][j], x[1][j]); + + // XTx[3][j] = x[1][j] - x[3][j]; + XTx[3][j] = vsub_f32(x[1][j], x[3][j]); + } + + // Compute U = XT . x . X + for (int i = 0; i < inner_tile_i; i++) + { + // U[i][0] = XTx[i][0] - XTx[i][2]; + U[i][0] = vsub_f32(XTx[i][0], XTx[i][2]); + + // U[i][1] = XTx[i][1] + XTx[i][2]; + U[i][1] = vadd_f32(XTx[i][1], XTx[i][2]); + + // U[i][2] = XTx[i][2] - XTx[i][1]; + U[i][2] = vsub_f32(XTx[i][2], XTx[i][1]); + + // U[i][3] = XTx[i][1] - XTx[i][3]; + U[i][3] = vsub_f32(XTx[i][1], XTx[i][3]); + } + + // Store the transformed matrix + for (int i = 0, m = 0; i < inner_tile_i; i++) + { + for (int j = 0; j < inner_tile_j; j++, m++) + { + vst1_f32(outptr + m*matrix_stride, U[i][j]); + } + } + outptr += 2; + } +#endif // __arm_any__ + for (; channels_remaining; channels_remaining--) + { + // Load x + for (int i = pad_top; i < cells_i; i++) + { + for (int j = pad_left; j < cells_j; j++) + { + x[i][j] = *(x_ptrs[i][j]++); + } + } + + // Compute XT . x + for (int j = pad_left; j < cells_j; j++) + { + XTx[0][j] = x[0][j] - x[2][j]; + XTx[1][j] = x[1][j] + x[2][j]; + XTx[2][j] = x[2][j] - x[1][j]; + XTx[3][j] = x[1][j] - x[3][j]; + } + + // Compute U = XT . x . X + for (int i = 0; i < inner_tile_i; i++) + { + U[i][0] = XTx[i][0] - XTx[i][2]; + U[i][1] = XTx[i][1] + XTx[i][2]; + U[i][2] = XTx[i][2] - XTx[i][1]; + U[i][3] = XTx[i][1] - XTx[i][3]; + } + + // Store the transformed matrix + for (int i = 0, m = 0; i < inner_tile_i; i++) + { + for (int j = 0; j < inner_tile_j; j++, m++) + { + *(outptr + m*matrix_stride) = U[i][j]; + } + } + outptr++; + } +} + +template <> +template <> +const Transform::TileFn Transform::tile_fns[2][2][max_pad_bottom][max_pad_right] = +{ + { + { + { + Transform::template process_tile<0, 0, 0, 0>, // No padding + Transform::template process_tile<0, 0, 0, 1>, // Right + Transform::template process_tile<0, 0, 0, 2>, // Right + }, + { + Transform::template process_tile<0, 0, 1, 0>, // Bottom + Transform::template process_tile<0, 0, 1, 1>, // Bottom-right + Transform::template process_tile<0, 0, 1, 2>, // Bottom-right + }, + { + Transform::template process_tile<0, 0, 2, 0>, // Bottom + Transform::template process_tile<0, 0, 2, 1>, // Bottom-right + Transform::template process_tile<0, 0, 2, 2>, // Bottom-right + } + }, + { + { + Transform::template process_tile<0, 1, 0, 0>, // Left + Transform::template process_tile<0, 1, 0, 1>, // Left AND right + Transform::template process_tile<0, 1, 0, 2>, // Left AND right + }, + { + Transform::template process_tile<0, 1, 1, 0>, // Left-bottom + Transform::template process_tile<0, 1, 1, 1>, // Left, bottom AND right + Transform::template process_tile<0, 1, 1, 2>, // Left, bottom AND right + }, + { + Transform::template process_tile<0, 1, 2, 0>, // Left-bottom + Transform::template process_tile<0, 1, 2, 1>, // Left, bottom AND right + Transform::template process_tile<0, 1, 2, 2>, // Left, bottom AND right + } + }, + }, + { + { + { + Transform::template process_tile<1, 0, 0, 0>, // Top + Transform::template process_tile<1, 0, 0, 1>, // Top-right + Transform::template process_tile<1, 0, 0, 2>, // Top-right + }, + { + Transform::template process_tile<1, 0, 1, 0>, // Top AND bottom + Transform::template process_tile<1, 0, 1, 1>, // Top, bottom AND right + Transform::template process_tile<1, 0, 1, 2>, // Top, bottom AND right + }, + { + Transform::template process_tile<1, 0, 2, 0>, // Top AND bottom + Transform::template process_tile<1, 0, 2, 1>, // Top, bottom AND right + Transform::template process_tile<1, 0, 2, 2>, // Top, bottom AND right + } + }, + { + { + Transform::template process_tile<1, 1, 0, 0>, // Top-left + Transform::template process_tile<1, 1, 0, 1>, // Top, left AND right + Transform::template process_tile<1, 1, 0, 2>, // Top, left AND right + }, + { + Transform::template process_tile<1, 1, 1, 0>, // Top, left AND bottom + Transform::template process_tile<1, 1, 1, 1>, // All padded + Transform::template process_tile<1, 1, 1, 2>, // All padded + }, + { + Transform::template process_tile<1, 1, 2, 0>, // Top, left AND bottom + Transform::template process_tile<1, 1, 2, 1>, // All padded + Transform::template process_tile<1, 1, 2, 2>, // All padded + } + } + } +}; + +template struct WinogradGEMM<2, 2, 3, 3>::InputTransform; +} // namespace winograd diff --git a/src/core/NEON/kernels/winograd/transforms/input_4x4_3x3_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/input_4x4_3x3_fp32.cpp new file mode 100644 index 0000000000..477aaaf34e --- /dev/null +++ b/src/core/NEON/kernels/winograd/transforms/input_4x4_3x3_fp32.cpp @@ -0,0 +1,486 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "transforms/input.hpp" +#include "winograd_gemm.hpp" +#include "arm.hpp" + +namespace winograd +{ + +using Transform = WinogradGEMM<4, 4, 3, 3>::InputTransform; + +template <> +template <> +int Transform::ops_performed(const Tensor4DShape &input_shape) +{ + // NOTE: Cost in FLOPs rather than instructions or uops. + const int tile_M = iceildiv(input_shape.n_rows, inner_tile_rows); + const int tile_N = iceildiv(input_shape.n_cols, inner_tile_cols); + return 12 * 24 * tile_M * tile_N * input_shape.n_channels; +} + +/* F(4x4, 3x3) implies the use of a 6x6 input tile. Such tiles can require a +* variety of padding types. For example, tiles at the top and left of an +* image can require one row or column of padding on their top and left sides +* if the padding type is SAME (where X represents a padded value): +* +* ___________ ___________ +* |X X X X X X| |X X X X X X| +* |X | | | +* |X | | | +* |X | | | +* |X | | | +* |X__________| |___________| +* ___________ +* |X | +* |X | +* |X | +* |X | +* |X | +* |X__________| +* +* For tiles near the right or bottom of the image it is more complicated. +* Such tiles might require padding by 0, 1, 2 or 3 rows or columns if the +* padding type is VALID or 1, 2, 3 or 4 rows or columns if the padding +* type is SAME. +* +* Build an array of the specialised methods that deal with each of the +* different padding combinations which may be required. These padding +* constraints are the space: +* +* Padding top in {0, 1} +* Padding left in {0, 1} +* Padding bottom in {0, 1, 2, 3, 4} +* Padding right in {0, 1, 2, 3, 4} +*/ +template <> +template <> +template +void Transform::process_tile( + int n_channels, + const float* const input_base, + const int input_row_stride, + const int input_col_stride, + float* const matrix_base, + const int matrix_stride +) +{ + constexpr int cells_i = 6 - pad_bottom; + constexpr int cells_j = 6 - pad_right; + + float *outptr = matrix_base; + + // Get pointers into the input tile + const float *x_ptrs[6][6]; + for (int i = pad_top, xi = 0; i < cells_i; i++, xi++) + { + // Get a pointer into the row + const float* const row_ptr = input_base + xi*input_row_stride; + + for (int j = pad_left, xj = 0; j < cells_j; j++, xj++) + { + x_ptrs[i][j] = row_ptr + xj*input_col_stride; + } + } + + // Matrices used/computed in this kernel. + float x[6][6], XTx[6][6], U[6][6]; + for (int i = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++) + { + x[i][j] = XTx[i][j] = 0.0f; + } + } + + // Perform the Winograd input transformation for each channel in the input + // tensor. + int channels_remaining = n_channels; +#ifdef __aarch64__ + for (; channels_remaining >= 4; channels_remaining -= 4) + { + // Matrices used/computed in this kernel + float32x4_t x[6][6], XTx[6][6], U[6][6]; + for (int i = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++) + { + x[i][j] = vdupq_n_f32(0.0f); + XTx[i][j] = vdupq_n_f32(0.0f); + } + } + + // Read a 6x6 tile in the Winograd domain + for (int i = pad_top; i < cells_i; i++) + { + for (int j = pad_left; j < cells_j; j++) + { + x[i][j] = vld1q_f32(x_ptrs[i][j]); + x_ptrs[i][j] += 4; + } + } + + // Compute XT . x + for (int j = pad_left; j < cells_j; j++) + { + // XTx[0][j] = 4*x[0][j] + -5*x[2][j] + 1*x[4][j]; + XTx[0][j] = vmlsq_n_f32(vmlaq_n_f32(x[4][j], x[0][j], 4.0f), x[2][j], 5.0f); + + // XTx[1][j] = -4*x[1][j] + -4*x[2][j] + 1*x[3][j] + 1*x[4][j]; + XTx[1][j] = vmlsq_n_f32(vaddq_f32(x[3][j], x[4][j]), vaddq_f32(x[1][j], x[2][j]), 4.0f); + + // XTx[2][j] = 4*x[1][j] + -4*x[2][j] + -1*x[3][j] + 1*x[4][j]; + XTx[2][j] = vmlaq_n_f32(vsubq_f32(x[4][j], x[3][j]), vsubq_f32(x[1][j], x[2][j]), 4.0f); + + // XTx[3][j] = -2*x[1][j] + -1*x[2][j] + 2*x[3][j] + 1*x[4][j]; + XTx[3][j] = vmlaq_n_f32(vsubq_f32(x[4][j], x[2][j]), vsubq_f32(x[3][j], x[1][j]), 2.0f); + + // XTx[4][j] = 2*x[1][j] + -1*x[2][j] + -2*x[3][j] + 1*x[4][j]; + XTx[4][j] = vmlaq_n_f32(vsubq_f32(x[4][j], x[2][j]), vsubq_f32(x[1][j], x[3][j]), 2.0f); + + // XTx[5][j] = 4*x[1][j] + -5*x[3][j] + 1*x[5][j]; + XTx[5][j] = vmlsq_n_f32(vmlaq_n_f32(x[5][j], x[1][j], 4.0f), x[3][j], 5.0f); + } + + // Compute U = XT . x . X + for (int i = 0; i < 6; i++) + { + // U[i][0] = 4*XTx[i][0] + -5*XTx[i][2] + 1*XTx[i][4]; + U[i][0] = vmlsq_n_f32(vmlaq_n_f32(XTx[i][4], XTx[i][0], 4.0f), XTx[i][2], 5.0f); + + // U[i][1] = -4*XTx[i][1] + -4*XTx[i][2] + 1*XTx[i][3] + 1*XTx[i][4]; + U[i][1] = vmlsq_n_f32(vaddq_f32(XTx[i][3], XTx[i][4]), vaddq_f32(XTx[i][1], XTx[i][2]), 4.0f); + + // U[i][2] = 4*XTx[i][1] + -4*XTx[i][2] + -1*XTx[i][3] + 1*XTx[i][4]; + U[i][2] = vmlaq_n_f32(vsubq_f32(XTx[i][4], XTx[i][3]), vsubq_f32(XTx[i][1], XTx[i][2]), 4.0f); + + // U[i][3] = -2*XTx[i][1] + -1*XTx[i][2] + 2*XTx[i][3] + 1*XTx[i][4]; + U[i][3] = vmlaq_n_f32(vsubq_f32(XTx[i][4], XTx[i][2]), vsubq_f32(XTx[i][3], XTx[i][1]), 2.0f); + + // U[i][4] = 2*XTx[i][1] + -1*XTx[i][2] + -2*XTx[i][3] + 1*XTx[i][4]; + U[i][4] = vmlaq_n_f32(vsubq_f32(XTx[i][4], XTx[i][2]), vsubq_f32(XTx[i][1], XTx[i][3]), 2.0f); + + // U[i][5] = 4*XTx[i][1] + -5*XTx[i][3] + 1*XTx[i][5]; + U[i][5] = vmlsq_n_f32(vmlaq_n_f32(XTx[i][5], XTx[i][1], 4.0f), XTx[i][3], 5.0f); + } + + // Store the transformed matrix + for (int i = 0, m = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++, m++) + { + vst1q_f32(outptr + m*matrix_stride, U[i][j]); + } + } + outptr += 4; + } +#endif // __aarch64__ +#ifdef __arm_any__ + for (; channels_remaining >= 2; channels_remaining -= 2) + { + // Matrices used/computed in this kernel + float32x2_t x[6][6], XTx[6][6], U[6][6]; + for (int i = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++) + { + x[i][j] = vdup_n_f32(0.0f); + XTx[i][j] = vdup_n_f32(0.0f); + } + } + + // Read a 6x6 tile in the Winograd domain + for (int i = pad_top; i < cells_i; i++) + { + for (int j = pad_left; j < cells_j; j++) + { + x[i][j] = vld1_f32(x_ptrs[i][j]); + x_ptrs[i][j] += 2; + } + } + + // Compute XT . x + for (int j = pad_left; j < cells_j; j++) + { + // XTx[0][j] = 4*x[0][j] + -5*x[2][j] + 1*x[4][j]; + XTx[0][j] = vmls_n_f32(vmla_n_f32(x[4][j], x[0][j], 4.0f), x[2][j], 5.0f); + + // XTx[1][j] = -4*x[1][j] + -4*x[2][j] + 1*x[3][j] + 1*x[4][j]; + XTx[1][j] = vmls_n_f32(vadd_f32(x[3][j], x[4][j]), vadd_f32(x[1][j], x[2][j]), 4.0f); + + // XTx[2][j] = 4*x[1][j] + -4*x[2][j] + -1*x[3][j] + 1*x[4][j]; + XTx[2][j] = vmla_n_f32(vsub_f32(x[4][j], x[3][j]), vsub_f32(x[1][j], x[2][j]), 4.0f); + + // XTx[3][j] = -2*x[1][j] + -1*x[2][j] + 2*x[3][j] + 1*x[4][j]; + XTx[3][j] = vmla_n_f32(vsub_f32(x[4][j], x[2][j]), vsub_f32(x[3][j], x[1][j]), 2.0f); + + // XTx[4][j] = 2*x[1][j] + -1*x[2][j] + -2*x[3][j] + 1*x[4][j]; + XTx[4][j] = vmla_n_f32(vsub_f32(x[4][j], x[2][j]), vsub_f32(x[1][j], x[3][j]), 2.0f); + + // XTx[5][j] = 4*x[1][j] + -5*x[3][j] + 1*x[5][j]; + XTx[5][j] = vmls_n_f32(vmla_n_f32(x[5][j], x[1][j], 4.0f), x[3][j], 5.0f); + } + + // Compute U = XT . x . X + for (int i = 0; i < 6; i++) + { + // U[i][0] = 4*XTx[i][0] + -5*XTx[i][2] + 1*XTx[i][4]; + U[i][0] = vmls_n_f32(vmla_n_f32(XTx[i][4], XTx[i][0], 4.0f), XTx[i][2], 5.0f); + + // U[i][1] = -4*XTx[i][1] + -4*XTx[i][2] + 1*XTx[i][3] + 1*XTx[i][4]; + U[i][1] = vmls_n_f32(vadd_f32(XTx[i][3], XTx[i][4]), vadd_f32(XTx[i][1], XTx[i][2]), 4.0f); + + // U[i][2] = 4*XTx[i][1] + -4*XTx[i][2] + -1*XTx[i][3] + 1*XTx[i][4]; + U[i][2] = vmla_n_f32(vsub_f32(XTx[i][4], XTx[i][3]), vsub_f32(XTx[i][1], XTx[i][2]), 4.0f); + + // U[i][3] = -2*XTx[i][1] + -1*XTx[i][2] + 2*XTx[i][3] + 1*XTx[i][4]; + U[i][3] = vmla_n_f32(vsub_f32(XTx[i][4], XTx[i][2]), vsub_f32(XTx[i][3], XTx[i][1]), 2.0f); + + // U[i][4] = 2*XTx[i][1] + -1*XTx[i][2] + -2*XTx[i][3] + 1*XTx[i][4]; + U[i][4] = vmla_n_f32(vsub_f32(XTx[i][4], XTx[i][2]), vsub_f32(XTx[i][1], XTx[i][3]), 2.0f); + + // U[i][5] = 4*XTx[i][1] + -5*XTx[i][3] + 1*XTx[i][5]; + U[i][5] = vmls_n_f32(vmla_n_f32(XTx[i][5], XTx[i][1], 4.0f), XTx[i][3], 5.0f); + } + + // Store the transformed matrix + for (int i = 0, m = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++, m++) + { + vst1_f32(outptr + m*matrix_stride, U[i][j]); + } + } + outptr += 2; + } +#endif // __arm_any__ + for (; channels_remaining; channels_remaining--) + { + // Load x + for (int i = pad_top; i < cells_i; i++) + { + for (int j = pad_left; j < cells_j; j++) + { + x[i][j] = *(x_ptrs[i][j]++); + } + } + + // Compute XT . x + for (int j = pad_left; j < cells_j; j++) + { + XTx[0][j] = 4*x[0][j] + -5*x[2][j] + 1*x[4][j]; + XTx[1][j] = -4*x[1][j] + -4*x[2][j] + 1*x[3][j] + 1*x[4][j]; + XTx[2][j] = 4*x[1][j] + -4*x[2][j] + -1*x[3][j] + 1*x[4][j]; + XTx[3][j] = -2*x[1][j] + -1*x[2][j] + 2*x[3][j] + 1*x[4][j]; + XTx[4][j] = 2*x[1][j] + -1*x[2][j] + -2*x[3][j] + 1*x[4][j]; + XTx[5][j] = 4*x[1][j] + -5*x[3][j] + 1*x[5][j]; + } + + // Compute U = XT . x . X + for (int i = 0; i < 6; i++) + { + U[i][0] = 4*XTx[i][0] + -5*XTx[i][2] + 1*XTx[i][4]; + U[i][1] = -4*XTx[i][1] + -4*XTx[i][2] + 1*XTx[i][3] + 1*XTx[i][4]; + U[i][2] = 4*XTx[i][1] + -4*XTx[i][2] + -1*XTx[i][3] + 1*XTx[i][4]; + U[i][3] = -2*XTx[i][1] + -1*XTx[i][2] + 2*XTx[i][3] + 1*XTx[i][4]; + U[i][4] = 2*XTx[i][1] + -1*XTx[i][2] + -2*XTx[i][3] + 1*XTx[i][4]; + U[i][5] = 4*XTx[i][1] + -5*XTx[i][3] + 1*XTx[i][5]; + } + + // Store the transformed matrix + for (int i = 0, m = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++, m++) + { + *(outptr + m*matrix_stride) = U[i][j]; + } + } + outptr++; + } +} + +/* In the below, unusual or especially small tiles are routed via the slow + * path whereas common or large tiles are routed through a faster path. + */ +template <> +template <> +const Transform::TileFn Transform::tile_fns[2][2][max_pad_bottom][max_pad_right] = +{ + { + { + { + Transform::template process_tile<0, 0, 0, 0>, // No padding + Transform::template process_tile<0, 0, 0, 1>, // Right + Transform::template process_tile<0, 0, 0, 2>, // " " + Transform::template process_tile<0, 0, 0, 3>, // " " + Transform::template process_tile<0, 0, 0, 4>, // " " + }, + { + Transform::template process_tile<0, 0, 1, 0>, // Bottom + Transform::template process_tile<0, 0, 1, 1>, // Bottom right + Transform::template process_tile<0, 0, 1, 2>, // " " + Transform::template process_tile<0, 0, 1, 3>, // " " + Transform::template process_tile<0, 0, 1, 4>, // " " + }, + { + Transform::template process_tile<0, 0, 2, 0>, // Bottom + Transform::template process_tile<0, 0, 2, 1>, // Bottom right + Transform::template process_tile<0, 0, 2, 2>, // " " + Transform::template process_tile<0, 0, 2, 3>, // " " + Transform::template process_tile<0, 0, 2, 4>, // " " + }, + { + Transform::template process_tile<0, 0, 3, 0>, // Bottom + Transform::template process_tile<0, 0, 3, 1>, // Bottom right + Transform::template process_tile<0, 0, 3, 2>, // " " + Transform::template process_tile<0, 0, 3, 3>, // " " + Transform::template process_tile<0, 0, 3, 4>, // " " + }, + { + Transform::template process_tile<0, 0, 4, 0>, // Bottom + Transform::template process_tile<0, 0, 4, 1>, // Bottom right + Transform::template process_tile<0, 0, 4, 2>, // " " + Transform::template process_tile<0, 0, 4, 3>, // " " + Transform::template process_tile<0, 0, 4, 4>, // " " + } + }, + { + { + Transform::template process_tile<0, 1, 0, 0>, // Left + Transform::template process_tile<0, 1, 0, 1>, + Transform::template process_tile<0, 1, 0, 2>, + Transform::template process_tile<0, 1, 0, 3>, + Transform::template process_tile<0, 1, 0, 4>, + }, + { + Transform::template process_tile<0, 1, 1, 0>, // Bottom left + Transform::template process_tile<0, 1, 1, 1>, + Transform::template process_tile<0, 1, 1, 2>, + Transform::template process_tile<0, 1, 1, 3>, + Transform::template process_tile<0, 1, 1, 4>, + }, + { + Transform::template process_tile<0, 1, 2, 0>, // " " + Transform::template process_tile<0, 1, 2, 1>, + Transform::template process_tile<0, 1, 2, 2>, + Transform::template process_tile<0, 1, 2, 3>, + Transform::template process_tile<0, 1, 2, 4>, + }, + { + Transform::template process_tile<0, 1, 3, 0>, // " " + Transform::template process_tile<0, 1, 3, 1>, + Transform::template process_tile<0, 1, 3, 2>, + Transform::template process_tile<0, 1, 3, 3>, + Transform::template process_tile<0, 1, 3, 4>, + }, + { + Transform::template process_tile<0, 1, 4, 0>, // " " + Transform::template process_tile<0, 1, 4, 1>, + Transform::template process_tile<0, 1, 4, 2>, + Transform::template process_tile<0, 1, 4, 3>, + Transform::template process_tile<0, 1, 4, 4>, + } + } + }, + { + { + { + Transform::template process_tile<1, 0, 0, 0>, // Top + Transform::template process_tile<1, 0, 0, 1>, // Top right + Transform::template process_tile<1, 0, 0, 2>, // " " + Transform::template process_tile<1, 0, 0, 3>, // " " + Transform::template process_tile<1, 0, 0, 4>, // " " + }, + { + Transform::template process_tile<1, 0, 1, 0>, + Transform::template process_tile<1, 0, 1, 1>, + Transform::template process_tile<1, 0, 1, 2>, + Transform::template process_tile<1, 0, 1, 3>, + Transform::template process_tile<1, 0, 1, 4>, + }, + { + Transform::template process_tile<1, 0, 2, 0>, + Transform::template process_tile<1, 0, 2, 1>, + Transform::template process_tile<1, 0, 2, 2>, + Transform::template process_tile<1, 0, 2, 3>, + Transform::template process_tile<1, 0, 2, 4>, + }, + { + Transform::template process_tile<1, 0, 3, 0>, + Transform::template process_tile<1, 0, 3, 1>, + Transform::template process_tile<1, 0, 3, 2>, + Transform::template process_tile<1, 0, 3, 3>, + Transform::template process_tile<1, 0, 3, 4>, + }, + { + Transform::template process_tile<1, 0, 4, 0>, + Transform::template process_tile<1, 0, 4, 1>, + Transform::template process_tile<1, 0, 4, 2>, + Transform::template process_tile<1, 0, 4, 3>, + Transform::template process_tile<1, 0, 4, 4>, + }, + }, + { + { + Transform::template process_tile<1, 1, 0, 0>, // Top left + Transform::template process_tile<1, 1, 0, 1>, + Transform::template process_tile<1, 1, 0, 2>, + Transform::template process_tile<1, 1, 0, 3>, + Transform::template process_tile<1, 1, 0, 4>, + }, + { + Transform::template process_tile<1, 1, 1, 0>, + Transform::template process_tile<1, 1, 1, 1>, + Transform::template process_tile<1, 1, 1, 2>, + Transform::template process_tile<1, 1, 1, 3>, + Transform::template process_tile<1, 1, 1, 4>, + }, + { + Transform::template process_tile<1, 1, 2, 0>, + Transform::template process_tile<1, 1, 2, 1>, + Transform::template process_tile<1, 1, 2, 2>, + Transform::template process_tile<1, 1, 2, 3>, + Transform::template process_tile<1, 1, 2, 4>, + }, + { + Transform::template process_tile<1, 1, 3, 0>, + Transform::template process_tile<1, 1, 3, 1>, + Transform::template process_tile<1, 1, 3, 2>, + Transform::template process_tile<1, 1, 3, 3>, + Transform::template process_tile<1, 1, 3, 4>, + }, + { + Transform::template process_tile<1, 1, 4, 0>, + Transform::template process_tile<1, 1, 4, 1>, + Transform::template process_tile<1, 1, 4, 2>, + Transform::template process_tile<1, 1, 4, 3>, + Transform::template process_tile<1, 1, 4, 4>, + } + } + } +}; + +template struct WinogradGEMM<4, 4, 3, 3>::InputTransform; +} // namespace winograd diff --git a/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp b/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp deleted file mode 100644 index 033442aa14..0000000000 --- a/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -namespace winograd { - /* Transform a kernel into the Winograd domain. - * - * NOTE: It is assumed that the kernel is in the form [height x width x - * input_channels x output_channel]. - */ - template - struct winograd2x2_3x3_gemm_kernel_transform_impl{ - static void execute( - const KernelShape &shape, - const T* const kernel, - T* const matrix_base, - const int matrix_stride, - const int matrix_row_stride - ); - - protected: - template - static void transform_kernel( - const T* const kernel, - const int n_input_channels, - const int n_output_channels, - T* const matrix_base, - const int matrix_stride, - const int matrix_row_stride - ); - }; -} - -/*****************************************************************************/ -/* Transform a fp32 kernel into the Winograd domain. - */ -#include "kernel_2x2_3x3/a64_float.hpp" // AArch64 specialisations - -namespace winograd -{ -template <> -inline void winograd2x2_3x3_gemm_kernel_transform_impl::execute( - const KernelShape &shape, - const float* const kernel, - float* const matrix_base, - const int matrix_stride, - const int matrix_row_stride -) { - // Delegate based on tail size - const int n_input_channels = shape.n_input_channels; - const int n_output_channels = shape.n_output_channels; - - switch (n_output_channels % 4) { - case 0: - transform_kernel<0>( - kernel, n_input_channels, n_output_channels, - matrix_base, matrix_stride, matrix_row_stride - ); - break; - case 1: - transform_kernel<1>( - kernel, n_input_channels, n_output_channels, - matrix_base, matrix_stride, matrix_row_stride - ); - break; - case 2: - transform_kernel<2>( - kernel, n_input_channels, n_output_channels, - matrix_base, matrix_stride, matrix_row_stride - ); - break; - case 3: - transform_kernel<3>( - kernel, n_input_channels, n_output_channels, - matrix_base, matrix_stride, matrix_row_stride - ); - break; - default: - ARM_COMPUTE_ERROR("Cannot happen"); - break; - } -} - -template <> -template -inline void winograd2x2_3x3_gemm_kernel_transform_impl::transform_kernel( - const float* const kernel, - const int n_input_channels, - const int n_output_channels, - float* const matrix_base, - const int mstride, - const int matrix_row_stride -) { - // Use one input pointer for each row of the kernel, use two additional - // offsets to extract columns. - const int kernel_col_stride = n_input_channels * n_output_channels; - const int kernel_row_stride = 3 * kernel_col_stride; - const float *inptr0 = kernel; - const float *inptr1 = kernel + kernel_row_stride; - const float *inptr2 = kernel + kernel_row_stride*2; - - // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three - // offsets to extract further matrices. - float *outptr0 = matrix_base; - float *outptr4 = matrix_base + mstride * 4; - float *outptr8 = matrix_base + mstride * 8; - float *outptr12 = matrix_base + mstride * 12; - - // For every input channel - for (int in_c = 0; in_c < n_input_channels; in_c++) { - // For every output channel - for (int c = 0; c < n_output_channels; c++) { - // Read in the kernel - float w11 = inptr0[0], w12 = inptr0[kernel_col_stride], w13 = inptr0[kernel_col_stride*2]; - float w21 = inptr1[0], w22 = inptr1[kernel_col_stride], w23 = inptr1[kernel_col_stride*2]; - float w31 = inptr2[0], w32 = inptr2[kernel_col_stride], w33 = inptr2[kernel_col_stride*2]; - - // Progress input pointers - inptr0++; - inptr1++; - inptr2++; - - // Compute the kernel W w, note we need only compute the middle two rows - // (2 and 3) because the first and last rows are merely copies of values - // from the matrix w. - float Ww11 = w11, Ww12 = w12, Ww13 = w13; - float Ww21 = 0.5*(w11 + w21 + w31), Ww22 = 0.5*(w12 + w22 + w32), Ww23 = 0.5*(w13 + w23 + w33); - float Ww31 = 0.5*(w11 - w21 + w31), Ww32 = 0.5*(w12 - w22 + w32), Ww33 = 0.5*(w13 - w23 + w33); - float Ww41 = w31, Ww42 = w32, Ww43 = w33; - - // Hence compute W w W.T; again note we need compute only the middle two - // columns since the first and last columns are copies of the first and - // last columns of the previous matrix. - float WwWT11 = Ww11, WwWT12 = 0.5*(Ww11 + Ww12 + Ww13), WwWT13 = 0.5*(Ww11 - Ww12 + Ww13), WwWT14 = Ww13; - float WwWT21 = Ww21, WwWT22 = 0.5*(Ww21 + Ww22 + Ww23), WwWT23 = 0.5*(Ww21 - Ww22 + Ww23), WwWT24 = Ww23; - float WwWT31 = Ww31, WwWT32 = 0.5*(Ww31 + Ww32 + Ww33), WwWT33 = 0.5*(Ww31 - Ww32 + Ww33), WwWT34 = Ww33; - float WwWT41 = Ww41, WwWT42 = 0.5*(Ww41 + Ww42 + Ww43), WwWT43 = 0.5*(Ww41 - Ww42 + Ww43), WwWT44 = Ww43; - - // Store the computed weights - outptr0[0 * mstride] = WwWT11; - outptr0[1 * mstride] = WwWT12; - outptr0[2 * mstride] = WwWT13; - outptr0[3 * mstride] = WwWT14; - - outptr4[0 * mstride] = WwWT21; - outptr4[1 * mstride] = WwWT22; - outptr4[2 * mstride] = WwWT23; - outptr4[3 * mstride] = WwWT24; - - outptr8[0 * mstride] = WwWT31; - outptr8[1 * mstride] = WwWT32; - outptr8[2 * mstride] = WwWT33; - outptr8[3 * mstride] = WwWT34; - - outptr12[0 * mstride] = WwWT41; - outptr12[1 * mstride] = WwWT42; - outptr12[2 * mstride] = WwWT43; - outptr12[3 * mstride] = WwWT44; - - // Progress output pointers - outptr0++; - outptr4++; - outptr8++; - outptr12++; - } - - // Progression to complete stride - outptr0 += matrix_row_stride - n_output_channels; - outptr4 += matrix_row_stride - n_output_channels; - outptr8 += matrix_row_stride - n_output_channels; - outptr12 += matrix_row_stride - n_output_channels; - } -} -} diff --git a/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp b/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp deleted file mode 100644 index 3dd62d1ac1..0000000000 --- a/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp +++ /dev/null @@ -1,822 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -#ifdef __aarch64__ -namespace winograd { -template <> -template <> -inline void winograd2x2_3x3_gemm_kernel_transform_impl::transform_kernel<0>( - const float* const kernel, - const int n_input_channels, - const int n_output_channels, - float* const matrix_base, - const int mstride, - const int matrix_row_stride -) { - // Use one input pointer for each row of the kernel, use two additional - // offsets to extract columns. - const int kernel_col_stride = n_input_channels * n_output_channels; - const int kernel_row_stride = 3 * kernel_col_stride; - const float *inptr0 = kernel; - const float *inptr1 = kernel + kernel_row_stride; - const float *inptr2 = kernel + kernel_row_stride*2; - - // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three - // offsets to extract further matrices. - float *outptr0 = matrix_base; - float *outptr4 = matrix_base + mstride * 4; - float *outptr8 = matrix_base + mstride * 8; - float *outptr12 = matrix_base + mstride * 12; - - // For every input channel - for (int in_c = 0; in_c < n_input_channels; in_c++) { - int n_remaining_channels = n_output_channels; - - asm volatile ( - // Registers into which to read the kernel - "w_11 .req v0\n" "qw_11 .req q0\n" - "w_12 .req v1\n" "qw_12 .req q1\n" - "w_13 .req v2\n" "qw_13 .req q2\n" - "w_21 .req v3\n" "qw_21 .req q3\n" - "w_22 .req v4\n" "qw_22 .req q4\n" - "w_23 .req v5\n" "qw_23 .req q5\n" - "w_31 .req v6\n" "qw_31 .req q6\n" - "w_32 .req v7\n" "qw_32 .req q7\n" - "w_33 .req v8\n" "qw_33 .req q8\n" - - // Transformed matrix Ww - "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n" - "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n" - "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n" - "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n" - - // Output matrix U = WwWT - "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n" - "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n" - "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n" - "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n" - - // Storage view of output matrices - "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n" - "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n" - "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n" - "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n" - - "half .req v23\n" // {0.5, ..., 0.5} - "dup half.4s, %w[one_half]\n" - "scratch .req v24\n" - - "1:" - // Load tile of the kernel - "ldr qw_11, [%x[inptr0]]\n" - "str qU11, [%x[outptr0]]\n" - "ldr qw_12, [%x[inptr0], %x[colstride1]]\n" - "ldr qw_13, [%x[inptr0], %x[colstride2]]\n" - "str qU14, [%x[outptr0], %x[mstride3]]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "ldr qw_21, [%x[inptr1]]\n" - "ldr qw_22, [%x[inptr1], %x[colstride1]]\n" - "ldr qw_23, [%x[inptr1], %x[colstride2]]\n" - "add %x[inptr1], %x[inptr1], #0x10\n" - - "ldr qw_31, [%x[inptr2]]\n" - "str qU41, [%x[outptr12]]\n" - "ldr qw_32, [%x[inptr2], %x[colstride1]]\n" - "ldr qw_33, [%x[inptr2], %x[colstride2]]\n" - "str qU44, [%x[outptr12], %x[mstride3]]\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - // Compute 2nd and 3rd rows of Ww - "fadd scratch.4s, w_11.4s, w_31.4s\n" - "fmul Ww21.4s, scratch.4s, half.4s\n" - "fmla Ww21.4s, w_21.4s, half.4s\n" - "str qU21, [%x[outptr4]]\n" - "fmul Ww31.4s, scratch.4s, half.4s\n" - "fmls Ww31.4s, w_21.4s, half.4s\n" - "str qU31, [%x[outptr8]]\n" - - "fadd scratch.4s, w_12.4s, w_32.4s\n" - "fmul Ww22.4s, scratch.4s, half.4s\n" - "fmla Ww22.4s, w_22.4s, half.4s\n" - "fmul Ww32.4s, scratch.4s, half.4s\n" - "fmls Ww32.4s, w_22.4s, half.4s\n" - - "fadd scratch.4s, w_13.4s, w_33.4s\n" - "fmul Ww23.4s, scratch.4s, half.4s\n" - "fmla Ww23.4s, w_23.4s, half.4s\n" - "str qU24, [%x[outptr4], %x[mstride3]]\n" - "fmul Ww33.4s, scratch.4s, half.4s\n" - "fmls Ww33.4s, w_23.4s, half.4s\n" - "str qU34, [%x[outptr8], %x[mstride3]]\n" - - // Compute and store U, only need to compute the 2nd and 3rd columns - // of U and update output pointers - "fadd scratch.4s, Ww11.4s, Ww13.4s\n" - "fmul U12.4s, scratch.4s, half.4s\n" - "fmla U12.4s, Ww12.4s, half.4s\n" - "str qU12, [%x[outptr0], %x[mstride1]]\n" - "fmul U13.4s, scratch.4s, half.4s\n" - "fmls U13.4s, Ww12.4s, half.4s\n" - "str qU13, [%x[outptr0], %x[mstride2]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd scratch.4s, Ww21.4s, Ww23.4s\n" - "fmul U22.4s, scratch.4s, half.4s\n" - "fmla U22.4s, Ww22.4s, half.4s\n" - "str qU22, [%x[outptr4], %x[mstride1]]\n" - "fmul U23.4s, scratch.4s, half.4s\n" - "fmls U23.4s, Ww22.4s, half.4s\n" - "str qU23, [%x[outptr4], %x[mstride2]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fadd scratch.4s, Ww31.4s, Ww33.4s\n" - "fmul U32.4s, scratch.4s, half.4s\n" - "fmla U32.4s, Ww32.4s, half.4s\n" - "str qU32, [%x[outptr8], %x[mstride1]]\n" - "fmul U33.4s, scratch.4s, half.4s\n" - "fmls U33.4s, Ww32.4s, half.4s\n" - "str qU33, [%x[outptr8], %x[mstride2]]\n" - "add %x[outptr8], %x[outptr8], #0x10\n" - - "fadd scratch.4s, Ww41.4s, Ww43.4s\n" - "fmul U42.4s, scratch.4s, half.4s\n" - "fmla U42.4s, Ww42.4s, half.4s\n" - "str qU42, [%x[outptr12], %x[mstride1]]\n" - "fmul U43.4s, scratch.4s, half.4s\n" - "fmls U43.4s, Ww42.4s, half.4s\n" - "str qU43, [%x[outptr12], %x[mstride2]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n" - "bne 1b\n" - - // Clear aliases - ".unreq half\n" - ".unreq scratch\n" - ".unreq w_11\n" ".unreq qw_11\n" - ".unreq w_12\n" ".unreq qw_12\n" - ".unreq w_13\n" ".unreq qw_13\n" - ".unreq w_21\n" ".unreq qw_21\n" - ".unreq w_22\n" ".unreq qw_22\n" - ".unreq w_23\n" ".unreq qw_23\n" - ".unreq w_31\n" ".unreq qw_31\n" - ".unreq w_32\n" ".unreq qw_32\n" - ".unreq w_33\n" ".unreq qw_33\n" - ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n" - ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n" - ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n" - ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n" - ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n" - ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n" - ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n" - ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n" - ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n" - ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n" - ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n" - ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq qU44\n" - - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr4), - [outptr8] "+r" (outptr8), - [outptr12] "+r" (outptr12), - [n_remaining_channels] "+r" (n_remaining_channels) - : [mstride1] "r" (sizeof(float) * mstride), - [mstride2] "r" (sizeof(float) * mstride * 2), - [mstride3] "r" (sizeof(float) * mstride * 3), - [colstride1] "r" (sizeof(float) * kernel_col_stride), - [colstride2] "r" (sizeof(float) * kernel_col_stride * 2), - [one_half] "r" (0.5f) - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24" - ); - - // Progression to complete stride - outptr0 += matrix_row_stride - n_output_channels; - outptr4 += matrix_row_stride - n_output_channels; - outptr8 += matrix_row_stride - n_output_channels; - outptr12 += matrix_row_stride - n_output_channels; - } -} - -template <> -template <> -inline void winograd2x2_3x3_gemm_kernel_transform_impl::transform_kernel<2>( - const float* const kernel, - const int n_input_channels, - const int n_output_channels, - float* const matrix_base, - const int mstride, - const int matrix_row_stride -) { - // Use one input pointer for each row of the kernel, use two additional - // offsets to extract columns. - const int kernel_col_stride = n_input_channels * n_output_channels; - const int kernel_row_stride = 3 * kernel_col_stride; - const float *inptr0 = kernel; - const float *inptr1 = kernel + kernel_row_stride; - const float *inptr2 = kernel + kernel_row_stride*2; - - // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three - // offsets to extract further matrices. - float *outptr0 = matrix_base; - float *outptr4 = matrix_base + mstride * 4; - float *outptr8 = matrix_base + mstride * 8; - float *outptr12 = matrix_base + mstride * 12; - - // For every input channel - for (int in_c = 0; in_c < n_input_channels; in_c++) { - int n_remaining_channels = n_output_channels; - - asm volatile ( - // Registers into which to read the kernel - "w_11 .req v0\n" "qw_11 .req q0\n" "dw_11 .req d0\n" - "w_12 .req v1\n" "qw_12 .req q1\n" "dw_12 .req d1\n" - "w_13 .req v2\n" "qw_13 .req q2\n" "dw_13 .req d2\n" - "w_21 .req v3\n" "qw_21 .req q3\n" "dw_21 .req d3\n" - "w_22 .req v4\n" "qw_22 .req q4\n" "dw_22 .req d4\n" - "w_23 .req v5\n" "qw_23 .req q5\n" "dw_23 .req d5\n" - "w_31 .req v6\n" "qw_31 .req q6\n" "dw_31 .req d6\n" - "w_32 .req v7\n" "qw_32 .req q7\n" "dw_32 .req d7\n" - "w_33 .req v8\n" "qw_33 .req q8\n" "dw_33 .req d8\n" - - // Transformed matrix Ww - "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n" - "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n" - "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n" - "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n" - - // Output matrix U = WwWT - "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n" - "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n" - "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n" - "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n" - - // Storage view of output matrices - "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n" - "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n" - "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n" - "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n" - - "dU11 .req d0\n" "dU12 .req d15\n" "dU13 .req d16\n" "dU14 .req d2\n" - "dU21 .req d9\n" "dU22 .req d17\n" "dU23 .req d18\n" "dU24 .req d11\n" - "dU31 .req d12\n" "dU32 .req d19\n" "dU33 .req d20\n" "dU34 .req d14\n" - "dU41 .req d6\n" "dU42 .req d21\n" "dU43 .req d22\n" "dU44 .req d8\n" - - "half .req v23\n" // {0.5, ..., 0.5} - "dup half.4s, %w[one_half]\n" - "scratch .req v24\n" - - // Subtract the tail from the number of remaining channels and jump to - // the tail if necessary. - "subs %x[n_remaining_channels], %x[n_remaining_channels], #2\n" - "beq 2f\n" - - "1:" - // Load tile of the kernel - "ldr qw_11, [%x[inptr0]]\n" - "str qU11, [%x[outptr0]]\n" - "ldr qw_12, [%x[inptr0], %x[colstride1]]\n" - "ldr qw_13, [%x[inptr0], %x[colstride2]]\n" - "str qU14, [%x[outptr0], %x[mstride3]]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "ldr qw_21, [%x[inptr1]]\n" - "ldr qw_22, [%x[inptr1], %x[colstride1]]\n" - "ldr qw_23, [%x[inptr1], %x[colstride2]]\n" - "add %x[inptr1], %x[inptr1], #0x10\n" - - "ldr qw_31, [%x[inptr2]]\n" - "str qU41, [%x[outptr12]]\n" - "ldr qw_32, [%x[inptr2], %x[colstride1]]\n" - "ldr qw_33, [%x[inptr2], %x[colstride2]]\n" - "str qU44, [%x[outptr12], %x[mstride3]]\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - // Compute 2nd and 3rd rows of Ww - "fadd scratch.4s, w_11.4s, w_31.4s\n" - "fmul Ww21.4s, scratch.4s, half.4s\n" - "fmla Ww21.4s, w_21.4s, half.4s\n" - "str qU21, [%x[outptr4]]\n" - "fmul Ww31.4s, scratch.4s, half.4s\n" - "fmls Ww31.4s, w_21.4s, half.4s\n" - "str qU31, [%x[outptr8]]\n" - - "fadd scratch.4s, w_12.4s, w_32.4s\n" - "fmul Ww22.4s, scratch.4s, half.4s\n" - "fmla Ww22.4s, w_22.4s, half.4s\n" - "fmul Ww32.4s, scratch.4s, half.4s\n" - "fmls Ww32.4s, w_22.4s, half.4s\n" - - "fadd scratch.4s, w_13.4s, w_33.4s\n" - "fmul Ww23.4s, scratch.4s, half.4s\n" - "fmla Ww23.4s, w_23.4s, half.4s\n" - "str qU24, [%x[outptr4], %x[mstride3]]\n" - "fmul Ww33.4s, scratch.4s, half.4s\n" - "fmls Ww33.4s, w_23.4s, half.4s\n" - "str qU34, [%x[outptr8], %x[mstride3]]\n" - - // Compute and store U, only need to compute the 2nd and 3rd columns - // of U and update output pointers - "fadd scratch.4s, Ww11.4s, Ww13.4s\n" - "fmul U12.4s, scratch.4s, half.4s\n" - "fmla U12.4s, Ww12.4s, half.4s\n" - "str qU12, [%x[outptr0], %x[mstride1]]\n" - "fmul U13.4s, scratch.4s, half.4s\n" - "fmls U13.4s, Ww12.4s, half.4s\n" - "str qU13, [%x[outptr0], %x[mstride2]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd scratch.4s, Ww21.4s, Ww23.4s\n" - "fmul U22.4s, scratch.4s, half.4s\n" - "fmla U22.4s, Ww22.4s, half.4s\n" - "str qU22, [%x[outptr4], %x[mstride1]]\n" - "fmul U23.4s, scratch.4s, half.4s\n" - "fmls U23.4s, Ww22.4s, half.4s\n" - "str qU23, [%x[outptr4], %x[mstride2]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fadd scratch.4s, Ww31.4s, Ww33.4s\n" - "fmul U32.4s, scratch.4s, half.4s\n" - "fmla U32.4s, Ww32.4s, half.4s\n" - "str qU32, [%x[outptr8], %x[mstride1]]\n" - "fmul U33.4s, scratch.4s, half.4s\n" - "fmls U33.4s, Ww32.4s, half.4s\n" - "str qU33, [%x[outptr8], %x[mstride2]]\n" - "add %x[outptr8], %x[outptr8], #0x10\n" - - "fadd scratch.4s, Ww41.4s, Ww43.4s\n" - "fmul U42.4s, scratch.4s, half.4s\n" - "fmla U42.4s, Ww42.4s, half.4s\n" - "str qU42, [%x[outptr12], %x[mstride1]]\n" - "fmul U43.4s, scratch.4s, half.4s\n" - "fmls U43.4s, Ww42.4s, half.4s\n" - "str qU43, [%x[outptr12], %x[mstride2]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n" - "bne 1b\n" - - // Tail size 2 - "2:" - // Load tile of the kernel - "ldr dw_11, [%x[inptr0]]\n" - "str dU11, [%x[outptr0]]\n" - "ldr dw_12, [%x[inptr0], %x[colstride1]]\n" - "ldr dw_13, [%x[inptr0], %x[colstride2]]\n" - "str dU14, [%x[outptr0], %x[mstride3]]\n" - "add %x[inptr0], %x[inptr0], #0x08\n" - - "ldr dw_21, [%x[inptr1]]\n" - "ldr dw_22, [%x[inptr1], %x[colstride1]]\n" - "ldr dw_23, [%x[inptr1], %x[colstride2]]\n" - "add %x[inptr1], %x[inptr1], #0x08\n" - - "ldr dw_31, [%x[inptr2]]\n" - "str dU41, [%x[outptr12]]\n" - "ldr dw_32, [%x[inptr2], %x[colstride1]]\n" - "ldr dw_33, [%x[inptr2], %x[colstride2]]\n" - "str dU44, [%x[outptr12], %x[mstride3]]\n" - "add %x[inptr2], %x[inptr2], #0x08\n" - - // Compute 2nd and 3rd rows of Ww - "fadd scratch.2s, w_11.2s, w_31.2s\n" - "fmul Ww21.2s, scratch.2s, half.2s\n" - "fmla Ww21.2s, w_21.2s, half.2s\n" - "str dU21, [%x[outptr4]]\n" - "fmul Ww31.2s, scratch.2s, half.2s\n" - "fmls Ww31.2s, w_21.2s, half.2s\n" - "str dU31, [%x[outptr8]]\n" - - "fadd scratch.2s, w_12.2s, w_32.2s\n" - "fmul Ww22.2s, scratch.2s, half.2s\n" - "fmla Ww22.2s, w_22.2s, half.2s\n" - "fmul Ww32.2s, scratch.2s, half.2s\n" - "fmls Ww32.2s, w_22.2s, half.2s\n" - - "fadd scratch.2s, w_13.2s, w_33.2s\n" - "fmul Ww23.2s, scratch.2s, half.2s\n" - "fmla Ww23.2s, w_23.2s, half.2s\n" - "str dU24, [%x[outptr4], %x[mstride3]]\n" - "fmul Ww33.2s, scratch.2s, half.2s\n" - "fmls Ww33.2s, w_23.2s, half.2s\n" - "str dU34, [%x[outptr8], %x[mstride3]]\n" - - // Compute and store U, only need to compute the 2nd and 3rd columns of - // U and update output pointers - "fadd scratch.2s, Ww11.2s, Ww13.2s\n" - "fmul U12.2s, scratch.2s, half.2s\n" - "fmla U12.2s, Ww12.2s, half.2s\n" - "str dU12, [%x[outptr0], %x[mstride1]]\n" - "fmul U13.2s, scratch.2s, half.2s\n" - "fmls U13.2s, Ww12.2s, half.2s\n" - "str dU13, [%x[outptr0], %x[mstride2]]\n" - "add %x[outptr0], %x[outptr0], #0x08\n" - - "fadd scratch.2s, Ww21.2s, Ww23.2s\n" - "fmul U22.2s, scratch.2s, half.2s\n" - "fmla U22.2s, Ww22.2s, half.2s\n" - "str dU22, [%x[outptr4], %x[mstride1]]\n" - "fmul U23.2s, scratch.2s, half.2s\n" - "fmls U23.2s, Ww22.2s, half.2s\n" - "str dU23, [%x[outptr4], %x[mstride2]]\n" - "add %x[outptr4], %x[outptr4], #0x08\n" - - "fadd scratch.2s, Ww31.2s, Ww33.2s\n" - "fmul U32.2s, scratch.2s, half.2s\n" - "fmla U32.2s, Ww32.2s, half.2s\n" - "str dU32, [%x[outptr8], %x[mstride1]]\n" - "fmul U33.2s, scratch.2s, half.2s\n" - "fmls U33.2s, Ww32.2s, half.2s\n" - "str dU33, [%x[outptr8], %x[mstride2]]\n" - "add %x[outptr8], %x[outptr8], #0x08\n" - - "fadd scratch.2s, Ww41.2s, Ww43.2s\n" - "fmul U42.2s, scratch.2s, half.2s\n" - "fmla U42.2s, Ww42.2s, half.2s\n" - "str dU42, [%x[outptr12], %x[mstride1]]\n" - "fmul U43.2s, scratch.2s, half.2s\n" - "fmls U43.2s, Ww42.2s, half.2s\n" - "str dU43, [%x[outptr12], %x[mstride2]]\n" - "add %x[outptr12], %x[outptr12], #0x08\n" - - // Clear aliases - ".unreq half\n" - ".unreq scratch\n" - ".unreq w_11\n" ".unreq qw_11\n" ".unreq dw_11\n" - ".unreq w_12\n" ".unreq qw_12\n" ".unreq dw_12\n" - ".unreq w_13\n" ".unreq qw_13\n" ".unreq dw_13\n" - ".unreq w_21\n" ".unreq qw_21\n" ".unreq dw_21\n" - ".unreq w_22\n" ".unreq qw_22\n" ".unreq dw_22\n" - ".unreq w_23\n" ".unreq qw_23\n" ".unreq dw_23\n" - ".unreq w_31\n" ".unreq qw_31\n" ".unreq dw_31\n" - ".unreq w_32\n" ".unreq qw_32\n" ".unreq dw_32\n" - ".unreq w_33\n" ".unreq qw_33\n" ".unreq dw_33\n" - ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n" - ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n" - ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n" - ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n" - ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n" - ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n" - ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n" - ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n" - ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n" - ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n" - ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n" - ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq qU44\n" - ".unreq dU11\n" ".unreq dU12\n" ".unreq dU13\n" ".unreq dU14\n" - ".unreq dU21\n" ".unreq dU22\n" ".unreq dU23\n" ".unreq dU24\n" - ".unreq dU31\n" ".unreq dU32\n" ".unreq dU33\n" ".unreq dU34\n" - ".unreq dU41\n" ".unreq dU42\n" ".unreq dU43\n" ".unreq dU44\n" - - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr4), - [outptr8] "+r" (outptr8), - [outptr12] "+r" (outptr12), - [n_remaining_channels] "+r" (n_remaining_channels) - : [mstride1] "r" (sizeof(float) * mstride), - [mstride2] "r" (sizeof(float) * mstride * 2), - [mstride3] "r" (sizeof(float) * mstride * 3), - [colstride1] "r" (sizeof(float) * kernel_col_stride), - [colstride2] "r" (sizeof(float) * kernel_col_stride * 2), - [one_half] "r" (0.5f) - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24" - ); - - // Progression to complete stride - outptr0 += matrix_row_stride - n_output_channels; - outptr4 += matrix_row_stride - n_output_channels; - outptr8 += matrix_row_stride - n_output_channels; - outptr12 += matrix_row_stride - n_output_channels; - } -} - -template <> -template <> -inline void winograd2x2_3x3_gemm_kernel_transform_impl::transform_kernel<1>( - const float* const kernel, - const int n_input_channels, - const int n_output_channels, - float* const matrix_base, - const int mstride, - const int matrix_row_stride -) { - // Use one input pointer for each row of the kernel, use two additional - // offsets to extract columns. - const int kernel_col_stride = n_input_channels * n_output_channels; - const int kernel_row_stride = 3 * kernel_col_stride; - const float *inptr0 = kernel; - const float *inptr1 = kernel + kernel_row_stride; - const float *inptr2 = kernel + kernel_row_stride*2; - - // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three - // offsets to extract further matrices. - float *outptr0 = matrix_base; - float *outptr4 = matrix_base + mstride * 4; - float *outptr8 = matrix_base + mstride * 8; - float *outptr12 = matrix_base + mstride * 12; - - // For every input channel - for (int in_c = 0; in_c < n_input_channels; in_c++) { - int n_remaining_channels = n_output_channels; - - asm volatile ( - // Registers into which to read the kernel - "w_11 .req v0\n" "qw_11 .req q0\n" "sw_11 .req s0\n" - "w_12 .req v1\n" "qw_12 .req q1\n" "sw_12 .req s1\n" - "w_13 .req v2\n" "qw_13 .req q2\n" "sw_13 .req s2\n" - "w_21 .req v3\n" "qw_21 .req q3\n" "sw_21 .req s3\n" - "w_22 .req v4\n" "qw_22 .req q4\n" "sw_22 .req s4\n" - "w_23 .req v5\n" "qw_23 .req q5\n" "sw_23 .req s5\n" - "w_31 .req v6\n" "qw_31 .req q6\n" "sw_31 .req s6\n" - "w_32 .req v7\n" "qw_32 .req q7\n" "sw_32 .req s7\n" - "w_33 .req v8\n" "qw_33 .req q8\n" "sw_33 .req s8\n" - - // Transformed matrix Ww - "Ww11 .req w_11\n" "Ww12 .req w_12\n" "Ww13 .req w_13\n" - "Ww21 .req v9\n" "Ww22 .req v10\n" "Ww23 .req v11\n" - "Ww31 .req v12\n" "Ww32 .req v13\n" "Ww33 .req v14\n" - "Ww41 .req w_31\n" "Ww42 .req w_32\n" "Ww43 .req w_33\n" - - // Output matrix U = WwWT - "U11 .req Ww11\n" "U12 .req v15\n" "U13 .req v16\n" "U14 .req Ww13\n" - "U21 .req Ww21\n" "U22 .req v17\n" "U23 .req v18\n" "U24 .req Ww23\n" - "U31 .req Ww31\n" "U32 .req v19\n" "U33 .req v20\n" "U34 .req Ww33\n" - "U41 .req Ww41\n" "U42 .req v21\n" "U43 .req v22\n" "U44 .req Ww43\n" - - // Storage view of output matrices - "qU11 .req q0\n" "qU12 .req q15\n" "qU13 .req q16\n" "qU14 .req q2\n" - "qU21 .req q9\n" "qU22 .req q17\n" "qU23 .req q18\n" "qU24 .req q11\n" - "qU31 .req q12\n" "qU32 .req q19\n" "qU33 .req q20\n" "qU34 .req q14\n" - "qU41 .req q6\n" "qU42 .req q21\n" "qU43 .req q22\n" "qU44 .req q8\n" - - "sU11 .req s0\n" "sU12 .req s15\n" "sU13 .req s16\n" "sU14 .req s2\n" - "sU21 .req s9\n" "sU22 .req s17\n" "sU23 .req s18\n" "sU24 .req s11\n" - "sU31 .req s12\n" "sU32 .req s19\n" "sU33 .req s20\n" "sU34 .req s14\n" - "sU41 .req s6\n" "sU42 .req s21\n" "sU43 .req s22\n" "sU44 .req s8\n" - - "half .req v23\n" // {0.5, ..., 0.5} - "dup half.4s, %w[one_half]\n" - "scratch .req v24\n" - - // Subtract the tail from the number of remaining channels and jump to - // the tail if necessary. - "subs %x[n_remaining_channels], %x[n_remaining_channels], #1\n" - "beq 2f\n" - - "1:" - // Load tile of the kernel - "ldr qw_11, [%x[inptr0]]\n" - "str qU11, [%x[outptr0]]\n" - "ldr qw_12, [%x[inptr0], %x[colstride1]]\n" - "ldr qw_13, [%x[inptr0], %x[colstride2]]\n" - "str qU14, [%x[outptr0], %x[mstride3]]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "ldr qw_21, [%x[inptr1]]\n" - "ldr qw_22, [%x[inptr1], %x[colstride1]]\n" - "ldr qw_23, [%x[inptr1], %x[colstride2]]\n" - "add %x[inptr1], %x[inptr1], #0x10\n" - - "ldr qw_31, [%x[inptr2]]\n" - "str qU41, [%x[outptr12]]\n" - "ldr qw_32, [%x[inptr2], %x[colstride1]]\n" - "ldr qw_33, [%x[inptr2], %x[colstride2]]\n" - "str qU44, [%x[outptr12], %x[mstride3]]\n" - "add %x[inptr2], %x[inptr2], #0x10\n" - - // Compute 2nd and 3rd rows of Ww - "fadd scratch.4s, w_11.4s, w_31.4s\n" - "fmul Ww21.4s, scratch.4s, half.4s\n" - "fmla Ww21.4s, w_21.4s, half.4s\n" - "str qU21, [%x[outptr4]]\n" - "fmul Ww31.4s, scratch.4s, half.4s\n" - "fmls Ww31.4s, w_21.4s, half.4s\n" - "str qU31, [%x[outptr8]]\n" - - "fadd scratch.4s, w_12.4s, w_32.4s\n" - "fmul Ww22.4s, scratch.4s, half.4s\n" - "fmla Ww22.4s, w_22.4s, half.4s\n" - "fmul Ww32.4s, scratch.4s, half.4s\n" - "fmls Ww32.4s, w_22.4s, half.4s\n" - - "fadd scratch.4s, w_13.4s, w_33.4s\n" - "fmul Ww23.4s, scratch.4s, half.4s\n" - "fmla Ww23.4s, w_23.4s, half.4s\n" - "str qU24, [%x[outptr4], %x[mstride3]]\n" - "fmul Ww33.4s, scratch.4s, half.4s\n" - "fmls Ww33.4s, w_23.4s, half.4s\n" - "str qU34, [%x[outptr8], %x[mstride3]]\n" - - // Compute and store U, only need to compute the 2nd and 3rd columns - // of U and update output pointers - "fadd scratch.4s, Ww11.4s, Ww13.4s\n" - "fmul U12.4s, scratch.4s, half.4s\n" - "fmla U12.4s, Ww12.4s, half.4s\n" - "str qU12, [%x[outptr0], %x[mstride1]]\n" - "fmul U13.4s, scratch.4s, half.4s\n" - "fmls U13.4s, Ww12.4s, half.4s\n" - "str qU13, [%x[outptr0], %x[mstride2]]\n" - "add %x[outptr0], %x[outptr0], #0x10\n" - - "fadd scratch.4s, Ww21.4s, Ww23.4s\n" - "fmul U22.4s, scratch.4s, half.4s\n" - "fmla U22.4s, Ww22.4s, half.4s\n" - "str qU22, [%x[outptr4], %x[mstride1]]\n" - "fmul U23.4s, scratch.4s, half.4s\n" - "fmls U23.4s, Ww22.4s, half.4s\n" - "str qU23, [%x[outptr4], %x[mstride2]]\n" - "add %x[outptr4], %x[outptr4], #0x10\n" - - "fadd scratch.4s, Ww31.4s, Ww33.4s\n" - "fmul U32.4s, scratch.4s, half.4s\n" - "fmla U32.4s, Ww32.4s, half.4s\n" - "str qU32, [%x[outptr8], %x[mstride1]]\n" - "fmul U33.4s, scratch.4s, half.4s\n" - "fmls U33.4s, Ww32.4s, half.4s\n" - "str qU33, [%x[outptr8], %x[mstride2]]\n" - "add %x[outptr8], %x[outptr8], #0x10\n" - - "fadd scratch.4s, Ww41.4s, Ww43.4s\n" - "fmul U42.4s, scratch.4s, half.4s\n" - "fmla U42.4s, Ww42.4s, half.4s\n" - "str qU42, [%x[outptr12], %x[mstride1]]\n" - "fmul U43.4s, scratch.4s, half.4s\n" - "fmls U43.4s, Ww42.4s, half.4s\n" - "str qU43, [%x[outptr12], %x[mstride2]]\n" - "add %x[outptr12], %x[outptr12], #0x10\n" - - "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n" - "bne 1b\n" - - // Tail size 1 - "2:" - // Load tile of the kernel - "ldr sw_11, [%x[inptr0]]\n" - "str sU11, [%x[outptr0]]\n" - "ldr sw_12, [%x[inptr0], %x[colstride1]]\n" - "ldr sw_13, [%x[inptr0], %x[colstride2]]\n" - "str sU14, [%x[outptr0], %x[mstride3]]\n" - "add %x[inptr0], %x[inptr0], #0x04\n" - - "ldr sw_21, [%x[inptr1]]\n" - "ldr sw_22, [%x[inptr1], %x[colstride1]]\n" - "ldr sw_23, [%x[inptr1], %x[colstride2]]\n" - "add %x[inptr1], %x[inptr1], #0x04\n" - - "ldr sw_31, [%x[inptr2]]\n" - "str sU41, [%x[outptr12]]\n" - "ldr sw_32, [%x[inptr2], %x[colstride1]]\n" - "ldr sw_33, [%x[inptr2], %x[colstride2]]\n" - "str sU44, [%x[outptr12], %x[mstride3]]\n" - "add %x[inptr2], %x[inptr2], #0x04\n" - - // Compute 2nd and 3rd rows of Ww - "fadd scratch.2s, w_11.2s, w_31.2s\n" - "fmul Ww21.2s, scratch.2s, half.2s\n" - "fmla Ww21.2s, w_21.2s, half.2s\n" - "str sU21, [%x[outptr4]]\n" - "fmul Ww31.2s, scratch.2s, half.2s\n" - "fmls Ww31.2s, w_21.2s, half.2s\n" - "str sU31, [%x[outptr8]]\n" - - "fadd scratch.2s, w_12.2s, w_32.2s\n" - "fmul Ww22.2s, scratch.2s, half.2s\n" - "fmla Ww22.2s, w_22.2s, half.2s\n" - "fmul Ww32.2s, scratch.2s, half.2s\n" - "fmls Ww32.2s, w_22.2s, half.2s\n" - - "fadd scratch.2s, w_13.2s, w_33.2s\n" - "fmul Ww23.2s, scratch.2s, half.2s\n" - "fmla Ww23.2s, w_23.2s, half.2s\n" - "str sU24, [%x[outptr4], %x[mstride3]]\n" - "fmul Ww33.2s, scratch.2s, half.2s\n" - "fmls Ww33.2s, w_23.2s, half.2s\n" - "str sU34, [%x[outptr8], %x[mstride3]]\n" - - // Compute and store U, only need to compute the 2nd and 3rd columns of - // U and update output pointers - "fadd scratch.2s, Ww11.2s, Ww13.2s\n" - "fmul U12.2s, scratch.2s, half.2s\n" - "fmla U12.2s, Ww12.2s, half.2s\n" - "str sU12, [%x[outptr0], %x[mstride1]]\n" - "fmul U13.2s, scratch.2s, half.2s\n" - "fmls U13.2s, Ww12.2s, half.2s\n" - "str sU13, [%x[outptr0], %x[mstride2]]\n" - "add %x[outptr0], %x[outptr0], #0x04\n" - - "fadd scratch.2s, Ww21.2s, Ww23.2s\n" - "fmul U22.2s, scratch.2s, half.2s\n" - "fmla U22.2s, Ww22.2s, half.2s\n" - "str sU22, [%x[outptr4], %x[mstride1]]\n" - "fmul U23.2s, scratch.2s, half.2s\n" - "fmls U23.2s, Ww22.2s, half.2s\n" - "str sU23, [%x[outptr4], %x[mstride2]]\n" - "add %x[outptr4], %x[outptr4], #0x04\n" - - "fadd scratch.2s, Ww31.2s, Ww33.2s\n" - "fmul U32.2s, scratch.2s, half.2s\n" - "fmla U32.2s, Ww32.2s, half.2s\n" - "str sU32, [%x[outptr8], %x[mstride1]]\n" - "fmul U33.2s, scratch.2s, half.2s\n" - "fmls U33.2s, Ww32.2s, half.2s\n" - "str sU33, [%x[outptr8], %x[mstride2]]\n" - "add %x[outptr8], %x[outptr8], #0x04\n" - - "fadd scratch.2s, Ww41.2s, Ww43.2s\n" - "fmul U42.2s, scratch.2s, half.2s\n" - "fmla U42.2s, Ww42.2s, half.2s\n" - "str sU42, [%x[outptr12], %x[mstride1]]\n" - "fmul U43.2s, scratch.2s, half.2s\n" - "fmls U43.2s, Ww42.2s, half.2s\n" - "str sU43, [%x[outptr12], %x[mstride2]]\n" - "add %x[outptr12], %x[outptr12], #0x04\n" - - // Clear aliases - ".unreq half\n" - ".unreq scratch\n" - ".unreq w_11\n" ".unreq qw_11\n" ".unreq sw_11\n" - ".unreq w_12\n" ".unreq qw_12\n" ".unreq sw_12\n" - ".unreq w_13\n" ".unreq qw_13\n" ".unreq sw_13\n" - ".unreq w_21\n" ".unreq qw_21\n" ".unreq sw_21\n" - ".unreq w_22\n" ".unreq qw_22\n" ".unreq sw_22\n" - ".unreq w_23\n" ".unreq qw_23\n" ".unreq sw_23\n" - ".unreq w_31\n" ".unreq qw_31\n" ".unreq sw_31\n" - ".unreq w_32\n" ".unreq qw_32\n" ".unreq sw_32\n" - ".unreq w_33\n" ".unreq qw_33\n" ".unreq sw_33\n" - ".unreq Ww11\n" ".unreq Ww12\n" ".unreq Ww13\n" - ".unreq Ww21\n" ".unreq Ww22\n" ".unreq Ww23\n" - ".unreq Ww31\n" ".unreq Ww32\n" ".unreq Ww33\n" - ".unreq Ww41\n" ".unreq Ww42\n" ".unreq Ww43\n" - ".unreq U11\n" ".unreq U12\n" ".unreq U13\n" ".unreq U14\n" - ".unreq U21\n" ".unreq U22\n" ".unreq U23\n" ".unreq U24\n" - ".unreq U31\n" ".unreq U32\n" ".unreq U33\n" ".unreq U34\n" - ".unreq U41\n" ".unreq U42\n" ".unreq U43\n" ".unreq U44\n" - ".unreq qU11\n" ".unreq qU12\n" ".unreq qU13\n" ".unreq qU14\n" - ".unreq qU21\n" ".unreq qU22\n" ".unreq qU23\n" ".unreq qU24\n" - ".unreq qU31\n" ".unreq qU32\n" ".unreq qU33\n" ".unreq qU34\n" - ".unreq qU41\n" ".unreq qU42\n" ".unreq qU43\n" ".unreq qU44\n" - ".unreq sU11\n" ".unreq sU12\n" ".unreq sU13\n" ".unreq sU14\n" - ".unreq sU21\n" ".unreq sU22\n" ".unreq sU23\n" ".unreq sU24\n" - ".unreq sU31\n" ".unreq sU32\n" ".unreq sU33\n" ".unreq sU34\n" - ".unreq sU41\n" ".unreq sU42\n" ".unreq sU43\n" ".unreq sU44\n" - - : [inptr0] "+r" (inptr0), - [inptr1] "+r" (inptr1), - [inptr2] "+r" (inptr2), - [outptr0] "+r" (outptr0), - [outptr4] "+r" (outptr4), - [outptr8] "+r" (outptr8), - [outptr12] "+r" (outptr12), - [n_remaining_channels] "+r" (n_remaining_channels) - : [mstride1] "r" (sizeof(float) * mstride), - [mstride2] "r" (sizeof(float) * mstride * 2), - [mstride3] "r" (sizeof(float) * mstride * 3), - [colstride1] "r" (sizeof(float) * kernel_col_stride), - [colstride2] "r" (sizeof(float) * kernel_col_stride * 2), - [one_half] "r" (0.5f) - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24" - ); - - // Progression to complete stride - outptr0 += matrix_row_stride - n_output_channels; - outptr4 += matrix_row_stride - n_output_channels; - outptr8 += matrix_row_stride - n_output_channels; - outptr12 += matrix_row_stride - n_output_channels; - } -} -} -#endif // __aarch64__ diff --git a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp b/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp deleted file mode 100644 index 0992c0bb44..0000000000 --- a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp +++ /dev/null @@ -1,356 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -namespace winograd { - /* Transform from the Winograd domain back to the spatial domain. - */ - template - struct Winograd2x2_3x3GemmOutput { - static void execute( - const Tensor4DShape &output_shape, - T* const matrix_base, - const int matrix_stride, - const int matrix_row_stride, - T* const output - ); - - protected: - /* Specialised implementation method. */ - template - static void _execute( - const Tensor4DShape &output_shape, - T *output, - const T *input, - const int matrix_stride, - const int matrix_row_stride - ); - }; - - /* Two-stage implementation of the transformation from the Winograd domain. - * - * First computes Z.F and then computes (Z.F).Z^T. - */ - template - struct Winograd2x2_3x3GemmOutput_TwoStage { - static void execute( - const Tensor4DShape &output_shape, - T* const matrix_base, - const int matrix_stride, - const int matrix_row_stride, - T* const output - ); - - protected: - template - static void compute_zf( - const int n_rows, const int n_channels, - T* const zf, const T* const input[16] - ); - - template - static void compute_zfzT( - const Tensor4DShape &output_shape, - T* const output, const T* const zf - ); - }; -} - -#include "output_2x2_3x3/a64_float.hpp" -// #include "output_2x2_3x3/a64_float_two_stage.hpp" - -/*****************************************************************************/ -/* -template -void winograd::Winograd2x2_3x3GemmOutput::execute( - const Tensor4DShape &output_shape, - const int tile_M, - const int tile_N, - T* const matrix_base, - const int matrix_stride, - const int matrix_row_stride, - T* const output -) { - T* const antipadding = reinterpret_cast(malloc(sizeof(T) * output_shape.n_channels)); - - // Get input pointers - const T* inptrs[16]; - for (int i = 0; i < 16; i++) { - inptrs[i] = matrices[i]; - } - - for (int batch = 0; batch < output_shape.n_batches; batch++) { - for (int tile_i = 0; tile_i < tile_M; tile_i++) { - for (int tile_j = 0; tile_j < tile_N; tile_j++) { - // Get pointers for each of the 4 output cells required for this computation - T* outptrs[4]; - for (int cell_i = 0, c = 0; cell_i < 2; cell_i++) { - for (int cell_j = 0; cell_j < 2; cell_j++, c++) { - const int i = tile_i*2 + cell_i; - const int j = tile_j*2 + cell_j; - - if (i < output_shape.n_rows && j < output_shape.n_cols) { - outptrs[c] = output + ( - (batch*output_shape.n_rows + i) * output_shape.n_cols + - j) * output_shape.n_channels; - } else { - outptrs[c] = antipadding; - } - } // cell_j - } // cell_i - - for (int n = 0; n < output_shape.n_channels; n++) { - // Read 16 values and progress pointers - T v[16]; - for (int i = 0; i < 16; i++) { - v[i] = *(inptrs[i]++); - } - - // Compute output for 4 pixels - *(outptrs[0]++) = v[ 0] + v[ 1] + v[ 2] + - v[ 4] + v[ 5] + v[ 6] + - v[ 8] + v[ 9] + v[10]; - *(outptrs[1]++) = v[ 1] - v[ 2] - v[ 3] + - v[ 5] - v[ 6] - v[ 7] + - v[ 9] - v[10] - v[11]; - *(outptrs[2]++) = v[ 4] + v[ 5] + v[ 6] - - v[ 8] - v[ 9] - v[10] - - v[12] - v[13] - v[14]; - *(outptrs[3]++) = v[ 5] - v[ 6] - v[ 7] - - v[ 9] + v[10] + v[11] - - v[13] + v[14] + v[15]; - } // output_channel - } // tile_j - } // tile_i - } // batch - - free(antipadding); -} -*/ - -/*****************************************************************************/ -/* -template -void winograd::Winograd2x2_3x3GemmOutput_TwoStage::execute( - const Tensor4DShape &output_shape, - T* const matrices[16], T* const output -) { - // Allocate memory for the intermediate matrices - const int tile_M = iceildiv(output_shape.n_rows, 2); - const int tile_N = iceildiv(output_shape.n_cols, 2); - const int n_rows = output_shape.n_batches * tile_M * tile_N; - const int n_channels = output_shape.n_channels; - T* matrices_zf = reinterpret_cast( - calloc(8 * n_rows * n_channels, sizeof(T)) - ); - - // Perform the first stage transform, computing ZF. - // Specializations should dispatch to different methods based on tail size. - compute_zf<0>(n_rows, n_channels, matrices_zf, matrices); - - // Perform the second stage transform, finishing Z F Z^T - variable dispatch - // based on size of the output. Specialisations can also dispatch based on - // the tail-size of the channel. - if (output_shape.n_rows % 2 && output_shape.n_cols % 2) { - compute_zfzT(output_shape, output, matrices_zf); - } else if (output_shape.n_rows % 2) { - compute_zfzT(output_shape, output, matrices_zf); - } else if (output_shape.n_cols % 2) { - compute_zfzT(output_shape, output, matrices_zf); - } else { - compute_zfzT(output_shape, output, matrices_zf); - } - - free(reinterpret_cast(matrices_zf)); -} - -template -template -void winograd::Winograd2x2_3x3GemmOutput_TwoStage::compute_zf( - const int n_rows, const int n_channels, - T* output, const T* const input[16] -) { - // Extract 8 output pointers - T* outptr[8]; - for (int i = 0; i < 8; i++) { - outptr[i] = output + i*n_rows*n_channels; - } - - // Copy the 16 input pointers - const T* inptr[16]; - for (int i = 0; i < 16; i++) { - inptr[i] = input[i]; - } - - // For every row of the matrices - for (int i = 0; i < n_rows; i++) { - // For every channel - for (int j = 0; j < n_channels; j++) { - // Extract values from the input matrices - T val[16]; - for (int n = 0; n < 16; n++) { - val[n] = *(inptr[n]++); - } - - // Compute output values - *(outptr[0]++) = val[0] + val[1] + val[2]; - *(outptr[1]++) = val[1] - val[2] - val[3]; - *(outptr[2]++) = val[4] + val[5] + val[6]; - *(outptr[3]++) = val[5] - val[6] - val[7]; - *(outptr[4]++) = val[8] + val[9] + val[10]; - *(outptr[5]++) = val[9] - val[10] - val[11]; - *(outptr[6]++) = val[12] + val[13] + val[14]; - *(outptr[7]++) = val[13] - val[14] - val[15]; - } - } -} - -template -template -void winograd::Winograd2x2_3x3GemmOutput_TwoStage::compute_zfzT( - const Tensor4DShape &output_shape, - T* const output, const T* const input -) { - // Sizing information - const int tile_M = output_shape.n_rows / 2; - const int tile_N = output_shape.n_cols / 2; - - const int n_rows = (output_shape.n_batches * - (tile_M + (tail_M ? 1 : 0)) * - (tile_N + (tail_N ? 1 : 0))); - const int n_channels = output_shape.n_channels; - - // Extract 8 input pointers - const T* inptr[8]; - for (int i = 0; i < 8; i++) { - inptr[i] = input + i*n_rows*n_channels; - } - - // Extract 4 output pointers - T* outptr00 = output; - T* outptr01 = outptr00 + n_channels; - T* outptr10 = outptr00 + output_shape.n_cols * n_channels; - T* outptr11 = outptr10 + n_channels; - - // Progress over the output tiles, generating output values. - for (int batch = 0; batch < output_shape.n_batches; batch++) { - for (int tile_i = 0; tile_i < tile_M; tile_i++) { - for (int tile_j = 0; tile_j < tile_N; tile_j++) { - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - T v[8]; - for (int i = 0; i < 8; i++) { - v[i] = *(inptr[i]++); - } - - // Compute the output values and progress the output pointers. - *(outptr00++) = v[0] + v[2] + v[4]; - *(outptr01++) = v[1] + v[3] + v[5]; - *(outptr10++) = v[2] - v[4] - v[6]; - *(outptr11++) = v[3] - v[5] - v[7]; - } - - // Progress the output pointers to the next column - outptr00 += n_channels; - outptr01 += n_channels; - outptr10 += n_channels; - outptr11 += n_channels; - } - - if (tail_N) { - // Only evaluate the left-most columns of the output - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - T v[8]; - for (int i = 0; i < 4; i++) { - v[i * 2] = *inptr[i * 2]; - } - for (int i = 0; i < 8; i++) { - inptr[i]++; - } - - // Compute the output values and progress the output pointers. - *(outptr00++) = v[0] + v[2] + v[4]; - *(outptr10++) = v[2] - v[4] - v[6]; - } - - // Progress the output pointers to the next column - outptr01 += n_channels; // Account for being skipped above - outptr11 += n_channels; // Account for being skipped above - } - - // Progress the output pointers to the next row - outptr00 += output_shape.n_cols * n_channels; - outptr01 += output_shape.n_cols * n_channels; - outptr10 += output_shape.n_cols * n_channels; - outptr11 += output_shape.n_cols * n_channels; - } - - if (tail_M) { - // Only work on the upper row of the output - for (int tile_j = 0; tile_j < tile_N; tile_j++) { - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - T v[8]; - for (int i = 0; i < 8; i++) { - v[i] = *(inptr[i]++); - } - - // Compute the output values and progress the output pointers. - *(outptr00++) = v[0] + v[2] + v[4]; - *(outptr01++) = v[1] + v[3] + v[5]; - } - - // Progress the output pointers to the next column - outptr00 += n_channels; - outptr01 += n_channels; - outptr10 += 2 * n_channels; // Account for being skipped above - outptr11 += 2 * n_channels; // Account for being skipped above - } - - if (tail_N) { - // Only evaluate the upper-left cell of the output - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - T v[8]; - for (int i = 0; i < 3; i++) { - v[i * 2] = *inptr[i * 2]; - } - for (int i = 0; i < 8; i++) { - inptr[i]++; - } - - // Compute the output values and progress the output pointers. - *(outptr00++) = v[0] + v[2] + v[4]; - } - - // Progress the output pointers to the next column - outptr01 += n_channels; // Account for being skipped above - outptr10 += n_channels; // Account for being skipped above - outptr11 += n_channels; // Account for being skipped above - } - } - } -} -*/ diff --git a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp b/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp deleted file mode 100644 index bf6ba907b9..0000000000 --- a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp +++ /dev/null @@ -1,650 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -/* Float implementation for AArch64. - */ -#ifdef __aarch64__ -namespace winograd { - - -template <> -template <> -inline void Winograd2x2_3x3GemmOutput::_execute( - const Tensor4DShape &output_shape, - float *output, - const float *input, - const int mstride, - const int matrix_row_stride -) { - const int tile_M = output_shape.n_rows / 2; - const int tile_N = output_shape.n_cols / 2; - int batch = output_shape.n_batches; - float *outptr = output; - - const float *inptr0 = input; - const float *inptr4 = input + 4 * mstride; - const float *inptr8 = input + 8 * mstride; - const float *inptr12 = input + 12 * mstride; - - const size_t col_stride = sizeof(float) * output_shape.n_channels; - const size_t row_stride = col_stride * tile_N * 2; - - asm volatile ( - // Aliases for elements of the input matrix `F` - // V-register Q-register - "F11 .req v0\n" "qF11 .req q0\n" - "F12 .req v1\n" "qF12 .req q1\n" - "F13 .req v2\n" "qF13 .req q2\n" - "F14 .req v3\n" "qF14 .req q3\n" - "F21 .req v4\n" "qF21 .req q4\n" - "F22 .req v5\n" "qF22 .req q5\n" - "F23 .req v6\n" "qF23 .req q6\n" - "F24 .req v7\n" "qF24 .req q7\n" - "F31 .req v8\n" "qF31 .req q8\n" - "F32 .req v9\n" "qF32 .req q9\n" - "F33 .req v10\n" "qF33 .req q10\n" - "F34 .req v11\n" "qF34 .req q11\n" - "F41 .req v12\n" "qF41 .req q12\n" - "F42 .req v13\n" "qF42 .req q13\n" - "F43 .req v14\n" "qF43 .req q14\n" - "F44 .req v15\n" "qF44 .req q15\n" - - // Aliases for elements of the intermediate matrix `FZ` - "FZ11 .req v16\n" - "FZ12 .req v17\n" - "FZ21 .req v18\n" - "FZ22 .req v19\n" - "FZ31 .req v20\n" - "FZ32 .req v21\n" - "FZ41 .req v22\n" - "FZ42 .req v23\n" - - // Aliases for elements of the output matrix `f` (called `g` due to case - // insensitivity of aliases). - " g11 .req v24\n" - "qg11 .req q24\n" - " g12 .req v25\n" - "qg12 .req q25\n" - " g21 .req v26\n" - "qg21 .req q26\n" - " g22 .req v27\n" - "qg22 .req q27\n" - - // Prepare the various strides - "col_stride .req %x[col_stride]\n" - "row_stride .req %x[row_stride]\n" - "row_plus_col_stride .req %x[row_plus_col_stride]\n" - - "mstride1 .req %x[mstride1]\n" - "mstride2 .req %x[mstride2]\n" - "mstride3 .req %x[mstride3]\n" - - "tile_i .req x19\n" // Tile row counter - "tile_j .req x20\n" // Tile column counter - "channel .req x21\n" // Channel counter - - "1:" // Loop over batches - "mov tile_i, %x[tile_M]\n" // Reset tile row counter - - "2:" // Loop over rows of tiles - "mov tile_j, %x[tile_N]\n" // Reset tile column counter - - "3:" // Loop over columns of tiles - // Perform initial loads of the matrix `F` - "ldr qF11, [%x[inptr0]]\n" - "ldr qF12, [%x[inptr0], mstride1]\n" - "ldr qF13, [%x[inptr0], mstride2]\n" - "ldr qF14, [%x[inptr0], mstride3]\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - "ldr qF21, [%x[inptr4]]\n" - "ldr qF22, [%x[inptr4], mstride1]\n" - "subs channel, %x[n_channels], #4\n" // Reset channel counter - - "ldr qF23, [%x[inptr4], mstride2]\n" - "ldr qF24, [%x[inptr4], mstride3]\n" - "add %x[inptr4], %x[inptr4], #0x10\n" - "beq 5f\n" // Jump straight to tail if necessary - - "4:" // Loop over channels - "ldr qF31, [%x[inptr8]]\n" - "fadd FZ11.4s, F11.4s, F12.4s\n" - - "ldr qF32, [%x[inptr8], mstride1]\n" - "fsub FZ12.4s, F12.4s, F13.4s\n" - - "ldr qF33, [%x[inptr8], mstride2]\n" - "fadd FZ11.4s, FZ11.4s, F13.4s\n" - - "ldr qF34, [%x[inptr8], mstride3]\n" - "fsub FZ12.4s, FZ12.4s, F14.4s\n" - - "ldr qF41, [%x[inptr12]]\n" - "fadd FZ21.4s, F21.4s, F22.4s\n" - - "ldr qF42, [%x[inptr12], mstride1]\n" - "fsub FZ22.4s, F22.4s, F23.4s\n" - - "ldr qF43, [%x[inptr12], mstride2]\n" - "fadd FZ21.4s, FZ21.4s, F23.4s\n" - - "ldr qF44, [%x[inptr12], mstride3]\n" - "fsub FZ22.4s, FZ22.4s, F24.4s\n" - - "fadd FZ31.4s, F31.4s, F32.4s\n" - "add %x[inptr8], %x[inptr8], #0x10\n" - - "fsub FZ32.4s, F32.4s, F33.4s\n" - "add %x[inptr12], %x[inptr12], #0x10\n" - - "fadd FZ31.4s, FZ31.4s, F33.4s\n" - - "fsub FZ32.4s, FZ32.4s, F34.4s\n" - - "fadd g11.4s, FZ11.4s, FZ21.4s\n" - - "fadd g12.4s, FZ12.4s, FZ22.4s\n" - - "fadd g11.4s, g11.4s, FZ31.4s\n" - - "fadd g12.4s, g12.4s, FZ32.4s\n" - - "ldr qF11, [%x[inptr0]]\n" - "fadd FZ41.4s, F41.4s, F42.4s\n" - - "ldr qF12, [%x[inptr0], mstride1]\n" - "fsub g21.4s, FZ21.4s, FZ31.4s\n" - - "ldr qF13, [%x[inptr0], mstride2]\n" - "fsub FZ42.4s, F42.4s, F43.4s\n" - - "ldr qF14, [%x[inptr0], mstride3]\n" - "str qg11, [%x[outptr]]\n" - - "ldr qF21, [%x[inptr4]]\n" - "fadd FZ41.4s, FZ41.4s, F43.4s\n" - - "ldr qF22, [%x[inptr4], mstride1]\n" - "str qg12, [%x[outptr], col_stride]\n" - - "ldr qF23, [%x[inptr4], mstride2]\n" - "fsub FZ42.4s, FZ42.4s, F44.4s\n" - - "ldr qF24, [%x[inptr4], mstride3]\n" - "fsub g22.4s, FZ22.4s, FZ32.4s\n" - - "fsub g21.4s, g21.4s, FZ41.4s\n" - "add %x[inptr0], %x[inptr0], #0x10\n" - - "fsub g22.4s, g22.4s, FZ42.4s\n" - "add %x[inptr4], %x[inptr4], #0x10\n" - - "subs channel, channel, #4\n" - - "str qg21, [%x[outptr], row_stride]\n" - - "str qg22, [%x[outptr], row_plus_col_stride]\n" - - "add %x[outptr], %x[outptr], #0x10\n" - - "bne 4b\n" - - "5:" // Channel tail - "ldr qF31, [%x[inptr8]]\n" - "fadd FZ11.4s, F11.4s, F12.4s\n" - - "ldr qF32, [%x[inptr8], mstride1]\n" - "fsub FZ12.4s, F12.4s, F13.4s\n" - - "ldr qF33, [%x[inptr8], mstride2]\n" - "fadd FZ11.4s, FZ11.4s, F13.4s\n" - - "ldr qF34, [%x[inptr8], mstride3]\n" - "fsub FZ12.4s, FZ12.4s, F14.4s\n" - - "ldr qF41, [%x[inptr12]]\n" - "fadd FZ21.4s, F21.4s, F22.4s\n" - - "ldr qF42, [%x[inptr12], mstride1]\n" - "fsub FZ22.4s, F22.4s, F23.4s\n" - - "ldr qF43, [%x[inptr12], mstride2]\n" - "fadd FZ21.4s, FZ21.4s, F23.4s\n" - - "ldr qF44, [%x[inptr12], mstride3]\n" - "fsub FZ22.4s, FZ22.4s, F24.4s\n" - - "fadd FZ31.4s, F31.4s, F32.4s\n" - "add %x[inptr8], %x[inptr8], #0x10\n" - - "fsub FZ32.4s, F32.4s, F33.4s\n" - "add %x[inptr12], %x[inptr12], #0x10\n" - - "fadd FZ31.4s, FZ31.4s, F33.4s\n" - - "fsub FZ32.4s, FZ32.4s, F34.4s\n" - - "fadd g11.4s, FZ11.4s, FZ21.4s\n" - - "fadd g12.4s, FZ12.4s, FZ22.4s\n" - - "fadd g11.4s, g11.4s, FZ31.4s\n" - - "fadd g12.4s, g12.4s, FZ32.4s\n" - - "fadd FZ41.4s, F41.4s, F42.4s\n" - - "fsub g21.4s, FZ21.4s, FZ31.4s\n" - - "fsub FZ42.4s, F42.4s, F43.4s\n" - - "str qg11, [%x[outptr]]\n" - - "fadd FZ41.4s, FZ41.4s, F43.4s\n" - - "str qg12, [%x[outptr], col_stride]\n" - - "fsub FZ42.4s, FZ42.4s, F44.4s\n" - - "fsub g22.4s, FZ22.4s, FZ32.4s\n" - - "fsub g21.4s, g21.4s, FZ41.4s\n" - - "fsub g22.4s, g22.4s, FZ42.4s\n" - - "subs channel, channel, #4\n" - - "str qg21, [%x[outptr], row_stride]\n" - - // Progress input pointers to the next row of the matrix - "add %x[inptr0], %x[inptr0], %x[mrowpad]\n" - "add %x[inptr4], %x[inptr4], %x[mrowpad]\n" - "add %x[inptr8], %x[inptr8], %x[mrowpad]\n" - "add %x[inptr12], %x[inptr12], %x[mrowpad]\n" - - "str qg22, [%x[outptr], row_plus_col_stride]\n" - - "add %x[outptr], %x[outptr], #0x10\n" - - - "add %x[outptr], %x[outptr], col_stride\n" - "subs tile_j, tile_j, #1\n" - "bne 3b\n" - - "add %x[outptr], %x[outptr], row_stride\n" - "subs tile_i, tile_i, #1\n" - "bne 2b\n" - - "subs %w[batch], %w[batch], #1\n" - "bne 1b\n" - - ".unreq F11\n" ".unreq qF11\n" - ".unreq F12\n" ".unreq qF12\n" - ".unreq F13\n" ".unreq qF13\n" - ".unreq F14\n" ".unreq qF14\n" - ".unreq F21\n" ".unreq qF21\n" - ".unreq F22\n" ".unreq qF22\n" - ".unreq F23\n" ".unreq qF23\n" - ".unreq F24\n" ".unreq qF24\n" - ".unreq F31\n" ".unreq qF31\n" - ".unreq F32\n" ".unreq qF32\n" - ".unreq F33\n" ".unreq qF33\n" - ".unreq F34\n" ".unreq qF34\n" - ".unreq F41\n" ".unreq qF41\n" - ".unreq F42\n" ".unreq qF42\n" - ".unreq F43\n" ".unreq qF43\n" - ".unreq F44\n" ".unreq qF44\n" - - ".unreq FZ11\n" ".unreq FZ12\n" - ".unreq FZ21\n" ".unreq FZ22\n" - ".unreq FZ31\n" ".unreq FZ32\n" - ".unreq FZ41\n" ".unreq FZ42\n" - - ".unreq g11\n" ".unreq qg11\n" - ".unreq g12\n" ".unreq qg12\n" - ".unreq g21\n" ".unreq qg21\n" - ".unreq g22\n" ".unreq qg22\n" - - ".unreq col_stride\n" - ".unreq row_stride\n" - ".unreq row_plus_col_stride\n" - - ".unreq mstride1\n" - ".unreq mstride2\n" - ".unreq mstride3\n" - - ".unreq tile_i \n" - ".unreq tile_j \n" - ".unreq channel\n" - - : [batch] "+r" (batch), - [outptr] "+r" (outptr), - [inptr0] "+r" (inptr0), - [inptr4] "+r" (inptr4), - [inptr8] "+r" (inptr8), - [inptr12] "+r" (inptr12) - : [tile_M] "r" (tile_M), - [tile_N] "r" (tile_N), - [n_channels] "r" (output_shape.n_channels), - [col_stride] "r" (col_stride), - [row_stride] "r" (row_stride), - [row_plus_col_stride] "r" (row_stride + col_stride), - [mstride1] "r" (mstride * sizeof(float)), - [mstride2] "r" (2 * mstride * sizeof(float)), - [mstride3] "r" (3 * mstride * sizeof(float)), - [mrowpad] "r" ((matrix_row_stride - output_shape.n_channels) * sizeof(float)) - : "x19", "x20", "x21", - "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", - "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24", "v25", "v26", "v27", - "cc", "memory" - ); -} - -template <> -template -inline void Winograd2x2_3x3GemmOutput::_execute( - const Tensor4DShape &output_shape, - float *output, - const float *input, - const int mstride, - const int matrix_row_stride -) { - // Compute basic information about the shape of the matrices - const int tile_M = output_shape.n_rows / 2; - const int tile_N = output_shape.n_cols / 2; - const int n_channels = output_shape.n_channels; - - // Extract 16 input pointers - const float* inptr[16]; - for (int i = 0; i < 16; i++) { - inptr[i] = input + i*mstride; - } - - // Extract 4 output pointers - float *outptr00 = output; - float *outptr01 = outptr00 + n_channels; - float *outptr10 = outptr00 + output_shape.n_cols * n_channels; - float *outptr11 = outptr10 + n_channels; - - // Progress over the output tiles, generating output values. - for (int batch = 0; batch < output_shape.n_batches; batch++) { - for (int tile_i = 0; tile_i < tile_M; tile_i++) { - for (int tile_j = 0; tile_j < tile_N; tile_j++) { - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - float F[4][4]; - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 4; j++) { - F[i][j] = *(inptr[i*4 + j]++); - } - } - - // Compute the matrix F.Z - float ZF[4][2]; - ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; - ZF[0][1] = F[0][1] - F[0][2] - F[0][3]; - ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; - ZF[1][1] = F[1][1] - F[1][2] - F[1][3]; - ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; - ZF[2][1] = F[2][1] - F[2][2] - F[2][3]; - ZF[3][0] = F[3][0] + F[3][1] + F[3][2]; - ZF[3][1] = F[3][1] - F[3][2] - F[3][3]; - - // Hence compute the output matrix Z^T . (F.Z) - *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; - *(outptr01++) = ZF[0][1] + ZF[1][1] + ZF[2][1]; - *(outptr10++) = ZF[1][0] - ZF[2][0] - ZF[3][0]; - *(outptr11++) = ZF[1][1] - ZF[2][1] - ZF[3][1]; - } - - // Progress the input pointers to the next row - for (int i = 0; i < 16; i++) { - inptr[i] += matrix_row_stride - n_channels; - } - - // Progress the output pointers to the next column - outptr00 += n_channels; - outptr01 += n_channels; - outptr10 += n_channels; - outptr11 += n_channels; - } - - if (tail_N) { - // Only evaluate the left-most columns of the output - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - float F[4][3]; - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 3; j++) { - F[i][j] = *(inptr[i*4 + j]++); - } - } - for (int i = 0; i < 4; i++) { - inptr[i*4 + 3]++; - } - - // Compute the matrix F.Z - float ZF[4][1]; - ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; - ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; - ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; - ZF[3][0] = F[3][0] + F[3][1] + F[3][2]; - - // Hence compute the output matrix Z^T . (F.Z) - *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; - *(outptr10++) = ZF[1][0] - ZF[2][0] - ZF[3][0]; - } - - // Progress the input pointers to the next row - for (int i = 0; i < 16; i++) { - inptr[i] += matrix_row_stride - n_channels; - } - - // Progress the output pointers to the next column - outptr01 += n_channels; // Account for being skipped above - outptr11 += n_channels; // Account for being skipped above - } - - // Progress the output pointers to the next row - outptr00 += output_shape.n_cols * n_channels; - outptr01 += output_shape.n_cols * n_channels; - outptr10 += output_shape.n_cols * n_channels; - outptr11 += output_shape.n_cols * n_channels; - } - - if (tail_M) { - // Only work on the upper row of the output - for (int tile_j = 0; tile_j < tile_N; tile_j++) { - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - float F[3][4]; - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 4; j++) { - F[i][j] = *(inptr[i*4 + j]++); - } - } - for (int j = 0; j < 4; j++) { - inptr[12 + j]++; - } - - // Compute the matrix F.Z - float ZF[3][2]; - ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; - ZF[0][1] = F[0][1] - F[0][2] - F[0][3]; - ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; - ZF[1][1] = F[1][1] - F[1][2] - F[1][3]; - ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; - ZF[2][1] = F[2][1] - F[2][2] - F[2][3]; - - // Hence compute the output matrix Z^T . (F.Z) - *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; - *(outptr01++) = ZF[0][1] + ZF[1][1] + ZF[2][1]; - } - - // Progress the input pointers to the next row - for (int i = 0; i < 16; i++) { - inptr[i] += matrix_row_stride - n_channels; - } - - // Progress the output pointers to the next column - outptr00 += n_channels; - outptr01 += n_channels; - outptr10 += 2 * n_channels; // Account for being skipped above - outptr11 += 2 * n_channels; // Account for being skipped above - } - - if (tail_N) { - // Only evaluate the upper-left cell of the output - for (int channel = 0; channel < n_channels; channel++) { - // Read values from the input pointers - float F[3][3]; - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 3; j++) { - F[i][j] = *(inptr[i*4 + j]); - } - } - for (int i = 0; i < 16; i++) { - inptr[i]++; - } - - // Compute the matrix F.Z - float ZF[3][1]; - ZF[0][0] = F[0][0] + F[0][1] + F[0][2]; - ZF[1][0] = F[1][0] + F[1][1] + F[1][2]; - ZF[2][0] = F[2][0] + F[2][1] + F[2][2]; - - // Hence compute the output matrix Z^T . (F.Z) - *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0]; - } - - // Progress the input pointers to the next row - for (int i = 0; i < 16; i++) { - inptr[i] += matrix_row_stride - n_channels; - } - - // Progress the output pointers to the next column - outptr01 += n_channels; // Account for being skipped above - outptr10 += n_channels; // Account for being skipped above - outptr11 += n_channels; // Account for being skipped above - } - } - } -} - -/*****************************************************************************/ -template <> -inline void Winograd2x2_3x3GemmOutput::execute( - const Tensor4DShape &output_shape, - float* const matrix_base, - const int matrix_stride, - const int matrix_row_stride, - float* const output -) { - // Dispatch to an appropriate implementation based on the shape of the output - // tensor. - if (output_shape.n_rows % 2 && output_shape.n_cols % 2) { - constexpr bool tail_M = true, tail_N = true; - switch (output_shape.n_channels % 4) { - case 0: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 1: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 2: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 3: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - default: - assert(0); - break; - } - } else if (output_shape.n_rows % 2) { - constexpr bool tail_M = true, tail_N = false; - switch (output_shape.n_channels % 4) { - case 0: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 1: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 2: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 3: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - default: - assert(0); - break; - } - } else if (output_shape.n_cols % 2) { - constexpr bool tail_M = false, tail_N = true; - switch (output_shape.n_channels % 4) { - case 0: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 1: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 2: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 3: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - default: - assert(0); - break; - - } - } else { - constexpr bool tail_M = false, tail_N = false; - switch (output_shape.n_channels % 4) { - case 0: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 1: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 2: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - case 3: - _execute(output_shape, output, matrix_base, matrix_stride, matrix_row_stride); - break; - default: - assert(0); - break; - - } - } -} -/*****************************************************************************/ - -} // namespace winograd -#endif // __aarch64__ diff --git a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp b/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp deleted file mode 100644 index f551b12b52..0000000000 --- a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp +++ /dev/null @@ -1,655 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -#ifdef __aarch64__ - -/*****************************************************************************/ -// Compute ZF specializations - -template <> -template <> -inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage::compute_zf<0>( - const int n_rows, const int n_channels, - float* output, const float* const input[16] -) { - // Make copies of some variables - int row = n_rows; - float* outptr = output; - const float* inptr = input[0]; - - // Perform the transformation - asm volatile ( - // "inptr0 .req %x[inptr]\n" - "inptr1 .req x0\n" - "inptr2 .req x1\n" - "inptr3 .req x2\n" - "inptr4 .req x3\n" - "inptr5 .req x4\n" - "inptr6 .req x5\n" - "inptr7 .req x6\n" - "inptr8 .req x7\n" - "inptr9 .req x8\n" - "inptr10 .req x9\n" - "inptr11 .req x10\n" - "inptr12 .req x11\n" - "inptr13 .req x12\n" - "inptr14 .req x13\n" - "inptr15 .req x14\n" - - // "outptr0 .req %x[outptr]\n" - "outptr1 .req x15\n" - "outptr2 .req x16\n" - "outptr3 .req x17\n" - "outptr4 .req x18\n" - "outptr5 .req x19\n" - "outptr6 .req x20\n" - "outptr7 .req x21\n" - - // Compute additional pointers into the input and output matrices. - "mstride .req x22\n" // Matrix stride - "mul mstride, %x[row], %x[n_channels]\n" - "lsl mstride, mstride, #2\n" // * sizeof(float) - - "add inptr1, %x[inptr], mstride\n" - "add inptr2, %x[inptr], mstride, LSL #1\n" - "add inptr3, inptr2, mstride\n" - "add inptr4, inptr3, mstride\n" - "add inptr5, inptr4, mstride\n" - "add inptr6, inptr5, mstride\n" - "add inptr7, inptr6, mstride\n" - "add inptr8, inptr7, mstride\n" - "add inptr9, inptr8, mstride\n" - "add inptr10, inptr9, mstride\n" - "add inptr11, inptr10, mstride\n" - "add inptr12, inptr11, mstride\n" - "add inptr13, inptr12, mstride\n" - "add inptr14, inptr13, mstride\n" - "add inptr15, inptr14, mstride\n" - - "add outptr1, %[outptr], mstride\n" - "add outptr2, outptr1, mstride\n" - "add outptr3, outptr2, mstride\n" - "add outptr4, outptr3, mstride\n" - "add outptr5, outptr4, mstride\n" - "add outptr6, outptr5, mstride\n" - "add outptr7, outptr6, mstride\n" - - ".unreq mstride\n" - - "column .req x22\n" // Column loop counter - - "1:" // Loop over rows - "ldr q0, [%x[inptr]], #0x10\n" - "ldr q1, [inptr1], #0x10\n" - "ldr q2, [inptr2], #0x10\n" - "ldr q3, [inptr3], #0x10\n" - "ldr q4, [inptr4], #0x10\n" - "ldr q5, [inptr5], #0x10\n" - "ldr q6, [inptr6], #0x10\n" - "ldr q7, [inptr7], #0x10\n" - "subs column, %x[n_channels], #0x4\n" - "beq 3f\n" - - "2:" // Loop over columns - "ldr q8, [inptr8], #0x10\n" - "prfm pldl1keep, [%x[inptr], #196]\n" - "fadd v16.4s, v0.4s, v1.4s\n" - - "ldr q9, [inptr9], #0x10\n" - "prfm pldl1keep, [inptr1, #196]\n" - "fsub v17.4s, v1.4s, v2.4s\n" - - "ldr q10, [inptr10], #0x10\n" - "prfm pldl1keep, [inptr2, #196]\n" - "fadd v16.4s, v16.4s, v2.4s\n" - - "ldr q11, [inptr11], #0x10\n" - "prfm pldl1keep, [inptr3, #196]\n" - "fsub v17.4s, v17.4s, v3.4s\n" - - "ldr q12, [inptr12], #0x10\n" - "prfm pldl1keep, [inptr4, #196]\n" - "str q16, [%x[outptr]], #0x10\n" - - "ldr q13, [inptr13], #0x10\n" - "prfm pldl1keep, [inptr5, #196]\n" - "str q17, [outptr1], #0x10\n" - - "ldr q14, [inptr14], #0x10\n" - "prfm pldl1keep, [inptr6, #196]\n" - "fadd v16.4s, v4.4s, v5.4s\n" - - "ldr q15, [inptr15], #0x10\n" - "prfm pldl1keep, [inptr7, #196]\n" - "fsub v17.4s, v5.4s, v6.4s\n" - - "ldr q0, [%x[inptr]], #0x10\n" - "prfm pldl1keep, [inptr8, #196]\n" - "fadd v16.4s, v16.4s, v6.4s\n" - - "ldr q1, [inptr1], #0x10\n" - "prfm pldl1keep, [inptr9, #196]\n" - "fsub v17.4s, v17.4s, v7.4s\n" - - "ldr q2, [inptr2], #0x10\n" - "prfm pldl1keep, [inptr10, #196]\n" - "str q16, [outptr2], #0x10\n" - - "ldr q3, [inptr3], #0x10\n" - "prfm pldl1keep, [inptr11, #196]\n" - "str q17, [outptr3], #0x10\n" - - "ldr q4, [inptr4], #0x10\n" - "prfm pldl1keep, [inptr12, #196]\n" - "fadd v16.4s, v8.4s, v9.4s\n" - - "ldr q5, [inptr5], #0x10\n" - "prfm pldl1keep, [inptr13, #196]\n" - "fsub v17.4s, v9.4s, v10.4s\n" - - "ldr q6, [inptr6], #0x10\n" - "prfm pldl1keep, [inptr14, #196]\n" - "fadd v16.4s, v16.4s, v10.4s\n" - - "ldr q7, [inptr7], #0x10\n" - "prfm pldl1keep, [inptr15, #196]\n" - "fsub v17.4s, v17.4s, v11.4s\n" - - "str q16, [outptr4], #0x10\n" - "fadd v16.4s, v12.4s, v13.4s\n" - "fsub v18.4s, v13.4s, v14.4s\n" - - "str q17, [outptr5], #0x10\n" - "fadd v16.4s, v16.4s, v14.4s\n" - "fsub v18.4s, v18.4s, v15.4s\n" - - "str q16, [outptr6], #0x10\n" - "subs column, column, #0x4\n" - - "str q18, [outptr7], #0x10\n" - "bne 2b\n" - - "3:" // Tail - "ldr q8, [inptr8], #0x10\n" - "prfm pldl1keep, [%x[inptr], #196]\n" - "fadd v16.4s, v0.4s, v1.4s\n" - - "ldr q9, [inptr9], #0x10\n" - "prfm pldl1keep, [inptr1, #196]\n" - "fsub v17.4s, v1.4s, v2.4s\n" - - "ldr q10, [inptr10], #0x10\n" - "prfm pldl1keep, [inptr2, #196]\n" - "fadd v16.4s, v16.4s, v2.4s\n" - - "ldr q11, [inptr11], #0x10\n" - "prfm pldl1keep, [inptr3, #196]\n" - "fsub v17.4s, v17.4s, v3.4s\n" - - "ldr q12, [inptr12], #0x10\n" - "prfm pldl1keep, [inptr4, #196]\n" - "str q16, [%x[outptr]], #0x10\n" - - "ldr q13, [inptr13], #0x10\n" - "prfm pldl1keep, [inptr5, #196]\n" - "str q17, [outptr1], #0x10\n" - - "ldr q14, [inptr14], #0x10\n" - "prfm pldl1keep, [inptr6, #196]\n" - "fadd v16.4s, v4.4s, v5.4s\n" - - "ldr q15, [inptr15], #0x10\n" - "prfm pldl1keep, [inptr7, #196]\n" - "fsub v17.4s, v5.4s, v6.4s\n" - - "prfm pldl1keep, [inptr8, #196]\n" - "prfm pldl1keep, [inptr9, #196]\n" - "fadd v16.4s, v16.4s, v6.4s\n" - - "prfm pldl1keep, [inptr10, #196]\n" - "prfm pldl1keep, [inptr11, #196]\n" - "fsub v17.4s, v17.4s, v7.4s\n" - - "prfm pldl1keep, [inptr12, #196]\n" - "prfm pldl1keep, [inptr13, #196]\n" - "str q16, [outptr2], #0x10\n" - - "prfm pldl1keep, [inptr14, #196]\n" - "prfm pldl1keep, [inptr15, #196]\n" - "str q17, [outptr3], #0x10\n" - - "fadd v16.4s, v8.4s, v9.4s\n" - "fsub v17.4s, v9.4s, v10.4s\n" - - "fadd v16.4s, v16.4s, v10.4s\n" - "fsub v17.4s, v17.4s, v11.4s\n" - - "str q16, [outptr4], #0x10\n" - "fadd v16.4s, v12.4s, v13.4s\n" - "fsub v18.4s, v13.4s, v14.4s\n" - - "str q17, [outptr5], #0x10\n" - "fadd v16.4s, v16.4s, v14.4s\n" - "fsub v18.4s, v18.4s, v15.4s\n" - - "str q16, [outptr6], #0x10\n" - "str q18, [outptr7], #0x10\n" - - "subs %x[row], %x[row], #0x1\n" - "bne 1b\n" - - ".unreq inptr1\n" - ".unreq inptr2\n" - ".unreq inptr3\n" - ".unreq inptr4\n" - ".unreq inptr5\n" - ".unreq inptr6\n" - ".unreq inptr7\n" - ".unreq inptr8\n" - ".unreq inptr9\n" - ".unreq inptr10\n" - ".unreq inptr11\n" - ".unreq inptr12\n" - ".unreq inptr13\n" - ".unreq inptr14\n" - ".unreq inptr15\n" - ".unreq outptr1\n" - ".unreq outptr2\n" - ".unreq outptr3\n" - ".unreq outptr4\n" - ".unreq outptr5\n" - ".unreq outptr6\n" - ".unreq outptr7\n" - - : [row] "+r" (row), - [inptr] "+r" (inptr), - [outptr] "+r" (outptr) - : [n_channels] "r" (n_channels), - [sizeof_float] "i" (sizeof(float)) - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", - "q12", "q13", "q14", "q15", "q16", "q17", "x0", "x1", "x2", "x3", "x4", - "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", - "x16", "x17", "x18", "x19", "x20", "x21", "x22", "cc", "memory" - ); -} - -/*****************************************************************************/ -// Compute ZFZ^T specializations - -template <> -template <> -inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage::compute_zfzT( - const Tensor4DShape &output_shape, - float* const output, const float* const input -) { - const int tile_M = output_shape.n_rows / 2; - const int tile_N = output_shape.n_cols / 2; - int batch = output_shape.n_batches; - float *outptr = output; - const float *inptr = input; - - asm volatile ( - // Compute input pointers - "inptr1 .req x0\n" - "inptr2 .req x1\n" - "inptr3 .req x2\n" - "inptr4 .req x3\n" - "inptr5 .req x4\n" - "inptr6 .req x5\n" - "inptr7 .req x6\n" - "inptr8 .req x7\n" - - "mstride .req x8\n" - "mul mstride, %x[tile_M], %x[tile_N]\n" - "mul mstride, mstride, %x[n_channels]\n" - "lsl mstride, mstride, #2\n" // * sizeof(float) - - "add inptr1, %[inptr], mstride\n" - "add inptr2, inptr1, mstride\n" - "add inptr3, inptr2, mstride\n" - "add inptr4, inptr3, mstride\n" - "add inptr5, inptr4, mstride\n" - "add inptr6, inptr5, mstride\n" - "add inptr7, inptr6, mstride\n" - "add inptr8, inptr7, mstride\n" - - ".unreq mstride\n" - - // Compute initial output pointers - "outptr01 .req x8\n" - "outptr10 .req x9\n" - "outptr11 .req x10\n" - - "add outptr01, %x[outptr], %x[n_channels], LSL #2\n" - "add outptr10, %x[outptr], %x[row_stride], LSL #2\n" - "add outptr11, outptr10, %x[n_channels], LSL #2\n" - - "tile_i .req x11\n" - "tile_j .req x12\n" - "channel .req x13\n" - - "1:" // Loop over batches - "mov tile_i, %x[tile_M]\n" - - "2:" // Loop over rows of output tiles - "mov tile_j, %x[tile_N]\n" - - "3:" // Loop over columns of output tiles - "ldr q0, [%x[inptr]], #0x10\n" - "ldr q2, [inptr2], #0x10\n" - "subs channel, %x[n_channels], #0x4\n" - - "ldr q1, [inptr1], #0x10\n" - "ldr q3, [inptr3], #0x10\n" - "beq 6f\n" - - "4:" - "ldr q4, [inptr4], #0x10\n" - "ldr q5, [inptr5], #0x10\n" - "fadd v16.4s, v0.4s, v2.4s\n" - - "ldr q6, [inptr6], #0x10\n" - "ldr q7, [inptr7], #0x10\n" - "fadd v17.4s, v1.4s, v3.4s\n" - - "ldr q8, [%x[inptr]], #0x10\n" - "ldr q10, [inptr2], #0x10\n" - "fadd v16.4s, v16.4s, v4.4s\n" - - "ldr q9, [inptr1], #0x10\n" - "ldr q11, [inptr3], #0x10\n" - "fadd v17.4s, v17.4s, v5.4s\n" - - "str q16, [%x[outptr]], #0x10\n" - "prfm pldl1strm, [%x[inptr], #196]\n" - "fsub v18.4s, v2.4s, v4.4s\n" - - "str q17, [outptr01], #0x10\n" - "prfm pldl1strm, [inptr2, #196]\n" - "fsub v19.4s, v3.4s, v5.4s\n" - - "prfm pldl1strm, [inptr1, #196]\n" - "prfm pldl1strm, [inptr3, #196]\n" - "fsub v18.4s, v18.4s, v6.4s\n" - - "prfm pldl1strm, [inptr4, #196]\n" - "prfm pldl1strm, [inptr5, #196]\n" - "fsub v19.4s, v19.4s, v7.4s\n" - - "str q18, [outptr10], #0x10\n" - "prfm pldl1strm, [inptr6, #196]\n" - "prfm pldl1strm, [inptr7, #196]\n" - - "subs channel, channel, #0x4\n" - - "str q19, [outptr11], #0x10\n" - "beq 6f\n" // Branch to tail - - "ldr q12, [inptr4], #0x10\n" - "ldr q13, [inptr5], #0x10\n" - "fadd v16.4s, v8.4s, v10.4s\n" - - "ldr q14, [inptr6], #0x10\n" - "ldr q15, [inptr7], #0x10\n" - "fadd v17.4s, v9.4s, v11.4s\n" - - "ldr q0, [%x[inptr]], #0x10\n" - "ldr q2, [inptr2], #0x10\n" - "fadd v16.4s, v16.4s, v12.4s\n" - - "ldr q1, [inptr1], #0x10\n" - "ldr q3, [inptr3], #0x10\n" - "fadd v17.4s, v17.4s, v13.4s\n" - - "str q16, [%x[outptr]], #0x10\n" - "prfm pldl1strm, [%x[inptr], #196]\n" - "fsub v18.4s, v10.4s, v12.4s\n" - - "str q17, [outptr01], #0x10\n" - "prfm pldl1strm, [inptr2, #196]\n" - "fsub v19.4s, v11.4s, v13.4s\n" - - "prfm pldl1strm, [inptr1, #196]\n" - "prfm pldl1strm, [inptr3, #196]\n" - "fsub v18.4s, v18.4s, v14.4s\n" - - "prfm pldl1strm, [inptr4, #196]\n" - "prfm pldl1strm, [inptr5, #196]\n" - "fsub v19.4s, v19.4s, v15.4s\n" - - "str q18, [outptr10], #0x10\n" - "prfm pldl1strm, [inptr6, #196]\n" - "prfm pldl1strm, [inptr7, #196]\n" - - "subs channel, channel, #0x4\n" - - "str q19, [outptr11], #0x10\n" - "bne 4b\n" // Continue loop - - "5:" // Tail - "ldr q12, [inptr4], #0x10\n" - "ldr q13, [inptr5], #0x10\n" - "fadd v16.4s, v8.4s, v10.4s\n" - - "ldr q14, [inptr6], #0x10\n" - "ldr q15, [inptr7], #0x10\n" - "fadd v17.4s, v9.4s, v11.4s\n" - - "fadd v16.4s, v16.4s, v12.4s\n" - - "fadd v17.4s, v17.4s, v13.4s\n" - - "str q16, [%x[outptr]], #0x10\n" - "fsub v18.4s, v10.4s, v12.4s\n" - "fsub v19.4s, v11.4s, v13.4s\n" - - "str q17, [outptr01], #0x10\n" - "fsub v18.4s, v18.4s, v14.4s\n" - "fsub v19.4s, v19.4s, v15.4s\n" - - "str q18, [outptr10], #0x10\n" - "str q19, [outptr11], #0x10\n" - "b 7f\n" - - "6:" // Tail - "ldr q4, [inptr4], #0x10\n" - "ldr q5, [inptr5], #0x10\n" - "fadd v16.4s, v0.4s, v2.4s\n" - - "ldr q6, [inptr6], #0x10\n" - "ldr q7, [inptr7], #0x10\n" - "fadd v17.4s, v1.4s, v3.4s\n" - - "fadd v16.4s, v16.4s, v4.4s\n" - - "fadd v17.4s, v17.4s, v5.4s\n" - - "str q16, [%x[outptr]], #0x10\n" - "fsub v18.4s, v2.4s, v4.4s\n" - "fsub v19.4s, v3.4s, v5.4s\n" - - "str q17, [outptr01], #0x10\n" - "fsub v18.4s, v18.4s, v6.4s\n" - "fsub v19.4s, v19.4s, v7.4s\n" - - "str q18, [outptr10], #0x10\n" - "str q19, [outptr11], #0x10\n" - - "7:" - "add %x[outptr], %x[outptr], %x[n_channels], LSL #2\n" - "add outptr01, outptr01, %x[n_channels], LSL #2\n" - "add outptr10, outptr10, %x[n_channels], LSL #2\n" - "add outptr11, outptr11, %x[n_channels], LSL #2\n" - - "subs tile_j, tile_j, #1\n" - "bne 3b\n" - - // Progress the output pointers to the new row - "add %x[outptr], %x[outptr], %x[row_stride], LSL #2\n" - "add outptr01, outptr01, %x[row_stride], LSL #2\n" - "add outptr10, outptr10, %x[row_stride], LSL #2\n" - "add outptr11, outptr11, %x[row_stride], LSL #2\n" - - "subs tile_i, tile_i, #1\n" - "bne 2b\n" - - "subs %[batch], %[batch], #1\n" - "bne 1b\n" - "5:" - - ".unreq inptr1\n" - ".unreq inptr2\n" - ".unreq inptr3\n" - ".unreq inptr4\n" - ".unreq inptr5\n" - ".unreq inptr6\n" - ".unreq inptr7\n" - ".unreq inptr8\n" - ".unreq outptr01\n" - ".unreq outptr10\n" - ".unreq outptr11\n" - : [batch] "+r" (batch), - [outptr] "+r" (outptr), - [inptr] "+r" (inptr) - : [tile_M] "r" (tile_M), - [tile_N] "r" (tile_N), - [n_channels] "r" (output_shape.n_channels), - [row_stride] "r" (output_shape.n_cols * output_shape.n_channels) - : "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", - "x12", "x13", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", - "cc", "memory" - ); -} -/*****************************************************************************/ - -/*****************************************************************************/ -template <> -inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage::execute( - const Tensor4DShape &output_shape, - float* const matrices[16], float* const output -) { - // profiler prof; - - // Allocate memory for the intermediate matrices - const int tile_M = iceildiv(output_shape.n_rows, 2); - const int tile_N = iceildiv(output_shape.n_cols, 2); - const int n_rows = output_shape.n_batches * tile_M * tile_N; - const int n_channels = output_shape.n_channels; - float* matrices_zf = reinterpret_cast( - calloc(8 * n_rows * n_channels, sizeof(float)) - ); - - // Perform the first stage transform, computing ZF. - const auto f_compute_zf = [&] () { - switch (n_channels % 4) { - case 0: - compute_zf<0>(n_rows, n_channels, matrices_zf, matrices); - break; - case 1: - compute_zf<1>(n_rows, n_channels, matrices_zf, matrices); - break; - case 2: - compute_zf<2>(n_rows, n_channels, matrices_zf, matrices); - break; - case 3: - compute_zf<3>(n_rows, n_channels, matrices_zf, matrices); - }; - }; - // prof("Compute ZF", f_compute_zf, 16 * n_rows * n_channels * sizeof(float), 0, 8 * n_rows * n_channels * sizeof(float)); - f_compute_zf(); - - // Perform the second stage transform, finishing Z F Z^T - variable dispatch - // based on size of the output and the channel tail. - const auto f_compute_zfzT = [&] () { - if (output_shape.n_rows % 2 && output_shape.n_cols % 2) { - constexpr bool tail_M = true, tail_N = true; - switch (n_channels % 4) { - case 0: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 1: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 2: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 3: - compute_zfzT(output_shape, output, matrices_zf); - } - } else if (output_shape.n_rows % 2) { - constexpr bool tail_M = true, tail_N = false; - switch (n_channels % 4) { - case 0: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 1: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 2: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 3: - compute_zfzT(output_shape, output, matrices_zf); - } - } else if (output_shape.n_cols % 2) { - constexpr bool tail_M = false, tail_N = true; - switch (n_channels % 4) { - case 0: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 1: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 2: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 3: - compute_zfzT(output_shape, output, matrices_zf); - } - } else { - constexpr bool tail_M = false, tail_N = false; - switch (n_channels % 4) { - case 0: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 1: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 2: - compute_zfzT(output_shape, output, matrices_zf); - break; - case 3: - compute_zfzT(output_shape, output, matrices_zf); - } - } - }; - // prof("Compute ZFZT", f_compute_zfzT, 8 * n_rows * n_channels * sizeof(float), 0, 4 * n_rows * n_channels * sizeof(float)); - f_compute_zfzT(); - - free(reinterpret_cast(matrices_zf)); -} -/*****************************************************************************/ - -#endif // __aarch64__ diff --git a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3_fp32.cpp new file mode 100644 index 0000000000..e7907d18c0 --- /dev/null +++ b/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3_fp32.cpp @@ -0,0 +1,238 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "transforms/output.hpp" +#include "winograd_gemm.hpp" +#include "arm.hpp" + +namespace winograd +{ + +using Transform = WinogradGEMM<2, 2, 3, 3>::OutputTransform; + +template <> +template <> +int Transform::ops_performed(const Tensor4DShape &shape) +{ + // NOTE: Cost in FLOPs rather than instructions or uops. + const int tile_M = iceildiv(shape.n_rows, 2); + const int tile_N = iceildiv(shape.n_cols, 2); + return 24 * tile_M * tile_N * shape.n_channels; +} + +/* F(2x2, 3x3) constructs 2x2 output tiles from a 3x3 convolution. Since we use + * enough tiles to cover the output space each output tile may contain 0 or 1 + * padded values to the right and bottom columns or rows of the tile, e.g.: + * + * ___ ___ + * | | | X| + * |___| |__X| + * + * ___ ___ + * | | | X| + * |X_X| |X_X| + * + * + * We provide a specialised output transform for each of these instances. + * Consequently we below construct an array of the various padding options, the + * array contains pointers to the specific implementations. + */ +template <> +template <> +template +void Transform::process_tile( + const int n_channels, + const float* const matrix_base, + const int matrix_stride, + float* const output, + const int output_row_stride, + const int output_col_stride +) +{ + constexpr int cells_i = 2 - pad_bottom; + constexpr int cells_j = 2 - pad_right; + + // Construct a map to the output cells + float *outptrs[cells_i][cells_j]; + for (int i = 0; i < cells_i; i++) + { + for (int j = 0; j < cells_j; j++) + { + outptrs[i][j] = output + i*output_row_stride + j*output_col_stride; + } + } + const float *inptr = matrix_base; + + // For each channel of the output + int channels_remaining = n_channels; +#ifdef __aarch64__ + for (; channels_remaining >= 4; channels_remaining -= 4) + { + // Matrices used and computed during this transform + float32x4_t F[4][4], FZ[4][2], f[2][2]; + + // Read a 4x4 tile in the Winograd domain + for (int i = 0, m = 0; i < 4; i++) + { + for (int j = 0; j < 4; j++, m++) + { + F[i][j] = vld1q_f32(inptr + m*matrix_stride); + } + } + inptr += 4; + + // Compute the matrix F Z + for (int i = 0; i < 4; i++) + { + // FZ[i][0] = F[i][0] + F[i][1] + F[i][2]; + FZ[i][0] = vaddq_f32(vaddq_f32(F[i][0], F[i][1]), F[i][2]); + + // FZ[i][1] = F[i][1] - F[i][2] - F[i][3]; + FZ[i][1] = vsubq_f32(vsubq_f32(F[i][1], F[i][2]), F[i][3]); + } + + // Compute the output tile f = ZT F Z + for (int j = 0; j < 2; j++) + { + // f[0][j] = FZ[0][j] + FZ[1][j] + FZ[2][j]; + f[0][j] = vaddq_f32(vaddq_f32(FZ[0][j], FZ[1][j]), FZ[2][j]); + + // f[1][j] = FZ[1][j] - FZ[2][j] - FZ[3][j]; + f[1][j] = vsubq_f32(vsubq_f32(FZ[1][j], FZ[2][j]), FZ[3][j]); + } + + // Write out the output tile + for (int i = 0; i < cells_i; i++) + { + for (int j = 0; j < cells_j; j++) + { + vst1q_f32(outptrs[i][j], f[i][j]); + outptrs[i][j] += 4; + } + } + } +#endif // __aarch64__ +#ifdef __arm_any__ + for (; channels_remaining >= 2; channels_remaining -= 2) + { + // Matrices used and computed during this transform + float32x2_t F[4][4], FZ[4][2], f[2][2]; + + // Read a 4x4 tile in the Winograd domain + for (int i = 0, m = 0; i < 4; i++) + { + for (int j = 0; j < 4; j++, m++) + { + F[i][j] = vld1_f32(inptr + m*matrix_stride); + } + } + inptr += 2; + + // Compute the matrix F Z + for (int i = 0; i < 4; i++) + { + // FZ[i][0] = F[i][0] + F[i][1] + F[i][2]; + FZ[i][0] = vadd_f32(vadd_f32(F[i][0], F[i][1]), F[i][2]); + + // FZ[i][1] = F[i][1] - F[i][2] - F[i][3]; + FZ[i][1] = vsub_f32(vsub_f32(F[i][1], F[i][2]), F[i][3]); + } + + // Compute the output tile f = ZT F Z + for (int j = 0; j < 2; j++) + { + // f[0][j] = FZ[0][j] + FZ[1][j] + FZ[2][j]; + f[0][j] = vadd_f32(vadd_f32(FZ[0][j], FZ[1][j]), FZ[2][j]); + + // f[1][j] = FZ[1][j] - FZ[2][j] - FZ[3][j]; + f[1][j] = vsub_f32(vsub_f32(FZ[1][j], FZ[2][j]), FZ[3][j]); + } + + // Write out the output tile + for (int i = 0; i < cells_i; i++) + { + for (int j = 0; j < cells_j; j++) + { + vst1_f32(outptrs[i][j], f[i][j]); + outptrs[i][j] += 2; + } + } + } +#endif // __arm_any__ + for (; channels_remaining; channels_remaining--) + { + // Matrices used and computed during this transform + float F[4][4], FZ[4][2], f[2][2]; + + // Read a 4x4 tile in the Winograd domain + for (int i = 0, m = 0; i < 4; i++) + { + for (int j = 0; j < 4; j++, m++) + { + F[i][j] = *(inptr + m*matrix_stride); + } + } + inptr++; + + // Compute the matrix F Z + for (int i = 0; i < 4; i++) + { + FZ[i][0] = F[i][0] + F[i][1] + F[i][2]; + FZ[i][1] = F[i][1] - F[i][2] - F[i][3]; + } + + // Compute the output tile f = ZT F Z + for (int j = 0; j < 2; j++) + { + f[0][j] = FZ[0][j] + FZ[1][j] + FZ[2][j]; + f[1][j] = FZ[1][j] - FZ[2][j] - FZ[3][j]; + } + + // Write out the output tile + for (int i = 0; i < cells_i; i++) + { + for (int j = 0; j < cells_j; j++) + { + *(outptrs[i][j]++) = f[i][j]; + } + } + } +} + +template <> +template <> +const Transform::TileFn Transform::tile_fns[max_pad_bottom][max_pad_right] = +{ + { + Transform::template process_tile<0, 0>, // No padding + Transform::template process_tile<0, 1>, // Right padding + }, + { + Transform::template process_tile<1, 0>, // Bottom padding + Transform::template process_tile<1, 1>, // Bottom and right padding + } +}; + +template struct WinogradGEMM<2, 2, 3, 3>::OutputTransform; +} // namespace winograd diff --git a/src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp new file mode 100644 index 0000000000..483e5c110b --- /dev/null +++ b/src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp @@ -0,0 +1,299 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "transforms/output.hpp" +#include "winograd_gemm.hpp" +#include "arm.hpp" + +namespace winograd +{ + +using Transform = WinogradGEMM<4, 4, 3, 3>::OutputTransform; + +template <> +template <> +int Transform::ops_performed(const Tensor4DShape &shape) +{ + // NOTE: Cost in FLOPs rather than instructions or uops. + const int tile_M = iceildiv(shape.n_rows, 4); + const int tile_N = iceildiv(shape.n_cols, 4); + return 170 * tile_M * tile_N * shape.n_channels; +} + +// Instantiate cost methods +template int Transform::ops_performed(const Tensor4DShape&); + +/* F(4x4, 3x3) constructs 4x4 output tiles from a 3x3 convolution. Since we use + * enough tiles to cover the output space each output tile may contain up to 3 + * padded values to the right and bottom columns or rows of the tile, e.g.: +* +* ________ ________ ________ ________ +* | | | X| | X X| | X X X| +* | | | X| | X X| | X X X| +* | | | X| | X X| | X X X| +* |_______| |______X| |____X_X| |__X_X_X| +* +* ________ ________ ________ ________ +* | | | X| | X X| | X X X| +* | | | X| | X X| | X X X| +* | | | X| | X X| | X X X| +* |X_X_X_X| |X_X_X_X| |X_X_X_X| |X_X_X_X| +* +* ________ ________ ________ ________ +* | | | X| | X X| | X X X| +* | | | X| | X X| | X X X| +* |X X X X| |X X X X| |X X X X| |X X X X| +* |X_X_X_X| |X_X_X_X| |X_X_X_X| |X_X_X_X| +* +* ________ ________ ________ ________ +* | | | X| | X X| | X X X| +* |X X X X| |X X X X| |X X X X| |X X X X| +* |X X X X| |X X X X| |X X X X| |X X X X| +* |X_X_X_X| |X_X_X_X| |X_X_X_X| |X_X_X_X| +* +* +* We provide a specialised output transform for each of these instances. +*/ +template <> +template <> +template +void Transform::process_tile( + const int n_channels, + const float* const matrix_base, + const int matrix_stride, + float* const output, + const int output_row_stride, + const int output_col_stride +) +{ + constexpr int cells_i = 4 - pad_bottom; + constexpr int cells_j = 4 - pad_right; + + // Construct a map to the output cells + float *outptrs[cells_i][cells_j]; + for (int i = 0; i < cells_i; i++) + { + for (int j = 0; j < cells_j; j++) + { + outptrs[i][j] = output + i*output_row_stride + j*output_col_stride; + } + } + const float *inptr = matrix_base; + + // For each channel of the output + int channels_remaining = n_channels; +#ifdef __aarch64__ + for (; channels_remaining >= 4; channels_remaining -= 4) + { + // Matrices used and computed during this transform + float32x4_t F[6][6], FZ[6][4], f[4][4]; + + // Read a 6x6 tile in the Winograd domain + for (int i = 0, m = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++, m++) + { + F[i][j] = vld1q_f32(inptr + m*matrix_stride); + } + } + inptr += 4; + + // Compute the matrix F Z + for (int i = 0; i < 6; i++) + { + // FZ[i][0] = 1*F[i][0] + 1*F[i][1] + 1*F[i][2] + 1*F[i][3] + 1*F[i][4]; + FZ[i][0] = vaddq_f32(vaddq_f32(vaddq_f32(F[i][0], F[i][1]), vaddq_f32(F[i][2], F[i][3])), F[i][4]); + + // FZ[i][1] = 1*F[i][1] + -1*F[i][2] + 2*F[i][3] + -2*F[i][4]; + FZ[i][1] = vmlaq_n_f32(vsubq_f32(F[i][1], F[i][2]), vsubq_f32(F[i][3], F[i][4]), 2.0f); + + // FZ[i][2] = 1*F[i][1] + 1*F[i][2] + 4*F[i][3] + 4*F[i][4]; + FZ[i][2] = vmlaq_n_f32(vaddq_f32(F[i][1], F[i][2]), vaddq_f32(F[i][3], F[i][4]), 4.0f); + + // FZ[i][3] = 1*F[i][1] + -1*F[i][2] + 8*F[i][3] + -8*F[i][4] + 1*F[i][5]; + FZ[i][3] = vaddq_f32(vmlaq_n_f32(vsubq_f32(F[i][1], F[i][2]), vsubq_f32(F[i][3], F[i][4]), 8.0f), F[i][5]); + } + + // Compute the output tile f = ZT F Z + for (int j = 0; j < 4; j++) + { + // f[0][j] = 1*FZ[0][j] + 1*FZ[1][j] + 1*FZ[2][j] + 1*FZ[3][j] + 1*FZ[4][j]; + f[0][j] = vaddq_f32(vaddq_f32(vaddq_f32(FZ[0][j], FZ[1][j]), vaddq_f32(FZ[2][j], FZ[3][j])), FZ[4][j]); + + // f[1][j] = 1*FZ[1][j] + -1*FZ[2][j] + 2*FZ[3][j] + -2*FZ[4][j]; + f[1][j] = vmlaq_n_f32(vsubq_f32(FZ[1][j], FZ[2][j]), vsubq_f32(FZ[3][j], FZ[4][j]), 2.0f); + + // f[2][j] = 1*FZ[1][j] + 1*FZ[2][j] + 4*FZ[3][j] + 4*FZ[4][j]; + f[2][j] = vmlaq_n_f32(vaddq_f32(FZ[1][j], FZ[2][j]), vaddq_f32(FZ[3][j], FZ[4][j]), 4.0f); + + // f[3][j] = 1*FZ[1][j] + -1*FZ[2][j] + 8*FZ[3][j] + -8*FZ[4][j] + 1*FZ[5][j]; + f[3][j] = vaddq_f32(vmlaq_n_f32(vsubq_f32(FZ[1][j], FZ[2][j]), vsubq_f32(FZ[3][j], FZ[4][j]), 8.0f), FZ[5][j]); + } + + // Write out the output tile + for (int i = 0; i < cells_i; i++) + { + for (int j = 0; j < cells_j; j++) + { + vst1q_f32(outptrs[i][j], f[i][j]); + outptrs[i][j] += 4; + } + } + } +#endif // __aarch64__ +#ifdef __arm_any__ + for (; channels_remaining >= 2; channels_remaining -= 2) + { + // Matrices used and computed during this transform + float32x2_t F[6][6], FZ[6][4], f[4][4]; + + // Read a 6x6 tile in the Winograd domain + for (int i = 0, m = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++, m++) + { + F[i][j] = vld1_f32(inptr + m*matrix_stride); + } + } + inptr += 2; + + // Compute the matrix F Z + for (int i = 0; i < 6; i++) + { + // FZ[i][0] = 1*F[i][0] + 1*F[i][1] + 1*F[i][2] + 1*F[i][3] + 1*F[i][4]; + FZ[i][0] = vadd_f32(vadd_f32(vadd_f32(F[i][0], F[i][1]), vadd_f32(F[i][2], F[i][3])), F[i][4]); + + // FZ[i][1] = 1*F[i][1] + -1*F[i][2] + 2*F[i][3] + -2*F[i][4]; + FZ[i][1] = vmla_n_f32(vsub_f32(F[i][1], F[i][2]), vsub_f32(F[i][3], F[i][4]), 2.0f); + + // FZ[i][2] = 1*F[i][1] + 1*F[i][2] + 4*F[i][3] + 4*F[i][4]; + FZ[i][2] = vmla_n_f32(vadd_f32(F[i][1], F[i][2]), vadd_f32(F[i][3], F[i][4]), 4.0f); + + // FZ[i][3] = 1*F[i][1] + -1*F[i][2] + 8*F[i][3] + -8*F[i][4] + 1*F[i][5]; + FZ[i][3] = vadd_f32(vmla_n_f32(vsub_f32(F[i][1], F[i][2]), vsub_f32(F[i][3], F[i][4]), 8.0f), F[i][5]); + } + + // Compute the output tile f = ZT F Z + for (int j = 0; j < 4; j++) + { + // f[0][j] = 1*FZ[0][j] + 1*FZ[1][j] + 1*FZ[2][j] + 1*FZ[3][j] + 1*FZ[4][j]; + f[0][j] = vadd_f32(vadd_f32(vadd_f32(FZ[0][j], FZ[1][j]), vadd_f32(FZ[2][j], FZ[3][j])), FZ[4][j]); + + // f[1][j] = 1*FZ[1][j] + -1*FZ[2][j] + 2*FZ[3][j] + -2*FZ[4][j]; + f[1][j] = vmla_n_f32(vsub_f32(FZ[1][j], FZ[2][j]), vsub_f32(FZ[3][j], FZ[4][j]), 2.0f); + + // f[2][j] = 1*FZ[1][j] + 1*FZ[2][j] + 4*FZ[3][j] + 4*FZ[4][j]; + f[2][j] = vmla_n_f32(vadd_f32(FZ[1][j], FZ[2][j]), vadd_f32(FZ[3][j], FZ[4][j]), 4.0f); + + // f[3][j] = 1*FZ[1][j] + -1*FZ[2][j] + 8*FZ[3][j] + -8*FZ[4][j] + 1*FZ[5][j]; + f[3][j] = vadd_f32(vmla_n_f32(vsub_f32(FZ[1][j], FZ[2][j]), vsub_f32(FZ[3][j], FZ[4][j]), 8.0f), FZ[5][j]); + } + + // Write out the output tile + for (int i = 0; i < cells_i; i++) + { + for (int j = 0; j < cells_j; j++) + { + vst1_f32(outptrs[i][j], f[i][j]); + outptrs[i][j] += 2; + } + } + } +#endif + for (; channels_remaining; channels_remaining--) + { + // Matrices used and computed during this transform + float F[6][6], FZ[6][4], f[4][4]; + + // Read a 6x6 tile in the Winograd domain + for (int i = 0, m = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++, m++) + { + F[i][j] = *(inptr + m*matrix_stride); + } + } + inptr++; + + // Compute the matrix F Z + for (int i = 0; i < 6; i++) + { + FZ[i][0] = 1*F[i][0] + 1*F[i][1] + 1*F[i][2] + 1*F[i][3] + 1*F[i][4]; + FZ[i][1] = 1*F[i][1] + -1*F[i][2] + 2*F[i][3] + -2*F[i][4]; + FZ[i][2] = 1*F[i][1] + 1*F[i][2] + 4*F[i][3] + 4*F[i][4]; + FZ[i][3] = 1*F[i][1] + -1*F[i][2] + 8*F[i][3] + -8*F[i][4] + 1*F[i][5]; + } + + // Compute the output tile f = ZT F Z + for (int j = 0; j < 4; j++) + { + f[0][j] = 1*FZ[0][j] + 1*FZ[1][j] + 1*FZ[2][j] + 1*FZ[3][j] + 1*FZ[4][j]; + f[1][j] = 1*FZ[1][j] + -1*FZ[2][j] + 2*FZ[3][j] + -2*FZ[4][j]; + f[2][j] = 1*FZ[1][j] + 1*FZ[2][j] + 4*FZ[3][j] + 4*FZ[4][j]; + f[3][j] = 1*FZ[1][j] + -1*FZ[2][j] + 8*FZ[3][j] + -8*FZ[4][j] + 1*FZ[5][j]; + } + + // Write out the output tile + for (int i = 0; i < cells_i; i++) + { + for (int j = 0; j < cells_j; j++) + { + *(outptrs[i][j]++) = f[i][j]; + } + } + } +} + +template <> +template <> +const Transform::TileFn Transform::tile_fns[max_pad_bottom][max_pad_right] = +{ + { + Transform::template process_tile<0, 0>, + Transform::template process_tile<0, 1>, + Transform::template process_tile<0, 2>, + Transform::template process_tile<0, 3>, + }, + { + Transform::template process_tile<1, 0>, + Transform::template process_tile<1, 1>, + Transform::template process_tile<1, 2>, + Transform::template process_tile<1, 3>, + }, + { + Transform::template process_tile<2, 0>, + Transform::template process_tile<2, 1>, + Transform::template process_tile<2, 2>, + Transform::template process_tile<2, 3>, + }, + { + Transform::template process_tile<3, 0>, + Transform::template process_tile<3, 1>, + Transform::template process_tile<3, 2>, + Transform::template process_tile<3, 3>, + } +}; + +template struct WinogradGEMM<4, 4, 3, 3>::OutputTransform; +} // namespace winograd diff --git a/src/core/NEON/kernels/winograd/transforms/weights_2x2_3x3_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/weights_2x2_3x3_fp32.cpp new file mode 100644 index 0000000000..c0b282431e --- /dev/null +++ b/src/core/NEON/kernels/winograd/transforms/weights_2x2_3x3_fp32.cpp @@ -0,0 +1,228 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "arm.hpp" +#include "winograd_gemm.hpp" +#include "transforms/kernel.hpp" + +namespace winograd +{ + template <> + template <> + void WinogradGEMM<2, 2, 3, 3>::WeightsTransform::execute( + const int n_output_channels, + const int n_input_channels, + const float* const input, + float* const output, + const int matrix_stride, + const int matrix_row_stride + ) + { + constexpr int inner_tile_i = 4; + constexpr int inner_tile_j = 4; + + // Get pointers to each cell of the weight tensor + const auto weight_col_stride = n_input_channels * n_output_channels; + const auto weight_row_stride = 3 * weight_col_stride; + const float *inptrs[3][3]; + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < 3; j++) + { + inptrs[i][j] = input + i*weight_row_stride + j*weight_col_stride; + } + } + + // For each input channel + for (int ic = 0; ic < n_input_channels; ic++) + { + float *outptr = output + ic * matrix_row_stride; + + // For each output channel + int channels_remaining = n_output_channels; +#ifdef __aarch64__ + for (; channels_remaining >= 4; channels_remaining -= 4) + { + // Matrices used and computed in this kernel + float32x4_t w[3][3], Ww[inner_tile_i][3], V[inner_tile_i][inner_tile_j]; + + // Read weights + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < 3; j++) + { + w[i][j] = vld1q_f32(inptrs[i][j]); + inptrs[i][j] += 4; + } + } + + // Compute the matrix W w + for (int j = 0; j < 3; j++) + { + Ww[0][j] = w[0][j]; + + // Ww[1][j] = 0.5*(w[0][j] + w[1][j] + w[2][j]); + Ww[1][j] = vmulq_n_f32(vaddq_f32(vaddq_f32(w[0][j], w[1][j]), w[2][j]), 0.5f); + + // Ww[2][j] = 0.5*(w[0][j] - w[1][j] + w[2][j]); + Ww[2][j] = vmulq_n_f32(vaddq_f32(vsubq_f32(w[0][j], w[1][j]), w[2][j]), 0.5f); + + Ww[3][j] = w[2][j]; + } + + // Compute V = W w WT + for (int i = 0; i < inner_tile_i; i++) + { + V[i][0] = Ww[i][0]; + + // V[i][1] = 0.5*(Ww[i][0] + Ww[i][1] + Ww[i][2]); + V[i][1] = vmulq_n_f32(vaddq_f32(vaddq_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f); + + // V[i][2] = 0.5*(Ww[i][0] - Ww[i][1] + Ww[i][2]); + V[i][2] = vmulq_n_f32(vaddq_f32(vsubq_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f); + + V[i][3] = Ww[i][2]; + } + + // Store the transformed weights + for (int i = 0, m = 0; i < inner_tile_i; i++) + { + for (int j = 0; j < inner_tile_j; j++, m++) + { + vst1q_f32(outptr + m*matrix_stride, V[i][j]); + } + } + outptr += 4; + } +#endif // __aarch64__ +#ifdef __arm_any__ + for (; channels_remaining >= 2; channels_remaining -= 2) + { + // Matrices used and computed in this kernel + float32x2_t w[3][3], Ww[inner_tile_i][3], V[inner_tile_i][inner_tile_j]; + + // Read weights + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < 3; j++) + { + w[i][j] = vld1_f32(inptrs[i][j]); + inptrs[i][j] += 2; + } + } + + // Compute the matrix W w + for (int j = 0; j < 3; j++) + { + Ww[0][j] = w[0][j]; + + // Ww[1][j] = 0.5*(w[0][j] + w[1][j] + w[2][j]); + Ww[1][j] = vmul_n_f32(vadd_f32(vadd_f32(w[0][j], w[1][j]), w[2][j]), 0.5f); + + // Ww[2][j] = 0.5*(w[0][j] - w[1][j] + w[2][j]); + Ww[2][j] = vmul_n_f32(vadd_f32(vsub_f32(w[0][j], w[1][j]), w[2][j]), 0.5f); + + Ww[3][j] = w[2][j]; + } + + // Compute V = W w WT + for (int i = 0; i < inner_tile_i; i++) + { + V[i][0] = Ww[i][0]; + + // V[i][1] = 0.5*(Ww[i][0] + Ww[i][1] + Ww[i][2]); + V[i][1] = vmul_n_f32(vadd_f32(vadd_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f); + + // V[i][2] = 0.5*(Ww[i][0] - Ww[i][1] + Ww[i][2]); + V[i][2] = vmul_n_f32(vadd_f32(vsub_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f); + + V[i][3] = Ww[i][2]; + } + + // Store the transformed weights + for (int i = 0, m = 0; i < inner_tile_i; i++) + { + for (int j = 0; j < inner_tile_j; j++, m++) + { + vst1_f32(outptr + m*matrix_stride, V[i][j]); + } + } + outptr += 2; + } +#endif // __arm_any__ + for (; channels_remaining; channels_remaining--) + { + // Matrices used and computed in this kernel + float w[3][3], Ww[inner_tile_i][3], V[inner_tile_i][inner_tile_j]; + + // Read weights + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < 3; j++) + { + w[i][j] = *(inptrs[i][j]++); + } + } + + // Compute the matrix W w + for (int j = 0; j < 3; j++) + { + Ww[0][j] = w[0][j]; + Ww[1][j] = 0.5*(w[0][j] + w[1][j] + w[2][j]); + Ww[2][j] = 0.5*(w[0][j] - w[1][j] + w[2][j]); + Ww[3][j] = w[2][j]; + } + + // Compute V = W w WT + for (int i = 0; i < inner_tile_i; i++) + { + V[i][0] = Ww[i][0]; + V[i][1] = 0.5*(Ww[i][0] + Ww[i][1] + Ww[i][2]); + V[i][2] = 0.5*(Ww[i][0] - Ww[i][1] + Ww[i][2]); + V[i][3] = Ww[i][2]; + } + + // Store the transformed weights + for (int i = 0, m = 0; i < inner_tile_i; i++) + { + for (int j = 0; j < inner_tile_j; j++, m++) + { + *(outptr + m*matrix_stride) = V[i][j]; + } + } + outptr++; + } + } + } + + template <> + template <> + int WinogradGEMM<2, 2, 3, 3>::WeightsTransform::ops_performed(const KernelShape &shape) + { + const int channel_prod = shape.n_input_channels * shape.n_output_channels; + return 2 * 18 * channel_prod; + } + + template struct WinogradGEMM<2, 2, 3, 3>::WeightsTransform; +} // namespace winograd diff --git a/src/core/NEON/kernels/winograd/transforms/weights_4x4_3x3_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/weights_4x4_3x3_fp32.cpp new file mode 100644 index 0000000000..de659c38e0 --- /dev/null +++ b/src/core/NEON/kernels/winograd/transforms/weights_4x4_3x3_fp32.cpp @@ -0,0 +1,266 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "arm.hpp" +#include "winograd_gemm.hpp" +#include "transforms/kernel.hpp" + +namespace winograd +{ + /* Float implementation for kernel transform F(4x4, 3x3) */ + template <> + template <> + void WinogradGEMM<4, 4, 3, 3>::WeightsTransform::execute( + const int n_output_channels, + const int n_input_channels, + const float* const input, // NOTE: Data in HWIO order + float* const output, + const int matrix_stride, + const int matrix_row_stride + ) + { + // Get pointers to each cell of the weight tensor + const auto weight_col_stride = n_input_channels * n_output_channels; + const auto weight_row_stride = 3 * weight_col_stride; + const float *inptrs[3][3]; + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < 3; j++) + { + inptrs[i][j] = input + i*weight_row_stride + j*weight_col_stride; + } + } + + // For each input channel + for (int ic = 0; ic < n_input_channels; ic++) + { + float *outptr = output + ic * matrix_row_stride; + + // For each output channel + int channels_remaining = n_output_channels; +#ifdef __aarch64__ + for (; channels_remaining >= 4; channels_remaining -= 4) + { + // Matrices used and computed in this kernel + float32x4_t w[3][3], Ww[6][3], V[6][6]; + + // Read weights + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < 3; j++) + { + w[i][j] = vld1q_f32(inptrs[i][j]); + inptrs[i][j] += 4; + } + } + + // Compute the matrix W w + for (int j = 0; j < 3; j++) + { + // Ww[0][j] = 6*w[0][j]; + Ww[0][j] = vmulq_n_f32(w[0][j], 6.0); + + // Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j]; + Ww[1][j] = vmulq_n_f32(vaddq_f32(vaddq_f32(w[0][j], w[1][j]), w[2][j]), -4.0); + + // Ww[2][j] = -4*w[0][j] + 4*w[1][j] + -4*w[2][j]; + Ww[2][j] = vmulq_n_f32(vsubq_f32(vsubq_f32(w[1][j], w[0][j]), w[2][j]), 4.0); + + // Ww[3][j] = 1*w[0][j] + 2*w[1][j] + 4*w[2][j]; + Ww[3][j] = vmlaq_n_f32(vmlaq_n_f32(w[0][j], w[1][j], 2.0f), w[2][j], 4.0f); + + // Ww[4][j] = 1*w[0][j] + -2*w[1][j] + 4*w[2][j]; + Ww[4][j] = vmlaq_n_f32(vmlsq_n_f32(w[0][j], w[1][j], 2.0f), w[2][j], 4.0f); + + // Ww[5][j] = 24*w[2][j]; + Ww[5][j] = vmulq_n_f32(w[2][j], 24.0f); + } + + // Compute V = W w WT + for (int i = 0; i < 6; i++) + { + const float recip576 = 1.0f / 576.0f; + + // V[i][0] = 6*Ww[i][0]; + V[i][0] = vmulq_n_f32(vmulq_n_f32(Ww[i][0], 6.0), recip576); + + // V[i][1] = -4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2]; + V[i][1] = vmulq_n_f32(vmulq_n_f32(vaddq_f32(vaddq_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), -4.0), recip576); + + // V[i][2] = -4*Ww[i][0] + 4*Ww[i][1] + -4*Ww[i][2]; + V[i][2] = vmulq_n_f32(vmulq_n_f32(vsubq_f32(vsubq_f32(Ww[i][1], Ww[i][0]), Ww[i][2]), 4.0), recip576); + + // V[i][3] = 1*Ww[i][0] + 2*Ww[i][1] + 4*Ww[i][2]; + V[i][3] = vmulq_n_f32(vmlaq_n_f32(vmlaq_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576); + + // V[i][4] = 1*Ww[i][0] + -2*Ww[i][1] + 4*Ww[i][2]; + V[i][4] = vmulq_n_f32(vmlaq_n_f32(vmlsq_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576); + + // V[i][5] = 24*Ww[i][2]; + V[i][5] = vmulq_n_f32(vmulq_n_f32(Ww[i][2], 24.0f), recip576); + } + + // Store the transformed weights + for (int i = 0, m = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++, m++) + { + vst1q_f32(outptr + m*matrix_stride, V[i][j]); + } + } + outptr += 4; + } +#endif // __aarch64__ +#ifdef __arm_any__ + for (; channels_remaining >= 2; channels_remaining -= 2) + { + // Matrices used and computed in this kernel + float32x2_t w[3][3], Ww[6][3], V[6][6]; + + // Read weights + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < 3; j++) + { + w[i][j] = vld1_f32(inptrs[i][j]); + inptrs[i][j] += 2; + } + } + + // Compute the matrix W w + for (int j = 0; j < 3; j++) + { + // Ww[0][j] = 6*w[0][j]; + Ww[0][j] = vmul_n_f32(w[0][j], 6.0); + + // Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j]; + Ww[1][j] = vmul_n_f32(vadd_f32(vadd_f32(w[0][j], w[1][j]), w[2][j]), -4.0); + + // Ww[2][j] = -4*w[0][j] + 4*w[1][j] + -4*w[2][j]; + Ww[2][j] = vmul_n_f32(vsub_f32(vsub_f32(w[1][j], w[0][j]), w[2][j]), 4.0); + + // Ww[3][j] = 1*w[0][j] + 2*w[1][j] + 4*w[2][j]; + Ww[3][j] = vmla_n_f32(vmla_n_f32(w[0][j], w[1][j], 2.0f), w[2][j], 4.0f); + + // Ww[4][j] = 1*w[0][j] + -2*w[1][j] + 4*w[2][j]; + Ww[4][j] = vmla_n_f32(vmls_n_f32(w[0][j], w[1][j], 2.0f), w[2][j], 4.0f); + + // Ww[5][j] = 24*w[2][j]; + Ww[5][j] = vmul_n_f32(w[2][j], 24.0f); + } + + // Compute V = W w WT + for (int i = 0; i < 6; i++) + { + const float recip576 = 1.0f / 576.0f; + + // V[i][0] = 6*Ww[i][0]; + V[i][0] = vmul_n_f32(vmul_n_f32(Ww[i][0], 6.0), recip576); + + // V[i][1] = -4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2]; + V[i][1] = vmul_n_f32(vmul_n_f32(vadd_f32(vadd_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), -4.0), recip576); + + // V[i][2] = -4*Ww[i][0] + 4*Ww[i][1] + -4*Ww[i][2]; + V[i][2] = vmul_n_f32(vmul_n_f32(vsub_f32(vsub_f32(Ww[i][1], Ww[i][0]), Ww[i][2]), 4.0), recip576); + + // V[i][3] = 1*Ww[i][0] + 2*Ww[i][1] + 4*Ww[i][2]; + V[i][3] = vmul_n_f32(vmla_n_f32(vmla_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576); + + // V[i][4] = 1*Ww[i][0] + -2*Ww[i][1] + 4*Ww[i][2]; + V[i][4] = vmul_n_f32(vmla_n_f32(vmls_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576); + + // V[i][5] = 24*Ww[i][2]; + V[i][5] = vmul_n_f32(vmul_n_f32(Ww[i][2], 24.0f), recip576); + } + + // Store the transformed weights + for (int i = 0, m = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++, m++) + { + vst1_f32(outptr + m*matrix_stride, V[i][j]); + } + } + outptr += 2; + } +#endif // __arm_any__ + for (; channels_remaining; channels_remaining--) + { + // Matrices used and computed in this kernel + float w[3][3], Ww[6][3], V[6][6]; + + // Read weights + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < 3; j++) + { + w[i][j] = *(inptrs[i][j]++); + } + } + + // Compute the matrix W w + for (int j = 0; j < 3; j++) + { + Ww[0][j] = 6*w[0][j]; + Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j]; + Ww[2][j] = -4*w[0][j] + 4*w[1][j] + -4*w[2][j]; + Ww[3][j] = 1*w[0][j] + 2*w[1][j] + 4*w[2][j]; + Ww[4][j] = 1*w[0][j] + -2*w[1][j] + 4*w[2][j]; + Ww[5][j] = 24*w[2][j]; + } + + // Compute V = W w WT + for (int i = 0; i < 6; i++) + { + V[i][0] = ( 6*Ww[i][0]) / 576.0; + V[i][1] = (-4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2]) / 576.0; + V[i][2] = (-4*Ww[i][0] + 4*Ww[i][1] + -4*Ww[i][2]) / 576.0; + V[i][3] = ( 1*Ww[i][0] + 2*Ww[i][1] + 4*Ww[i][2]) / 576.0; + V[i][4] = ( 1*Ww[i][0] + -2*Ww[i][1] + 4*Ww[i][2]) / 576.0; + V[i][5] = (24*Ww[i][2]) / 576.0; + } + + // Store the transformed weights + for (int i = 0, m = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++, m++) + { + *(outptr + m*matrix_stride) = V[i][j]; + } + } + outptr++; + } + } + } + + template <> + template <> + int WinogradGEMM<4, 4, 3, 3>::WeightsTransform::ops_performed(const KernelShape &shape) + { + const int channel_prod = shape.n_input_channels * shape.n_output_channels; + return 9 * 16 * channel_prod; + } + + template struct WinogradGEMM<4, 4, 3, 3>::WeightsTransform; +} diff --git a/src/core/NEON/kernels/winograd/utils.cpp b/src/core/NEON/kernels/winograd/utils.cpp new file mode 100644 index 0000000000..24d0386c76 --- /dev/null +++ b/src/core/NEON/kernels/winograd/utils.cpp @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include +#include + +double TimeInUs(void) +{ +#ifdef CYCLE_PROFILING + timespec t; + clock_gettime(CLOCK_REALTIME, &t); + return 1e6*t.tv_sec + 1e-3*t.tv_nsec; +#else + return 0; +#endif +} + +void PrintMatrix(const float* const m, const int M, const int N, const int row_stride) +{ + for (int i = 0; i < M; i++) + { + for (int j = 0; j < N; j++) + { + printf("%.3f ", m[i*row_stride + j]); + } + printf("\n"); + } + printf("\n"); +} diff --git a/src/core/NEON/kernels/winograd/utils.hpp b/src/core/NEON/kernels/winograd/utils.hpp deleted file mode 100644 index 14e709f028..0000000000 --- a/src/core/NEON/kernels/winograd/utils.hpp +++ /dev/null @@ -1,55 +0,0 @@ - -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once -#include - -inline double TimeInUs(void) { -#ifdef CYCLE_PROFILING - timespec t; - clock_gettime(CLOCK_THREAD_CPUTIME_ID, &t); - return 1e6*t.tv_sec + 1e-3*t.tv_nsec; -#else - return 0; -#endif -} - -inline int iceildiv(const int a, const int b) { - return (a + b - 1) / b; -} - -template -inline T roundup(const T a, const T b) { - return a + b - (a % b); -} - -inline void PrintMatrix(const float* const m, const int M, const int N, const int row_stride) { - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - printf("%.3f ", m[i*row_stride + j]); - } - printf("\n"); - } - printf("\n"); -} diff --git a/src/core/NEON/kernels/winograd/winograd_gemm.cpp b/src/core/NEON/kernels/winograd/winograd_gemm.cpp new file mode 100644 index 0000000000..b44a45367f --- /dev/null +++ b/src/core/NEON/kernels/winograd/winograd_gemm.cpp @@ -0,0 +1,560 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "winograd_gemm.hpp" +#include "batched_blocked_gemm.hpp" +using namespace winograd; + +/** Get the output shape of a convolution. */ +template +template +Tensor4DShape WinogradGEMM::Convolution::get_output_shape( + const KernelShape &kernel_shape, + const Tensor4DShape &in_shape, + const PaddingType padding +) +{ + // TODO Accept different kernel sizes + return Tensor4DShape { + in_shape.n_batches, + (padding == PADDING_SAME) ? in_shape.n_rows : in_shape.n_rows - 2, + (padding == PADDING_SAME) ? in_shape.n_cols : in_shape.n_cols - 2, + kernel_shape.n_output_channels, + in_shape.ordering + }; +} + +/* Get the memory required to transform the kernel. + */ +template +template +size_t WinogradGEMM::Convolution::get_kernel_transform_working_size(const KernelShape &shape) +{ + if (shape.ordering == HWIO) + { + // Kernel is already in the correct order, so no additional memory is + // required. + return 0; + } + else + { + // Need to re-order the kernel into HWIO form, require enough space to + // represent the tensor. + return sizeof(TIn) * shape.size(); + } +} + +/** Get the memory required to store the kernel transformed into the + * Winograd domain. + */ +template +template +size_t WinogradGEMM::Convolution::get_kernel_storage_size(const KernelShape &shape) +{ + return N_GEMMS * get_kernel_matrix_size(shape); +} + + +template +template +size_t WinogradGEMM::Convolution::get_input_storage_size( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding +) +{ + return N_GEMMS * get_input_matrix_size(kernel_shape, input_shape, padding); +} + + +template +template +size_t WinogradGEMM::Convolution::get_output_storage_size( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding +) +{ + return N_GEMMS * get_output_matrix_size(kernel_shape, input_shape, padding); +} + + +/** Get the memory required to apply a Winograd operator to some input. + */ +template +template +size_t WinogradGEMM::Convolution::get_working_space_size( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding_type +) +{ + const auto output_shape = get_output_shape(kernel_shape, input_shape, padding_type); + + // Get the memory required to store the matrices + const size_t matrix_sizes = N_GEMMS * ( + get_input_matrix_size(kernel_shape, input_shape, padding_type) + + get_output_matrix_size(kernel_shape, input_shape, padding_type) + ); + + // Add additional space to re-order the input and output if the input tensor + // is not in NHWC format. + if (input_shape.ordering == NHWC) + { + return matrix_sizes; // No extra spacing required + } + else // NCHW, must reorder the input and output tensors + { + // We only need to re-order the input or output at any one time, so request + // enough memory to do the largest of these. + const size_t extra_memory = std::max( + sizeof(TIn) * input_shape.size(), + sizeof(TOut) * output_shape.size() + ); + return matrix_sizes + extra_memory; + } +} + + +/* Get the memory required by a single "input" matrix. + */ +template +template +size_t WinogradGEMM::Convolution::get_input_matrix_size( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding_type +) +{ + return get_input_matrix_stride(kernel_shape, input_shape, padding_type) * sizeof(TIn); +} + +template +template +int WinogradGEMM::Convolution::get_input_matrix_stride( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding_type +) +{ + // Compute shape for the GEMM + const auto output_shape = get_output_shape(kernel_shape, input_shape, padding_type); + const int tile_rows = iceildiv(output_shape.n_rows, output_tile_rows); + const int tile_cols = iceildiv(output_shape.n_cols, output_tile_cols); + const int M = roundup(input_shape.n_batches * tile_rows * tile_cols, M_BLOCK); + const int K = kernel_shape.n_input_channels; + + return M * K; +} + + +/* Get the memory required by a single "output" matrix. + */ +template +template +size_t WinogradGEMM::Convolution::get_output_matrix_size( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding_type +) +{ + return get_output_matrix_stride(kernel_shape, input_shape, padding_type) * sizeof(TOut); +} + + +template +template +int WinogradGEMM::Convolution::get_output_matrix_stride( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding_type +) +{ + // Compute shape for the GEMM + const auto output_shape = get_output_shape(kernel_shape, input_shape, padding_type); + const int tile_rows = iceildiv(output_shape.n_rows, output_tile_rows); + const int tile_cols = iceildiv(output_shape.n_cols, output_tile_cols); + const int M = roundup(tile_rows * tile_cols, M_BLOCK); + const int N = roundup(kernel_shape.n_output_channels, N_BLOCK); + + return input_shape.n_batches * M * N; +} + + +/* Get the memory required by a single "kernel" matrix. + */ +template +template +size_t WinogradGEMM::Convolution::get_kernel_matrix_size(const KernelShape &shape) +{ + return sizeof(TIn) * get_kernel_matrix_stride(shape); +} + +template +template +int WinogradGEMM::Convolution::get_kernel_matrix_stride(const KernelShape &shape) +{ + const int K = shape.n_input_channels; + const int N = roundup(shape.n_output_channels, N_BLOCK); + return K * N; +} + + +/** Create a new Winograd operator. */ +template +template +WinogradGEMM::Convolution::Convolution( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding, + void *kernel_storage +) : kernel_shape(kernel_shape), // Store the kernel shape + kernel_matrix_row_stride(roundup(kernel_shape.n_output_channels, N_BLOCK)), + manage_kernel_storage(kernel_storage == NULL), + _kernel_storage(manage_kernel_storage ? + ALLOCATE(get_kernel_storage_size(kernel_shape)) : + kernel_storage), + input_shape(input_shape), + padding(padding), + output_shape(get_output_shape(kernel_shape, input_shape, padding)), + tile_rows(iceildiv(output_shape.n_rows, output_tile_rows)), + tile_cols(iceildiv(output_shape.n_cols, output_tile_cols)), + M(input_shape.n_batches * tile_rows * tile_cols), + K(kernel_shape.n_input_channels), + N(kernel_shape.n_output_channels), + prof() +{ + // Create pointers to the kernel matrices + const int kernel_matrix_size_bytes = get_kernel_matrix_size(kernel_shape); + int8_t* const ks_bytes = reinterpret_cast(_kernel_storage); + for (int i = 0; i < N_GEMMS; i++) { + kernel_matrices[i] = reinterpret_cast( + ks_bytes + i*kernel_matrix_size_bytes); + } +} + + +/** Create a new Winograd operator and initialise the weights. */ +template +template +WinogradGEMM::Convolution::Convolution( + const KernelShape &kernel_shape, + const Tensor4DShape &input_shape, + const PaddingType padding, + const TIn* const kernel, + void *kernel_storage, + void *transform_working_space +) : Convolution(kernel_shape, input_shape, padding, kernel_storage) +{ + transform_weights(kernel, transform_working_space); +} + + +/** Clean up a convolution engine. */ +template +template +WinogradGEMM:: +Convolution::~Convolution() +{ + // If we were responsible for managing kernel storage ensure that it is + // freed. + if (manage_kernel_storage) + { + free(_kernel_storage); + } +} + + +/** Transform weights into the Winograd domain and store them for later use/reuse. */ +template +template +template +void WinogradGEMM:: +Convolution::transform_weights( + const TIn* const kernel, + void *transform_working_space +) +{ + // Allocate working space if it is required + bool allocated_working_space = false; + if (transform_working_space == NULL && // If no memory has been provided + get_kernel_transform_working_size(kernel_shape) != 0) // And we need the space + { + allocated_working_space = true; + transform_working_space = ALLOCATE( + get_kernel_transform_working_size(kernel_shape) + ); + } + + // The transformation methods only work on weights laid out in HWIO form, if + // the weights are not in this form then we need to re-order them. + const TIn *kernel_hwio = kernel; + if (kernel_shape.ordering != HWIO) + { + kernel_hwio = reinterpret_cast(transform_working_space); + + // Re-order the weights from OIHW to HWIO + this->prof( + "Weight reorder", + [&kernel, &kernel_hwio, this] () { + reorder::ofm_ifm_h_w_to_h_w_ifm_ofm( + kernel, const_cast(kernel_hwio), + kernel_shape.n_output_channels, + kernel_shape.n_input_channels, + kernel_shape.n_rows, + kernel_shape.n_cols + ); + }, + kernel_shape.size() * sizeof(TIn), + 0, + kernel_shape.size() * sizeof(TIn) + ); + } + + const int kernel_matrix_size_bytes = get_kernel_matrix_size(kernel_shape); + WeightsTransformT weights_transform( + kernel_hwio, kernel_matrices[0], + kernel_matrix_size_bytes / sizeof(TIn), + kernel_matrix_row_stride, + kernel_shape.n_output_channels, + kernel_shape.n_input_channels + ); + + // Transform the weights into the Winograd domain + auto kernel_prep = [&] () + { + weights_transform.run(0, weights_transform.get_window()); + }; + + prof( + "Kernel Prep", kernel_prep, + WeightsTransformT::bytes_read(kernel_shape), + WeightsTransformT::ops_performed(kernel_shape), + WeightsTransformT::bytes_written(kernel_shape) + ); + + // Free memory if we allocated it + if (allocated_working_space) + { + free(transform_working_space); + } +} + + +/** Perform a convolution. */ +template +template +void WinogradGEMM:: +Convolution::execute( + TOut* const output, + const TIn* const input, + void *working_space, + const int n_threads +) +{ + const auto padding_type = padding; + const auto input_shape = this->input_shape; + + // Allocate working space if none has been provided + const bool manage_working_space = (working_space == NULL); + if (manage_working_space) + { + const size_t ws_size = get_working_space_size( + kernel_shape, input_shape, padding_type + ); + working_space = ALLOCATE(ws_size * sizeof(int8_t)); + memset(working_space, 0x00, ws_size); + } + int8_t* const ws_bytes = reinterpret_cast(working_space); + + // Split the working space into that required for 16 input matrices and + // output matrices. + TIn *input_matrices[N_GEMMS]; + TOut *output_matrices[N_GEMMS]; + const int in_matrix_stride_bytes = get_input_matrix_size(kernel_shape, input_shape, padding_type); + const int out_matrix_stride_bytes = get_output_matrix_size(kernel_shape, input_shape, padding_type); + + for (int i = 0; i < N_GEMMS; i++) + { + input_matrices[i] = reinterpret_cast( + ws_bytes + i*in_matrix_stride_bytes); + output_matrices[i] = reinterpret_cast( + ws_bytes + N_GEMMS*in_matrix_stride_bytes + i*out_matrix_stride_bytes); + } + + // If we need to re-order the input and output tensors then the final chunk + // of the working space can be used for this purpose. + // TODO - Overlay the input reorder on top of the output matrices + // - Overlay the output reorder on top of the input matrices + // Reorder the input input form if it was not provided in this ordering. + const TIn* input_nhwc = input; + if (input_shape.ordering == NCHW) + { + input_nhwc = reinterpret_cast( + ws_bytes + N_GEMMS*(in_matrix_stride_bytes + out_matrix_stride_bytes) + ); + + this->prof( + "NCHW -> NHWC", + [input, input_shape, input_nhwc] () { + reorder::nchw_to_nhwc( + input, const_cast(input_nhwc), + input_shape.n_batches, + input_shape.n_channels, + input_shape.n_rows, + input_shape.n_cols + ); + }, + input_shape.size(), 0, input_shape.size() + ); + } + + // Compute shape for the GEMM + const auto output_shape = this->output_shape; + int M = this->M; + int K = this->K; + int N = this->N; + + const int in_matrix_row_stride = K; + const int out_matrix_row_stride = kernel_matrix_row_stride; + + InputTransform input_transform( + input_nhwc, + input_shape.n_batches, + input_shape.n_rows, + input_shape.n_cols, + input_shape.n_channels, + padding_type, + input_matrices[0], + in_matrix_stride_bytes / sizeof(TIn), + in_matrix_row_stride + ); + + // Transform the input into the Winograd domain + auto input_prep = [&] () { + input_transform.run(0, input_transform.get_window()); + }; + prof( + "Input Prep", input_prep, + InputTransform::bytes_read(input_shape), + InputTransform::ops_performed(input_shape), + InputTransform::bytes_written(input_shape) + ); + + // Perform the GEMMs + const int kernel_matrix_stride_bytes = get_kernel_matrix_size(kernel_shape); + BatchedBlockedGemm gemms( + N_GEMMS, M, K, N, + in_matrix_stride_bytes / sizeof(TIn), + in_matrix_row_stride, + kernel_matrix_stride_bytes / sizeof(TIn), + kernel_matrix_row_stride, + out_matrix_stride_bytes / sizeof(TOut), + out_matrix_row_stride, + input_matrices[0], + kernel_matrices[0], + output_matrices[0] + ); + gemms.run(0, gemms.get_window()); + + // If the output tensor needs to be in NCHW form then store the NHWC output + // tensor in temporary storage and then reorder. If the output tensor needs + // to be in NHWC then just write straight to the output tensor. + TOut *output_nhwc = output; + if (input_shape.ordering == NCHW) + { + output_nhwc = reinterpret_cast( + ws_bytes + N_GEMMS*(in_matrix_stride_bytes + out_matrix_stride_bytes) + ); + } + + // Transform the output tensor from the Winograd domain to the spatial + // domain. + OutputTransform output_transform( + output_matrices[0], + out_matrix_stride_bytes / sizeof(TOut), + out_matrix_row_stride, + output_nhwc, + output_shape.n_batches, + output_shape.n_rows, + output_shape.n_cols, + output_shape.n_channels + ); + auto output_prep = [&] () { + output_transform.run(0, output_transform.get_window()); + }; + prof( + "Output Comp", output_prep, + OutputTransform::bytes_read(output_shape), + OutputTransform::ops_performed(output_shape), + OutputTransform::bytes_written(output_shape) + ); + + // Reorder the output tensor if it is required to be in NCHW form. + if (input_shape.ordering == NCHW) + { + prof( + "NHWC -> NCHW", + [output_nhwc, output_shape, output] () { + reorder::nhwc_to_nchw( + output_nhwc, output, + output_shape.n_batches, + output_shape.n_rows, + output_shape.n_cols, + output_shape.n_channels + ); + }, + output_shape.size(), 0, output_shape.size() + ); + } + + // Free working space if we were responsible for allocating it + if (manage_working_space) + { + free(working_space); + } +} + + +/** Perform a convolution. */ +template +template +void WinogradGEMM:: +Convolution::execute( + TOut* const output, + const TIn* const input, + const int n_threads +) +{ + execute(output, input, NULL, n_threads); +} + + +// Instantiate required implementations +template class WinogradGEMM<2, 2, 3, 3>::Convolution; +template class WinogradGEMM<4, 4, 3, 3>::Convolution; diff --git a/src/core/NEON/kernels/winograd/winograd_gemm.hpp b/src/core/NEON/kernels/winograd/winograd_gemm.hpp deleted file mode 100644 index 59afa2f5ab..0000000000 --- a/src/core/NEON/kernels/winograd/winograd_gemm.hpp +++ /dev/null @@ -1,345 +0,0 @@ -/* - * Copyright (c) 2017 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once -#include -#include -#include - -#include "gemm.hpp" -#include "profiler.hpp" -#include "utils.hpp" -#include "shims.hpp" - -#include "transforms.hpp" - -namespace winograd { - /***************************************************************************/ - /* Implementation of the Winograd F(2x2, 3x3, 4x4) algorithm using GEMM - * internally. - */ - template - class Winograd2x2_3x3GEMM { - public: - /* Instantiate a new Winograd operator. - */ - Winograd2x2_3x3GEMM(const KernelShape &kernel_shape, const Tensor4DShape input_shape, const PaddingType padding_type, void *kernel_storage); - virtual ~Winograd2x2_3x3GEMM(); - - /** Transform the weights into the Winograd domain. - */ - template > - void transform_weights(const TIn* const kernel, void *transform_working_space); - - /* Initializes matrices pointers, to be called once before execute() - */ - template > - void reshape_input(const Tensor4DShape &input_shape, const PaddingType padding_type, const TIn* const input, void* working_space); - - /* Apply the Winograd operator to some input. - */ - template > - void reshape_output(const Tensor4DShape& input_shape, const PaddingType padding_type, TOut* const output); - - - /* Apply the Winograd operator to some input. - */ - void execute(size_t first, size_t last); - - /* Get the memory required to transform the kernel. - */ - static inline size_t get_kernel_transform_working_size(const KernelShape &shape); - - /* Get the output shape of a convolution. - */ - static Tensor4DShape get_output_shape(const Tensor4DShape &input_shape, const KernelShape &k_shape, - const PaddingType padding_type); - - /* Get the memory required to instantiate a new Winograd operator. - */ - static size_t get_kernel_storage_size(const KernelShape &shape); - - /* Get the memory required to apply a Winograd operator to some input. - */ - static size_t get_working_space_size(const Tensor4DShape &input_shape,const KernelShape &k_shape, - const PaddingType padding); - - - Winograd2x2_3x3GEMM(const Winograd2x2_3x3GEMM &) = delete; - /** Prevent instances of this class from being copied (As this class contains pointers) */ - Winograd2x2_3x3GEMM &operator=(const Winograd2x2_3x3GEMM &) = delete; - /** Allow instances of this class to be moved */ - Winograd2x2_3x3GEMM(Winograd2x2_3x3GEMM &&) = default; - /** Allow instances of this class to be moved */ - Winograd2x2_3x3GEMM &operator=(Winograd2x2_3x3GEMM &&) = default; - - protected: - /* Get the memory required by a single "input" matrix. - */ - static size_t get_input_matrix_size(const Tensor4DShape &input_shape,const KernelShape &k_shape, - const PaddingType padding); - - /* Get the memory required by a single "output" matrix. - */ - static size_t get_output_matrix_size(const Tensor4DShape &input_shape, const KernelShape &k_shape, - const PaddingType padding); - - /* Get the memory required by a single "kernel" matrix. - */ - static size_t get_kernel_matrix_size(const KernelShape &shape); - - const KernelShape kernel_shape; // Shape of applied kernel - const Tensor4DShape in_shape; - const PaddingType padding; - - const int kernel_matrix_row_stride; // Stride within kernel matrix - - const bool manage_kernel_storage; // Free kernel storage when done - void* const _kernel_storage; // Base pointer for kernel matrices - - profiler prof; // Profiler - - TIn *kernel_matrices[16]; // Prepared form of kernel - TIn *input_matrices[16]; - TOut *output_matrices[16]; - - - static const int M_BLOCK = 4; - static const int N_BLOCK = 16; - }; -} // namespace winograd - -template -size_t winograd::Winograd2x2_3x3GEMM::get_kernel_transform_working_size( - const KernelShape &shape -) -{ - // Need to re-order the kernel into HWIO form, require enough space to - // represent the tensor. - return sizeof(TIn) * shape.size(); -} - - -template -template -void winograd::Winograd2x2_3x3GEMM::transform_weights( - const TIn* const kernel, - void *transform_working_space -) -{ - const int kernel_matrix_size_bytes = get_kernel_matrix_size(kernel_shape); - int8_t* const ks_bytes = reinterpret_cast(_kernel_storage); - for (int i = 0; i < 16; i++) { - kernel_matrices[i] = reinterpret_cast( - ks_bytes + i*kernel_matrix_size_bytes); - } - - const TIn *kernel_hwio = kernel; - if( transform_working_space) - { - kernel_hwio = reinterpret_cast(transform_working_space); - ofm_ifm_h_w_to_h_w_ifm_ofm( - kernel, const_cast(kernel_hwio), - kernel_shape.n_output_channels, - kernel_shape.n_input_channels, - kernel_shape.n_rows, - kernel_shape.n_cols - ); - } - KernelTransform::execute( - kernel_shape, kernel_hwio, kernel_matrices[0], - kernel_matrix_size_bytes / sizeof(TIn), - kernel_matrix_row_stride - ); -} - -template -winograd::Winograd2x2_3x3GEMM::Winograd2x2_3x3GEMM( const KernelShape &kernel_shape, const Tensor4DShape input_shape, - const PaddingType padding_type, void *kernel_storage) - : kernel_shape(kernel_shape), in_shape(input_shape), padding(padding_type),kernel_matrix_row_stride(roundup(kernel_shape.n_output_channels, N_BLOCK)), manage_kernel_storage(false), - _kernel_storage(kernel_storage), prof() { - memset(kernel_matrices, 0x00, sizeof(TIn)*16); - memset(input_matrices, 0x00, sizeof(TIn)*16); - memset(output_matrices, 0x00, sizeof(TOut)*16); -} - -/*****************************************************************************/ -template -winograd::Winograd2x2_3x3GEMM::~Winograd2x2_3x3GEMM() {} - -/*****************************************************************************/ -template -template -void winograd::Winograd2x2_3x3GEMM::reshape_input( - const Tensor4DShape& input_shape, - const PaddingType padding_type, - const TIn* const input, - void *working_space -) { - assert(working_space); - int8_t* const ws_bytes = reinterpret_cast(working_space); - // Split the working space into that required for 16 input matrices and - // output matrices. - const int in_matrix_stride_bytes = get_input_matrix_size(input_shape, kernel_shape, padding_type); - const int out_matrix_stride_bytes = get_output_matrix_size(input_shape, kernel_shape, padding_type); - - for (int i = 0; i < 16; i++) { - input_matrices[i] = reinterpret_cast( - ws_bytes + i*in_matrix_stride_bytes); - output_matrices[i] = reinterpret_cast( - ws_bytes + 16*in_matrix_stride_bytes + i*out_matrix_stride_bytes); - } - - // Compute shape for the GEMM - const auto output_shape = get_output_shape(input_shape,kernel_shape, padding_type); - const int tile_rows = iceildiv(output_shape.n_rows, 2); - const int tile_cols = iceildiv(output_shape.n_cols, 2); - const int K = kernel_shape.n_input_channels; - - const int in_matrix_row_stride = K; - const int in_matrix_batch_stride = tile_rows*tile_cols*in_matrix_row_stride; - - // Transform the input tensor into an appropriate form - auto input_prep = [&] () { - InputTransform::execute( - input, input_shape, padding_type, tile_rows, tile_cols, - input_matrices[0], in_matrix_stride_bytes / sizeof(TIn), - in_matrix_batch_stride, in_matrix_row_stride - ); - }; - prof( - "Input Prep", input_prep, - InputTransform::bytes_read(input_shape, output_shape), - InputTransform::flops_performed(input_shape, output_shape), - InputTransform::bytes_written(input_shape, output_shape) - ); - -} - -/*****************************************************************************/ -template -template -void winograd::Winograd2x2_3x3GEMM::reshape_output(const Tensor4DShape& input_shape, const PaddingType padding_type, TOut* const output) { - assert(output_matrices[0]); - const int out_matrix_stride_bytes = get_output_matrix_size(input_shape, kernel_shape, padding_type); - const auto output_shape = get_output_shape(input_shape,kernel_shape, padding_type); - const int out_matrix_row_stride = kernel_matrix_row_stride; - - // Transform the output tensor into an appropriate form - OutputTransform::execute( - output_shape, - output_matrices[0], - out_matrix_stride_bytes / sizeof(TOut), - out_matrix_row_stride, - output - ); -} - - -/*****************************************************************************/ -template -void winograd::Winograd2x2_3x3GEMM::execute( size_t first, size_t last ) { - assert(input_matrices[0] && kernel_matrices[0] && output_matrices[0]); - assert(first < 16 && last < 16 && first < last); - // Compute shape for the GEMM - const auto output_shape = get_output_shape(in_shape,kernel_shape, padding); - const int tile_rows = iceildiv(output_shape.n_rows, 2); - const int tile_cols = iceildiv(output_shape.n_cols, 2); - const int M = in_shape.n_batches * tile_rows * tile_cols; - const int K = kernel_shape.n_input_channels; - const int N = kernel_shape.n_output_channels; - - const int in_matrix_row_stride = K; - const int out_matrix_row_stride = kernel_matrix_row_stride; - // Perform the GEMMs - for (size_t i = first; i <= last; i++) { - BlockedGemm( - input_matrices[i], kernel_matrices[i], output_matrices[i], M, K, N, - in_matrix_row_stride, kernel_matrix_row_stride, out_matrix_row_stride - ); -// prof("GEMM", perform_gemm, 0, 2*M*K*N, 0); // TODO Memory - } - -} - -/*****************************************************************************/ -template -Tensor4DShape winograd::Winograd2x2_3x3GEMM::get_output_shape( - const Tensor4DShape &in_shape, const KernelShape &k_shape, const PaddingType padding) { - return Tensor4DShape { - in_shape.n_batches, - (padding == PADDING_SAME) ? in_shape.n_rows : in_shape.n_rows - 2, - (padding == PADDING_SAME) ? in_shape.n_cols : in_shape.n_cols - 2, - k_shape.n_output_channels - }; -} - -template -size_t winograd::Winograd2x2_3x3GEMM::get_kernel_storage_size( - const KernelShape &shape) { - return 16 * get_kernel_matrix_size(shape); -} - -template -size_t winograd::Winograd2x2_3x3GEMM::get_kernel_matrix_size( - const KernelShape &shape) { - const int K = shape.n_input_channels; - const int N = roundup(shape.n_output_channels, N_BLOCK); - return sizeof(TIn) * K * N; -} - -template -size_t winograd::Winograd2x2_3x3GEMM::get_working_space_size( - const Tensor4DShape& input_shape, const KernelShape &k_shape, const PaddingType padding_type -) { - return 16 * get_input_matrix_size(input_shape, k_shape, padding_type) + - 16 * get_output_matrix_size(input_shape, k_shape, padding_type); -} - -template -size_t winograd::Winograd2x2_3x3GEMM::get_input_matrix_size( - const Tensor4DShape& input_shape, const KernelShape &k_shape, const PaddingType padding_type -) { - // Compute shape for the GEMM - const auto output_shape = get_output_shape(input_shape, k_shape, padding_type); - const int tile_rows = iceildiv(output_shape.n_rows, 2); - const int tile_cols = iceildiv(output_shape.n_cols, 2); - const int M = roundup(tile_rows * tile_cols, M_BLOCK); - const int K = k_shape.n_input_channels; - - return input_shape.n_batches * M * K * sizeof(TIn); -} - -template -size_t winograd::Winograd2x2_3x3GEMM::get_output_matrix_size( - const Tensor4DShape& input_shape, const KernelShape &k_shape,const PaddingType padding_type -) { - // Compute shape for the GEMM - const auto output_shape = get_output_shape(input_shape, k_shape, padding_type); - const int tile_rows = iceildiv(output_shape.n_rows, 2); - const int tile_cols = iceildiv(output_shape.n_cols, 2); - const int M = roundup(tile_rows * tile_cols, M_BLOCK); - const int N = roundup(k_shape.n_output_channels, N_BLOCK); - - return input_shape.n_batches * M * N * sizeof(TOut); -} diff --git a/src/core/NEON/kernels/winograd/winograd_layer.cpp b/src/core/NEON/kernels/winograd/winograd_layer.cpp new file mode 100644 index 0000000000..689ecba5fb --- /dev/null +++ b/src/core/NEON/kernels/winograd/winograd_layer.cpp @@ -0,0 +1,204 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "convolution.hpp" +#include "winograd_layer.hpp" +#include "tensor.hpp" + + +/** Determine how much memory (in units of TIn) to allocate for the transformed + * weights. + */ +template < + int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols, + typename TIn, typename TOut +> +unsigned int WinogradConvolutionLayer< + OutputTileRows, OutputTileCols, KernelRows, KernelCols, TIn, TOut +>::get_weight_storage_size( + const int n_output_channels, /** Number of output feature maps. */ + const int n_input_channels /** Number of input feature maps. */ +) +{ + const KernelShape shape( + n_output_channels, KernelRows, KernelCols, n_input_channels + ); + return static_cast( + // WinogradConv returns the size in bytes, we divide by `sizeof(TIn)` to + // express that in units of TIn. + WinogradConv::get_kernel_storage_size(shape) / sizeof(TIn) + ); +} + + +/** Determine how much memory (in units of TIn) to allocate for the transformed + * input. + */ +template < + int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols, + typename TIn, typename TOut +> +unsigned int WinogradConvolutionLayer< + OutputTileRows, OutputTileCols, KernelRows, KernelCols, TIn, TOut +>::get_input_storage_size( + const int n_batches, /** Number of batches in the input tensor. */ + const int n_channels, /** Number of feature maps in the input tensor. */ + const int n_rows, /** Number of rows in each feature map. */ + const int n_cols, /** Number of columns in each feature map. */ + const bool same_padding /** Use "SAME" padding, otherwise use "VALID". */ +) +{ + // Construct shapes for the input and kernel tensors. + const Tensor4DShape input_shape(n_batches, n_rows, n_cols, n_channels); + const KernelShape kern_shape(1, KernelRows, KernelCols, n_channels); + const PaddingType padding = (same_padding) ? PADDING_SAME : PADDING_VALID; + + // Return the size, converted into units of TIn + return static_cast( + WinogradConv::get_input_storage_size(kern_shape, input_shape, padding) / + sizeof(TIn) + ); +} + + +/** Determine how much memory (in units of TOut) to allocate for the (Winograd + * domain) output. + */ +template < + int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols, + typename TIn, typename TOut +> +unsigned int WinogradConvolutionLayer< + OutputTileRows, OutputTileCols, KernelRows, KernelCols, TIn, TOut +>::get_output_storage_size( + const int n_batches, /** Number of batches in the output tensor. */ + const int n_rows, /** Number of rows in each feature map of the input tensor. */ + const int n_cols, /** Number of columns in each feature map of the input tensor. */ + const int n_output_channels, /** Number of feature maps in the output tensor. */ + const bool same_padding /** Use "SAME" padding, otherwise use "VALID". */ +) +{ + // Construct shapes for the input and kernel tensors. + const Tensor4DShape input_shape(n_batches, n_rows, n_cols, 1); + const KernelShape kern_shape(n_output_channels, KernelRows, KernelCols, 1); + const PaddingType padding = (same_padding) ? PADDING_SAME : PADDING_VALID; + + // Return the size, converted into units of TOut + return static_cast( + WinogradConv::get_output_storage_size(kern_shape, input_shape, padding) / + sizeof(TOut) + ); +} + + +/** Get the shape (rows, cols) of a feature map of the output tensor. */ +template < + int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols, + typename TIn, typename TOut +> +std::pair WinogradConvolutionLayer< + OutputTileRows, OutputTileCols, KernelRows, KernelCols, TIn, TOut +>::get_output_feature_map_shape( + const int n_input_rows, /** Number of rows in the input feature map. */ + const int n_input_cols, /** Number of columns in the input feature map. */ + const bool same_padding /** Use "SAME" padding, otherwise use "VALID". */ +) +{ + // Construct shapes for the input and kernel tensors. + const Tensor4DShape input_shape(1, n_input_rows, n_input_cols, 1); + const KernelShape kern_shape(1, KernelRows, KernelCols, 1); + const PaddingType padding = (same_padding) ? PADDING_SAME : PADDING_VALID; + + // Compute the new shape + const auto output_shape = WinogradConv::get_output_shape( + kern_shape, input_shape, padding + ); + + return std::make_pair(output_shape.n_rows, output_shape.n_cols); +} + + +/** Create a new Winograd convolution layer. + */ +template < + int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols, + typename TIn, typename TOut +> +WinogradConvolutionLayer:: +WinogradConvolutionLayer( + const int n_batches, /** Number of batches in the input and output tensors. */ + const int n_input_channels, /** Number of feature maps in a batch of the input tensor. */ + const int n_input_rows, /** Number of rows in a feature map of the input tensor. */ + const int n_input_cols, /** Number of columns in a feature map of the input tensor. */ + const int n_output_channels, /** Number of feature maps in the output tensor. */ + const bool same_padding, /** Use "SAME" padding, otherwise use "VALID". */ + const TIn* const weights, /** Pointer to weight tensor in spatial domain. Must be ordered as "Height x Rows x Input Feature Maps x Output Feature Maps. */ + TIn* const winograd_weights, /** Pointer to storage for weight tensor in the Winograd domain. Must be at least the size returned by `get_weight_storage_size`. */ + const TIn* const input, /** Pointer to NHWC ordered input tensor, in the spatial domain. */ + TIn* const winograd_input, /** Pointer to working space for the input tensor in the Winograd domain. Must be at least the size returned by `get_input_storage_size`. */ + TOut* const output, /** Pointer to NHWC ordered output tensor, in the spatial domain. */ + TOut* const winograd_output /** Pointer to working space for the output tensor in the Winograd domain. Must be at least the size returned by `get_output_storage_size`. */ +) : _kernel_shape(n_output_channels, KernelRows, KernelCols, n_input_channels), + _input_shape(n_batches, n_input_rows, n_input_cols, n_input_channels), + _padding(same_padding ? PADDING_SAME : PADDING_VALID), + _output_shape(WinogradConv::get_output_shape(_kernel_shape, _input_shape, _padding)), + _n_output_rows(_output_shape.n_rows), + _n_output_cols(_output_shape.n_cols), + _kernel_matrix_stride(WinogradConv::get_kernel_matrix_stride(_kernel_shape)), + _kernel_matrix_row_stride(roundup(n_output_channels, WinogradConv::N_BLOCK)), + _input_matrix_stride(WinogradConv::get_input_matrix_stride(_kernel_shape, _input_shape, _padding)), + _input_matrix_row_stride(n_input_channels), + _output_matrix_stride(WinogradConv::get_output_matrix_stride(_kernel_shape, _input_shape, _padding)), + _output_matrix_row_stride(_kernel_matrix_row_stride), + _tile_rows(iceildiv(_n_output_rows, OutputTileRows)), + _tile_cols(iceildiv(_n_output_cols, OutputTileCols)), + _m(n_batches * _tile_rows * _tile_cols), + _k(n_input_channels), + _n(n_output_channels), + weights_transform( + weights, winograd_weights, + _kernel_matrix_stride, _kernel_matrix_row_stride, + n_output_channels, n_input_channels + ), + input_transform( + input, n_batches, n_input_rows, n_input_cols, n_input_channels, _padding, + winograd_input, _input_matrix_stride, _input_matrix_row_stride + ), + gemms( + WinogradBase::N_GEMMS, _m, _k, _n, + _input_matrix_stride, _input_matrix_row_stride, + _kernel_matrix_stride, _kernel_matrix_row_stride, + _output_matrix_stride, _output_matrix_row_stride, + winograd_input, winograd_weights, winograd_output + ), + output_transform( + winograd_output, _output_matrix_stride, _output_matrix_row_stride, + output, n_batches, _n_output_rows, _n_output_cols, n_output_channels + ) +{ +} + +// Instantiate valid implementations. +template class WinogradConvolutionLayer<2, 2, 3, 3, float, float>; +template class WinogradConvolutionLayer<4, 4, 3, 3, float, float>; diff --git a/src/runtime/NEON/functions/NEWinogradLayer.cpp b/src/runtime/NEON/functions/NEWinogradLayer.cpp index 21f298ca25..da46f8773c 100644 --- a/src/runtime/NEON/functions/NEWinogradLayer.cpp +++ b/src/runtime/NEON/functions/NEWinogradLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017, 2018 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -43,8 +43,8 @@ inline Tensor4DShape internal_get_input_shape(const arm_compute::ITensor *input) namespace arm_compute { NEWinogradLayer::NEWinogradLayer(std::shared_ptr memory_manager) - : _memory_group(std::move(memory_manager)), _winograd_kernel(), _permute_input(), _permute_weights(), _permute_output(), _workspace(), _kernel_storage(), _input_nhwc(), _output_nhwc(), - _weights_hwio(), _input(), _weights(), _output(), _reshaped_kernel(false), _conv() + : _memory_group(std::move(memory_manager)), _winograd_kernel(), _permute_input(), _permute_weights(), _permute_output(), _input_workspace(), _output_workspace(), _kernel_storage(), _input_nhwc(), + _output_nhwc(), _weights_hwio(), _input(), _weights(), _output(), _reshaped_kernel(false), _conv() { } /* arm_compute */ @@ -72,36 +72,37 @@ void NEWinogradLayer::configure(const ITensor *input, const ITensor *weights, co ARM_COMPUTE_ERROR_ON_MSG(stride_y != 1 || stride_x != 1, "Winograd layer only supports unit strides."); // Get convolved dimensions - auto padding = PADDING_VALID; - const int in_channels = input->info()->dimension(2); - const int out_channels = output->info()->dimension(2); - const int weights_width = weights->info()->dimension(0); - const int weights_height = weights->info()->dimension(1); + const int in_channels = input->info()->dimension(2); + const int out_channels = output->info()->dimension(2); - const KernelShape kernel_shape({ out_channels, weights_height, weights_width, in_channels }); const Tensor4DShape in_shape(internal_get_input_shape(input)); // Get the memory required to instantiate a new Winograd operator. - constexpr size_t kstore_alignment = 64; - const size_t kernel_storage_per_thread = NEWinogradLayerKernel::get_kernel_storage_size(kernel_shape); - _kernel_storage.allocator()->init(TensorInfo(TensorShape{ (kernel_storage_per_thread + kstore_alignment - 1) }, 1, DataType::U8)); + constexpr size_t storage_alignment = 64; + const size_t kernel_storage_size = NEWinogradLayerKernel::get_weight_storage_size(out_channels, in_channels) * sizeof(float); + _kernel_storage.allocator()->init(TensorInfo(TensorShape{ (kernel_storage_size + storage_alignment - 1) }, 1, DataType::U8)); _memory_group.manage(&_kernel_storage); - - // Get workbench size and allocate memory - - constexpr size_t wspace_alignment = 64; - const size_t ws_size = NEWinogradLayerKernel::get_working_space_size(in_shape, kernel_shape, padding); - _workspace.allocator()->init(TensorInfo(TensorShape{ (ws_size + wspace_alignment - 1) }, 1, DataType::U8)); - _memory_group.manage(&_workspace); _memory_group.manage(&_input_nhwc); _kernel_storage.allocator()->allocate(); - _workspace.allocator()->allocate(); + // Input storage + const size_t input_storage_size = NEWinogradLayerKernel::get_input_storage_size(in_shape.n_batches, in_shape.n_channels, in_shape.n_rows, in_shape.n_cols, false) * sizeof(float); + _input_workspace.allocator()->init(TensorInfo(TensorShape{ (input_storage_size + storage_alignment - 1) }, 1, DataType::U8)); + _memory_group.manage(&_input_workspace); + _input_workspace.allocator()->allocate(); + + // Output storage + const size_t output_storage_size = NEWinogradLayerKernel::get_output_storage_size(in_shape.n_batches, in_shape.n_rows, in_shape.n_cols, out_channels, false) * sizeof(float); + _output_workspace.allocator()->init(TensorInfo(TensorShape{ (output_storage_size + storage_alignment - 1) }, 1, DataType::U8)); + _memory_group.manage(&_output_workspace); + _output_workspace.allocator()->allocate(); - // Create Winograd operator object - _conv = support::cpp14::make_unique(kernel_shape, in_shape, padding, _kernel_storage.buffer()); + // configure and allocate dst tensor to be used to convert from winograd domain to spatial domain when calling to reshape_output() + TensorInfo info(TensorShape(_output->info()->dimension(2), _output->info()->dimension(0), + _output->info()->dimension(1), _output->info()->dimension(3)), + 1, _output->info()->data_type()); + _output_nhwc.allocator()->init(info); - // Configure the kernel, padding not needed so it's safe to call configure after allocare - _winograd_kernel.configure(_conv.get()); + _output_nhwc.allocator()->allocate(); // Re-order a weight tensor from [Output feature map x Input feature map x Height x Width] to [Height x Width x Input feature map x Output feature map] switch(weights->info()->num_dimensions()) @@ -122,60 +123,56 @@ void NEWinogradLayer::configure(const ITensor *input, const ITensor *weights, co break; } } + + _weights_hwio.allocator()->allocate(); + // configure the kernel to transform the input tensor from NCHW -> NHWC _permute_input.configure(input, &_input_nhwc, PermutationVector(2U, 0U, 1U)); - // configure and allocate dst tensor to be used to convert from winograd domain to spatial domain when calling to reshape_output() - TensorInfo info(TensorShape(_output->info()->dimension(2), _output->info()->dimension(0), - _output->info()->dimension(1), _output->info()->dimension(3)), - 1, _output->info()->data_type()); - _output_nhwc.allocator()->init(info); - - _output_nhwc.allocator()->allocate(); - _weights_hwio.allocator()->allocate(); _input_nhwc.allocator()->allocate(); + + // Create Winograd operator object + _conv = support::cpp14::make_unique( + in_shape.n_batches, + in_shape.n_channels, + in_shape.n_rows, + in_shape.n_cols, + out_channels, + false, + reinterpret_cast(_weights_hwio.buffer()), + reinterpret_cast(_kernel_storage.buffer()), + reinterpret_cast(_input_nhwc.buffer()), + reinterpret_cast(_input_workspace.buffer()), + reinterpret_cast(_output_nhwc.buffer()), + reinterpret_cast(_output_workspace.buffer())); + + // Configure the kernel, padding not needed so it's safe to call configure after allocare + _winograd_kernel.configure(_conv.get()); + + // Reorder the convoluted output to ACL's ordering NCHW + _permute_output.configure(&_output_nhwc, _output, PermutationVector(1U, 2U, 0U)); + } void NEWinogradLayer::run() { -#if defined(__aarch64__) _memory_group.acquire(); if(!_reshaped_kernel) { _reshaped_kernel = true; _permute_weights.run(); - _conv->transform_weights(reinterpret_cast(_weights_hwio.buffer()), nullptr); + _conv->transform_weights(); } - const Tensor4DShape in_shape(internal_get_input_shape(_input)); - auto padding = PADDING_VALID; - //Bring channels to the front as Winograd code expects the tensor to be in the format NHWC _permute_input.run(); - - //Setup matrices ptrs and transfor the input tensor to the appropriate form before running GEMM. - _conv->reshape_input(in_shape, padding, reinterpret_cast(_input_nhwc.buffer()), _workspace.buffer()); - + // Transform input tensor to the winograd domain + _conv->transform_input(); //Run 16 GEMMs in multiple threads, each kernel runs one or more GEMMs NEScheduler::get().schedule(&_winograd_kernel, Window::DimX); - - //Transform the output to the appropriate form - _conv->reshape_output(in_shape, padding, reinterpret_cast(_output_nhwc.buffer())); - + // Transform output tensor to the spatial domain + _conv->transform_output(); // Reorder the convoluted output to ACL's ordering NCHW - _permute_output.configure(&_output_nhwc, _output, PermutationVector(1U, 2U, 0U)); _permute_output.run(); - _memory_group.release(); -#else /* __aarch64__ */ - ARM_COMPUTE_UNUSED(_winograd_kernel); - ARM_COMPUTE_UNUSED(_workspace); - ARM_COMPUTE_UNUSED(_kernel_storage); - ARM_COMPUTE_UNUSED(_input); - ARM_COMPUTE_UNUSED(_weights); - ARM_COMPUTE_UNUSED(_output); - ARM_COMPUTE_UNUSED(_reshaped_kernel); - ARM_COMPUTE_UNUSED(_conv); - ARM_COMPUTE_ERROR("Winograd only supported for aarch64, recompile with arch=arm64-v8a."); -#endif /* __aarch64__ */ } } // namespace arm_compute -- cgit v1.2.1