diff options
Diffstat (limited to 'src/core')
12 files changed, 2787 insertions, 355 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp index cf91ee0652..7f171ec15a 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -33,6 +33,7 @@ #include "kernels/a32_sgemm_8x6.hpp" #include "kernels/a64_hybrid_fp32_mla_16x4.hpp" +#include "kernels/a64_hybrid_fp32_mla_4x8.hpp" #include "kernels/a64_native_fp32_mla_16x4.hpp" #include "kernels/a64_smallK_hybrid_fp32_mla_4x6.hpp" #include "kernels/a64_smallK_hybrid_fp32_mla_4x8.hpp" @@ -106,9 +107,16 @@ static const GemmImplementation<float, float> gemm_fp32_methods[] = }, { GemmMethod::GEMM_HYBRID, + "hybrid_fp32_mla_4x8_normal", + [](const GemmArgs &args) { return (args._Ksize >= 4) && !args._trA && args._pretransposed_hint; }, + [](const GemmArgs &args) { return (args._Nsize < 12); }, + [](const GemmArgs &args) { return new GemmHybrid<hybrid_fp32_mla_4x8, float, float>(args); } +}, +{ + GemmMethod::GEMM_HYBRID, "hybrid_fp32_mla_16x4", [](const GemmArgs &args) { return (args._Ksize >= 4) && !args._trA && args._pretransposed_hint; }, - [](const GemmArgs &args) { return ((args._Ksize <= 256) && (args._Nsize <= 256)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); }, + [](const GemmArgs &args) { return ((args._Ksize <= 256) && (args._Nsize <= 256)) || (args._Msize < 16) || (args._nmulti > 1); }, [](const GemmArgs &args) { return new GemmHybrid<hybrid_fp32_mla_16x4, float, float>(args); } }, { diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp index 574ecef5b2..22b6960baf 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp @@ -68,7 +68,7 @@ class GemmHybridQuantized : public GemmCommon<To, Tr> { const NDRange<4> _window_range; - ARequantizeLayer32 _qp; + Requantize32 _qp; int32_t *row_bias = nullptr; int32_t *col_bias = nullptr; @@ -140,7 +140,7 @@ public: GemmHybridQuantized & operator= (GemmHybridQuantized &) = delete; /* Constructor */ - GemmHybridQuantized(const GemmArgs &args, const ARequantizeLayer32 &qp) + GemmHybridQuantized(const GemmArgs &args, const Requantize32 &qp) : _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize), _nbatches(args._nbatches), _nmulti(args._nmulti), _trB(args._trB), _k_block(compute_k_block(args)), _n_block(compute_n_block(args)), diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp new file mode 100644 index 0000000000..73d0c272a6 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2019 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include "arm_gemm.hpp" + +#include "kernels/a64_hybrid_s8s32_dot_16x4.hpp" +#include "kernels/a64_smallK_hybrid_s8s32_dot_4x6.hpp" +#include "kernels/a64_smallK_hybrid_s8s32_dot_4x8.hpp" +#include "kernels/sve_hybrid_s8s32_dot_4VLx4.hpp" +#include "kernels/sve_smallK_hybrid_s8s32_dot_1VLx8.hpp" + +#include "gemm_hybrid_quantized.hpp" +#include "quantize_wrapper.hpp" + +namespace arm_gemm { + +static const GemmImplementation<int8_t, int8_t, Requantize32> gemm_qint8_methods[] = +{ +#ifdef __ARM_FEATURE_SVE +{ + GemmMethod::GEMM_HYBRID_QUANTIZED, + "smallK_hybrid_s8s32_dot_1VLx8", + [](const GemmArgs &args, const Requantize32 &) { return args._Ksize<=64 && !args._trA && args._pretransposed_hint; }, + nullptr, + [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<smallK_hybrid_s8s32_dot_1VLx8, int8_t, int8_t>(args, qp); } +}, +{ + GemmMethod::GEMM_HYBRID_QUANTIZED, + "hybrid_s8s32_dot_4VLx4", + [](const GemmArgs &args, const Requantize32 &) { return args._Ksize>=16 && !args._trA && !args._trB && args._pretransposed_hint; }, + [](const GemmArgs &args, const Requantize32 &) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); }, + [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<hybrid_s8s32_dot_4VLx4, int8_t, int8_t>(args, qp); } +}, +#endif +{ + GemmMethod::GEMM_HYBRID_QUANTIZED, + "smallK_hybrid_s8s32_dot_4x8", + [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._trA && args._pretransposed_hint; }, + nullptr, + [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<smallK_hybrid_s8s32_dot_4x8, int8_t, int8_t>(args, qp); } +}, +{ + GemmMethod::GEMM_HYBRID_QUANTIZED, + "smallK_hybrid_s8s32_dot_4x6", + [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._trA && args._pretransposed_hint; }, + nullptr, + [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<smallK_hybrid_s8s32_dot_4x6, int8_t, int8_t>(args, qp); } +}, +{ + GemmMethod::GEMM_HYBRID_QUANTIZED, + "hybrid_s8s32_dot_16x4", + [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && args._Ksize>=16 && !args._trA && !args._trB && args._pretransposed_hint; }, + [](const GemmArgs &args, const Requantize32 &) { return args._Nsize<=256 && args._Ksize>128; }, + [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<hybrid_s8s32_dot_16x4, int8_t, int8_t>(args, qp); } +}, +{ + GemmMethod::QUANTIZE_WRAPPER, + "quantized_wrapper", + nullptr, + nullptr, + [](const GemmArgs &args, const Requantize32 &qp) { return new QuantizeWrapper<int8_t, int8_t, int32_t>(args, qp); } +}, +{ + GemmMethod::DEFAULT, + "", + nullptr, + nullptr, + nullptr +} +}; + +template<> +const GemmImplementation<int8_t, int8_t, Requantize32> *gemm_implementation_list<int8_t, int8_t, Requantize32>() { + return gemm_qint8_methods; +} + +template UniqueGemmCommon<int8_t, int8_t> gemm<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); +template KernelDescription get_gemm_method<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); +template std::vector<KernelDescription> get_compatible_kernels<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); + +} // namespace arm_gemm + +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp index 079c04ae06..59cd1704ff 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp @@ -36,51 +36,51 @@ namespace arm_gemm { -static const GemmImplementation<uint8_t, uint8_t, ARequantizeLayer32> gemm_quint8_methods[] = +static const GemmImplementation<uint8_t, uint8_t, Requantize32> gemm_quint8_methods[] = { #ifdef __ARM_FEATURE_SVE { GemmMethod::GEMM_HYBRID_QUANTIZED, "smallK_hybrid_u8u32_dot_1VLx8", - [](const GemmArgs &args, const ARequantizeLayer32 &) { return args._Ksize<=64 && !args._trA && args._pretransposed_hint; }, + [](const GemmArgs &args, const Requantize32 &) { return args._Ksize<=64 && !args._trA && args._pretransposed_hint; }, nullptr, - [](const GemmArgs &args, const ARequantizeLayer32 &qp) { return new GemmHybridQuantized<smallK_hybrid_u8u32_dot_1VLx8, uint8_t, uint8_t>(args, qp); } + [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<smallK_hybrid_u8u32_dot_1VLx8, uint8_t, uint8_t>(args, qp); } }, { GemmMethod::GEMM_HYBRID_QUANTIZED, "hybrid_u8u32_dot_4VLx4", - [](const GemmArgs &args, const ARequantizeLayer32 &) { return args._Ksize>=16 && !args._trA && !args._trB && args._pretransposed_hint; }, - [](const GemmArgs &args, const ARequantizeLayer32 &) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); }, - [](const GemmArgs &args, const ARequantizeLayer32 &qp) { return new GemmHybridQuantized<hybrid_u8u32_dot_4VLx4, uint8_t, uint8_t>(args, qp); } + [](const GemmArgs &args, const Requantize32 &) { return args._Ksize>=16 && !args._trA && !args._trB && args._pretransposed_hint; }, + [](const GemmArgs &args, const Requantize32 &) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); }, + [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<hybrid_u8u32_dot_4VLx4, uint8_t, uint8_t>(args, qp); } }, #endif { GemmMethod::GEMM_HYBRID_QUANTIZED, "smallK_hybrid_u8u32_dot_4x8", - [](const GemmArgs &args, const ARequantizeLayer32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._trA && args._pretransposed_hint; }, + [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._trA && args._pretransposed_hint; }, nullptr, - [](const GemmArgs &args, const ARequantizeLayer32 &qp) { return new GemmHybridQuantized<smallK_hybrid_u8u32_dot_4x8, uint8_t, uint8_t>(args, qp); } + [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<smallK_hybrid_u8u32_dot_4x8, uint8_t, uint8_t>(args, qp); } }, { GemmMethod::GEMM_HYBRID_QUANTIZED, "smallK_hybrid_u8u32_dot_4x6", - [](const GemmArgs &args, const ARequantizeLayer32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._trA && args._pretransposed_hint; }, + [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._trA && args._pretransposed_hint; }, nullptr, - [](const GemmArgs &args, const ARequantizeLayer32 &qp) { return new GemmHybridQuantized<smallK_hybrid_u8u32_dot_4x6, uint8_t, uint8_t>(args, qp); } + [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<smallK_hybrid_u8u32_dot_4x6, uint8_t, uint8_t>(args, qp); } }, { GemmMethod::GEMM_HYBRID_QUANTIZED, "hybrid_u8u32_dot_16x4", - [](const GemmArgs &args, const ARequantizeLayer32 &) { return args._ci->has_dotprod() && args._Ksize>=16 && !args._trA && !args._trB && args._pretransposed_hint; }, - [](const GemmArgs &args, const ARequantizeLayer32 &) { return args._Nsize<=256 && args._Ksize>128; }, - [](const GemmArgs &args, const ARequantizeLayer32 &qp) { return new GemmHybridQuantized<hybrid_u8u32_dot_16x4, uint8_t, uint8_t>(args, qp); } + [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && args._Ksize>=16 && !args._trA && !args._trB && args._pretransposed_hint; }, + [](const GemmArgs &args, const Requantize32 &) { return args._Nsize<=256 && args._Ksize>128; }, + [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<hybrid_u8u32_dot_16x4, uint8_t, uint8_t>(args, qp); } }, { GemmMethod::QUANTIZE_WRAPPER, "quantized_wrapper", nullptr, nullptr, - [](const GemmArgs &args, const ARequantizeLayer32 &qp) { return new QuantizeWrapper<uint8_t, uint8_t, uint32_t>(args, qp); } + [](const GemmArgs &args, const Requantize32 &qp) { return new QuantizeWrapper<uint8_t, uint8_t, uint32_t>(args, qp); } }, { GemmMethod::DEFAULT, @@ -92,13 +92,13 @@ static const GemmImplementation<uint8_t, uint8_t, ARequantizeLayer32> gemm_quint }; template<> -const GemmImplementation<uint8_t, uint8_t, ARequantizeLayer32> *gemm_implementation_list<uint8_t, uint8_t, ARequantizeLayer32>() { +const GemmImplementation<uint8_t, uint8_t, Requantize32> *gemm_implementation_list<uint8_t, uint8_t, Requantize32>() { return gemm_quint8_methods; } -template UniqueGemmCommon<uint8_t, uint8_t> gemm<uint8_t, uint8_t, ARequantizeLayer32>(const GemmArgs &args, const ARequantizeLayer32 &os); -template KernelDescription get_gemm_method<uint8_t, uint8_t, ARequantizeLayer32>(const GemmArgs &args, const ARequantizeLayer32 &os); -template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint8_t, ARequantizeLayer32>(const GemmArgs &args, const ARequantizeLayer32 &os); +template UniqueGemmCommon<uint8_t, uint8_t> gemm<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); +template KernelDescription get_gemm_method<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); +template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x8.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x8.hpp new file mode 100644 index 0000000000..da5beef48c --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x8.hpp @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2018-2019 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#ifdef __aarch64__ + + +#include "../std_transforms_fixed.hpp" + +namespace arm_gemm +{ + +// Actual kernel implementations +void a64_hybrid_fp32_mla_4x8(const float *, int, const float *, float *, int, int, int, int, const float *, Activation, bool); + +class hybrid_fp32_mla_4x8 +{ +public: + typedef float operand_type; + typedef float result_type; + + typedef void (*kern_type)(const float *, int, const float *, float *, int, int, int, int, const float *, Activation, bool); + + /* Kernel blocking parameters */ + static constexpr unsigned int out_height() + { + return 8; + } + + static unsigned int out_width() + { + return 4; + } + + static constexpr unsigned int k_unroll() + { + return 1; + } + + static constexpr bool supports_append() + { + return false; + } + + static constexpr bool supports_bias() + { + return true; + } + + static constexpr bool supports_activation() + { + return true; + } + + StdTransformsFixed<operand_type, result_type, 8, 4, 1> transforms = {}; + + // Default to the generic kernel + kern_type kernel=a64_hybrid_fp32_mla_4x8; + + hybrid_fp32_mla_4x8(const CPUInfo *ci) + { + UNUSED(ci); + } +}; + +} // namespace arm_gemm + +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x8/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x8/generic.cpp new file mode 100644 index 0000000000..db7eb83160 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x8/generic.cpp @@ -0,0 +1,1923 @@ +/* + * Copyright (c) 2018-2019 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include <algorithm> + +#include "arm_gemm.hpp" + +#include "../../asmlib.hpp" +#include "../../utils.hpp" + +namespace arm_gemm { + +void a64_hybrid_fp32_mla_4x8(const float *A, int lda, const float *B, float *C, int ldc, int M, int N, int K, const float *bias, Activation act, bool append) { + const int K_stride = K; + const long loops_count = ((K + 4) / 8) - 1; + K -= loops_count * 8; + const long regs_count = (K / 4) - 1; + K -= (regs_count + 1) * 4; + const long blocks_count = K / 1; + float nullbias[4]; + if (!append && !bias) { + memset(nullbias, 0, (4 * sizeof(float))); + } + float minval = - static_cast<float>(std::numeric_limits<float>::infinity()); + float maxval = static_cast<float>(std::numeric_limits<float>::infinity()); + const float * const minptr = &minval; + const float * const maxptr = &maxval; + + switch(act.type) + { + default: + case Activation::Type::None: + break; + case Activation::Type::BoundedReLU: + maxval = static_cast<float>(act.param1); + /* fall through */ + case Activation::Type::ReLU: + minval = 0.0f; + break; + } + + for (int y=0; y<M; y+=8) { + const float * const a_ptr0_base = A + (y * lda); + const unsigned long ldab = lda * sizeof(float); + + float *c_ptr0 = C + (y * ldc); + + for (int x0=0; x0<N; x0+=4ul) { + const long width = std::min((unsigned long)N-x0, 4ul); + long loops = loops_count; + long regs = regs_count; + long blocks = blocks_count; + const float *a_ptr0 = a_ptr0_base; + const float *b_ptr0 = B + (K_stride * x0); + const bool use_result_buffer = (width < 4); + float result_buffer[32]; + const unsigned long ldcb = (use_result_buffer ? 4 : ldc) * sizeof(float); + float *c_ptr_real = c_ptr0; + if (use_result_buffer && append) { + for(int cy=0; cy<std::min(M-y, 8); cy++) { + for(unsigned int cx=0; cx<width; cx++) { + result_buffer[cy * 4 + cx] = c_ptr_real[cy * ldc + cx]; + } + } + } + if (use_result_buffer) { + c_ptr0 = result_buffer; + } + const float *biasptr = bias ? bias+x0 : nullbias; + + switch(M-y) { + case 1: + __asm __volatile ( + "ldr q24, [%[biasptr]]\n" + "ldr q0, [%[a_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "ldr q16, [%[b_ptr0]]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "cbz %[loops], 1f\n" + "2:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "ldr q8, [%[a_ptr0]]\n" + "subs %[loops], %[loops], #0x1\n" + "ldr q16, [%[b_ptr0]]\n" + "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "add %[a_ptr0], %[a_ptr0], #0x20\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "ldr q0, [%[a_ptr0], #-0x10]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "ldr q16, [%[b_ptr0], #0x40]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "ldr q17, [%[b_ptr0], #0x50]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "ldr q18, [%[b_ptr0], #0x60]\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "b.ne 2b\n" + "1:\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "prfm PSTL1KEEP, [%[c_ptr0]]\n" + "cbz %[regs], 3f\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "ldr q16, [%[b_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "b 4f\n" + "3:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "4:\n" + "cbz %[blocks], 5f\n" + "6:\n" + "ldr q16, [%[b_ptr0]]\n" + "subs %[blocks], %[blocks], #0x1\n" + "add %[b_ptr0], %[b_ptr0], #0x10\n" + "ldr s0, [%[a_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x4\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "b.ne 6b\n" + "5:\n" + "ld1r {v22.4s}, [%[minptr]]\n" + "ld1r {v23.4s}, [%[maxptr]]\n" + "fmax v24.4s, v24.4s, v22.4s\n" + "fmin v24.4s, v24.4s, v23.4s\n" + "str q24, [%[c_ptr0]]\n" + "add %[c_ptr0], %[c_ptr0], #0x10\n" + : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks) + : [width] "r" (width), [append] "r" (static_cast<uint64_t>(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory" + ); + break; + case 2: + __asm __volatile ( + "a_ptr1 .req X0\n" + "c_ptr1 .req X1\n" + "ldr q24, [%[biasptr]]\n" + "add a_ptr1, %[a_ptr0], %[lda]\n" + "ldr q0, [%[a_ptr0]]\n" + "add c_ptr1, %[c_ptr0], %[ldc]\n" + "mov v25.16b, v24.16b\n" + "ldr q1, [a_ptr1]\n" + "ldr q16, [%[b_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "cbz %[loops], 1f\n" + "2:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "ldr q9, [a_ptr1]\n" + "subs %[loops], %[loops], #0x1\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n" + "add %[a_ptr0], %[a_ptr0], #0x20\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "add a_ptr1, a_ptr1, #0x20\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "prfm PLDL1KEEP, [a_ptr1, #0x40]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "ldr q0, [%[a_ptr0], #-0x10]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "ldr q1, [a_ptr1, #-0x10]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "ldr q16, [%[b_ptr0], #0x40]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "ldr q17, [%[b_ptr0], #0x50]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "ldr q18, [%[b_ptr0], #0x60]\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "b.ne 2b\n" + "1:\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "prfm PSTL1KEEP, [%[c_ptr0]]\n" + "prfm PSTL1KEEP, [c_ptr1]\n" + "cbz %[regs], 3f\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q9, [a_ptr1]\n" + "ldr q16, [%[b_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "b 4f\n" + "3:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "4:\n" + "cbz %[blocks], 5f\n" + "6:\n" + "ldr q16, [%[b_ptr0]]\n" + "subs %[blocks], %[blocks], #0x1\n" + "add %[b_ptr0], %[b_ptr0], #0x10\n" + "ldr s0, [%[a_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x4\n" + "ldr s1, [a_ptr1]\n" + "add a_ptr1, a_ptr1, #0x4\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "b.ne 6b\n" + "5:\n" + "ld1r {v22.4s}, [%[minptr]]\n" + "ld1r {v23.4s}, [%[maxptr]]\n" + "fmax v24.4s, v24.4s, v22.4s\n" + "fmax v25.4s, v25.4s, v22.4s\n" + "fmin v24.4s, v24.4s, v23.4s\n" + "fmin v25.4s, v25.4s, v23.4s\n" + "str q24, [%[c_ptr0]]\n" + "add %[c_ptr0], %[c_ptr0], #0x10\n" + "str q25, [c_ptr1]\n" + ".unreq a_ptr1\n" + ".unreq c_ptr1\n" + : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks) + : [width] "r" (width), [append] "r" (static_cast<uint64_t>(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "cc", "memory" + ); + break; + case 3: + __asm __volatile ( + "a_ptr1 .req X0\n" + "a_ptr2 .req X1\n" + "c_ptr1 .req X2\n" + "c_ptr2 .req X3\n" + "ldr q24, [%[biasptr]]\n" + "add a_ptr1, %[a_ptr0], %[lda]\n" + "ldr q0, [%[a_ptr0]]\n" + "add a_ptr2, a_ptr1, %[lda]\n" + "mov v25.16b, v24.16b\n" + "ldr q1, [a_ptr1]\n" + "mov v26.16b, v24.16b\n" + "ldr q2, [a_ptr2]\n" + "ldr q16, [%[b_ptr0]]\n" + "add c_ptr1, %[c_ptr0], %[ldc]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "add c_ptr2, c_ptr1, %[ldc]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "add a_ptr1, a_ptr1, #0x10\n" + "add a_ptr2, a_ptr2, #0x10\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "cbz %[loops], 1f\n" + "2:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q9, [a_ptr1]\n" + "ldr q10, [a_ptr2]\n" + "subs %[loops], %[loops], #0x1\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "add %[a_ptr0], %[a_ptr0], #0x20\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "add a_ptr1, a_ptr1, #0x20\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "ldr q0, [%[a_ptr0], #-0x10]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "ldr q1, [a_ptr1, #-0x10]\n" + "add a_ptr2, a_ptr2, #0x20\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "ldr q2, [a_ptr2, #-0x10]\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "prfm PLDL1KEEP, [a_ptr1, #0x40]\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "ldr q16, [%[b_ptr0], #0x40]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "prfm PLDL1KEEP, [a_ptr2, #0x40]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "ldr q17, [%[b_ptr0], #0x50]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "ldr q18, [%[b_ptr0], #0x60]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "b.ne 2b\n" + "1:\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "prfm PSTL1KEEP, [%[c_ptr0]]\n" + "prfm PSTL1KEEP, [c_ptr1]\n" + "prfm PSTL1KEEP, [c_ptr2]\n" + "cbz %[regs], 3f\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q9, [a_ptr1]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q10, [a_ptr2]\n" + "ldr q16, [%[b_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "add a_ptr2, a_ptr2, #0x10\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "b 4f\n" + "3:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "4:\n" + "cbz %[blocks], 5f\n" + "6:\n" + "ldr q16, [%[b_ptr0]]\n" + "subs %[blocks], %[blocks], #0x1\n" + "add %[b_ptr0], %[b_ptr0], #0x10\n" + "ldr s0, [%[a_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x4\n" + "ldr s1, [a_ptr1]\n" + "add a_ptr1, a_ptr1, #0x4\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr s2, [a_ptr2]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "add a_ptr2, a_ptr2, #0x4\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "b.ne 6b\n" + "5:\n" + "ld1r {v22.4s}, [%[minptr]]\n" + "ld1r {v23.4s}, [%[maxptr]]\n" + "fmax v24.4s, v24.4s, v22.4s\n" + "fmax v25.4s, v25.4s, v22.4s\n" + "fmax v26.4s, v26.4s, v22.4s\n" + "fmin v24.4s, v24.4s, v23.4s\n" + "fmin v25.4s, v25.4s, v23.4s\n" + "fmin v26.4s, v26.4s, v23.4s\n" + "str q24, [%[c_ptr0]]\n" + "add %[c_ptr0], %[c_ptr0], #0x10\n" + "str q25, [c_ptr1]\n" + "str q26, [c_ptr2]\n" + ".unreq a_ptr1\n" + ".unreq a_ptr2\n" + ".unreq c_ptr1\n" + ".unreq c_ptr2\n" + : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks) + : [width] "r" (width), [append] "r" (static_cast<uint64_t>(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "cc", "memory" + ); + break; + case 4: + __asm __volatile ( + "a_ptr1 .req X0\n" + "a_ptr2 .req X1\n" + "a_ptr3 .req X2\n" + "c_ptr1 .req X3\n" + "c_ptr2 .req X4\n" + "c_ptr3 .req X5\n" + "ldr q24, [%[biasptr]]\n" + "add a_ptr1, %[a_ptr0], %[lda]\n" + "ldr q0, [%[a_ptr0]]\n" + "add a_ptr2, a_ptr1, %[lda]\n" + "mov v25.16b, v24.16b\n" + "ldr q1, [a_ptr1]\n" + "mov v26.16b, v24.16b\n" + "ldr q2, [a_ptr2]\n" + "mov v27.16b, v24.16b\n" + "ldr q16, [%[b_ptr0]]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "add a_ptr3, a_ptr2, %[lda]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "add c_ptr1, %[c_ptr0], %[ldc]\n" + "ldr q3, [a_ptr3]\n" + "add c_ptr2, c_ptr1, %[ldc]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "add c_ptr3, c_ptr2, %[ldc]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "add a_ptr2, a_ptr2, #0x10\n" + "add a_ptr3, a_ptr3, #0x10\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "cbz %[loops], 1f\n" + "2:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q9, [a_ptr1]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "ldr q10, [a_ptr2]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q11, [a_ptr3]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "subs %[loops], %[loops], #0x1\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "add %[a_ptr0], %[a_ptr0], #0x20\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "add a_ptr1, a_ptr1, #0x20\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "ldr q0, [%[a_ptr0], #-0x10]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "ldr q1, [a_ptr1, #-0x10]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "add a_ptr2, a_ptr2, #0x20\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "ldr q2, [a_ptr2, #-0x10]\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "add a_ptr3, a_ptr3, #0x20\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "ldr q3, [a_ptr3, #-0x10]\n" + "fmla v27.4s, v16.4s, v11.s[0]\n" + "ldr q16, [%[b_ptr0], #0x40]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "prfm PLDL1KEEP, [a_ptr1, #0x40]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "prfm PLDL1KEEP, [a_ptr2, #0x40]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "prfm PLDL1KEEP, [a_ptr3, #0x40]\n" + "fmla v27.4s, v17.4s, v11.s[1]\n" + "ldr q17, [%[b_ptr0], #0x50]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "fmla v27.4s, v18.4s, v11.s[2]\n" + "ldr q18, [%[b_ptr0], #0x60]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "fmla v27.4s, v19.4s, v11.s[3]\n" + "b.ne 2b\n" + "1:\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "prfm PSTL1KEEP, [%[c_ptr0]]\n" + "prfm PSTL1KEEP, [c_ptr1]\n" + "prfm PSTL1KEEP, [c_ptr2]\n" + "prfm PSTL1KEEP, [c_ptr3]\n" + "cbz %[regs], 3f\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q9, [a_ptr1]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q10, [a_ptr2]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "ldr q11, [a_ptr3]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "add a_ptr2, a_ptr2, #0x10\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "add a_ptr3, a_ptr3, #0x10\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "fmla v27.4s, v16.4s, v11.s[0]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "fmla v27.4s, v17.4s, v11.s[1]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "fmla v27.4s, v18.4s, v11.s[2]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "fmla v27.4s, v19.4s, v11.s[3]\n" + "b 4f\n" + "3:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "4:\n" + "cbz %[blocks], 5f\n" + "6:\n" + "ldr q16, [%[b_ptr0]]\n" + "subs %[blocks], %[blocks], #0x1\n" + "add %[b_ptr0], %[b_ptr0], #0x10\n" + "ldr s0, [%[a_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x4\n" + "ldr s1, [a_ptr1]\n" + "add a_ptr1, a_ptr1, #0x4\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr s2, [a_ptr2]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "add a_ptr2, a_ptr2, #0x4\n" + "ldr s3, [a_ptr3]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "add a_ptr3, a_ptr3, #0x4\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "b.ne 6b\n" + "5:\n" + "ld1r {v22.4s}, [%[minptr]]\n" + "ld1r {v23.4s}, [%[maxptr]]\n" + "fmax v24.4s, v24.4s, v22.4s\n" + "fmax v25.4s, v25.4s, v22.4s\n" + "fmax v26.4s, v26.4s, v22.4s\n" + "fmax v27.4s, v27.4s, v22.4s\n" + "fmin v24.4s, v24.4s, v23.4s\n" + "fmin v25.4s, v25.4s, v23.4s\n" + "fmin v26.4s, v26.4s, v23.4s\n" + "fmin v27.4s, v27.4s, v23.4s\n" + "str q24, [%[c_ptr0]]\n" + "add %[c_ptr0], %[c_ptr0], #0x10\n" + "str q25, [c_ptr1]\n" + "str q26, [c_ptr2]\n" + "str q27, [c_ptr3]\n" + ".unreq a_ptr1\n" + ".unreq a_ptr2\n" + ".unreq a_ptr3\n" + ".unreq c_ptr1\n" + ".unreq c_ptr2\n" + ".unreq c_ptr3\n" + : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks) + : [width] "r" (width), [append] "r" (static_cast<uint64_t>(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "cc", "memory" + ); + break; + case 5: + __asm __volatile ( + "a_ptr1 .req X0\n" + "a_ptr2 .req X1\n" + "a_ptr3 .req X2\n" + "a_ptr4 .req X3\n" + "c_ptr1 .req X4\n" + "c_ptr2 .req X5\n" + "c_ptr3 .req X6\n" + "c_ptr4 .req X7\n" + "ldr q24, [%[biasptr]]\n" + "add a_ptr1, %[a_ptr0], %[lda]\n" + "ldr q0, [%[a_ptr0]]\n" + "add a_ptr2, a_ptr1, %[lda]\n" + "mov v25.16b, v24.16b\n" + "ldr q1, [a_ptr1]\n" + "mov v26.16b, v24.16b\n" + "ldr q2, [a_ptr2]\n" + "mov v27.16b, v24.16b\n" + "ldr q16, [%[b_ptr0]]\n" + "mov v28.16b, v24.16b\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "add a_ptr3, a_ptr2, %[lda]\n" + "add c_ptr1, %[c_ptr0], %[ldc]\n" + "ldr q3, [a_ptr3]\n" + "add a_ptr4, a_ptr3, %[lda]\n" + "add c_ptr2, c_ptr1, %[ldc]\n" + "ldr q4, [a_ptr4]\n" + "add c_ptr3, c_ptr2, %[ldc]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "add c_ptr4, c_ptr3, %[ldc]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "add a_ptr2, a_ptr2, #0x10\n" + "add a_ptr3, a_ptr3, #0x10\n" + "add a_ptr4, a_ptr4, #0x10\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "cbz %[loops], 1f\n" + "2:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q9, [a_ptr1]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "ldr q10, [a_ptr2]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "ldr q11, [a_ptr3]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q12, [a_ptr4]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "subs %[loops], %[loops], #0x1\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "add %[a_ptr0], %[a_ptr0], #0x20\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "add a_ptr1, a_ptr1, #0x20\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "add a_ptr2, a_ptr2, #0x20\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "add a_ptr3, a_ptr3, #0x20\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "ldr q0, [%[a_ptr0], #-0x10]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "ldr q1, [a_ptr1, #-0x10]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "ldr q2, [a_ptr2, #-0x10]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "ldr q3, [a_ptr3, #-0x10]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "add a_ptr4, a_ptr4, #0x20\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "ldr q4, [a_ptr4, #-0x10]\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "prfm PLDL1KEEP, [a_ptr1, #0x40]\n" + "fmla v27.4s, v16.4s, v11.s[0]\n" + "prfm PLDL1KEEP, [a_ptr2, #0x40]\n" + "fmla v28.4s, v16.4s, v12.s[0]\n" + "ldr q16, [%[b_ptr0], #0x40]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "prfm PLDL1KEEP, [a_ptr3, #0x40]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "fmla v27.4s, v17.4s, v11.s[1]\n" + "fmla v28.4s, v17.4s, v12.s[1]\n" + "ldr q17, [%[b_ptr0], #0x50]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "fmla v27.4s, v18.4s, v11.s[2]\n" + "fmla v28.4s, v18.4s, v12.s[2]\n" + "ldr q18, [%[b_ptr0], #0x60]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "fmla v27.4s, v19.4s, v11.s[3]\n" + "fmla v28.4s, v19.4s, v12.s[3]\n" + "b.ne 2b\n" + "1:\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "prfm PSTL1KEEP, [%[c_ptr0]]\n" + "prfm PSTL1KEEP, [c_ptr1]\n" + "prfm PSTL1KEEP, [c_ptr2]\n" + "prfm PSTL1KEEP, [c_ptr3]\n" + "prfm PSTL1KEEP, [c_ptr4]\n" + "cbz %[regs], 3f\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q9, [a_ptr1]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q10, [a_ptr2]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "ldr q11, [a_ptr3]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "ldr q12, [a_ptr4]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "add a_ptr2, a_ptr2, #0x10\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "add a_ptr3, a_ptr3, #0x10\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "add a_ptr4, a_ptr4, #0x10\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "fmla v27.4s, v16.4s, v11.s[0]\n" + "fmla v28.4s, v16.4s, v12.s[0]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "fmla v27.4s, v17.4s, v11.s[1]\n" + "fmla v28.4s, v17.4s, v12.s[1]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "fmla v27.4s, v18.4s, v11.s[2]\n" + "fmla v28.4s, v18.4s, v12.s[2]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "fmla v27.4s, v19.4s, v11.s[3]\n" + "fmla v28.4s, v19.4s, v12.s[3]\n" + "b 4f\n" + "3:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "4:\n" + "cbz %[blocks], 5f\n" + "6:\n" + "ldr q16, [%[b_ptr0]]\n" + "subs %[blocks], %[blocks], #0x1\n" + "add %[b_ptr0], %[b_ptr0], #0x10\n" + "ldr s0, [%[a_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x4\n" + "ldr s1, [a_ptr1]\n" + "add a_ptr1, a_ptr1, #0x4\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr s2, [a_ptr2]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "add a_ptr2, a_ptr2, #0x4\n" + "ldr s3, [a_ptr3]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "add a_ptr3, a_ptr3, #0x4\n" + "ldr s4, [a_ptr4]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "add a_ptr4, a_ptr4, #0x4\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "b.ne 6b\n" + "5:\n" + "ld1r {v22.4s}, [%[minptr]]\n" + "ld1r {v23.4s}, [%[maxptr]]\n" + "fmax v24.4s, v24.4s, v22.4s\n" + "fmax v25.4s, v25.4s, v22.4s\n" + "fmax v26.4s, v26.4s, v22.4s\n" + "fmax v27.4s, v27.4s, v22.4s\n" + "fmin v24.4s, v24.4s, v23.4s\n" + "fmin v25.4s, v25.4s, v23.4s\n" + "fmin v26.4s, v26.4s, v23.4s\n" + "fmin v27.4s, v27.4s, v23.4s\n" + "str q24, [%[c_ptr0]]\n" + "fmax v28.4s, v28.4s, v22.4s\n" + "add %[c_ptr0], %[c_ptr0], #0x10\n" + "str q25, [c_ptr1]\n" + "fmin v28.4s, v28.4s, v23.4s\n" + "str q26, [c_ptr2]\n" + "str q27, [c_ptr3]\n" + "str q28, [c_ptr4]\n" + ".unreq a_ptr1\n" + ".unreq a_ptr2\n" + ".unreq a_ptr3\n" + ".unreq a_ptr4\n" + ".unreq c_ptr1\n" + ".unreq c_ptr2\n" + ".unreq c_ptr3\n" + ".unreq c_ptr4\n" + : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks) + : [width] "r" (width), [append] "r" (static_cast<uint64_t>(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory" + ); + break; + case 6: + __asm __volatile ( + "a_ptr1 .req X0\n" + "a_ptr2 .req X1\n" + "a_ptr3 .req X2\n" + "a_ptr4 .req X3\n" + "a_ptr5 .req X4\n" + "c_ptr1 .req X5\n" + "c_ptr2 .req X6\n" + "c_ptr3 .req X7\n" + "c_ptr4 .req X8\n" + "c_ptr5 .req X9\n" + "ldr q24, [%[biasptr]]\n" + "add a_ptr1, %[a_ptr0], %[lda]\n" + "ldr q0, [%[a_ptr0]]\n" + "add a_ptr2, a_ptr1, %[lda]\n" + "mov v25.16b, v24.16b\n" + "ldr q1, [a_ptr1]\n" + "mov v26.16b, v24.16b\n" + "ldr q2, [a_ptr2]\n" + "mov v27.16b, v24.16b\n" + "ldr q16, [%[b_ptr0]]\n" + "mov v28.16b, v24.16b\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "mov v29.16b, v24.16b\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "add a_ptr3, a_ptr2, %[lda]\n" + "add c_ptr1, %[c_ptr0], %[ldc]\n" + "ldr q3, [a_ptr3]\n" + "add a_ptr4, a_ptr3, %[lda]\n" + "add c_ptr2, c_ptr1, %[ldc]\n" + "ldr q4, [a_ptr4]\n" + "add a_ptr5, a_ptr4, %[lda]\n" + "add c_ptr3, c_ptr2, %[ldc]\n" + "ldr q5, [a_ptr5]\n" + "add c_ptr4, c_ptr3, %[ldc]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "add c_ptr5, c_ptr4, %[ldc]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "add a_ptr2, a_ptr2, #0x10\n" + "add a_ptr3, a_ptr3, #0x10\n" + "add a_ptr4, a_ptr4, #0x10\n" + "add a_ptr5, a_ptr5, #0x10\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "cbz %[loops], 1f\n" + "2:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q9, [a_ptr1]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "ldr q10, [a_ptr2]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "ldr q11, [a_ptr3]\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "ldr q12, [a_ptr4]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q13, [a_ptr5]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "subs %[loops], %[loops], #0x1\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "add %[a_ptr0], %[a_ptr0], #0x20\n" + "fmla v29.4s, v17.4s, v5.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "add a_ptr1, a_ptr1, #0x20\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "add a_ptr2, a_ptr2, #0x20\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "add a_ptr3, a_ptr3, #0x20\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "add a_ptr4, a_ptr4, #0x20\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "add a_ptr5, a_ptr5, #0x20\n" + "fmla v29.4s, v18.4s, v5.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "ldr q0, [%[a_ptr0], #-0x10]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "ldr q1, [a_ptr1, #-0x10]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "ldr q2, [a_ptr2, #-0x10]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "ldr q3, [a_ptr3, #-0x10]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "ldr q4, [a_ptr4, #-0x10]\n" + "fmla v29.4s, v19.4s, v5.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "ldr q5, [a_ptr5, #-0x10]\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "prfm PLDL1KEEP, [a_ptr1, #0x40]\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "prfm PLDL1KEEP, [a_ptr2, #0x40]\n" + "fmla v27.4s, v16.4s, v11.s[0]\n" + "prfm PLDL1KEEP, [a_ptr3, #0x40]\n" + "fmla v28.4s, v16.4s, v12.s[0]\n" + "fmla v29.4s, v16.4s, v13.s[0]\n" + "ldr q16, [%[b_ptr0], #0x40]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "fmla v27.4s, v17.4s, v11.s[1]\n" + "fmla v28.4s, v17.4s, v12.s[1]\n" + "fmla v29.4s, v17.4s, v13.s[1]\n" + "ldr q17, [%[b_ptr0], #0x50]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "fmla v27.4s, v18.4s, v11.s[2]\n" + "fmla v28.4s, v18.4s, v12.s[2]\n" + "fmla v29.4s, v18.4s, v13.s[2]\n" + "ldr q18, [%[b_ptr0], #0x60]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "fmla v27.4s, v19.4s, v11.s[3]\n" + "fmla v28.4s, v19.4s, v12.s[3]\n" + "fmla v29.4s, v19.4s, v13.s[3]\n" + "b.ne 2b\n" + "1:\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "prfm PSTL1KEEP, [%[c_ptr0]]\n" + "prfm PSTL1KEEP, [c_ptr1]\n" + "prfm PSTL1KEEP, [c_ptr2]\n" + "prfm PSTL1KEEP, [c_ptr3]\n" + "prfm PSTL1KEEP, [c_ptr4]\n" + "prfm PSTL1KEEP, [c_ptr5]\n" + "cbz %[regs], 3f\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q9, [a_ptr1]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q10, [a_ptr2]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "ldr q11, [a_ptr3]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "ldr q12, [a_ptr4]\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "ldr q13, [a_ptr5]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "add a_ptr2, a_ptr2, #0x10\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "add a_ptr3, a_ptr3, #0x10\n" + "fmla v29.4s, v17.4s, v5.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "add a_ptr4, a_ptr4, #0x10\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "add a_ptr5, a_ptr5, #0x10\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "fmla v29.4s, v18.4s, v5.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "fmla v29.4s, v19.4s, v5.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "fmla v27.4s, v16.4s, v11.s[0]\n" + "fmla v28.4s, v16.4s, v12.s[0]\n" + "fmla v29.4s, v16.4s, v13.s[0]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "fmla v27.4s, v17.4s, v11.s[1]\n" + "fmla v28.4s, v17.4s, v12.s[1]\n" + "fmla v29.4s, v17.4s, v13.s[1]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "fmla v27.4s, v18.4s, v11.s[2]\n" + "fmla v28.4s, v18.4s, v12.s[2]\n" + "fmla v29.4s, v18.4s, v13.s[2]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "fmla v27.4s, v19.4s, v11.s[3]\n" + "fmla v28.4s, v19.4s, v12.s[3]\n" + "fmla v29.4s, v19.4s, v13.s[3]\n" + "b 4f\n" + "3:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "fmla v29.4s, v17.4s, v5.s[1]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "fmla v29.4s, v18.4s, v5.s[2]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "fmla v29.4s, v19.4s, v5.s[3]\n" + "4:\n" + "cbz %[blocks], 5f\n" + "6:\n" + "ldr q16, [%[b_ptr0]]\n" + "subs %[blocks], %[blocks], #0x1\n" + "add %[b_ptr0], %[b_ptr0], #0x10\n" + "ldr s0, [%[a_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x4\n" + "ldr s1, [a_ptr1]\n" + "add a_ptr1, a_ptr1, #0x4\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr s2, [a_ptr2]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "add a_ptr2, a_ptr2, #0x4\n" + "ldr s3, [a_ptr3]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "add a_ptr3, a_ptr3, #0x4\n" + "ldr s4, [a_ptr4]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "add a_ptr4, a_ptr4, #0x4\n" + "ldr s5, [a_ptr5]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "add a_ptr5, a_ptr5, #0x4\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "b.ne 6b\n" + "5:\n" + "ld1r {v22.4s}, [%[minptr]]\n" + "ld1r {v23.4s}, [%[maxptr]]\n" + "fmax v24.4s, v24.4s, v22.4s\n" + "fmax v25.4s, v25.4s, v22.4s\n" + "fmax v26.4s, v26.4s, v22.4s\n" + "fmax v27.4s, v27.4s, v22.4s\n" + "fmin v24.4s, v24.4s, v23.4s\n" + "fmin v25.4s, v25.4s, v23.4s\n" + "fmin v26.4s, v26.4s, v23.4s\n" + "fmin v27.4s, v27.4s, v23.4s\n" + "str q24, [%[c_ptr0]]\n" + "fmax v28.4s, v28.4s, v22.4s\n" + "add %[c_ptr0], %[c_ptr0], #0x10\n" + "fmax v29.4s, v29.4s, v22.4s\n" + "str q25, [c_ptr1]\n" + "fmin v28.4s, v28.4s, v23.4s\n" + "fmin v29.4s, v29.4s, v23.4s\n" + "str q26, [c_ptr2]\n" + "str q27, [c_ptr3]\n" + "str q28, [c_ptr4]\n" + "str q29, [c_ptr5]\n" + ".unreq a_ptr1\n" + ".unreq a_ptr2\n" + ".unreq a_ptr3\n" + ".unreq a_ptr4\n" + ".unreq a_ptr5\n" + ".unreq c_ptr1\n" + ".unreq c_ptr2\n" + ".unreq c_ptr3\n" + ".unreq c_ptr4\n" + ".unreq c_ptr5\n" + : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks) + : [width] "r" (width), [append] "r" (static_cast<uint64_t>(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "cc", "memory" + ); + break; + case 7: + __asm __volatile ( + "a_ptr1 .req X0\n" + "a_ptr2 .req X1\n" + "a_ptr3 .req X2\n" + "a_ptr4 .req X3\n" + "a_ptr5 .req X4\n" + "a_ptr6 .req X5\n" + "c_ptr1 .req X6\n" + "c_ptr2 .req X7\n" + "c_ptr3 .req X8\n" + "c_ptr4 .req X9\n" + "c_ptr5 .req X10\n" + "c_ptr6 .req X11\n" + "ldr q24, [%[biasptr]]\n" + "add a_ptr1, %[a_ptr0], %[lda]\n" + "ldr q0, [%[a_ptr0]]\n" + "add a_ptr2, a_ptr1, %[lda]\n" + "mov v25.16b, v24.16b\n" + "ldr q1, [a_ptr1]\n" + "mov v26.16b, v24.16b\n" + "ldr q2, [a_ptr2]\n" + "mov v27.16b, v24.16b\n" + "ldr q16, [%[b_ptr0]]\n" + "mov v28.16b, v24.16b\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "mov v29.16b, v24.16b\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "mov v30.16b, v24.16b\n" + "add a_ptr3, a_ptr2, %[lda]\n" + "add c_ptr1, %[c_ptr0], %[ldc]\n" + "ldr q3, [a_ptr3]\n" + "add a_ptr4, a_ptr3, %[lda]\n" + "add c_ptr2, c_ptr1, %[ldc]\n" + "ldr q4, [a_ptr4]\n" + "add a_ptr5, a_ptr4, %[lda]\n" + "add c_ptr3, c_ptr2, %[ldc]\n" + "ldr q5, [a_ptr5]\n" + "add a_ptr6, a_ptr5, %[lda]\n" + "add c_ptr4, c_ptr3, %[ldc]\n" + "ldr q6, [a_ptr6]\n" + "add c_ptr5, c_ptr4, %[ldc]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "add c_ptr6, c_ptr5, %[ldc]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "add a_ptr2, a_ptr2, #0x10\n" + "add a_ptr3, a_ptr3, #0x10\n" + "add a_ptr4, a_ptr4, #0x10\n" + "add a_ptr5, a_ptr5, #0x10\n" + "add a_ptr6, a_ptr6, #0x10\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "cbz %[loops], 1f\n" + "2:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q9, [a_ptr1]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "ldr q10, [a_ptr2]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "ldr q11, [a_ptr3]\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "ldr q12, [a_ptr4]\n" + "fmla v30.4s, v16.4s, v6.s[0]\n" + "ldr q13, [a_ptr5]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q14, [a_ptr6]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "subs %[loops], %[loops], #0x1\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "add %[a_ptr0], %[a_ptr0], #0x20\n" + "fmla v29.4s, v17.4s, v5.s[1]\n" + "add a_ptr1, a_ptr1, #0x20\n" + "fmla v30.4s, v17.4s, v6.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "add a_ptr2, a_ptr2, #0x20\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "add a_ptr3, a_ptr3, #0x20\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "add a_ptr4, a_ptr4, #0x20\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "add a_ptr5, a_ptr5, #0x20\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "add a_ptr6, a_ptr6, #0x20\n" + "fmla v29.4s, v18.4s, v5.s[2]\n" + "prfm PLDL1KEEP, [a_ptr1, #0x40]\n" + "fmla v30.4s, v18.4s, v6.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "ldr q0, [%[a_ptr0], #-0x10]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "ldr q1, [a_ptr1, #-0x10]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "ldr q2, [a_ptr2, #-0x10]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "ldr q3, [a_ptr3, #-0x10]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "ldr q4, [a_ptr4, #-0x10]\n" + "fmla v29.4s, v19.4s, v5.s[3]\n" + "ldr q5, [a_ptr5, #-0x10]\n" + "fmla v30.4s, v19.4s, v6.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "ldr q6, [a_ptr6, #-0x10]\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "prfm PLDL1KEEP, [a_ptr2, #0x40]\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "prfm PLDL1KEEP, [a_ptr3, #0x40]\n" + "fmla v27.4s, v16.4s, v11.s[0]\n" + "fmla v28.4s, v16.4s, v12.s[0]\n" + "fmla v29.4s, v16.4s, v13.s[0]\n" + "fmla v30.4s, v16.4s, v14.s[0]\n" + "ldr q16, [%[b_ptr0], #0x40]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "fmla v27.4s, v17.4s, v11.s[1]\n" + "fmla v28.4s, v17.4s, v12.s[1]\n" + "fmla v29.4s, v17.4s, v13.s[1]\n" + "fmla v30.4s, v17.4s, v14.s[1]\n" + "ldr q17, [%[b_ptr0], #0x50]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "fmla v27.4s, v18.4s, v11.s[2]\n" + "fmla v28.4s, v18.4s, v12.s[2]\n" + "fmla v29.4s, v18.4s, v13.s[2]\n" + "fmla v30.4s, v18.4s, v14.s[2]\n" + "ldr q18, [%[b_ptr0], #0x60]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "fmla v27.4s, v19.4s, v11.s[3]\n" + "fmla v28.4s, v19.4s, v12.s[3]\n" + "fmla v29.4s, v19.4s, v13.s[3]\n" + "fmla v30.4s, v19.4s, v14.s[3]\n" + "b.ne 2b\n" + "1:\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "prfm PSTL1KEEP, [%[c_ptr0]]\n" + "prfm PSTL1KEEP, [c_ptr1]\n" + "prfm PSTL1KEEP, [c_ptr2]\n" + "prfm PSTL1KEEP, [c_ptr3]\n" + "prfm PSTL1KEEP, [c_ptr4]\n" + "prfm PSTL1KEEP, [c_ptr5]\n" + "prfm PSTL1KEEP, [c_ptr6]\n" + "cbz %[regs], 3f\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q9, [a_ptr1]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q10, [a_ptr2]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "ldr q11, [a_ptr3]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "ldr q12, [a_ptr4]\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "ldr q13, [a_ptr5]\n" + "fmla v30.4s, v16.4s, v6.s[0]\n" + "ldr q14, [a_ptr6]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "add a_ptr2, a_ptr2, #0x10\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "add a_ptr3, a_ptr3, #0x10\n" + "fmla v29.4s, v17.4s, v5.s[1]\n" + "add a_ptr4, a_ptr4, #0x10\n" + "fmla v30.4s, v17.4s, v6.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "add a_ptr5, a_ptr5, #0x10\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "add a_ptr6, a_ptr6, #0x10\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "fmla v29.4s, v18.4s, v5.s[2]\n" + "fmla v30.4s, v18.4s, v6.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "fmla v29.4s, v19.4s, v5.s[3]\n" + "fmla v30.4s, v19.4s, v6.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "fmla v27.4s, v16.4s, v11.s[0]\n" + "fmla v28.4s, v16.4s, v12.s[0]\n" + "fmla v29.4s, v16.4s, v13.s[0]\n" + "fmla v30.4s, v16.4s, v14.s[0]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "fmla v27.4s, v17.4s, v11.s[1]\n" + "fmla v28.4s, v17.4s, v12.s[1]\n" + "fmla v29.4s, v17.4s, v13.s[1]\n" + "fmla v30.4s, v17.4s, v14.s[1]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "fmla v27.4s, v18.4s, v11.s[2]\n" + "fmla v28.4s, v18.4s, v12.s[2]\n" + "fmla v29.4s, v18.4s, v13.s[2]\n" + "fmla v30.4s, v18.4s, v14.s[2]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "fmla v27.4s, v19.4s, v11.s[3]\n" + "fmla v28.4s, v19.4s, v12.s[3]\n" + "fmla v29.4s, v19.4s, v13.s[3]\n" + "fmla v30.4s, v19.4s, v14.s[3]\n" + "b 4f\n" + "3:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "fmla v30.4s, v16.4s, v6.s[0]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "fmla v29.4s, v17.4s, v5.s[1]\n" + "fmla v30.4s, v17.4s, v6.s[1]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "fmla v29.4s, v18.4s, v5.s[2]\n" + "fmla v30.4s, v18.4s, v6.s[2]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "fmla v29.4s, v19.4s, v5.s[3]\n" + "fmla v30.4s, v19.4s, v6.s[3]\n" + "4:\n" + "cbz %[blocks], 5f\n" + "6:\n" + "ldr q16, [%[b_ptr0]]\n" + "subs %[blocks], %[blocks], #0x1\n" + "add %[b_ptr0], %[b_ptr0], #0x10\n" + "ldr s0, [%[a_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x4\n" + "ldr s1, [a_ptr1]\n" + "add a_ptr1, a_ptr1, #0x4\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr s2, [a_ptr2]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "add a_ptr2, a_ptr2, #0x4\n" + "ldr s3, [a_ptr3]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "add a_ptr3, a_ptr3, #0x4\n" + "ldr s4, [a_ptr4]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "add a_ptr4, a_ptr4, #0x4\n" + "ldr s5, [a_ptr5]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "add a_ptr5, a_ptr5, #0x4\n" + "ldr s6, [a_ptr6]\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "add a_ptr6, a_ptr6, #0x4\n" + "fmla v30.4s, v16.4s, v6.s[0]\n" + "b.ne 6b\n" + "5:\n" + "ld1r {v22.4s}, [%[minptr]]\n" + "ld1r {v23.4s}, [%[maxptr]]\n" + "fmax v24.4s, v24.4s, v22.4s\n" + "fmax v25.4s, v25.4s, v22.4s\n" + "fmax v26.4s, v26.4s, v22.4s\n" + "fmax v27.4s, v27.4s, v22.4s\n" + "fmin v24.4s, v24.4s, v23.4s\n" + "fmin v25.4s, v25.4s, v23.4s\n" + "fmin v26.4s, v26.4s, v23.4s\n" + "fmin v27.4s, v27.4s, v23.4s\n" + "str q24, [%[c_ptr0]]\n" + "fmax v28.4s, v28.4s, v22.4s\n" + "add %[c_ptr0], %[c_ptr0], #0x10\n" + "fmax v29.4s, v29.4s, v22.4s\n" + "str q25, [c_ptr1]\n" + "fmax v30.4s, v30.4s, v22.4s\n" + "fmin v28.4s, v28.4s, v23.4s\n" + "fmin v29.4s, v29.4s, v23.4s\n" + "str q26, [c_ptr2]\n" + "fmin v30.4s, v30.4s, v23.4s\n" + "str q27, [c_ptr3]\n" + "str q28, [c_ptr4]\n" + "str q29, [c_ptr5]\n" + "str q30, [c_ptr6]\n" + ".unreq a_ptr1\n" + ".unreq a_ptr2\n" + ".unreq a_ptr3\n" + ".unreq a_ptr4\n" + ".unreq a_ptr5\n" + ".unreq a_ptr6\n" + ".unreq c_ptr1\n" + ".unreq c_ptr2\n" + ".unreq c_ptr3\n" + ".unreq c_ptr4\n" + ".unreq c_ptr5\n" + ".unreq c_ptr6\n" + : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks) + : [width] "r" (width), [append] "r" (static_cast<uint64_t>(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "cc", "memory" + ); + break; + default: + case 8: + __asm __volatile ( + "a_ptr1 .req X0\n" + "a_ptr2 .req X1\n" + "a_ptr3 .req X2\n" + "a_ptr4 .req X3\n" + "a_ptr5 .req X4\n" + "a_ptr6 .req X5\n" + "a_ptr7 .req X6\n" + "c_ptr1 .req X7\n" + "c_ptr2 .req X8\n" + "c_ptr3 .req X9\n" + "c_ptr4 .req X10\n" + "c_ptr5 .req X11\n" + "c_ptr6 .req X12\n" + "c_ptr7 .req X13\n" + "ldr q24, [%[biasptr]]\n" + "add a_ptr1, %[a_ptr0], %[lda]\n" + "ldr q0, [%[a_ptr0]]\n" + "add a_ptr2, a_ptr1, %[lda]\n" + "mov v25.16b, v24.16b\n" + "ldr q1, [a_ptr1]\n" + "mov v26.16b, v24.16b\n" + "ldr q2, [a_ptr2]\n" + "mov v27.16b, v24.16b\n" + "ldr q16, [%[b_ptr0]]\n" + "mov v28.16b, v24.16b\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "mov v29.16b, v24.16b\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "mov v30.16b, v24.16b\n" + "add a_ptr3, a_ptr2, %[lda]\n" + "mov v31.16b, v24.16b\n" + "ldr q3, [a_ptr3]\n" + "add a_ptr4, a_ptr3, %[lda]\n" + "add c_ptr1, %[c_ptr0], %[ldc]\n" + "ldr q4, [a_ptr4]\n" + "add a_ptr5, a_ptr4, %[lda]\n" + "add c_ptr2, c_ptr1, %[ldc]\n" + "ldr q5, [a_ptr5]\n" + "add a_ptr6, a_ptr5, %[lda]\n" + "add c_ptr3, c_ptr2, %[ldc]\n" + "ldr q6, [a_ptr6]\n" + "add a_ptr7, a_ptr6, %[lda]\n" + "add c_ptr4, c_ptr3, %[ldc]\n" + "ldr q7, [a_ptr7]\n" + "add c_ptr5, c_ptr4, %[ldc]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "add c_ptr6, c_ptr5, %[ldc]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "add c_ptr7, c_ptr6, %[ldc]\n" + "add a_ptr2, a_ptr2, #0x10\n" + "add a_ptr3, a_ptr3, #0x10\n" + "add a_ptr4, a_ptr4, #0x10\n" + "add a_ptr5, a_ptr5, #0x10\n" + "add a_ptr6, a_ptr6, #0x10\n" + "add a_ptr7, a_ptr7, #0x10\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "cbz %[loops], 1f\n" + "2:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q9, [a_ptr1]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "ldr q10, [a_ptr2]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "ldr q11, [a_ptr3]\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "ldr q12, [a_ptr4]\n" + "fmla v30.4s, v16.4s, v6.s[0]\n" + "ldr q13, [a_ptr5]\n" + "fmla v31.4s, v16.4s, v7.s[0]\n" + "ldr q14, [a_ptr6]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q15, [a_ptr7]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "subs %[loops], %[loops], #0x1\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "add %[a_ptr0], %[a_ptr0], #0x20\n" + "fmla v29.4s, v17.4s, v5.s[1]\n" + "add a_ptr1, a_ptr1, #0x20\n" + "fmla v30.4s, v17.4s, v6.s[1]\n" + "add a_ptr2, a_ptr2, #0x20\n" + "fmla v31.4s, v17.4s, v7.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "add a_ptr3, a_ptr3, #0x20\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "add a_ptr4, a_ptr4, #0x20\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "add a_ptr5, a_ptr5, #0x20\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "add a_ptr6, a_ptr6, #0x20\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "add a_ptr7, a_ptr7, #0x20\n" + "fmla v29.4s, v18.4s, v5.s[2]\n" + "prfm PLDL1KEEP, [a_ptr1, #0x40]\n" + "fmla v30.4s, v18.4s, v6.s[2]\n" + "prfm PLDL1KEEP, [a_ptr2, #0x40]\n" + "fmla v31.4s, v18.4s, v7.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "ldr q0, [%[a_ptr0], #-0x10]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "ldr q1, [a_ptr1, #-0x10]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "ldr q2, [a_ptr2, #-0x10]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "ldr q3, [a_ptr3, #-0x10]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "ldr q4, [a_ptr4, #-0x10]\n" + "fmla v29.4s, v19.4s, v5.s[3]\n" + "ldr q5, [a_ptr5, #-0x10]\n" + "fmla v30.4s, v19.4s, v6.s[3]\n" + "ldr q6, [a_ptr6, #-0x10]\n" + "fmla v31.4s, v19.4s, v7.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "ldr q7, [a_ptr7, #-0x10]\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "prfm PLDL1KEEP, [a_ptr3, #0x40]\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "fmla v27.4s, v16.4s, v11.s[0]\n" + "fmla v28.4s, v16.4s, v12.s[0]\n" + "fmla v29.4s, v16.4s, v13.s[0]\n" + "fmla v30.4s, v16.4s, v14.s[0]\n" + "fmla v31.4s, v16.4s, v15.s[0]\n" + "ldr q16, [%[b_ptr0], #0x40]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "fmla v27.4s, v17.4s, v11.s[1]\n" + "fmla v28.4s, v17.4s, v12.s[1]\n" + "fmla v29.4s, v17.4s, v13.s[1]\n" + "fmla v30.4s, v17.4s, v14.s[1]\n" + "fmla v31.4s, v17.4s, v15.s[1]\n" + "ldr q17, [%[b_ptr0], #0x50]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "fmla v27.4s, v18.4s, v11.s[2]\n" + "fmla v28.4s, v18.4s, v12.s[2]\n" + "fmla v29.4s, v18.4s, v13.s[2]\n" + "fmla v30.4s, v18.4s, v14.s[2]\n" + "fmla v31.4s, v18.4s, v15.s[2]\n" + "ldr q18, [%[b_ptr0], #0x60]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "fmla v27.4s, v19.4s, v11.s[3]\n" + "fmla v28.4s, v19.4s, v12.s[3]\n" + "fmla v29.4s, v19.4s, v13.s[3]\n" + "fmla v30.4s, v19.4s, v14.s[3]\n" + "fmla v31.4s, v19.4s, v15.s[3]\n" + "b.ne 2b\n" + "1:\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "prfm PSTL1KEEP, [%[c_ptr0]]\n" + "prfm PSTL1KEEP, [c_ptr1]\n" + "prfm PSTL1KEEP, [c_ptr2]\n" + "prfm PSTL1KEEP, [c_ptr3]\n" + "prfm PSTL1KEEP, [c_ptr4]\n" + "prfm PSTL1KEEP, [c_ptr5]\n" + "prfm PSTL1KEEP, [c_ptr6]\n" + "prfm PSTL1KEEP, [c_ptr7]\n" + "cbz %[regs], 3f\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q9, [a_ptr1]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q10, [a_ptr2]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "ldr q11, [a_ptr3]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "ldr q12, [a_ptr4]\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "ldr q13, [a_ptr5]\n" + "fmla v30.4s, v16.4s, v6.s[0]\n" + "ldr q14, [a_ptr6]\n" + "fmla v31.4s, v16.4s, v7.s[0]\n" + "ldr q15, [a_ptr7]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "add a_ptr2, a_ptr2, #0x10\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "add a_ptr3, a_ptr3, #0x10\n" + "fmla v29.4s, v17.4s, v5.s[1]\n" + "add a_ptr4, a_ptr4, #0x10\n" + "fmla v30.4s, v17.4s, v6.s[1]\n" + "add a_ptr5, a_ptr5, #0x10\n" + "fmla v31.4s, v17.4s, v7.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "add a_ptr6, a_ptr6, #0x10\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "add a_ptr7, a_ptr7, #0x10\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "fmla v29.4s, v18.4s, v5.s[2]\n" + "fmla v30.4s, v18.4s, v6.s[2]\n" + "fmla v31.4s, v18.4s, v7.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "fmla v29.4s, v19.4s, v5.s[3]\n" + "fmla v30.4s, v19.4s, v6.s[3]\n" + "fmla v31.4s, v19.4s, v7.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "fmla v27.4s, v16.4s, v11.s[0]\n" + "fmla v28.4s, v16.4s, v12.s[0]\n" + "fmla v29.4s, v16.4s, v13.s[0]\n" + "fmla v30.4s, v16.4s, v14.s[0]\n" + "fmla v31.4s, v16.4s, v15.s[0]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "fmla v27.4s, v17.4s, v11.s[1]\n" + "fmla v28.4s, v17.4s, v12.s[1]\n" + "fmla v29.4s, v17.4s, v13.s[1]\n" + "fmla v30.4s, v17.4s, v14.s[1]\n" + "fmla v31.4s, v17.4s, v15.s[1]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "fmla v27.4s, v18.4s, v11.s[2]\n" + "fmla v28.4s, v18.4s, v12.s[2]\n" + "fmla v29.4s, v18.4s, v13.s[2]\n" + "fmla v30.4s, v18.4s, v14.s[2]\n" + "fmla v31.4s, v18.4s, v15.s[2]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "fmla v27.4s, v19.4s, v11.s[3]\n" + "fmla v28.4s, v19.4s, v12.s[3]\n" + "fmla v29.4s, v19.4s, v13.s[3]\n" + "fmla v30.4s, v19.4s, v14.s[3]\n" + "fmla v31.4s, v19.4s, v15.s[3]\n" + "b 4f\n" + "3:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "fmla v30.4s, v16.4s, v6.s[0]\n" + "fmla v31.4s, v16.4s, v7.s[0]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "fmla v29.4s, v17.4s, v5.s[1]\n" + "fmla v30.4s, v17.4s, v6.s[1]\n" + "fmla v31.4s, v17.4s, v7.s[1]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "fmla v29.4s, v18.4s, v5.s[2]\n" + "fmla v30.4s, v18.4s, v6.s[2]\n" + "fmla v31.4s, v18.4s, v7.s[2]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "fmla v29.4s, v19.4s, v5.s[3]\n" + "fmla v30.4s, v19.4s, v6.s[3]\n" + "fmla v31.4s, v19.4s, v7.s[3]\n" + "4:\n" + "cbz %[blocks], 5f\n" + "6:\n" + "ldr q16, [%[b_ptr0]]\n" + "subs %[blocks], %[blocks], #0x1\n" + "add %[b_ptr0], %[b_ptr0], #0x10\n" + "ldr s0, [%[a_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x4\n" + "ldr s1, [a_ptr1]\n" + "add a_ptr1, a_ptr1, #0x4\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr s2, [a_ptr2]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "add a_ptr2, a_ptr2, #0x4\n" + "ldr s3, [a_ptr3]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "add a_ptr3, a_ptr3, #0x4\n" + "ldr s4, [a_ptr4]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "add a_ptr4, a_ptr4, #0x4\n" + "ldr s5, [a_ptr5]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "add a_ptr5, a_ptr5, #0x4\n" + "ldr s6, [a_ptr6]\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "add a_ptr6, a_ptr6, #0x4\n" + "ldr s7, [a_ptr7]\n" + "fmla v30.4s, v16.4s, v6.s[0]\n" + "add a_ptr7, a_ptr7, #0x4\n" + "fmla v31.4s, v16.4s, v7.s[0]\n" + "b.ne 6b\n" + "5:\n" + "ld1r {v22.4s}, [%[minptr]]\n" + "ld1r {v23.4s}, [%[maxptr]]\n" + "fmax v24.4s, v24.4s, v22.4s\n" + "fmax v25.4s, v25.4s, v22.4s\n" + "fmax v26.4s, v26.4s, v22.4s\n" + "fmax v27.4s, v27.4s, v22.4s\n" + "fmin v24.4s, v24.4s, v23.4s\n" + "fmin v25.4s, v25.4s, v23.4s\n" + "fmin v26.4s, v26.4s, v23.4s\n" + "fmin v27.4s, v27.4s, v23.4s\n" + "str q24, [%[c_ptr0]]\n" + "fmax v28.4s, v28.4s, v22.4s\n" + "add %[c_ptr0], %[c_ptr0], #0x10\n" + "fmax v29.4s, v29.4s, v22.4s\n" + "str q25, [c_ptr1]\n" + "fmax v30.4s, v30.4s, v22.4s\n" + "fmin v28.4s, v28.4s, v23.4s\n" + "fmax v31.4s, v31.4s, v22.4s\n" + "str q26, [c_ptr2]\n" + "fmin v29.4s, v29.4s, v23.4s\n" + "fmin v30.4s, v30.4s, v23.4s\n" + "fmin v31.4s, v31.4s, v23.4s\n" + "str q27, [c_ptr3]\n" + "str q28, [c_ptr4]\n" + "str q29, [c_ptr5]\n" + "str q30, [c_ptr6]\n" + "str q31, [c_ptr7]\n" + ".unreq a_ptr1\n" + ".unreq a_ptr2\n" + ".unreq a_ptr3\n" + ".unreq a_ptr4\n" + ".unreq a_ptr5\n" + ".unreq a_ptr6\n" + ".unreq a_ptr7\n" + ".unreq c_ptr1\n" + ".unreq c_ptr2\n" + ".unreq c_ptr3\n" + ".unreq c_ptr4\n" + ".unreq c_ptr5\n" + ".unreq c_ptr6\n" + ".unreq c_ptr7\n" + : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks) + : [width] "r" (width), [append] "r" (static_cast<uint64_t>(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", "memory" + ); + break; + } + if (use_result_buffer) { + for(int cy=0; cy<std::min(M-y, 8); cy++) { + for(unsigned int cx=0; cx<width; cx++) { + c_ptr_real[cy * ldc + cx] = result_buffer[cy * 4 + cx]; + } + } + } + } + } +} + +} // namespace arm_gemm + +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4.hpp index 5a6fabcfa9..bdc62ea181 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4.hpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4.hpp @@ -61,7 +61,7 @@ public: static constexpr bool supports_append() { - return false; + return true; } static constexpr bool supports_bias() diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/a55.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/a55.cpp index 3ecf0151aa..7c08aa2165 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/a55.cpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/a55.cpp @@ -35,7 +35,6 @@ namespace arm_gemm { void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, int32_t *C, int ldc, int M, int N, int K, const int32_t *bias, Activation act, bool append) { UNUSED(bias); UNUSED(act); - const int K_stride = ((K + 3) / 4) * 4; const long loops_count = ((K + 16) / 32) - 1; K -= loops_count * 32; @@ -80,6 +79,7 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in "temploadreg1 .req X1\n" "temploadreg2 .req X2\n" "temploadreg3 .req X3\n" + "cbnz %[append], 1f\n" "movi v16.4s, #0\n" "ldr q0, [%[a_ptr0]]\n" "movi v17.4s, #0\n" @@ -95,8 +95,26 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in "ldr d14, [%[b_ptr0], #0x60]\n" "ldr temploadreg2, [%[b_ptr0], #0x68]\n" "add %[b_ptr0], %[b_ptr0], #0x80\n" - "cbz %[loops], 1f\n" - "2:\n" + "cbz %[loops], 2f\n" + "b 3f\n" + "1:\n" + "ldr q16, [%[c_ptr0]]\n" + "ldr q17, [%[c_ptr0], #0x10]\n" + "ldr q18, [%[c_ptr0], #0x20]\n" + "ldr q19, [%[c_ptr0], #0x30]\n" + "ldr q0, [%[a_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "ldr q8, [%[b_ptr0]]\n" + "ldr q9, [%[b_ptr0], #0x10]\n" + "ldr q10, [%[b_ptr0], #0x20]\n" + "ldr q11, [%[b_ptr0], #0x30]\n" + "ldr q12, [%[b_ptr0], #0x40]\n" + "ldr q13, [%[b_ptr0], #0x50]\n" + "ldr d14, [%[b_ptr0], #0x60]\n" + "ldr temploadreg2, [%[b_ptr0], #0x68]\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "cbz %[loops], 2f\n" + "3:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ins v14.d[1], temploadreg2\n" ".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n" @@ -236,14 +254,14 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in "ins v11.d[1], temploadreg3\n" "ins v12.d[1], temploadreg0\n" "ins v13.d[1], temploadreg1\n" - "b.ne 2b\n" - "1:\n" + "b.ne 3b\n" + "2:\n" "ins v14.d[1], temploadreg2\n" "prfm PSTL1KEEP, [%[c_ptr0]]\n" "ldr d15, [%[b_ptr0], #-0x10]\n" "ldr temploadreg3, [%[b_ptr0], #-0x8]\n" "ins v15.d[1], temploadreg3\n" - "cbz %[regs], 3f\n" + "cbz %[regs], 4f\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr d4, [%[a_ptr0]]\n" ".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n" @@ -354,8 +372,8 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in ".inst 0x4fa4e9b1 // sdot v17.4s, v13.16b, v4.4b[3]\n" ".inst 0x4fa4e9d2 // sdot v18.4s, v14.16b, v4.4b[3]\n" ".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n" - "b 4f\n" - "3:\n" + "b 5f\n" + "4:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr d8, [%[b_ptr0]]\n" ".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n" @@ -397,9 +415,9 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in ".inst 0x4fa0e9b1 // sdot v17.4s, v13.16b, v0.4b[3]\n" ".inst 0x4fa0e9d2 // sdot v18.4s, v14.16b, v0.4b[3]\n" ".inst 0x4fa0e9f3 // sdot v19.4s, v15.16b, v0.4b[3]\n" - "4:\n" - "cbz %[blocks], 5f\n" - "6:\n" + "5:\n" + "cbz %[blocks], 6f\n" + "7:\n" "ldr q8, [%[b_ptr0]]\n" "subs %[blocks], %[blocks], #0x1\n" "ldr q9, [%[b_ptr0], #0x10]\n" @@ -412,17 +430,17 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in ".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n" ".inst 0x4f80e152 // sdot v18.4s, v10.16b, v0.4b[0]\n" ".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n" - "b.ne 6b\n" - "5:\n" - "cbz %[odds], 7f\n" + "b.ne 7b\n" + "6:\n" + "cbz %[odds], 8f\n" "ld1 {v0.b}[0], [%[a_ptr0]], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[1], [%[a_ptr0]], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[2], [%[a_ptr0]]\n" - "8:\n" + "9:\n" "ldr q8, [%[b_ptr0]]\n" "ldr q9, [%[b_ptr0], #0x10]\n" "ldr q10, [%[b_ptr0], #0x20]\n" @@ -431,7 +449,7 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in ".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n" ".inst 0x4f80e152 // sdot v18.4s, v10.16b, v0.4b[0]\n" ".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n" - "7:\n" + "8:\n" "str q16, [%[c_ptr0]]\n" "str q17, [%[c_ptr0], #0x10]\n" "str q18, [%[c_ptr0], #0x20]\n" @@ -454,74 +472,99 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in "temploadreg1 .req X3\n" "temploadreg2 .req X4\n" "temploadreg3 .req X5\n" + "add a_ptr1, %[a_ptr0], %[lda]\n" + "add c_ptr1, %[c_ptr0], %[ldc]\n" + "cbnz %[append], 1f\n" "movi v16.4s, #0\n" "ldr q0, [%[a_ptr0]]\n" "movi v17.4s, #0\n" - "ldr q8, [%[b_ptr0]]\n" + "ldr q1, [a_ptr1]\n" "movi v18.4s, #0\n" - "ldr q9, [%[b_ptr0], #0x10]\n" + "ldr q8, [%[b_ptr0]]\n" "movi v19.4s, #0\n" - "ldr q10, [%[b_ptr0], #0x20]\n" + "ldr q9, [%[b_ptr0], #0x10]\n" "movi v20.4s, #0\n" - "ldr q11, [%[b_ptr0], #0x30]\n" + "ldr q10, [%[b_ptr0], #0x20]\n" "movi v21.4s, #0\n" - "ldr q12, [%[b_ptr0], #0x40]\n" + "ldr q11, [%[b_ptr0], #0x30]\n" "movi v22.4s, #0\n" - "ldr q13, [%[b_ptr0], #0x50]\n" + "ldr q12, [%[b_ptr0], #0x40]\n" "movi v23.4s, #0\n" + "ldr q13, [%[b_ptr0], #0x50]\n" "ldr d14, [%[b_ptr0], #0x60]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" "ldr temploadreg2, [%[b_ptr0], #0x68]\n" - "add a_ptr1, %[a_ptr0], %[lda]\n" - "add c_ptr1, %[c_ptr0], %[ldc]\n" - "ldr q1, [a_ptr1]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "cbz %[loops], 2f\n" + "b 3f\n" + "1:\n" + "ldr q16, [%[c_ptr0]]\n" + "ldr q17, [%[c_ptr0], #0x10]\n" + "ldr q18, [%[c_ptr0], #0x20]\n" + "ldr q19, [%[c_ptr0], #0x30]\n" + "ldr q20, [c_ptr1]\n" + "ldr q21, [c_ptr1, #0x10]\n" + "ldr q22, [c_ptr1, #0x20]\n" + "ldr q23, [c_ptr1, #0x30]\n" + "ldr q0, [%[a_ptr0]]\n" "add %[a_ptr0], %[a_ptr0], #0x10\n" - "ins v14.d[1], temploadreg2\n" + "ldr q1, [a_ptr1]\n" "add a_ptr1, a_ptr1, #0x10\n" + "ldr q8, [%[b_ptr0]]\n" + "ldr q9, [%[b_ptr0], #0x10]\n" + "ldr q10, [%[b_ptr0], #0x20]\n" + "ldr q11, [%[b_ptr0], #0x30]\n" + "ldr q12, [%[b_ptr0], #0x40]\n" + "ldr q13, [%[b_ptr0], #0x50]\n" + "ldr d14, [%[b_ptr0], #0x60]\n" + "ldr temploadreg2, [%[b_ptr0], #0x68]\n" "add %[b_ptr0], %[b_ptr0], #0x80\n" - "cbz %[loops], 1f\n" - "2:\n" + "cbz %[loops], 2f\n" + "3:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" - "ldr d15, [%[b_ptr0], #-0x10]\n" + "ins v14.d[1], temploadreg2\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" - "ldr temploadreg3, [%[b_ptr0], #-0x8]\n" + "ldr d15, [%[b_ptr0], #-0x10]\n" ".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n" - "ldr d4, [%[a_ptr0]]\n" + "ldr temploadreg3, [%[b_ptr0], #-0x8]\n" ".inst 0x4f81e135 // sdot v21.4s, v9.16b, v1.4b[0]\n" - "ldr temploadreg0, [%[a_ptr0], #0x8]\n" + "ldr d4, [%[a_ptr0]]\n" ".inst 0x4f80e152 // sdot v18.4s, v10.16b, v0.4b[0]\n" - "ldr d5, [a_ptr1]\n" + "ldr temploadreg0, [%[a_ptr0], #0x8]\n" ".inst 0x4f81e156 // sdot v22.4s, v10.16b, v1.4b[0]\n" - "ldr temploadreg1, [a_ptr1, #0x8]\n" + "ldr d5, [a_ptr1]\n" ".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n" - "ldr d8, [%[b_ptr0]]\n" + "ldr temploadreg1, [a_ptr1, #0x8]\n" ".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n" - "ins v4.d[1], temploadreg0\n" + "ldr d8, [%[b_ptr0]]\n" ".inst 0x4fa0e190 // sdot v16.4s, v12.16b, v0.4b[1]\n" - "ldr temploadreg0, [%[b_ptr0], #0x8]\n" + "ins v4.d[1], temploadreg0\n" ".inst 0x4fa1e194 // sdot v20.4s, v12.16b, v1.4b[1]\n" - "ldr d9, [%[b_ptr0], #0x10]\n" + "ldr temploadreg0, [%[b_ptr0], #0x8]\n" ".inst 0x4fa0e1b1 // sdot v17.4s, v13.16b, v0.4b[1]\n" - "ins v5.d[1], temploadreg1\n" + "ldr d9, [%[b_ptr0], #0x10]\n" ".inst 0x4fa1e1b5 // sdot v21.4s, v13.16b, v1.4b[1]\n" - "ldr temploadreg1, [%[b_ptr0], #0x18]\n" + "ins v5.d[1], temploadreg1\n" ".inst 0x4fa0e1d2 // sdot v18.4s, v14.16b, v0.4b[1]\n" - "ldr d10, [%[b_ptr0], #0x20]\n" + "ldr temploadreg1, [%[b_ptr0], #0x18]\n" ".inst 0x4fa1e1d6 // sdot v22.4s, v14.16b, v1.4b[1]\n" + "ldr d10, [%[b_ptr0], #0x20]\n" "ldr temploadreg2, [%[b_ptr0], #0x28]\n" - "ldr d11, [%[b_ptr0], #0x30]\n" "subs %[loops], %[loops], #0x1\n" - "ins v15.d[1], temploadreg3\n" + "ldr d11, [%[b_ptr0], #0x30]\n" "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n" - "ldr temploadreg3, [%[b_ptr0], #0x38]\n" + "ins v15.d[1], temploadreg3\n" "add %[a_ptr0], %[a_ptr0], #0x20\n" + "ldr temploadreg3, [%[b_ptr0], #0x38]\n" + "add a_ptr1, a_ptr1, #0x20\n" ".inst 0x4fa0e1f3 // sdot v19.4s, v15.16b, v0.4b[1]\n" "ldr d12, [%[b_ptr0], #0x40]\n" ".inst 0x4fa1e1f7 // sdot v23.4s, v15.16b, v1.4b[1]\n" "ins v8.d[1], temploadreg0\n" "ldr temploadreg0, [%[b_ptr0], #0x48]\n" - "add a_ptr1, a_ptr1, #0x20\n" - "ldr d13, [%[b_ptr0], #0x50]\n" "prfm PLDL1KEEP, [a_ptr1, #0x40]\n" + "ldr d13, [%[b_ptr0], #0x50]\n" ".inst 0x4f80e910 // sdot v16.4s, v8.16b, v0.4b[2]\n" "ins v9.d[1], temploadreg1\n" ".inst 0x4f81e914 // sdot v20.4s, v8.16b, v1.4b[2]\n" @@ -658,15 +701,15 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in "ins v11.d[1], temploadreg3\n" "ins v12.d[1], temploadreg0\n" "ins v13.d[1], temploadreg1\n" + "b.ne 3b\n" + "2:\n" "ins v14.d[1], temploadreg2\n" - "b.ne 2b\n" - "1:\n" - "ldr d15, [%[b_ptr0], #-0x10]\n" "prfm PSTL1KEEP, [%[c_ptr0]]\n" - "ldr temploadreg3, [%[b_ptr0], #-0x8]\n" + "ldr d15, [%[b_ptr0], #-0x10]\n" "prfm PSTL1KEEP, [c_ptr1]\n" + "ldr temploadreg3, [%[b_ptr0], #-0x8]\n" "ins v15.d[1], temploadreg3\n" - "cbz %[regs], 3f\n" + "cbz %[regs], 4f\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr d4, [%[a_ptr0]]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" @@ -813,8 +856,8 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in ".inst 0x4fa5e9d6 // sdot v22.4s, v14.16b, v5.4b[3]\n" ".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n" ".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n" - "b 4f\n" - "3:\n" + "b 5f\n" + "4:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr temploadreg0, [%[b_ptr0], #0x8]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" @@ -872,9 +915,9 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in ".inst 0x4fa1e9d6 // sdot v22.4s, v14.16b, v1.4b[3]\n" ".inst 0x4fa0e9f3 // sdot v19.4s, v15.16b, v0.4b[3]\n" ".inst 0x4fa1e9f7 // sdot v23.4s, v15.16b, v1.4b[3]\n" - "4:\n" - "cbz %[blocks], 5f\n" - "6:\n" + "5:\n" + "cbz %[blocks], 6f\n" + "7:\n" "ldr q8, [%[b_ptr0]]\n" "subs %[blocks], %[blocks], #0x1\n" "ldr q9, [%[b_ptr0], #0x10]\n" @@ -893,20 +936,20 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in ".inst 0x4f81e156 // sdot v22.4s, v10.16b, v1.4b[0]\n" ".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n" ".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n" - "b.ne 6b\n" - "5:\n" - "cbz %[odds], 7f\n" + "b.ne 7b\n" + "6:\n" + "cbz %[odds], 8f\n" "ld1 {v0.b}[0], [%[a_ptr0]], #1\n" "ld1 {v1.b}[0], [a_ptr1], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[1], [%[a_ptr0]], #1\n" "ld1 {v1.b}[1], [a_ptr1], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[2], [%[a_ptr0]]\n" "ld1 {v1.b}[2], [a_ptr1]\n" - "8:\n" + "9:\n" "ldr q8, [%[b_ptr0]]\n" "ldr q9, [%[b_ptr0], #0x10]\n" "ldr q10, [%[b_ptr0], #0x20]\n" @@ -919,7 +962,7 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in ".inst 0x4f81e156 // sdot v22.4s, v10.16b, v1.4b[0]\n" ".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n" ".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n" - "7:\n" + "8:\n" "str q16, [%[c_ptr0]]\n" "str q17, [%[c_ptr0], #0x10]\n" "str q18, [%[c_ptr0], #0x20]\n" @@ -950,40 +993,72 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in "temploadreg1 .req X5\n" "temploadreg2 .req X6\n" "temploadreg3 .req X7\n" + "add a_ptr1, %[a_ptr0], %[lda]\n" + "add c_ptr1, %[c_ptr0], %[ldc]\n" + "add a_ptr2, a_ptr1, %[lda]\n" + "add c_ptr2, c_ptr1, %[ldc]\n" + "cbnz %[append], 1f\n" "movi v16.4s, #0\n" "ldr q0, [%[a_ptr0]]\n" "movi v17.4s, #0\n" - "ldr q8, [%[b_ptr0]]\n" + "ldr q1, [a_ptr1]\n" "movi v18.4s, #0\n" - "ldr q9, [%[b_ptr0], #0x10]\n" + "ldr q2, [a_ptr2]\n" "movi v19.4s, #0\n" - "ldr q10, [%[b_ptr0], #0x20]\n" + "ldr q8, [%[b_ptr0]]\n" "movi v20.4s, #0\n" - "ldr q11, [%[b_ptr0], #0x30]\n" + "ldr q9, [%[b_ptr0], #0x10]\n" "movi v21.4s, #0\n" - "ldr q12, [%[b_ptr0], #0x40]\n" + "ldr q10, [%[b_ptr0], #0x20]\n" "movi v22.4s, #0\n" - "ldr q13, [%[b_ptr0], #0x50]\n" + "ldr q11, [%[b_ptr0], #0x30]\n" "movi v23.4s, #0\n" - "ldr d14, [%[b_ptr0], #0x60]\n" + "ldr q12, [%[b_ptr0], #0x40]\n" "movi v24.4s, #0\n" - "ldr temploadreg2, [%[b_ptr0], #0x68]\n" + "ldr q13, [%[b_ptr0], #0x50]\n" "movi v25.4s, #0\n" - "add a_ptr1, %[a_ptr0], %[lda]\n" + "ldr d14, [%[b_ptr0], #0x60]\n" "movi v26.4s, #0\n" - "ldr q1, [a_ptr1]\n" + "ldr temploadreg2, [%[b_ptr0], #0x68]\n" "movi v27.4s, #0\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "add a_ptr1, a_ptr1, #0x10\n" "ins v14.d[1], temploadreg2\n" - "add a_ptr2, a_ptr1, %[lda]\n" - "add c_ptr1, %[c_ptr0], %[ldc]\n" - "ldr q2, [a_ptr2]\n" - "add c_ptr2, c_ptr1, %[ldc]\n" + "add a_ptr2, a_ptr2, #0x10\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "cbz %[loops], 2f\n" + "b 3f\n" + "1:\n" + "ldr q16, [%[c_ptr0]]\n" + "ldr q17, [%[c_ptr0], #0x10]\n" + "ldr q18, [%[c_ptr0], #0x20]\n" + "ldr q19, [%[c_ptr0], #0x30]\n" + "ldr q20, [c_ptr1]\n" + "ldr q21, [c_ptr1, #0x10]\n" + "ldr q22, [c_ptr1, #0x20]\n" + "ldr q23, [c_ptr1, #0x30]\n" + "ldr q24, [c_ptr2]\n" + "ldr q25, [c_ptr2, #0x10]\n" + "ldr q26, [c_ptr2, #0x20]\n" + "ldr q27, [c_ptr2, #0x30]\n" + "ldr q0, [%[a_ptr0]]\n" "add %[a_ptr0], %[a_ptr0], #0x10\n" + "ldr q1, [a_ptr1]\n" "add a_ptr1, a_ptr1, #0x10\n" + "ldr q2, [a_ptr2]\n" "add a_ptr2, a_ptr2, #0x10\n" + "ldr q8, [%[b_ptr0]]\n" + "ldr q9, [%[b_ptr0], #0x10]\n" + "ldr q10, [%[b_ptr0], #0x20]\n" + "ldr q11, [%[b_ptr0], #0x30]\n" + "ldr q12, [%[b_ptr0], #0x40]\n" + "ldr q13, [%[b_ptr0], #0x50]\n" + "ldr d14, [%[b_ptr0], #0x60]\n" + "ldr temploadreg2, [%[b_ptr0], #0x68]\n" "add %[b_ptr0], %[b_ptr0], #0x80\n" - "cbz %[loops], 1f\n" - "2:\n" + "ins v14.d[1], temploadreg2\n" + "cbz %[loops], 2f\n" + "3:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr d15, [%[b_ptr0], #-0x10]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" @@ -1203,15 +1278,15 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in "ins v12.d[1], temploadreg0\n" "ins v13.d[1], temploadreg1\n" "ins v14.d[1], temploadreg2\n" - "b.ne 2b\n" - "1:\n" + "b.ne 3b\n" + "2:\n" "ldr d15, [%[b_ptr0], #-0x10]\n" "prfm PSTL1KEEP, [%[c_ptr0]]\n" "ldr temploadreg3, [%[b_ptr0], #-0x8]\n" "prfm PSTL1KEEP, [c_ptr1]\n" "prfm PSTL1KEEP, [c_ptr2]\n" "ins v15.d[1], temploadreg3\n" - "cbz %[regs], 3f\n" + "cbz %[regs], 4f\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr d4, [%[a_ptr0]]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" @@ -1394,8 +1469,8 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in ".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n" ".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n" ".inst 0x4fa6e9fb // sdot v27.4s, v15.16b, v6.4b[3]\n" - "b 4f\n" - "3:\n" + "b 5f\n" + "4:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr temploadreg0, [%[b_ptr0], #0x8]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" @@ -1469,9 +1544,9 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in ".inst 0x4fa0e9f3 // sdot v19.4s, v15.16b, v0.4b[3]\n" ".inst 0x4fa1e9f7 // sdot v23.4s, v15.16b, v1.4b[3]\n" ".inst 0x4fa2e9fb // sdot v27.4s, v15.16b, v2.4b[3]\n" - "4:\n" - "cbz %[blocks], 5f\n" - "6:\n" + "5:\n" + "cbz %[blocks], 6f\n" + "7:\n" "ldr q8, [%[b_ptr0]]\n" "subs %[blocks], %[blocks], #0x1\n" "ldr q9, [%[b_ptr0], #0x10]\n" @@ -1496,23 +1571,23 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in ".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n" ".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n" ".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n" - "b.ne 6b\n" - "5:\n" - "cbz %[odds], 7f\n" + "b.ne 7b\n" + "6:\n" + "cbz %[odds], 8f\n" "ld1 {v0.b}[0], [%[a_ptr0]], #1\n" "ld1 {v1.b}[0], [a_ptr1], #1\n" "ld1 {v2.b}[0], [a_ptr2], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[1], [%[a_ptr0]], #1\n" "ld1 {v1.b}[1], [a_ptr1], #1\n" "ld1 {v2.b}[1], [a_ptr2], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[2], [%[a_ptr0]]\n" "ld1 {v1.b}[2], [a_ptr1]\n" "ld1 {v2.b}[2], [a_ptr2]\n" - "8:\n" + "9:\n" "ldr q8, [%[b_ptr0]]\n" "ldr q9, [%[b_ptr0], #0x10]\n" "ldr q10, [%[b_ptr0], #0x20]\n" @@ -1529,7 +1604,7 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in ".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n" ".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n" ".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n" - "7:\n" + "8:\n" "str q16, [%[c_ptr0]]\n" "str q17, [%[c_ptr0], #0x10]\n" "str q18, [%[c_ptr0], #0x20]\n" @@ -1569,48 +1644,86 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in "temploadreg1 .req X7\n" "temploadreg2 .req X8\n" "temploadreg3 .req X9\n" + "add a_ptr1, %[a_ptr0], %[lda]\n" + "add c_ptr1, %[c_ptr0], %[ldc]\n" + "add a_ptr2, a_ptr1, %[lda]\n" + "add c_ptr2, c_ptr1, %[ldc]\n" + "add a_ptr3, a_ptr2, %[lda]\n" + "add c_ptr3, c_ptr2, %[ldc]\n" + "cbnz %[append], 1f\n" "movi v16.4s, #0\n" "ldr q0, [%[a_ptr0]]\n" "movi v17.4s, #0\n" - "ldr q8, [%[b_ptr0]]\n" + "ldr q1, [a_ptr1]\n" "movi v18.4s, #0\n" - "ldr q9, [%[b_ptr0], #0x10]\n" + "ldr q2, [a_ptr2]\n" "movi v19.4s, #0\n" - "ldr q10, [%[b_ptr0], #0x20]\n" + "ldr q3, [a_ptr3]\n" "movi v20.4s, #0\n" - "ldr q11, [%[b_ptr0], #0x30]\n" + "ldr q8, [%[b_ptr0]]\n" "movi v21.4s, #0\n" - "ldr q12, [%[b_ptr0], #0x40]\n" + "ldr q9, [%[b_ptr0], #0x10]\n" "movi v22.4s, #0\n" - "ldr q13, [%[b_ptr0], #0x50]\n" + "ldr q10, [%[b_ptr0], #0x20]\n" "movi v23.4s, #0\n" - "ldr d14, [%[b_ptr0], #0x60]\n" + "ldr q11, [%[b_ptr0], #0x30]\n" "movi v24.4s, #0\n" - "ldr temploadreg2, [%[b_ptr0], #0x68]\n" + "ldr q12, [%[b_ptr0], #0x40]\n" "movi v25.4s, #0\n" - "add a_ptr1, %[a_ptr0], %[lda]\n" + "ldr q13, [%[b_ptr0], #0x50]\n" "movi v26.4s, #0\n" - "ldr q1, [a_ptr1]\n" + "ldr d14, [%[b_ptr0], #0x60]\n" "movi v27.4s, #0\n" - "ins v14.d[1], temploadreg2\n" + "ldr temploadreg2, [%[b_ptr0], #0x68]\n" "movi v28.4s, #0\n" - "add a_ptr2, a_ptr1, %[lda]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" "movi v29.4s, #0\n" - "ldr q2, [a_ptr2]\n" + "ins v14.d[1], temploadreg2\n" "movi v30.4s, #0\n" - "add a_ptr3, a_ptr2, %[lda]\n" + "add a_ptr1, a_ptr1, #0x10\n" "movi v31.4s, #0\n" - "ldr q3, [a_ptr3]\n" - "add c_ptr1, %[c_ptr0], %[ldc]\n" + "add a_ptr2, a_ptr2, #0x10\n" + "add a_ptr3, a_ptr3, #0x10\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "cbz %[loops], 2f\n" + "b 3f\n" + "1:\n" + "ldr q16, [%[c_ptr0]]\n" + "ldr q17, [%[c_ptr0], #0x10]\n" + "ldr q18, [%[c_ptr0], #0x20]\n" + "ldr q19, [%[c_ptr0], #0x30]\n" + "ldr q20, [c_ptr1]\n" + "ldr q21, [c_ptr1, #0x10]\n" + "ldr q22, [c_ptr1, #0x20]\n" + "ldr q23, [c_ptr1, #0x30]\n" + "ldr q24, [c_ptr2]\n" + "ldr q25, [c_ptr2, #0x10]\n" + "ldr q26, [c_ptr2, #0x20]\n" + "ldr q27, [c_ptr2, #0x30]\n" + "ldr q28, [c_ptr3]\n" + "ldr q29, [c_ptr3, #0x10]\n" + "ldr q30, [c_ptr3, #0x20]\n" + "ldr q31, [c_ptr3, #0x30]\n" + "ldr q0, [%[a_ptr0]]\n" "add %[a_ptr0], %[a_ptr0], #0x10\n" - "add c_ptr2, c_ptr1, %[ldc]\n" + "ldr q1, [a_ptr1]\n" "add a_ptr1, a_ptr1, #0x10\n" - "add c_ptr3, c_ptr2, %[ldc]\n" + "ldr q2, [a_ptr2]\n" "add a_ptr2, a_ptr2, #0x10\n" + "ldr q3, [a_ptr3]\n" "add a_ptr3, a_ptr3, #0x10\n" + "ldr q8, [%[b_ptr0]]\n" + "ldr q9, [%[b_ptr0], #0x10]\n" + "ldr q10, [%[b_ptr0], #0x20]\n" + "ldr q11, [%[b_ptr0], #0x30]\n" + "ldr q12, [%[b_ptr0], #0x40]\n" + "ldr q13, [%[b_ptr0], #0x50]\n" + "ldr d14, [%[b_ptr0], #0x60]\n" + "ldr temploadreg2, [%[b_ptr0], #0x68]\n" "add %[b_ptr0], %[b_ptr0], #0x80\n" - "cbz %[loops], 1f\n" - "2:\n" + "ins v14.d[1], temploadreg2\n" + "cbz %[loops], 2f\n" + "3:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr d15, [%[b_ptr0], #-0x10]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" @@ -1870,8 +1983,8 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in "ins v13.d[1], temploadreg1\n" "prfm PLDL1KEEP, [a_ptr3, #0x40]\n" "ins v14.d[1], temploadreg2\n" - "b.ne 2b\n" - "1:\n" + "b.ne 3b\n" + "2:\n" "ldr d15, [%[b_ptr0], #-0x10]\n" "prfm PSTL1KEEP, [%[c_ptr0]]\n" "ldr temploadreg3, [%[b_ptr0], #-0x8]\n" @@ -1879,7 +1992,7 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in "prfm PSTL1KEEP, [c_ptr2]\n" "prfm PSTL1KEEP, [c_ptr3]\n" "ins v15.d[1], temploadreg3\n" - "cbz %[regs], 3f\n" + "cbz %[regs], 4f\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr d4, [%[a_ptr0]]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" @@ -2098,8 +2211,8 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in ".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n" ".inst 0x4fa6e9fb // sdot v27.4s, v15.16b, v6.4b[3]\n" ".inst 0x4fa7e9ff // sdot v31.4s, v15.16b, v7.4b[3]\n" - "b 4f\n" - "3:\n" + "b 5f\n" + "4:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr temploadreg0, [%[b_ptr0], #0x8]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" @@ -2189,9 +2302,9 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in ".inst 0x4fa1e9f7 // sdot v23.4s, v15.16b, v1.4b[3]\n" ".inst 0x4fa2e9fb // sdot v27.4s, v15.16b, v2.4b[3]\n" ".inst 0x4fa3e9ff // sdot v31.4s, v15.16b, v3.4b[3]\n" - "4:\n" - "cbz %[blocks], 5f\n" - "6:\n" + "5:\n" + "cbz %[blocks], 6f\n" + "7:\n" "ldr q8, [%[b_ptr0]]\n" "subs %[blocks], %[blocks], #0x1\n" "ldr q9, [%[b_ptr0], #0x10]\n" @@ -2222,26 +2335,26 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in ".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n" ".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n" ".inst 0x4f83e17f // sdot v31.4s, v11.16b, v3.4b[0]\n" - "b.ne 6b\n" - "5:\n" - "cbz %[odds], 7f\n" + "b.ne 7b\n" + "6:\n" + "cbz %[odds], 8f\n" "ld1 {v0.b}[0], [%[a_ptr0]], #1\n" "ld1 {v1.b}[0], [a_ptr1], #1\n" "ld1 {v2.b}[0], [a_ptr2], #1\n" "ld1 {v3.b}[0], [a_ptr3], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[1], [%[a_ptr0]], #1\n" "ld1 {v1.b}[1], [a_ptr1], #1\n" "ld1 {v2.b}[1], [a_ptr2], #1\n" "ld1 {v3.b}[1], [a_ptr3], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[2], [%[a_ptr0]]\n" "ld1 {v1.b}[2], [a_ptr1]\n" "ld1 {v2.b}[2], [a_ptr2]\n" "ld1 {v3.b}[2], [a_ptr3]\n" - "8:\n" + "9:\n" "ldr q8, [%[b_ptr0]]\n" "ldr q9, [%[b_ptr0], #0x10]\n" "ldr q10, [%[b_ptr0], #0x20]\n" @@ -2262,7 +2375,7 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in ".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n" ".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n" ".inst 0x4f83e17f // sdot v31.4s, v11.16b, v3.4b[0]\n" - "7:\n" + "8:\n" "str q16, [%[c_ptr0]]\n" "str q17, [%[c_ptr0], #0x10]\n" "str q18, [%[c_ptr0], #0x20]\n" diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/generic.cpp index b48b674621..9f06a48ff5 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/generic.cpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/generic.cpp @@ -35,7 +35,6 @@ namespace arm_gemm { void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_t *C, int ldc, int M, int N, int K, const int32_t *bias, Activation act, bool append) { UNUSED(bias); UNUSED(act); - const int K_stride = ((K + 3) / 4) * 4; const long loops_count = ((K + 16) / 32) - 1; K -= loops_count * 32; @@ -76,6 +75,7 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ switch(M-y) { case 1: __asm __volatile ( + "cbnz %[append], 1f\n" "movi v16.4s, #0\n" "ldr q0, [%[a_ptr0]]\n" "movi v17.4s, #0\n" @@ -90,8 +90,25 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ "ldr q13, [%[b_ptr0], #0x50]\n" "ldr q14, [%[b_ptr0], #0x60]\n" "add %[b_ptr0], %[b_ptr0], #0x80\n" - "cbz %[loops], 1f\n" - "2:\n" + "cbz %[loops], 2f\n" + "b 3f\n" + "1:\n" + "ldr q16, [%[c_ptr0]]\n" + "ldr q17, [%[c_ptr0], #0x10]\n" + "ldr q18, [%[c_ptr0], #0x20]\n" + "ldr q19, [%[c_ptr0], #0x30]\n" + "ldr q0, [%[a_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "ldr q8, [%[b_ptr0]]\n" + "ldr q9, [%[b_ptr0], #0x10]\n" + "ldr q10, [%[b_ptr0], #0x20]\n" + "ldr q11, [%[b_ptr0], #0x30]\n" + "ldr q12, [%[b_ptr0], #0x40]\n" + "ldr q13, [%[b_ptr0], #0x50]\n" + "ldr q14, [%[b_ptr0], #0x60]\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "cbz %[loops], 2f\n" + "3:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr q15, [%[b_ptr0], #-0x10]\n" ".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n" @@ -163,11 +180,11 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ "ldr q12, [%[b_ptr0], #-0x40]\n" "ldr q13, [%[b_ptr0], #-0x30]\n" "ldr q14, [%[b_ptr0], #-0x20]\n" - "b.ne 2b\n" - "1:\n" + "b.ne 3b\n" + "2:\n" "ldr q15, [%[b_ptr0], #-0x10]\n" "prfm PSTL1KEEP, [%[c_ptr0]]\n" - "cbz %[regs], 3f\n" + "cbz %[regs], 4f\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr q4, [%[a_ptr0]]\n" ".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n" @@ -228,8 +245,8 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4fa4e9b1 // sdot v17.4s, v13.16b, v4.4b[3]\n" ".inst 0x4fa4e9d2 // sdot v18.4s, v14.16b, v4.4b[3]\n" ".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n" - "b 4f\n" - "3:\n" + "b 5f\n" + "4:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr q8, [%[b_ptr0]]\n" ".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n" @@ -255,9 +272,9 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4fa0e9b1 // sdot v17.4s, v13.16b, v0.4b[3]\n" ".inst 0x4fa0e9d2 // sdot v18.4s, v14.16b, v0.4b[3]\n" ".inst 0x4fa0e9f3 // sdot v19.4s, v15.16b, v0.4b[3]\n" - "4:\n" - "cbz %[blocks], 5f\n" - "6:\n" + "5:\n" + "cbz %[blocks], 6f\n" + "7:\n" "ldr q8, [%[b_ptr0]]\n" "subs %[blocks], %[blocks], #0x1\n" "ldr q9, [%[b_ptr0], #0x10]\n" @@ -270,17 +287,17 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n" ".inst 0x4f80e152 // sdot v18.4s, v10.16b, v0.4b[0]\n" ".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n" - "b.ne 6b\n" - "5:\n" - "cbz %[odds], 7f\n" + "b.ne 7b\n" + "6:\n" + "cbz %[odds], 8f\n" "ld1 {v0.b}[0], [%[a_ptr0]], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[1], [%[a_ptr0]], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[2], [%[a_ptr0]]\n" - "8:\n" + "9:\n" "ldr q8, [%[b_ptr0]]\n" "ldr q9, [%[b_ptr0], #0x10]\n" "ldr q10, [%[b_ptr0], #0x20]\n" @@ -289,7 +306,7 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n" ".inst 0x4f80e152 // sdot v18.4s, v10.16b, v0.4b[0]\n" ".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n" - "7:\n" + "8:\n" "str q16, [%[c_ptr0]]\n" "str q17, [%[c_ptr0], #0x10]\n" "str q18, [%[c_ptr0], #0x20]\n" @@ -304,30 +321,54 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ __asm __volatile ( "a_ptr1 .req X0\n" "c_ptr1 .req X1\n" + "add a_ptr1, %[a_ptr0], %[lda]\n" + "add c_ptr1, %[c_ptr0], %[ldc]\n" + "cbnz %[append], 1f\n" "movi v16.4s, #0\n" "ldr q0, [%[a_ptr0]]\n" "movi v17.4s, #0\n" - "ldr q8, [%[b_ptr0]]\n" + "ldr q1, [a_ptr1]\n" "movi v18.4s, #0\n" - "ldr q9, [%[b_ptr0], #0x10]\n" + "ldr q8, [%[b_ptr0]]\n" "movi v19.4s, #0\n" - "ldr q10, [%[b_ptr0], #0x20]\n" + "ldr q9, [%[b_ptr0], #0x10]\n" "movi v20.4s, #0\n" - "ldr q11, [%[b_ptr0], #0x30]\n" + "ldr q10, [%[b_ptr0], #0x20]\n" "movi v21.4s, #0\n" - "ldr q12, [%[b_ptr0], #0x40]\n" + "ldr q11, [%[b_ptr0], #0x30]\n" "movi v22.4s, #0\n" - "ldr q13, [%[b_ptr0], #0x50]\n" + "ldr q12, [%[b_ptr0], #0x40]\n" "movi v23.4s, #0\n" + "ldr q13, [%[b_ptr0], #0x50]\n" "ldr q14, [%[b_ptr0], #0x60]\n" - "add a_ptr1, %[a_ptr0], %[lda]\n" - "add c_ptr1, %[c_ptr0], %[ldc]\n" - "ldr q1, [a_ptr1]\n" "add %[a_ptr0], %[a_ptr0], #0x10\n" "add a_ptr1, a_ptr1, #0x10\n" "add %[b_ptr0], %[b_ptr0], #0x80\n" - "cbz %[loops], 1f\n" - "2:\n" + "cbz %[loops], 2f\n" + "b 3f\n" + "1:\n" + "ldr q16, [%[c_ptr0]]\n" + "ldr q17, [%[c_ptr0], #0x10]\n" + "ldr q18, [%[c_ptr0], #0x20]\n" + "ldr q19, [%[c_ptr0], #0x30]\n" + "ldr q20, [c_ptr1]\n" + "ldr q21, [c_ptr1, #0x10]\n" + "ldr q22, [c_ptr1, #0x20]\n" + "ldr q23, [c_ptr1, #0x30]\n" + "ldr q0, [%[a_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "ldr q1, [a_ptr1]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "ldr q8, [%[b_ptr0]]\n" + "ldr q9, [%[b_ptr0], #0x10]\n" + "ldr q10, [%[b_ptr0], #0x20]\n" + "ldr q11, [%[b_ptr0], #0x30]\n" + "ldr q12, [%[b_ptr0], #0x40]\n" + "ldr q13, [%[b_ptr0], #0x50]\n" + "ldr q14, [%[b_ptr0], #0x60]\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "cbz %[loops], 2f\n" + "3:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr q15, [%[b_ptr0], #-0x10]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" @@ -435,12 +476,12 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ "ldr q14, [%[b_ptr0], #-0x20]\n" ".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n" ".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n" - "b.ne 2b\n" - "1:\n" + "b.ne 3b\n" + "2:\n" "ldr q15, [%[b_ptr0], #-0x10]\n" "prfm PSTL1KEEP, [%[c_ptr0]]\n" "prfm PSTL1KEEP, [c_ptr1]\n" - "cbz %[regs], 3f\n" + "cbz %[regs], 4f\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr q4, [%[a_ptr0]]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" @@ -535,8 +576,8 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4fa5e9d6 // sdot v22.4s, v14.16b, v5.4b[3]\n" ".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n" ".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n" - "b 4f\n" - "3:\n" + "b 5f\n" + "4:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" "ldr q8, [%[b_ptr0]]\n" @@ -578,9 +619,9 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4fa1e9d6 // sdot v22.4s, v14.16b, v1.4b[3]\n" ".inst 0x4fa0e9f3 // sdot v19.4s, v15.16b, v0.4b[3]\n" ".inst 0x4fa1e9f7 // sdot v23.4s, v15.16b, v1.4b[3]\n" - "4:\n" - "cbz %[blocks], 5f\n" - "6:\n" + "5:\n" + "cbz %[blocks], 6f\n" + "7:\n" "ldr q8, [%[b_ptr0]]\n" "subs %[blocks], %[blocks], #0x1\n" "ldr q9, [%[b_ptr0], #0x10]\n" @@ -599,20 +640,20 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4f81e156 // sdot v22.4s, v10.16b, v1.4b[0]\n" ".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n" ".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n" - "b.ne 6b\n" - "5:\n" - "cbz %[odds], 7f\n" + "b.ne 7b\n" + "6:\n" + "cbz %[odds], 8f\n" "ld1 {v0.b}[0], [%[a_ptr0]], #1\n" "ld1 {v1.b}[0], [a_ptr1], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[1], [%[a_ptr0]], #1\n" "ld1 {v1.b}[1], [a_ptr1], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[2], [%[a_ptr0]]\n" "ld1 {v1.b}[2], [a_ptr1]\n" - "8:\n" + "9:\n" "ldr q8, [%[b_ptr0]]\n" "ldr q9, [%[b_ptr0], #0x10]\n" "ldr q10, [%[b_ptr0], #0x20]\n" @@ -625,7 +666,7 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4f81e156 // sdot v22.4s, v10.16b, v1.4b[0]\n" ".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n" ".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n" - "7:\n" + "8:\n" "str q16, [%[c_ptr0]]\n" "str q17, [%[c_ptr0], #0x10]\n" "str q18, [%[c_ptr0], #0x20]\n" @@ -648,38 +689,68 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ "a_ptr2 .req X1\n" "c_ptr1 .req X2\n" "c_ptr2 .req X3\n" + "add a_ptr1, %[a_ptr0], %[lda]\n" + "add c_ptr1, %[c_ptr0], %[ldc]\n" + "add a_ptr2, a_ptr1, %[lda]\n" + "add c_ptr2, c_ptr1, %[ldc]\n" + "cbnz %[append], 1f\n" "movi v16.4s, #0\n" "ldr q0, [%[a_ptr0]]\n" "movi v17.4s, #0\n" - "ldr q8, [%[b_ptr0]]\n" + "ldr q1, [a_ptr1]\n" "movi v18.4s, #0\n" - "ldr q9, [%[b_ptr0], #0x10]\n" + "ldr q2, [a_ptr2]\n" "movi v19.4s, #0\n" - "ldr q10, [%[b_ptr0], #0x20]\n" + "ldr q8, [%[b_ptr0]]\n" "movi v20.4s, #0\n" - "ldr q11, [%[b_ptr0], #0x30]\n" + "ldr q9, [%[b_ptr0], #0x10]\n" "movi v21.4s, #0\n" - "ldr q12, [%[b_ptr0], #0x40]\n" + "ldr q10, [%[b_ptr0], #0x20]\n" "movi v22.4s, #0\n" - "ldr q13, [%[b_ptr0], #0x50]\n" + "ldr q11, [%[b_ptr0], #0x30]\n" "movi v23.4s, #0\n" - "ldr q14, [%[b_ptr0], #0x60]\n" + "ldr q12, [%[b_ptr0], #0x40]\n" "movi v24.4s, #0\n" - "add a_ptr1, %[a_ptr0], %[lda]\n" + "ldr q13, [%[b_ptr0], #0x50]\n" "movi v25.4s, #0\n" - "ldr q1, [a_ptr1]\n" + "ldr q14, [%[b_ptr0], #0x60]\n" "movi v26.4s, #0\n" - "add a_ptr2, a_ptr1, %[lda]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" "movi v27.4s, #0\n" - "ldr q2, [a_ptr2]\n" - "add c_ptr1, %[c_ptr0], %[ldc]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "add a_ptr2, a_ptr2, #0x10\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "cbz %[loops], 2f\n" + "b 3f\n" + "1:\n" + "ldr q16, [%[c_ptr0]]\n" + "ldr q17, [%[c_ptr0], #0x10]\n" + "ldr q18, [%[c_ptr0], #0x20]\n" + "ldr q19, [%[c_ptr0], #0x30]\n" + "ldr q20, [c_ptr1]\n" + "ldr q21, [c_ptr1, #0x10]\n" + "ldr q22, [c_ptr1, #0x20]\n" + "ldr q23, [c_ptr1, #0x30]\n" + "ldr q24, [c_ptr2]\n" + "ldr q25, [c_ptr2, #0x10]\n" + "ldr q26, [c_ptr2, #0x20]\n" + "ldr q27, [c_ptr2, #0x30]\n" + "ldr q0, [%[a_ptr0]]\n" "add %[a_ptr0], %[a_ptr0], #0x10\n" - "add c_ptr2, c_ptr1, %[ldc]\n" + "ldr q1, [a_ptr1]\n" "add a_ptr1, a_ptr1, #0x10\n" + "ldr q2, [a_ptr2]\n" "add a_ptr2, a_ptr2, #0x10\n" + "ldr q8, [%[b_ptr0]]\n" + "ldr q9, [%[b_ptr0], #0x10]\n" + "ldr q10, [%[b_ptr0], #0x20]\n" + "ldr q11, [%[b_ptr0], #0x30]\n" + "ldr q12, [%[b_ptr0], #0x40]\n" + "ldr q13, [%[b_ptr0], #0x50]\n" + "ldr q14, [%[b_ptr0], #0x60]\n" "add %[b_ptr0], %[b_ptr0], #0x80\n" - "cbz %[loops], 1f\n" - "2:\n" + "cbz %[loops], 2f\n" + "3:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr q15, [%[b_ptr0], #-0x10]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" @@ -823,13 +894,13 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n" ".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n" ".inst 0x4fa6e9fb // sdot v27.4s, v15.16b, v6.4b[3]\n" - "b.ne 2b\n" - "1:\n" + "b.ne 3b\n" + "2:\n" "ldr q15, [%[b_ptr0], #-0x10]\n" "prfm PSTL1KEEP, [%[c_ptr0]]\n" "prfm PSTL1KEEP, [c_ptr1]\n" "prfm PSTL1KEEP, [c_ptr2]\n" - "cbz %[regs], 3f\n" + "cbz %[regs], 4f\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr q4, [%[a_ptr0]]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" @@ -958,8 +1029,8 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n" ".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n" ".inst 0x4fa6e9fb // sdot v27.4s, v15.16b, v6.4b[3]\n" - "b 4f\n" - "3:\n" + "b 5f\n" + "4:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" ".inst 0x4f82e118 // sdot v24.4s, v8.16b, v2.4b[0]\n" @@ -1017,9 +1088,9 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4fa0e9f3 // sdot v19.4s, v15.16b, v0.4b[3]\n" ".inst 0x4fa1e9f7 // sdot v23.4s, v15.16b, v1.4b[3]\n" ".inst 0x4fa2e9fb // sdot v27.4s, v15.16b, v2.4b[3]\n" - "4:\n" - "cbz %[blocks], 5f\n" - "6:\n" + "5:\n" + "cbz %[blocks], 6f\n" + "7:\n" "ldr q8, [%[b_ptr0]]\n" "subs %[blocks], %[blocks], #0x1\n" "ldr q9, [%[b_ptr0], #0x10]\n" @@ -1044,23 +1115,23 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n" ".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n" ".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n" - "b.ne 6b\n" - "5:\n" - "cbz %[odds], 7f\n" + "b.ne 7b\n" + "6:\n" + "cbz %[odds], 8f\n" "ld1 {v0.b}[0], [%[a_ptr0]], #1\n" "ld1 {v1.b}[0], [a_ptr1], #1\n" "ld1 {v2.b}[0], [a_ptr2], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[1], [%[a_ptr0]], #1\n" "ld1 {v1.b}[1], [a_ptr1], #1\n" "ld1 {v2.b}[1], [a_ptr2], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[2], [%[a_ptr0]]\n" "ld1 {v1.b}[2], [a_ptr1]\n" "ld1 {v2.b}[2], [a_ptr2]\n" - "8:\n" + "9:\n" "ldr q8, [%[b_ptr0]]\n" "ldr q9, [%[b_ptr0], #0x10]\n" "ldr q10, [%[b_ptr0], #0x20]\n" @@ -1077,7 +1148,7 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n" ".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n" ".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n" - "7:\n" + "8:\n" "str q16, [%[c_ptr0]]\n" "str q17, [%[c_ptr0], #0x10]\n" "str q18, [%[c_ptr0], #0x20]\n" @@ -1109,46 +1180,82 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ "c_ptr1 .req X3\n" "c_ptr2 .req X4\n" "c_ptr3 .req X5\n" + "add a_ptr1, %[a_ptr0], %[lda]\n" + "add c_ptr1, %[c_ptr0], %[ldc]\n" + "add a_ptr2, a_ptr1, %[lda]\n" + "add c_ptr2, c_ptr1, %[ldc]\n" + "add a_ptr3, a_ptr2, %[lda]\n" + "add c_ptr3, c_ptr2, %[ldc]\n" + "cbnz %[append], 1f\n" "movi v16.4s, #0\n" "ldr q0, [%[a_ptr0]]\n" "movi v17.4s, #0\n" - "ldr q8, [%[b_ptr0]]\n" + "ldr q1, [a_ptr1]\n" "movi v18.4s, #0\n" - "ldr q9, [%[b_ptr0], #0x10]\n" + "ldr q2, [a_ptr2]\n" "movi v19.4s, #0\n" - "ldr q10, [%[b_ptr0], #0x20]\n" + "ldr q3, [a_ptr3]\n" "movi v20.4s, #0\n" - "ldr q11, [%[b_ptr0], #0x30]\n" + "ldr q8, [%[b_ptr0]]\n" "movi v21.4s, #0\n" - "ldr q12, [%[b_ptr0], #0x40]\n" + "ldr q9, [%[b_ptr0], #0x10]\n" "movi v22.4s, #0\n" - "ldr q13, [%[b_ptr0], #0x50]\n" + "ldr q10, [%[b_ptr0], #0x20]\n" "movi v23.4s, #0\n" - "ldr q14, [%[b_ptr0], #0x60]\n" + "ldr q11, [%[b_ptr0], #0x30]\n" "movi v24.4s, #0\n" - "add a_ptr1, %[a_ptr0], %[lda]\n" + "ldr q12, [%[b_ptr0], #0x40]\n" "movi v25.4s, #0\n" - "ldr q1, [a_ptr1]\n" + "ldr q13, [%[b_ptr0], #0x50]\n" "movi v26.4s, #0\n" - "add a_ptr2, a_ptr1, %[lda]\n" + "ldr q14, [%[b_ptr0], #0x60]\n" "movi v27.4s, #0\n" - "ldr q2, [a_ptr2]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" "movi v28.4s, #0\n" - "add a_ptr3, a_ptr2, %[lda]\n" + "add a_ptr1, a_ptr1, #0x10\n" "movi v29.4s, #0\n" - "ldr q3, [a_ptr3]\n" + "add a_ptr2, a_ptr2, #0x10\n" "movi v30.4s, #0\n" - "add c_ptr1, %[c_ptr0], %[ldc]\n" + "add a_ptr3, a_ptr3, #0x10\n" "movi v31.4s, #0\n" - "add c_ptr2, c_ptr1, %[ldc]\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "cbz %[loops], 2f\n" + "b 3f\n" + "1:\n" + "ldr q16, [%[c_ptr0]]\n" + "ldr q17, [%[c_ptr0], #0x10]\n" + "ldr q18, [%[c_ptr0], #0x20]\n" + "ldr q19, [%[c_ptr0], #0x30]\n" + "ldr q20, [c_ptr1]\n" + "ldr q21, [c_ptr1, #0x10]\n" + "ldr q22, [c_ptr1, #0x20]\n" + "ldr q23, [c_ptr1, #0x30]\n" + "ldr q24, [c_ptr2]\n" + "ldr q25, [c_ptr2, #0x10]\n" + "ldr q26, [c_ptr2, #0x20]\n" + "ldr q27, [c_ptr2, #0x30]\n" + "ldr q28, [c_ptr3]\n" + "ldr q29, [c_ptr3, #0x10]\n" + "ldr q30, [c_ptr3, #0x20]\n" + "ldr q31, [c_ptr3, #0x30]\n" + "ldr q0, [%[a_ptr0]]\n" "add %[a_ptr0], %[a_ptr0], #0x10\n" - "add c_ptr3, c_ptr2, %[ldc]\n" + "ldr q1, [a_ptr1]\n" "add a_ptr1, a_ptr1, #0x10\n" + "ldr q2, [a_ptr2]\n" "add a_ptr2, a_ptr2, #0x10\n" + "ldr q3, [a_ptr3]\n" "add a_ptr3, a_ptr3, #0x10\n" + "ldr q8, [%[b_ptr0]]\n" + "ldr q9, [%[b_ptr0], #0x10]\n" + "ldr q10, [%[b_ptr0], #0x20]\n" + "ldr q11, [%[b_ptr0], #0x30]\n" + "ldr q12, [%[b_ptr0], #0x40]\n" + "ldr q13, [%[b_ptr0], #0x50]\n" + "ldr q14, [%[b_ptr0], #0x60]\n" "add %[b_ptr0], %[b_ptr0], #0x80\n" - "cbz %[loops], 1f\n" - "2:\n" + "cbz %[loops], 2f\n" + "3:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr q15, [%[b_ptr0], #-0x10]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" @@ -1328,14 +1435,14 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n" ".inst 0x4fa6e9fb // sdot v27.4s, v15.16b, v6.4b[3]\n" ".inst 0x4fa7e9ff // sdot v31.4s, v15.16b, v7.4b[3]\n" - "b.ne 2b\n" - "1:\n" + "b.ne 3b\n" + "2:\n" "ldr q15, [%[b_ptr0], #-0x10]\n" "prfm PSTL1KEEP, [%[c_ptr0]]\n" "prfm PSTL1KEEP, [c_ptr1]\n" "prfm PSTL1KEEP, [c_ptr2]\n" "prfm PSTL1KEEP, [c_ptr3]\n" - "cbz %[regs], 3f\n" + "cbz %[regs], 4f\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr q4, [%[a_ptr0]]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" @@ -1498,8 +1605,8 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n" ".inst 0x4fa6e9fb // sdot v27.4s, v15.16b, v6.4b[3]\n" ".inst 0x4fa7e9ff // sdot v31.4s, v15.16b, v7.4b[3]\n" - "b 4f\n" - "3:\n" + "b 5f\n" + "4:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" ".inst 0x4f82e118 // sdot v24.4s, v8.16b, v2.4b[0]\n" @@ -1573,9 +1680,9 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4fa1e9f7 // sdot v23.4s, v15.16b, v1.4b[3]\n" ".inst 0x4fa2e9fb // sdot v27.4s, v15.16b, v2.4b[3]\n" ".inst 0x4fa3e9ff // sdot v31.4s, v15.16b, v3.4b[3]\n" - "4:\n" - "cbz %[blocks], 5f\n" - "6:\n" + "5:\n" + "cbz %[blocks], 6f\n" + "7:\n" "ldr q8, [%[b_ptr0]]\n" "subs %[blocks], %[blocks], #0x1\n" "ldr q9, [%[b_ptr0], #0x10]\n" @@ -1606,26 +1713,26 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n" ".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n" ".inst 0x4f83e17f // sdot v31.4s, v11.16b, v3.4b[0]\n" - "b.ne 6b\n" - "5:\n" - "cbz %[odds], 7f\n" + "b.ne 7b\n" + "6:\n" + "cbz %[odds], 8f\n" "ld1 {v0.b}[0], [%[a_ptr0]], #1\n" "ld1 {v1.b}[0], [a_ptr1], #1\n" "ld1 {v2.b}[0], [a_ptr2], #1\n" "ld1 {v3.b}[0], [a_ptr3], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[1], [%[a_ptr0]], #1\n" "ld1 {v1.b}[1], [a_ptr1], #1\n" "ld1 {v2.b}[1], [a_ptr2], #1\n" "ld1 {v3.b}[1], [a_ptr3], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[2], [%[a_ptr0]]\n" "ld1 {v1.b}[2], [a_ptr1]\n" "ld1 {v2.b}[2], [a_ptr2]\n" "ld1 {v3.b}[2], [a_ptr3]\n" - "8:\n" + "9:\n" "ldr q8, [%[b_ptr0]]\n" "ldr q9, [%[b_ptr0], #0x10]\n" "ldr q10, [%[b_ptr0], #0x20]\n" @@ -1646,7 +1753,7 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n" ".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n" ".inst 0x4f83e17f // sdot v31.4s, v11.16b, v3.4b[0]\n" - "7:\n" + "8:\n" "str q16, [%[c_ptr0]]\n" "str q17, [%[c_ptr0], #0x10]\n" "str q18, [%[c_ptr0], #0x20]\n" diff --git a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp index 188dd0b06d..345060f206 100644 --- a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp +++ b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp @@ -40,7 +40,7 @@ private: UniqueGemmCommon<To, Tgemm> _subgemm = nullptr; int32_t *_row_sums = nullptr; int32_t *_col_sums = nullptr; - ARequantizeLayer32 _params; + Requantize32 _params; GemmArgs _args; barrier _barrier; @@ -125,7 +125,7 @@ public: QuantizeWrapper(const QuantizeWrapper &) = delete; QuantizeWrapper operator=(const QuantizeWrapper &) = delete; - QuantizeWrapper(const GemmArgs &args, const ARequantizeLayer32 &qp) : _params(qp), _args(args), _barrier(args._maxthreads) { + QuantizeWrapper(const GemmArgs &args, const Requantize32 &qp) : _params(qp), _args(args), _barrier(args._maxthreads) { GemmArgs newargs = GemmArgs(args._ci, args._Msize, args._Nsize, args._Ksize, args._nbatches, args._nmulti, args._trA, args._trB, Activation(), args._maxthreads, args._pretransposed_hint, nullptr); _subgemm = gemm<To, Tgemm>(newargs); diff --git a/src/core/NEON/kernels/arm_gemm/quantized.cpp b/src/core/NEON/kernels/arm_gemm/quantized.cpp index bffb7ddcb3..00b42cf422 100644 --- a/src/core/NEON/kernels/arm_gemm/quantized.cpp +++ b/src/core/NEON/kernels/arm_gemm/quantized.cpp @@ -47,13 +47,19 @@ namespace { * applied to negative values being shifted right to make sure they round * properly - if negative values are never output (e.g. fused ReLU) this is * unnecessary. + * + * The 'per_channel' template parameter selects between per channel and per + * layer requantization - in the former case we need to load vectors of + * shifts and multipliers for each column. A separate vector for each + * column is set up in any case (and it is hoped that the compiler can elide + * the needless movs in the per-layer case). */ -template<bool do_shift_correction> -void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, +template<bool do_shift_correction, bool per_channel> +void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigned int height, const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride, const int32_t *row_bias, const int32_t *col_bias) { - const int32x4_t v_mul = vdupq_n_s32(qp.requant_mul); - const int32x4_t v_shift = vdupq_n_s32(qp.requant_shift); + const int32x4_t v_mul = vdupq_n_s32(qp.per_layer_mul); + const int32x4_t v_shift = vdupq_n_s32(qp.per_layer_shift); const int32x4_t v_minval = vdupq_n_s32(qp.minval); const int32x4_t v_maxval = vdupq_n_s32(qp.maxval); const int32x4_t v_c_offset = vdupq_n_s32(qp.c_offset); @@ -70,6 +76,8 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u unsigned int odds=(width % 4); const int32_t *colptr = col_bias; + const int32_t *perch_mul_ptr = qp.per_channel_muls; + const int32_t *perch_shift_ptr = qp.per_channel_shifts; const int32_t *in_ptr = input + (row * in_stride); int8_t *out_ptr = output + (row * out_stride); @@ -93,6 +101,33 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u const int32x4_t v_row_sum1 = vdupq_n_s32(row_sum1); while (blocks--) { + int32x4_t v_mul0; + int32x4_t v_mul1; + int32x4_t v_mul2; + int32x4_t v_mul3; + + int32x4_t v_shf0; + int32x4_t v_shf1; + int32x4_t v_shf2; + int32x4_t v_shf3; + + if (per_channel) { + v_mul0 = vld1q_s32(perch_mul_ptr); + v_mul1 = vld1q_s32(perch_mul_ptr + 4); + v_mul2 = vld1q_s32(perch_mul_ptr + 8); + v_mul3 = vld1q_s32(perch_mul_ptr + 12); + perch_mul_ptr += 16; + + v_shf0 = vld1q_s32(perch_shift_ptr); + v_shf1 = vld1q_s32(perch_shift_ptr + 4); + v_shf2 = vld1q_s32(perch_shift_ptr + 8); + v_shf3 = vld1q_s32(perch_shift_ptr + 12); + perch_shift_ptr += 16; + } else { + v_mul0=v_mul1=v_mul2=v_mul3=v_mul; + v_shf0=v_shf1=v_shf2=v_shf3=v_shift; + } + // Load column pointers int32x4_t v_col0 = vld1q_s32(colptr); int32x4_t v_col1 = vld1q_s32(colptr + 4); @@ -136,27 +171,27 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u v_in13 = vaddq_s32(v_in13, v_col3); // Quantize - start with multiply - v_in00 = vqrdmulhq_s32(v_in00, v_mul); - v_in01 = vqrdmulhq_s32(v_in01, v_mul); - v_in02 = vqrdmulhq_s32(v_in02, v_mul); - v_in03 = vqrdmulhq_s32(v_in03, v_mul); + v_in00 = vqrdmulhq_s32(v_in00, v_mul0); + v_in01 = vqrdmulhq_s32(v_in01, v_mul1); + v_in02 = vqrdmulhq_s32(v_in02, v_mul2); + v_in03 = vqrdmulhq_s32(v_in03, v_mul3); - v_in10 = vqrdmulhq_s32(v_in10, v_mul); - v_in11 = vqrdmulhq_s32(v_in11, v_mul); - v_in12 = vqrdmulhq_s32(v_in12, v_mul); - v_in13 = vqrdmulhq_s32(v_in13, v_mul); + v_in10 = vqrdmulhq_s32(v_in10, v_mul0); + v_in11 = vqrdmulhq_s32(v_in11, v_mul1); + v_in12 = vqrdmulhq_s32(v_in12, v_mul2); + v_in13 = vqrdmulhq_s32(v_in13, v_mul3); // Compute and add on corrective offset if (do_shift_correction) { - int32x4_t v_temp00 = vandq_s32(v_in00, v_shift); - int32x4_t v_temp01 = vandq_s32(v_in01, v_shift); - int32x4_t v_temp02 = vandq_s32(v_in02, v_shift); - int32x4_t v_temp03 = vandq_s32(v_in03, v_shift); + int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0); + int32x4_t v_temp01 = vandq_s32(v_in01, v_shf1); + int32x4_t v_temp02 = vandq_s32(v_in02, v_shf2); + int32x4_t v_temp03 = vandq_s32(v_in03, v_shf3); - int32x4_t v_temp10 = vandq_s32(v_in10, v_shift); - int32x4_t v_temp11 = vandq_s32(v_in11, v_shift); - int32x4_t v_temp12 = vandq_s32(v_in12, v_shift); - int32x4_t v_temp13 = vandq_s32(v_in13, v_shift); + int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0); + int32x4_t v_temp11 = vandq_s32(v_in11, v_shf1); + int32x4_t v_temp12 = vandq_s32(v_in12, v_shf2); + int32x4_t v_temp13 = vandq_s32(v_in13, v_shf3); v_temp00 = vshrq_n_s32(v_temp00, 31); v_temp01 = vshrq_n_s32(v_temp01, 31); @@ -179,15 +214,15 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u v_in13 = vqaddq_s32(v_in13, v_temp13); } - v_in00 = vrshlq_s32(v_in00, v_shift); - v_in01 = vrshlq_s32(v_in01, v_shift); - v_in02 = vrshlq_s32(v_in02, v_shift); - v_in03 = vrshlq_s32(v_in03, v_shift); + v_in00 = vrshlq_s32(v_in00, v_shf0); + v_in01 = vrshlq_s32(v_in01, v_shf1); + v_in02 = vrshlq_s32(v_in02, v_shf2); + v_in03 = vrshlq_s32(v_in03, v_shf3); - v_in10 = vrshlq_s32(v_in10, v_shift); - v_in11 = vrshlq_s32(v_in11, v_shift); - v_in12 = vrshlq_s32(v_in12, v_shift); - v_in13 = vrshlq_s32(v_in13, v_shift); + v_in10 = vrshlq_s32(v_in10, v_shf0); + v_in11 = vrshlq_s32(v_in11, v_shf1); + v_in12 = vrshlq_s32(v_in12, v_shf2); + v_in13 = vrshlq_s32(v_in13, v_shf3); v_in00 = vaddq_s32(v_in00, v_c_offset); v_in01 = vaddq_s32(v_in01, v_c_offset); @@ -235,6 +270,20 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u } while (regs--) { + int32x4_t v_mul0; + int32x4_t v_shf0; + + if (per_channel) { + v_mul0 = vld1q_s32(perch_mul_ptr); + perch_mul_ptr += 4; + + v_shf0 = vld1q_s32(perch_shift_ptr); + perch_shift_ptr += 4; + } else { + v_mul0=v_mul; + v_shf0=v_shift; + } + // Load column pointers int32x4_t v_col0 = vld1q_s32(colptr); colptr += 4; @@ -258,15 +307,15 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u v_in10 = vaddq_s32(v_in10, v_col0); // Quantize - start with multiply - v_in00 = vqrdmulhq_s32(v_in00, v_mul); + v_in00 = vqrdmulhq_s32(v_in00, v_mul0); - v_in10 = vqrdmulhq_s32(v_in10, v_mul); + v_in10 = vqrdmulhq_s32(v_in10, v_mul0); // Compute and add on corrective offset if (do_shift_correction) { - int32x4_t v_temp00 = vandq_s32(v_in00, v_shift); + int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0); - int32x4_t v_temp10 = vandq_s32(v_in10, v_shift); + int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0); v_temp00 = vshrq_n_s32(v_temp00, 31); @@ -277,9 +326,9 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u v_in10 = vqaddq_s32(v_in10, v_temp10); } - v_in00 = vrshlq_s32(v_in00, v_shift); + v_in00 = vrshlq_s32(v_in00, v_shf0); - v_in10 = vrshlq_s32(v_in10, v_shift); + v_in10 = vrshlq_s32(v_in10, v_shf0); v_in00 = vaddq_s32(v_in00, v_c_offset); @@ -307,21 +356,40 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u int32x4_t v_col0 = vdupq_n_s32(0); int32x4_t v_in00 = vdupq_n_s32(0); int32x4_t v_in10 = vdupq_n_s32(0); + int32x4_t v_mul0 = vdupq_n_s32(0); + int32x4_t v_shf0 = vdupq_n_s32(0); + + if (!per_channel) { + v_mul0 = v_mul; + v_shf0 = v_shift; + } do { v_col0 = vld1q_lane_s32(colptr, v_col0, 0); v_in00 = vld1q_lane_s32(in_ptr, v_in00, 0); v_in10 = vld1q_lane_s32(in_ptr1, v_in10, 0); + if (per_channel) { + v_mul0 = vld1q_lane_s32(perch_mul_ptr, v_mul0, 0); + v_shf0 = vld1q_lane_s32(perch_shift_ptr, v_shf0, 0); + } if (odds == 1) { break; } v_col0 = vld1q_lane_s32(colptr + 1, v_col0, 1); v_in00 = vld1q_lane_s32(in_ptr + 1, v_in00, 1); v_in10 = vld1q_lane_s32(in_ptr1 + 1, v_in10, 1); + if (per_channel) { + v_mul0 = vld1q_lane_s32(perch_mul_ptr + 1, v_mul0, 1); + v_shf0 = vld1q_lane_s32(perch_shift_ptr + 1, v_shf0, 1); + } if (odds == 2) { break; } v_col0 = vld1q_lane_s32(colptr + 2, v_col0, 2); v_in00 = vld1q_lane_s32(in_ptr + 2, v_in00, 2); v_in10 = vld1q_lane_s32(in_ptr1 + 2, v_in10, 2); + if (per_channel) { + v_mul0 = vld1q_lane_s32(perch_mul_ptr + 2, v_mul0, 2); + v_shf0 = vld1q_lane_s32(perch_shift_ptr + 2, v_shf0, 2); + } } while (0); // Add on row sum and bias constant @@ -335,15 +403,15 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u v_in10 = vaddq_s32(v_in10, v_col0); // Quantize - start with multiply - v_in00 = vqrdmulhq_s32(v_in00, v_mul); + v_in00 = vqrdmulhq_s32(v_in00, v_mul0); - v_in10 = vqrdmulhq_s32(v_in10, v_mul); + v_in10 = vqrdmulhq_s32(v_in10, v_mul0); // Compute and add on corrective offset if (do_shift_correction) { - int32x4_t v_temp00 = vandq_s32(v_in00, v_shift); + int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0); - int32x4_t v_temp10 = vandq_s32(v_in10, v_shift); + int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0); v_temp00 = vshrq_n_s32(v_temp00, 31); @@ -354,9 +422,9 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u v_in10 = vqaddq_s32(v_in10, v_temp10); } - v_in00 = vrshlq_s32(v_in00, v_shift); + v_in00 = vrshlq_s32(v_in00, v_shf0); - v_in10 = vrshlq_s32(v_in10, v_shift); + v_in10 = vrshlq_s32(v_in10, v_shf0); v_in00 = vaddq_s32(v_in00, v_c_offset); @@ -391,23 +459,33 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u } // anonymous namespace template<typename Tin, typename Tout> -void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, +void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height, const Tin *input, unsigned int in_stride, Tout *output, unsigned int out_stride, const int32_t *row_bias, const int32_t *col_bias) { - if (qp.minval >= qp.c_offset) { - requantize_block_32_int<false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride, - reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias); + if (qp.per_channel_requant) { + if (qp.minval >= qp.c_offset) { + requantize_block_32_int<false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride, + reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias); + } else { + requantize_block_32_int<true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride, + reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias); + } } else { - requantize_block_32_int<true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride, - reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias); + if (qp.minval >= qp.c_offset) { + requantize_block_32_int<false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride, + reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias); + } else { + requantize_block_32_int<true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride, + reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias); + } } } -template void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, +template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height, const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride, const int32_t *row_bias, const int32_t *col_bias); -template void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, +template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height, const uint32_t *input, unsigned int in_stride, uint8_t *output, unsigned int out_stride, const int32_t *row_bias, const int32_t *col_bias); @@ -448,7 +526,7 @@ template void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int wid */ namespace { struct row_sum_helpers { - const ARequantizeLayer32 &qp; + const Requantize32 &qp; /* Load a full 16 byte vector, pairwise accumulate into 'sum' with uadalp or sadalp */ template<typename T> @@ -571,7 +649,7 @@ namespace { } } - row_sum_helpers(const ARequantizeLayer32 &qp) : qp(qp) { } + row_sum_helpers(const Requantize32 &qp) : qp(qp) { } }; template<> @@ -612,8 +690,14 @@ namespace { } template<typename T> -void compute_row_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, +void compute_row_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const T *input, unsigned int in_stride, int32_t *row_bias) { + /* If the 'b' offset is zero, just skip this entirely. */ + if (qp.b_offset == 0) { + memset(row_bias, 0, height * sizeof(int32_t)); + return; + } + row_sum_helpers thehelpers(qp); const int32x4_t offset_mul = vdupq_n_s32(-qp.b_offset); @@ -663,8 +747,8 @@ void compute_row_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned } /* Instantiate the two versions for uint8_t and int8_t. */ -template void compute_row_sums(const ARequantizeLayer32 &, unsigned int, unsigned int, const int8_t *, unsigned int, int32_t *); -template void compute_row_sums(const ARequantizeLayer32 &, unsigned int, unsigned int, const uint8_t *, unsigned int, int32_t *); +template void compute_row_sums(const Requantize32 &, unsigned int, unsigned int, const int8_t *, unsigned int, int32_t *); +template void compute_row_sums(const Requantize32 &, unsigned int, unsigned int, const uint8_t *, unsigned int, int32_t *); template<unsigned int active_rows, typename T> inline void add_block(const T *input, unsigned int in_stride, int32_t *output); @@ -739,41 +823,44 @@ inline void add_block(const int8_t *input, unsigned int in_stride, int32_t *outp * in cases where we are not computing the first columns of the output (i.e. * in multithreaded cases where we divide columns across threads) */ template<typename T> -void compute_col_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, const T *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col) { - memset(reinterpret_cast<void *>(col_bias), 0, width * sizeof(int32_t)); - - for (unsigned int row=0; row<height; row+=4) { - unsigned int numrows=std::min(height-row, 4u); - - for (unsigned int col=0; col<width; col+=16) { - unsigned int numcols=std::min(width-col, 16u); - - if (numcols==16) { - switch(numrows) { - default: - case 1: - add_block<1>(input + row * in_stride + col, in_stride, col_bias + col); - break; - - case 2: - add_block<2>(input + row * in_stride + col, in_stride, col_bias + col); - break; - - case 3: - add_block<3>(input + row * in_stride + col, in_stride, col_bias + col); - break; - - case 4: - add_block<4>(input + row * in_stride + col, in_stride, col_bias + col); - break; - } - } else { - for (; col<width; col++) { - int32_t sum=0; - for (unsigned int r=0; r<numrows; r++) { - sum += input[(row + r)*in_stride + col]; +void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const T *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col) { + /* Only actually add up the columns if a_offset is non-zero. */ + if (qp.a_offset != 0) { + memset(reinterpret_cast<void *>(col_bias), 0, width * sizeof(int32_t)); + + for (unsigned int row=0; row<height; row+=4) { + unsigned int numrows=std::min(height-row, 4u); + + for (unsigned int col=0; col<width; col+=16) { + unsigned int numcols=std::min(width-col, 16u); + + if (numcols==16) { + switch(numrows) { + default: + case 1: + add_block<1>(input + row * in_stride + col, in_stride, col_bias + col); + break; + + case 2: + add_block<2>(input + row * in_stride + col, in_stride, col_bias + col); + break; + + case 3: + add_block<3>(input + row * in_stride + col, in_stride, col_bias + col); + break; + + case 4: + add_block<4>(input + row * in_stride + col, in_stride, col_bias + col); + break; + } + } else { + for (; col<width; col++) { + int32_t sum=0; + for (unsigned int r=0; r<numrows; r++) { + sum += input[(row + r)*in_stride + col]; + } + col_bias[col] += sum; } - col_bias[col] += sum; } } } @@ -792,8 +879,8 @@ void compute_col_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned } } -template void compute_col_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, const int8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col); -template void compute_col_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, const uint8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col); +template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const int8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col); +template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const uint8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/quantized.hpp b/src/core/NEON/kernels/arm_gemm/quantized.hpp index a22750796c..a91a888ad9 100644 --- a/src/core/NEON/kernels/arm_gemm/quantized.hpp +++ b/src/core/NEON/kernels/arm_gemm/quantized.hpp @@ -26,16 +26,16 @@ namespace arm_gemm { template<typename Tin, typename Tout> -void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, +void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height, const Tin *input, unsigned int in_stride, Tout *output, unsigned int out_stride, const int32_t *row_bias, const int32_t *col_bias); template<typename T> -void compute_row_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, +void compute_row_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const T *input, unsigned int in_stride, int32_t *row_bias); template<typename T> -void compute_col_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, +void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const T *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col); |