From 71ac9037abce1c6c4af42c485d5395dd6fd79a5a Mon Sep 17 00:00:00 2001
From: Michalis Spyrou
Date: Thu, 14 Nov 2019 14:31:44 +0000
Subject: COMPMID-2923 Integrate arm_gemm per channel quantization

Signed-off-by: Michalis Spyrou
Change-Id: I8667e75843fdd6ac75bd8272a86a348b830da28d
Reviewed-on: https://review.mlplatform.org/c/2548
Tested-by: Arm Jenkins
Comments-Addressed: Arm Jenkins
---
 .../core/NEON/kernels/assembly/arm_gemm.hpp        |   52 +-
 src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp       |   12 +-
 .../kernels/arm_gemm/gemm_hybrid_quantized.hpp     |    4 +-
 src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp      |  105 ++
 src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp     |   36 +-
 .../arm_gemm/kernels/a64_hybrid_fp32_mla_4x8.hpp   |   89 +
 .../kernels/a64_hybrid_fp32_mla_4x8/generic.cpp    | 1923 ++++++++++++++++++++
 .../arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4.hpp |    2 +-
 .../kernels/a64_hybrid_s8s32_dot_16x4/a55.cpp      |  375 ++--
 .../kernels/a64_hybrid_s8s32_dot_16x4/generic.cpp  |  321 ++--
 .../NEON/kernels/arm_gemm/quantize_wrapper.hpp     |    4 +-
 src/core/NEON/kernels/arm_gemm/quantized.cpp       |  265 ++-
 src/core/NEON/kernels/arm_gemm/quantized.hpp       |    6 +-
 .../NEON/functions/NEGEMMAssemblyDispatch.cpp      |   70 +-
 .../functions/NEGEMMLowpMatrixMultiplyCore.cpp     |    4 +-
 15 files changed, 2887 insertions(+), 381 deletions(-)
 create mode 100644 src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
 create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x8.hpp
 create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x8/generic.cpp

diff --git a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp
index d51fda525b..e89523981d 100644
--- a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp
+++ b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp
@@ -108,23 +108,45 @@ public:
     }
 };
 
-struct ARequantizeLayer32
+struct Requantize32
 {
 public:
-    const int32_t *bias;
-    size_t         bias_multi_stride;
-    int32_t        a_offset;
-    int32_t        b_offset;
-    int32_t        c_offset;
-    int32_t        requant_shift;
-    int32_t        requant_mul;
-    int32_t        minval;
-    int32_t        maxval;
-
-    ARequantizeLayer32() = default;
-
-    ARequantizeLayer32(const int32_t *b, size_t bms, int32_t ao, int32_t bo, int32_t co, int32_t rs, int32_t rm, int32_t minv, int32_t maxv) :
-        bias(b), bias_multi_stride(bms), a_offset(ao), b_offset(bo), c_offset(co), requant_shift(rs), requant_mul(rm), minval(minv), maxval(maxv)
+    const int32_t *bias = nullptr;
+    size_t         bias_multi_stride = 0;
+    int32_t        a_offset = 0;
+    int32_t        b_offset = 0;
+    int32_t        c_offset = 0;
+    bool           per_channel_requant = false;
+    int32_t        per_layer_shift = 0;
+    int32_t        per_layer_mul = 0;
+    const int32_t *per_channel_shifts = nullptr;
+    const int32_t *per_channel_muls = nullptr;
+    int32_t        minval = 0;
+    int32_t        maxval = 0;
+
+    Requantize32() = default;
+
+    // Constructor for per-tensor quantization
+    Requantize32(const int32_t *bias, size_t bias_multi_stride,
+                 int32_t a_offset, int32_t b_offset, int32_t c_offset,
+                 int32_t requant_shift, int32_t requant_mul,
+                 int32_t minv, int32_t maxv) :
+        bias(bias), bias_multi_stride(bias_multi_stride),
+        a_offset(a_offset), b_offset(b_offset), c_offset(c_offset),
+        per_channel_requant(false), per_layer_shift(requant_shift), per_layer_mul(requant_mul),
+        minval(minv), maxval(maxv)
+    {
+    }
+
+    // Constructor for per-channel quantization
+    Requantize32(const int32_t *bias, size_t bias_multi_stride,
+                 int32_t a_offset, int32_t b_offset, int32_t c_offset,
+                 const int32_t *requant_shifts, const int32_t *requant_muls,
+                 int32_t minv, int32_t maxv) :
+        bias(bias), bias_multi_stride(bias_multi_stride),
+        a_offset(a_offset), b_offset(b_offset), c_offset(c_offset),
+        per_channel_requant(true), per_channel_shifts(requant_shifts), per_channel_muls(requant_muls),
+        minval(minv), maxval(maxv)
     {
     }
 };
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index cf91ee0652..7f171ec15a 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -33,6 +33,7 @@
 #include "kernels/a32_sgemm_8x6.hpp"
 #include "kernels/a64_hybrid_fp32_mla_16x4.hpp"
+#include "kernels/a64_hybrid_fp32_mla_4x8.hpp"
 #include "kernels/a64_native_fp32_mla_16x4.hpp"
 #include "kernels/a64_smallK_hybrid_fp32_mla_4x6.hpp"
 #include "kernels/a64_smallK_hybrid_fp32_mla_4x8.hpp"
@@ -104,11 +105,18 @@ static const GemmImplementation gemm_fp32_methods[] =
     nullptr,
     [](const GemmArgs &args) { return new GemmHybrid(args); }
 },
+{
+    GemmMethod::GEMM_HYBRID,
+    "hybrid_fp32_mla_4x8_normal",
+    [](const GemmArgs &args) { return (args._Ksize >= 4) && !args._trA && args._pretransposed_hint; },
+    [](const GemmArgs &args) { return (args._Nsize < 12); },
+    [](const GemmArgs &args) { return new GemmHybrid(args); }
+},
 {
     GemmMethod::GEMM_HYBRID,
     "hybrid_fp32_mla_16x4",
     [](const GemmArgs &args) { return (args._Ksize >= 4) && !args._trA && args._pretransposed_hint; },
-    [](const GemmArgs &args) { return ((args._Ksize <= 256) && (args._Nsize <= 256)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); },
+    [](const GemmArgs &args) { return ((args._Ksize <= 256) && (args._Nsize <= 256)) || (args._Msize < 16) || (args._nmulti > 1); },
     [](const GemmArgs &args) { return new GemmHybrid(args); }
 },
 {
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
index 574ecef5b2..22b6960baf 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
@@ -68,7 +68,7 @@ class GemmHybridQuantized : public GemmCommon {
     const NDRange<4> _window_range;
 
-    ARequantizeLayer32 _qp;
+    Requantize32 _qp;
 
     int32_t *row_bias = nullptr;
     int32_t *col_bias = nullptr;
@@ -140,7 +140,7 @@ public:
     GemmHybridQuantized & operator= (GemmHybridQuantized &) = delete;
 
     /* Constructor */
-    GemmHybridQuantized(const GemmArgs &args, const ARequantizeLayer32 &qp)
+    GemmHybridQuantized(const GemmArgs &args, const Requantize32 &qp)
               : _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize),
                 _nbatches(args._nbatches), _nmulti(args._nmulti), _trB(args._trB),
                 _k_block(compute_k_block(args)), _n_block(compute_n_block(args)),
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
new file mode 100644
index 0000000000..73d0c272a6
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) 2019 Arm Limited.
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include "arm_gemm.hpp" + +#include "kernels/a64_hybrid_s8s32_dot_16x4.hpp" +#include "kernels/a64_smallK_hybrid_s8s32_dot_4x6.hpp" +#include "kernels/a64_smallK_hybrid_s8s32_dot_4x8.hpp" +#include "kernels/sve_hybrid_s8s32_dot_4VLx4.hpp" +#include "kernels/sve_smallK_hybrid_s8s32_dot_1VLx8.hpp" + +#include "gemm_hybrid_quantized.hpp" +#include "quantize_wrapper.hpp" + +namespace arm_gemm { + +static const GemmImplementation gemm_qint8_methods[] = +{ +#ifdef __ARM_FEATURE_SVE +{ + GemmMethod::GEMM_HYBRID_QUANTIZED, + "smallK_hybrid_s8s32_dot_1VLx8", + [](const GemmArgs &args, const Requantize32 &) { return args._Ksize<=64 && !args._trA && args._pretransposed_hint; }, + nullptr, + [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized(args, qp); } +}, +{ + GemmMethod::GEMM_HYBRID_QUANTIZED, + "hybrid_s8s32_dot_4VLx4", + [](const GemmArgs &args, const Requantize32 &) { return args._Ksize>=16 && !args._trA && !args._trB && args._pretransposed_hint; }, + [](const GemmArgs &args, const Requantize32 &) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); }, + [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized(args, qp); } +}, +#endif +{ + GemmMethod::GEMM_HYBRID_QUANTIZED, + "smallK_hybrid_s8s32_dot_4x8", + [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._trA && args._pretransposed_hint; }, + nullptr, + [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized(args, qp); } +}, +{ + GemmMethod::GEMM_HYBRID_QUANTIZED, + "smallK_hybrid_s8s32_dot_4x6", + [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._trA && args._pretransposed_hint; }, + nullptr, + [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized(args, qp); } +}, +{ + GemmMethod::GEMM_HYBRID_QUANTIZED, + "hybrid_s8s32_dot_16x4", + [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && args._Ksize>=16 && !args._trA && !args._trB && args._pretransposed_hint; }, + [](const GemmArgs &args, const Requantize32 &) { return args._Nsize<=256 && args._Ksize>128; }, + [](const GemmArgs &args, const Requantize32 &qp) { 
return new GemmHybridQuantized(args, qp); } +}, +{ + GemmMethod::QUANTIZE_WRAPPER, + "quantized_wrapper", + nullptr, + nullptr, + [](const GemmArgs &args, const Requantize32 &qp) { return new QuantizeWrapper(args, qp); } +}, +{ + GemmMethod::DEFAULT, + "", + nullptr, + nullptr, + nullptr +} +}; + +template<> +const GemmImplementation *gemm_implementation_list() { + return gemm_qint8_methods; +} + +template UniqueGemmCommon gemm(const GemmArgs &args, const Requantize32 &os); +template KernelDescription get_gemm_method(const GemmArgs &args, const Requantize32 &os); +template std::vector get_compatible_kernels(const GemmArgs &args, const Requantize32 &os); + +} // namespace arm_gemm + +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp index 079c04ae06..59cd1704ff 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp @@ -36,51 +36,51 @@ namespace arm_gemm { -static const GemmImplementation gemm_quint8_methods[] = +static const GemmImplementation gemm_quint8_methods[] = { #ifdef __ARM_FEATURE_SVE { GemmMethod::GEMM_HYBRID_QUANTIZED, "smallK_hybrid_u8u32_dot_1VLx8", - [](const GemmArgs &args, const ARequantizeLayer32 &) { return args._Ksize<=64 && !args._trA && args._pretransposed_hint; }, + [](const GemmArgs &args, const Requantize32 &) { return args._Ksize<=64 && !args._trA && args._pretransposed_hint; }, nullptr, - [](const GemmArgs &args, const ARequantizeLayer32 &qp) { return new GemmHybridQuantized(args, qp); } + [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized(args, qp); } }, { GemmMethod::GEMM_HYBRID_QUANTIZED, "hybrid_u8u32_dot_4VLx4", - [](const GemmArgs &args, const ARequantizeLayer32 &) { return args._Ksize>=16 && !args._trA && !args._trB && args._pretransposed_hint; }, - [](const GemmArgs &args, const ARequantizeLayer32 &) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); }, - [](const GemmArgs &args, const ARequantizeLayer32 &qp) { return new GemmHybridQuantized(args, qp); } + [](const GemmArgs &args, const Requantize32 &) { return args._Ksize>=16 && !args._trA && !args._trB && args._pretransposed_hint; }, + [](const GemmArgs &args, const Requantize32 &) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); }, + [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized(args, qp); } }, #endif { GemmMethod::GEMM_HYBRID_QUANTIZED, "smallK_hybrid_u8u32_dot_4x8", - [](const GemmArgs &args, const ARequantizeLayer32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._trA && args._pretransposed_hint; }, + [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._trA && args._pretransposed_hint; }, nullptr, - [](const GemmArgs &args, const ARequantizeLayer32 &qp) { return new GemmHybridQuantized(args, qp); } + [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized(args, qp); } }, { GemmMethod::GEMM_HYBRID_QUANTIZED, "smallK_hybrid_u8u32_dot_4x6", - [](const GemmArgs &args, const ARequantizeLayer32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._trA && args._pretransposed_hint; }, + [](const GemmArgs &args, const Requantize32 &) { return 
args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._trA && args._pretransposed_hint; }, nullptr, - [](const GemmArgs &args, const ARequantizeLayer32 &qp) { return new GemmHybridQuantized(args, qp); } + [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized(args, qp); } }, { GemmMethod::GEMM_HYBRID_QUANTIZED, "hybrid_u8u32_dot_16x4", - [](const GemmArgs &args, const ARequantizeLayer32 &) { return args._ci->has_dotprod() && args._Ksize>=16 && !args._trA && !args._trB && args._pretransposed_hint; }, - [](const GemmArgs &args, const ARequantizeLayer32 &) { return args._Nsize<=256 && args._Ksize>128; }, - [](const GemmArgs &args, const ARequantizeLayer32 &qp) { return new GemmHybridQuantized(args, qp); } + [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && args._Ksize>=16 && !args._trA && !args._trB && args._pretransposed_hint; }, + [](const GemmArgs &args, const Requantize32 &) { return args._Nsize<=256 && args._Ksize>128; }, + [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized(args, qp); } }, { GemmMethod::QUANTIZE_WRAPPER, "quantized_wrapper", nullptr, nullptr, - [](const GemmArgs &args, const ARequantizeLayer32 &qp) { return new QuantizeWrapper(args, qp); } + [](const GemmArgs &args, const Requantize32 &qp) { return new QuantizeWrapper(args, qp); } }, { GemmMethod::DEFAULT, @@ -92,13 +92,13 @@ static const GemmImplementation gemm_quint }; template<> -const GemmImplementation *gemm_implementation_list() { +const GemmImplementation *gemm_implementation_list() { return gemm_quint8_methods; } -template UniqueGemmCommon gemm(const GemmArgs &args, const ARequantizeLayer32 &os); -template KernelDescription get_gemm_method(const GemmArgs &args, const ARequantizeLayer32 &os); -template std::vector get_compatible_kernels(const GemmArgs &args, const ARequantizeLayer32 &os); +template UniqueGemmCommon gemm(const GemmArgs &args, const Requantize32 &os); +template KernelDescription get_gemm_method(const GemmArgs &args, const Requantize32 &os); +template std::vector get_compatible_kernels(const GemmArgs &args, const Requantize32 &os); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x8.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x8.hpp new file mode 100644 index 0000000000..da5beef48c --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x8.hpp @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2018-2019 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#ifdef __aarch64__ + + +#include "../std_transforms_fixed.hpp" + +namespace arm_gemm +{ + +// Actual kernel implementations +void a64_hybrid_fp32_mla_4x8(const float *, int, const float *, float *, int, int, int, int, const float *, Activation, bool); + +class hybrid_fp32_mla_4x8 +{ +public: + typedef float operand_type; + typedef float result_type; + + typedef void (*kern_type)(const float *, int, const float *, float *, int, int, int, int, const float *, Activation, bool); + + /* Kernel blocking parameters */ + static constexpr unsigned int out_height() + { + return 8; + } + + static unsigned int out_width() + { + return 4; + } + + static constexpr unsigned int k_unroll() + { + return 1; + } + + static constexpr bool supports_append() + { + return false; + } + + static constexpr bool supports_bias() + { + return true; + } + + static constexpr bool supports_activation() + { + return true; + } + + StdTransformsFixed transforms = {}; + + // Default to the generic kernel + kern_type kernel=a64_hybrid_fp32_mla_4x8; + + hybrid_fp32_mla_4x8(const CPUInfo *ci) + { + UNUSED(ci); + } +}; + +} // namespace arm_gemm + +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x8/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x8/generic.cpp new file mode 100644 index 0000000000..db7eb83160 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x8/generic.cpp @@ -0,0 +1,1923 @@ +/* + * Copyright (c) 2018-2019 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#ifdef __aarch64__ + +#include + +#include "arm_gemm.hpp" + +#include "../../asmlib.hpp" +#include "../../utils.hpp" + +namespace arm_gemm { + +void a64_hybrid_fp32_mla_4x8(const float *A, int lda, const float *B, float *C, int ldc, int M, int N, int K, const float *bias, Activation act, bool append) { + const int K_stride = K; + const long loops_count = ((K + 4) / 8) - 1; + K -= loops_count * 8; + const long regs_count = (K / 4) - 1; + K -= (regs_count + 1) * 4; + const long blocks_count = K / 1; + float nullbias[4]; + if (!append && !bias) { + memset(nullbias, 0, (4 * sizeof(float))); + } + float minval = - static_cast(std::numeric_limits::infinity()); + float maxval = static_cast(std::numeric_limits::infinity()); + const float * const minptr = &minval; + const float * const maxptr = &maxval; + + switch(act.type) + { + default: + case Activation::Type::None: + break; + case Activation::Type::BoundedReLU: + maxval = static_cast(act.param1); + /* fall through */ + case Activation::Type::ReLU: + minval = 0.0f; + break; + } + + for (int y=0; y(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory" + ); + break; + case 2: + __asm __volatile ( + "a_ptr1 .req X0\n" + "c_ptr1 .req X1\n" + "ldr q24, [%[biasptr]]\n" + "add a_ptr1, %[a_ptr0], %[lda]\n" + "ldr q0, [%[a_ptr0]]\n" + "add c_ptr1, %[c_ptr0], %[ldc]\n" + "mov v25.16b, v24.16b\n" + "ldr q1, [a_ptr1]\n" + "ldr q16, [%[b_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "cbz %[loops], 1f\n" + "2:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "ldr q9, [a_ptr1]\n" + "subs %[loops], %[loops], #0x1\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n" + "add %[a_ptr0], %[a_ptr0], #0x20\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "add a_ptr1, a_ptr1, #0x20\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "prfm PLDL1KEEP, [a_ptr1, #0x40]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "ldr q0, [%[a_ptr0], #-0x10]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "ldr q1, [a_ptr1, #-0x10]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "ldr q16, [%[b_ptr0], #0x40]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "ldr q17, [%[b_ptr0], #0x50]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "ldr q18, [%[b_ptr0], #0x60]\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "b.ne 2b\n" + "1:\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "prfm PSTL1KEEP, [%[c_ptr0]]\n" + "prfm PSTL1KEEP, [c_ptr1]\n" + "cbz %[regs], 3f\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q9, [a_ptr1]\n" + "ldr q16, [%[b_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, 
v18.4s, v0.s[2]\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "b 4f\n" + "3:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "4:\n" + "cbz %[blocks], 5f\n" + "6:\n" + "ldr q16, [%[b_ptr0]]\n" + "subs %[blocks], %[blocks], #0x1\n" + "add %[b_ptr0], %[b_ptr0], #0x10\n" + "ldr s0, [%[a_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x4\n" + "ldr s1, [a_ptr1]\n" + "add a_ptr1, a_ptr1, #0x4\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "b.ne 6b\n" + "5:\n" + "ld1r {v22.4s}, [%[minptr]]\n" + "ld1r {v23.4s}, [%[maxptr]]\n" + "fmax v24.4s, v24.4s, v22.4s\n" + "fmax v25.4s, v25.4s, v22.4s\n" + "fmin v24.4s, v24.4s, v23.4s\n" + "fmin v25.4s, v25.4s, v23.4s\n" + "str q24, [%[c_ptr0]]\n" + "add %[c_ptr0], %[c_ptr0], #0x10\n" + "str q25, [c_ptr1]\n" + ".unreq a_ptr1\n" + ".unreq c_ptr1\n" + : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks) + : [width] "r" (width), [append] "r" (static_cast(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "cc", "memory" + ); + break; + case 3: + __asm __volatile ( + "a_ptr1 .req X0\n" + "a_ptr2 .req X1\n" + "c_ptr1 .req X2\n" + "c_ptr2 .req X3\n" + "ldr q24, [%[biasptr]]\n" + "add a_ptr1, %[a_ptr0], %[lda]\n" + "ldr q0, [%[a_ptr0]]\n" + "add a_ptr2, a_ptr1, %[lda]\n" + "mov v25.16b, v24.16b\n" + "ldr q1, [a_ptr1]\n" + "mov v26.16b, v24.16b\n" + "ldr q2, [a_ptr2]\n" + "ldr q16, [%[b_ptr0]]\n" + "add c_ptr1, %[c_ptr0], %[ldc]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "add c_ptr2, c_ptr1, %[ldc]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "add a_ptr1, a_ptr1, #0x10\n" + "add a_ptr2, a_ptr2, #0x10\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "cbz %[loops], 1f\n" + "2:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q9, [a_ptr1]\n" + "ldr q10, [a_ptr2]\n" + "subs %[loops], %[loops], #0x1\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "add %[a_ptr0], %[a_ptr0], #0x20\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "add a_ptr1, a_ptr1, #0x20\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "ldr q0, [%[a_ptr0], #-0x10]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "ldr q1, [a_ptr1, #-0x10]\n" + "add a_ptr2, 
a_ptr2, #0x20\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "ldr q2, [a_ptr2, #-0x10]\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "prfm PLDL1KEEP, [a_ptr1, #0x40]\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "ldr q16, [%[b_ptr0], #0x40]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "prfm PLDL1KEEP, [a_ptr2, #0x40]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "ldr q17, [%[b_ptr0], #0x50]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "ldr q18, [%[b_ptr0], #0x60]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "b.ne 2b\n" + "1:\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "prfm PSTL1KEEP, [%[c_ptr0]]\n" + "prfm PSTL1KEEP, [c_ptr1]\n" + "prfm PSTL1KEEP, [c_ptr2]\n" + "cbz %[regs], 3f\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q9, [a_ptr1]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q10, [a_ptr2]\n" + "ldr q16, [%[b_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "add a_ptr2, a_ptr2, #0x10\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "b 4f\n" + "3:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "4:\n" + "cbz %[blocks], 5f\n" + "6:\n" + "ldr q16, [%[b_ptr0]]\n" + "subs %[blocks], %[blocks], #0x1\n" + "add %[b_ptr0], %[b_ptr0], #0x10\n" + "ldr s0, [%[a_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x4\n" + "ldr s1, [a_ptr1]\n" + "add a_ptr1, a_ptr1, #0x4\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr s2, [a_ptr2]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "add a_ptr2, a_ptr2, #0x4\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "b.ne 6b\n" + "5:\n" + "ld1r {v22.4s}, [%[minptr]]\n" + "ld1r {v23.4s}, [%[maxptr]]\n" + "fmax v24.4s, v24.4s, v22.4s\n" + "fmax v25.4s, v25.4s, v22.4s\n" + "fmax v26.4s, v26.4s, v22.4s\n" + "fmin v24.4s, v24.4s, v23.4s\n" + "fmin v25.4s, v25.4s, v23.4s\n" + "fmin v26.4s, v26.4s, v23.4s\n" + "str q24, [%[c_ptr0]]\n" + "add %[c_ptr0], %[c_ptr0], #0x10\n" + "str q25, [c_ptr1]\n" + "str q26, [c_ptr2]\n" + ".unreq a_ptr1\n" + ".unreq a_ptr2\n" + ".unreq c_ptr1\n" + ".unreq c_ptr2\n" + : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" 
(loops), [regs] "+r" (regs), [blocks] "+r" (blocks) + : [width] "r" (width), [append] "r" (static_cast(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "cc", "memory" + ); + break; + case 4: + __asm __volatile ( + "a_ptr1 .req X0\n" + "a_ptr2 .req X1\n" + "a_ptr3 .req X2\n" + "c_ptr1 .req X3\n" + "c_ptr2 .req X4\n" + "c_ptr3 .req X5\n" + "ldr q24, [%[biasptr]]\n" + "add a_ptr1, %[a_ptr0], %[lda]\n" + "ldr q0, [%[a_ptr0]]\n" + "add a_ptr2, a_ptr1, %[lda]\n" + "mov v25.16b, v24.16b\n" + "ldr q1, [a_ptr1]\n" + "mov v26.16b, v24.16b\n" + "ldr q2, [a_ptr2]\n" + "mov v27.16b, v24.16b\n" + "ldr q16, [%[b_ptr0]]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "add a_ptr3, a_ptr2, %[lda]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "add c_ptr1, %[c_ptr0], %[ldc]\n" + "ldr q3, [a_ptr3]\n" + "add c_ptr2, c_ptr1, %[ldc]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "add c_ptr3, c_ptr2, %[ldc]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "add a_ptr2, a_ptr2, #0x10\n" + "add a_ptr3, a_ptr3, #0x10\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "cbz %[loops], 1f\n" + "2:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q9, [a_ptr1]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "ldr q10, [a_ptr2]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q11, [a_ptr3]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "subs %[loops], %[loops], #0x1\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "add %[a_ptr0], %[a_ptr0], #0x20\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "add a_ptr1, a_ptr1, #0x20\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "ldr q0, [%[a_ptr0], #-0x10]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "ldr q1, [a_ptr1, #-0x10]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "add a_ptr2, a_ptr2, #0x20\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "ldr q2, [a_ptr2, #-0x10]\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "add a_ptr3, a_ptr3, #0x20\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "ldr q3, [a_ptr3, #-0x10]\n" + "fmla v27.4s, v16.4s, v11.s[0]\n" + "ldr q16, [%[b_ptr0], #0x40]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "prfm PLDL1KEEP, [a_ptr1, #0x40]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "prfm PLDL1KEEP, [a_ptr2, #0x40]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "prfm PLDL1KEEP, [a_ptr3, #0x40]\n" + "fmla v27.4s, v17.4s, v11.s[1]\n" + "ldr q17, [%[b_ptr0], #0x50]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "fmla v27.4s, v18.4s, v11.s[2]\n" + "ldr q18, [%[b_ptr0], #0x60]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "fmla v27.4s, v19.4s, v11.s[3]\n" + "b.ne 2b\n" + "1:\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "prfm PSTL1KEEP, [%[c_ptr0]]\n" + "prfm PSTL1KEEP, [c_ptr1]\n" + "prfm PSTL1KEEP, [c_ptr2]\n" + "prfm PSTL1KEEP, [c_ptr3]\n" + "cbz %[regs], 3f\n" + "fmla 
v24.4s, v16.4s, v0.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q9, [a_ptr1]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q10, [a_ptr2]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "ldr q11, [a_ptr3]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "add a_ptr2, a_ptr2, #0x10\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "add a_ptr3, a_ptr3, #0x10\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "fmla v27.4s, v16.4s, v11.s[0]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "fmla v27.4s, v17.4s, v11.s[1]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "fmla v27.4s, v18.4s, v11.s[2]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "fmla v27.4s, v19.4s, v11.s[3]\n" + "b 4f\n" + "3:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "4:\n" + "cbz %[blocks], 5f\n" + "6:\n" + "ldr q16, [%[b_ptr0]]\n" + "subs %[blocks], %[blocks], #0x1\n" + "add %[b_ptr0], %[b_ptr0], #0x10\n" + "ldr s0, [%[a_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x4\n" + "ldr s1, [a_ptr1]\n" + "add a_ptr1, a_ptr1, #0x4\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr s2, [a_ptr2]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "add a_ptr2, a_ptr2, #0x4\n" + "ldr s3, [a_ptr3]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "add a_ptr3, a_ptr3, #0x4\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "b.ne 6b\n" + "5:\n" + "ld1r {v22.4s}, [%[minptr]]\n" + "ld1r {v23.4s}, [%[maxptr]]\n" + "fmax v24.4s, v24.4s, v22.4s\n" + "fmax v25.4s, v25.4s, v22.4s\n" + "fmax v26.4s, v26.4s, v22.4s\n" + "fmax v27.4s, v27.4s, v22.4s\n" + "fmin v24.4s, v24.4s, v23.4s\n" + "fmin v25.4s, v25.4s, v23.4s\n" + "fmin v26.4s, v26.4s, v23.4s\n" + "fmin v27.4s, v27.4s, v23.4s\n" + "str q24, [%[c_ptr0]]\n" + "add %[c_ptr0], %[c_ptr0], #0x10\n" + "str q25, [c_ptr1]\n" + "str q26, [c_ptr2]\n" + "str q27, [c_ptr3]\n" + ".unreq a_ptr1\n" + ".unreq a_ptr2\n" + ".unreq a_ptr3\n" + ".unreq c_ptr1\n" + ".unreq c_ptr2\n" + ".unreq c_ptr3\n" + : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks) + : [width] "r" (width), [append] "r" (static_cast(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr) + : "v0", "v1", 
"v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "cc", "memory" + ); + break; + case 5: + __asm __volatile ( + "a_ptr1 .req X0\n" + "a_ptr2 .req X1\n" + "a_ptr3 .req X2\n" + "a_ptr4 .req X3\n" + "c_ptr1 .req X4\n" + "c_ptr2 .req X5\n" + "c_ptr3 .req X6\n" + "c_ptr4 .req X7\n" + "ldr q24, [%[biasptr]]\n" + "add a_ptr1, %[a_ptr0], %[lda]\n" + "ldr q0, [%[a_ptr0]]\n" + "add a_ptr2, a_ptr1, %[lda]\n" + "mov v25.16b, v24.16b\n" + "ldr q1, [a_ptr1]\n" + "mov v26.16b, v24.16b\n" + "ldr q2, [a_ptr2]\n" + "mov v27.16b, v24.16b\n" + "ldr q16, [%[b_ptr0]]\n" + "mov v28.16b, v24.16b\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "add a_ptr3, a_ptr2, %[lda]\n" + "add c_ptr1, %[c_ptr0], %[ldc]\n" + "ldr q3, [a_ptr3]\n" + "add a_ptr4, a_ptr3, %[lda]\n" + "add c_ptr2, c_ptr1, %[ldc]\n" + "ldr q4, [a_ptr4]\n" + "add c_ptr3, c_ptr2, %[ldc]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "add c_ptr4, c_ptr3, %[ldc]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "add a_ptr2, a_ptr2, #0x10\n" + "add a_ptr3, a_ptr3, #0x10\n" + "add a_ptr4, a_ptr4, #0x10\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "cbz %[loops], 1f\n" + "2:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q9, [a_ptr1]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "ldr q10, [a_ptr2]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "ldr q11, [a_ptr3]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q12, [a_ptr4]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "subs %[loops], %[loops], #0x1\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "add %[a_ptr0], %[a_ptr0], #0x20\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "add a_ptr1, a_ptr1, #0x20\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "add a_ptr2, a_ptr2, #0x20\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "add a_ptr3, a_ptr3, #0x20\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "ldr q0, [%[a_ptr0], #-0x10]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "ldr q1, [a_ptr1, #-0x10]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "ldr q2, [a_ptr2, #-0x10]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "ldr q3, [a_ptr3, #-0x10]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "add a_ptr4, a_ptr4, #0x20\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "ldr q4, [a_ptr4, #-0x10]\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "prfm PLDL1KEEP, [a_ptr1, #0x40]\n" + "fmla v27.4s, v16.4s, v11.s[0]\n" + "prfm PLDL1KEEP, [a_ptr2, #0x40]\n" + "fmla v28.4s, v16.4s, v12.s[0]\n" + "ldr q16, [%[b_ptr0], #0x40]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "prfm PLDL1KEEP, [a_ptr3, #0x40]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "fmla v27.4s, v17.4s, v11.s[1]\n" + "fmla v28.4s, v17.4s, v12.s[1]\n" + "ldr q17, [%[b_ptr0], #0x50]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "fmla v27.4s, v18.4s, v11.s[2]\n" + "fmla v28.4s, v18.4s, v12.s[2]\n" + "ldr q18, [%[b_ptr0], #0x60]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "fmla v25.4s, v19.4s, 
v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "fmla v27.4s, v19.4s, v11.s[3]\n" + "fmla v28.4s, v19.4s, v12.s[3]\n" + "b.ne 2b\n" + "1:\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "prfm PSTL1KEEP, [%[c_ptr0]]\n" + "prfm PSTL1KEEP, [c_ptr1]\n" + "prfm PSTL1KEEP, [c_ptr2]\n" + "prfm PSTL1KEEP, [c_ptr3]\n" + "prfm PSTL1KEEP, [c_ptr4]\n" + "cbz %[regs], 3f\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q9, [a_ptr1]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q10, [a_ptr2]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "ldr q11, [a_ptr3]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "ldr q12, [a_ptr4]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "add a_ptr2, a_ptr2, #0x10\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "add a_ptr3, a_ptr3, #0x10\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "add a_ptr4, a_ptr4, #0x10\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "fmla v27.4s, v16.4s, v11.s[0]\n" + "fmla v28.4s, v16.4s, v12.s[0]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "fmla v27.4s, v17.4s, v11.s[1]\n" + "fmla v28.4s, v17.4s, v12.s[1]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "fmla v27.4s, v18.4s, v11.s[2]\n" + "fmla v28.4s, v18.4s, v12.s[2]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "fmla v27.4s, v19.4s, v11.s[3]\n" + "fmla v28.4s, v19.4s, v12.s[3]\n" + "b 4f\n" + "3:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "4:\n" + "cbz %[blocks], 5f\n" + "6:\n" + "ldr q16, [%[b_ptr0]]\n" + "subs %[blocks], %[blocks], #0x1\n" + "add %[b_ptr0], %[b_ptr0], #0x10\n" + "ldr s0, [%[a_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x4\n" + "ldr s1, [a_ptr1]\n" + "add a_ptr1, a_ptr1, #0x4\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr s2, [a_ptr2]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "add a_ptr2, a_ptr2, #0x4\n" + "ldr s3, [a_ptr3]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "add a_ptr3, a_ptr3, #0x4\n" + "ldr s4, [a_ptr4]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "add a_ptr4, a_ptr4, #0x4\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "b.ne 
6b\n" + "5:\n" + "ld1r {v22.4s}, [%[minptr]]\n" + "ld1r {v23.4s}, [%[maxptr]]\n" + "fmax v24.4s, v24.4s, v22.4s\n" + "fmax v25.4s, v25.4s, v22.4s\n" + "fmax v26.4s, v26.4s, v22.4s\n" + "fmax v27.4s, v27.4s, v22.4s\n" + "fmin v24.4s, v24.4s, v23.4s\n" + "fmin v25.4s, v25.4s, v23.4s\n" + "fmin v26.4s, v26.4s, v23.4s\n" + "fmin v27.4s, v27.4s, v23.4s\n" + "str q24, [%[c_ptr0]]\n" + "fmax v28.4s, v28.4s, v22.4s\n" + "add %[c_ptr0], %[c_ptr0], #0x10\n" + "str q25, [c_ptr1]\n" + "fmin v28.4s, v28.4s, v23.4s\n" + "str q26, [c_ptr2]\n" + "str q27, [c_ptr3]\n" + "str q28, [c_ptr4]\n" + ".unreq a_ptr1\n" + ".unreq a_ptr2\n" + ".unreq a_ptr3\n" + ".unreq a_ptr4\n" + ".unreq c_ptr1\n" + ".unreq c_ptr2\n" + ".unreq c_ptr3\n" + ".unreq c_ptr4\n" + : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks) + : [width] "r" (width), [append] "r" (static_cast(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory" + ); + break; + case 6: + __asm __volatile ( + "a_ptr1 .req X0\n" + "a_ptr2 .req X1\n" + "a_ptr3 .req X2\n" + "a_ptr4 .req X3\n" + "a_ptr5 .req X4\n" + "c_ptr1 .req X5\n" + "c_ptr2 .req X6\n" + "c_ptr3 .req X7\n" + "c_ptr4 .req X8\n" + "c_ptr5 .req X9\n" + "ldr q24, [%[biasptr]]\n" + "add a_ptr1, %[a_ptr0], %[lda]\n" + "ldr q0, [%[a_ptr0]]\n" + "add a_ptr2, a_ptr1, %[lda]\n" + "mov v25.16b, v24.16b\n" + "ldr q1, [a_ptr1]\n" + "mov v26.16b, v24.16b\n" + "ldr q2, [a_ptr2]\n" + "mov v27.16b, v24.16b\n" + "ldr q16, [%[b_ptr0]]\n" + "mov v28.16b, v24.16b\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "mov v29.16b, v24.16b\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "add a_ptr3, a_ptr2, %[lda]\n" + "add c_ptr1, %[c_ptr0], %[ldc]\n" + "ldr q3, [a_ptr3]\n" + "add a_ptr4, a_ptr3, %[lda]\n" + "add c_ptr2, c_ptr1, %[ldc]\n" + "ldr q4, [a_ptr4]\n" + "add a_ptr5, a_ptr4, %[lda]\n" + "add c_ptr3, c_ptr2, %[ldc]\n" + "ldr q5, [a_ptr5]\n" + "add c_ptr4, c_ptr3, %[ldc]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "add c_ptr5, c_ptr4, %[ldc]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "add a_ptr2, a_ptr2, #0x10\n" + "add a_ptr3, a_ptr3, #0x10\n" + "add a_ptr4, a_ptr4, #0x10\n" + "add a_ptr5, a_ptr5, #0x10\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "cbz %[loops], 1f\n" + "2:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q9, [a_ptr1]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "ldr q10, [a_ptr2]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "ldr q11, [a_ptr3]\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "ldr q12, [a_ptr4]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q13, [a_ptr5]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "subs %[loops], %[loops], #0x1\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "add %[a_ptr0], %[a_ptr0], #0x20\n" + "fmla v29.4s, v17.4s, v5.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "add a_ptr1, a_ptr1, #0x20\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "add a_ptr2, a_ptr2, #0x20\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "add a_ptr3, 
a_ptr3, #0x20\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "add a_ptr4, a_ptr4, #0x20\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "add a_ptr5, a_ptr5, #0x20\n" + "fmla v29.4s, v18.4s, v5.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "ldr q0, [%[a_ptr0], #-0x10]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "ldr q1, [a_ptr1, #-0x10]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "ldr q2, [a_ptr2, #-0x10]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "ldr q3, [a_ptr3, #-0x10]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "ldr q4, [a_ptr4, #-0x10]\n" + "fmla v29.4s, v19.4s, v5.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "ldr q5, [a_ptr5, #-0x10]\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "prfm PLDL1KEEP, [a_ptr1, #0x40]\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "prfm PLDL1KEEP, [a_ptr2, #0x40]\n" + "fmla v27.4s, v16.4s, v11.s[0]\n" + "prfm PLDL1KEEP, [a_ptr3, #0x40]\n" + "fmla v28.4s, v16.4s, v12.s[0]\n" + "fmla v29.4s, v16.4s, v13.s[0]\n" + "ldr q16, [%[b_ptr0], #0x40]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "fmla v27.4s, v17.4s, v11.s[1]\n" + "fmla v28.4s, v17.4s, v12.s[1]\n" + "fmla v29.4s, v17.4s, v13.s[1]\n" + "ldr q17, [%[b_ptr0], #0x50]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "fmla v27.4s, v18.4s, v11.s[2]\n" + "fmla v28.4s, v18.4s, v12.s[2]\n" + "fmla v29.4s, v18.4s, v13.s[2]\n" + "ldr q18, [%[b_ptr0], #0x60]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "fmla v27.4s, v19.4s, v11.s[3]\n" + "fmla v28.4s, v19.4s, v12.s[3]\n" + "fmla v29.4s, v19.4s, v13.s[3]\n" + "b.ne 2b\n" + "1:\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "prfm PSTL1KEEP, [%[c_ptr0]]\n" + "prfm PSTL1KEEP, [c_ptr1]\n" + "prfm PSTL1KEEP, [c_ptr2]\n" + "prfm PSTL1KEEP, [c_ptr3]\n" + "prfm PSTL1KEEP, [c_ptr4]\n" + "prfm PSTL1KEEP, [c_ptr5]\n" + "cbz %[regs], 3f\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q9, [a_ptr1]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q10, [a_ptr2]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "ldr q11, [a_ptr3]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "ldr q12, [a_ptr4]\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "ldr q13, [a_ptr5]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "add a_ptr2, a_ptr2, #0x10\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "add a_ptr3, a_ptr3, #0x10\n" + "fmla v29.4s, v17.4s, v5.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "add a_ptr4, a_ptr4, #0x10\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "add a_ptr5, a_ptr5, #0x10\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "fmla v29.4s, v18.4s, v5.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "fmla v29.4s, v19.4s, v5.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "fmla v27.4s, v16.4s, v11.s[0]\n" + "fmla v28.4s, v16.4s, 
v12.s[0]\n" + "fmla v29.4s, v16.4s, v13.s[0]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "fmla v27.4s, v17.4s, v11.s[1]\n" + "fmla v28.4s, v17.4s, v12.s[1]\n" + "fmla v29.4s, v17.4s, v13.s[1]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "fmla v27.4s, v18.4s, v11.s[2]\n" + "fmla v28.4s, v18.4s, v12.s[2]\n" + "fmla v29.4s, v18.4s, v13.s[2]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "fmla v27.4s, v19.4s, v11.s[3]\n" + "fmla v28.4s, v19.4s, v12.s[3]\n" + "fmla v29.4s, v19.4s, v13.s[3]\n" + "b 4f\n" + "3:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "fmla v29.4s, v17.4s, v5.s[1]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "fmla v29.4s, v18.4s, v5.s[2]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "fmla v29.4s, v19.4s, v5.s[3]\n" + "4:\n" + "cbz %[blocks], 5f\n" + "6:\n" + "ldr q16, [%[b_ptr0]]\n" + "subs %[blocks], %[blocks], #0x1\n" + "add %[b_ptr0], %[b_ptr0], #0x10\n" + "ldr s0, [%[a_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x4\n" + "ldr s1, [a_ptr1]\n" + "add a_ptr1, a_ptr1, #0x4\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr s2, [a_ptr2]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "add a_ptr2, a_ptr2, #0x4\n" + "ldr s3, [a_ptr3]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "add a_ptr3, a_ptr3, #0x4\n" + "ldr s4, [a_ptr4]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "add a_ptr4, a_ptr4, #0x4\n" + "ldr s5, [a_ptr5]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "add a_ptr5, a_ptr5, #0x4\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "b.ne 6b\n" + "5:\n" + "ld1r {v22.4s}, [%[minptr]]\n" + "ld1r {v23.4s}, [%[maxptr]]\n" + "fmax v24.4s, v24.4s, v22.4s\n" + "fmax v25.4s, v25.4s, v22.4s\n" + "fmax v26.4s, v26.4s, v22.4s\n" + "fmax v27.4s, v27.4s, v22.4s\n" + "fmin v24.4s, v24.4s, v23.4s\n" + "fmin v25.4s, v25.4s, v23.4s\n" + "fmin v26.4s, v26.4s, v23.4s\n" + "fmin v27.4s, v27.4s, v23.4s\n" + "str q24, [%[c_ptr0]]\n" + "fmax v28.4s, v28.4s, v22.4s\n" + "add %[c_ptr0], %[c_ptr0], #0x10\n" + "fmax v29.4s, v29.4s, v22.4s\n" + "str q25, [c_ptr1]\n" + "fmin v28.4s, v28.4s, v23.4s\n" + "fmin v29.4s, v29.4s, v23.4s\n" + "str q26, [c_ptr2]\n" + "str q27, [c_ptr3]\n" + "str q28, [c_ptr4]\n" + "str q29, [c_ptr5]\n" + ".unreq a_ptr1\n" + ".unreq a_ptr2\n" + ".unreq a_ptr3\n" + ".unreq a_ptr4\n" + ".unreq a_ptr5\n" + ".unreq c_ptr1\n" + ".unreq c_ptr2\n" + ".unreq c_ptr3\n" + ".unreq c_ptr4\n" + ".unreq c_ptr5\n" + : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks) + : [width] "r" (width), [append] "r" (static_cast(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", 
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "cc", "memory" + ); + break; + case 7: + __asm __volatile ( + "a_ptr1 .req X0\n" + "a_ptr2 .req X1\n" + "a_ptr3 .req X2\n" + "a_ptr4 .req X3\n" + "a_ptr5 .req X4\n" + "a_ptr6 .req X5\n" + "c_ptr1 .req X6\n" + "c_ptr2 .req X7\n" + "c_ptr3 .req X8\n" + "c_ptr4 .req X9\n" + "c_ptr5 .req X10\n" + "c_ptr6 .req X11\n" + "ldr q24, [%[biasptr]]\n" + "add a_ptr1, %[a_ptr0], %[lda]\n" + "ldr q0, [%[a_ptr0]]\n" + "add a_ptr2, a_ptr1, %[lda]\n" + "mov v25.16b, v24.16b\n" + "ldr q1, [a_ptr1]\n" + "mov v26.16b, v24.16b\n" + "ldr q2, [a_ptr2]\n" + "mov v27.16b, v24.16b\n" + "ldr q16, [%[b_ptr0]]\n" + "mov v28.16b, v24.16b\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "mov v29.16b, v24.16b\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "mov v30.16b, v24.16b\n" + "add a_ptr3, a_ptr2, %[lda]\n" + "add c_ptr1, %[c_ptr0], %[ldc]\n" + "ldr q3, [a_ptr3]\n" + "add a_ptr4, a_ptr3, %[lda]\n" + "add c_ptr2, c_ptr1, %[ldc]\n" + "ldr q4, [a_ptr4]\n" + "add a_ptr5, a_ptr4, %[lda]\n" + "add c_ptr3, c_ptr2, %[ldc]\n" + "ldr q5, [a_ptr5]\n" + "add a_ptr6, a_ptr5, %[lda]\n" + "add c_ptr4, c_ptr3, %[ldc]\n" + "ldr q6, [a_ptr6]\n" + "add c_ptr5, c_ptr4, %[ldc]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "add c_ptr6, c_ptr5, %[ldc]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "add a_ptr2, a_ptr2, #0x10\n" + "add a_ptr3, a_ptr3, #0x10\n" + "add a_ptr4, a_ptr4, #0x10\n" + "add a_ptr5, a_ptr5, #0x10\n" + "add a_ptr6, a_ptr6, #0x10\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "cbz %[loops], 1f\n" + "2:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q9, [a_ptr1]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "ldr q10, [a_ptr2]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "ldr q11, [a_ptr3]\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "ldr q12, [a_ptr4]\n" + "fmla v30.4s, v16.4s, v6.s[0]\n" + "ldr q13, [a_ptr5]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q14, [a_ptr6]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "subs %[loops], %[loops], #0x1\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "add %[a_ptr0], %[a_ptr0], #0x20\n" + "fmla v29.4s, v17.4s, v5.s[1]\n" + "add a_ptr1, a_ptr1, #0x20\n" + "fmla v30.4s, v17.4s, v6.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "add a_ptr2, a_ptr2, #0x20\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "add a_ptr3, a_ptr3, #0x20\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "add a_ptr4, a_ptr4, #0x20\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "add a_ptr5, a_ptr5, #0x20\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "add a_ptr6, a_ptr6, #0x20\n" + "fmla v29.4s, v18.4s, v5.s[2]\n" + "prfm PLDL1KEEP, [a_ptr1, #0x40]\n" + "fmla v30.4s, v18.4s, v6.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "ldr q0, [%[a_ptr0], #-0x10]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "ldr q1, [a_ptr1, #-0x10]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "ldr q2, [a_ptr2, #-0x10]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "ldr q3, [a_ptr3, #-0x10]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "ldr q4, [a_ptr4, #-0x10]\n" + "fmla v29.4s, v19.4s, v5.s[3]\n" + "ldr q5, [a_ptr5, #-0x10]\n" + "fmla v30.4s, v19.4s, v6.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "ldr q6, [a_ptr6, #-0x10]\n" + "fmla v25.4s, v16.4s, 
v9.s[0]\n" + "prfm PLDL1KEEP, [a_ptr2, #0x40]\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "prfm PLDL1KEEP, [a_ptr3, #0x40]\n" + "fmla v27.4s, v16.4s, v11.s[0]\n" + "fmla v28.4s, v16.4s, v12.s[0]\n" + "fmla v29.4s, v16.4s, v13.s[0]\n" + "fmla v30.4s, v16.4s, v14.s[0]\n" + "ldr q16, [%[b_ptr0], #0x40]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "fmla v27.4s, v17.4s, v11.s[1]\n" + "fmla v28.4s, v17.4s, v12.s[1]\n" + "fmla v29.4s, v17.4s, v13.s[1]\n" + "fmla v30.4s, v17.4s, v14.s[1]\n" + "ldr q17, [%[b_ptr0], #0x50]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "fmla v27.4s, v18.4s, v11.s[2]\n" + "fmla v28.4s, v18.4s, v12.s[2]\n" + "fmla v29.4s, v18.4s, v13.s[2]\n" + "fmla v30.4s, v18.4s, v14.s[2]\n" + "ldr q18, [%[b_ptr0], #0x60]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "fmla v27.4s, v19.4s, v11.s[3]\n" + "fmla v28.4s, v19.4s, v12.s[3]\n" + "fmla v29.4s, v19.4s, v13.s[3]\n" + "fmla v30.4s, v19.4s, v14.s[3]\n" + "b.ne 2b\n" + "1:\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "prfm PSTL1KEEP, [%[c_ptr0]]\n" + "prfm PSTL1KEEP, [c_ptr1]\n" + "prfm PSTL1KEEP, [c_ptr2]\n" + "prfm PSTL1KEEP, [c_ptr3]\n" + "prfm PSTL1KEEP, [c_ptr4]\n" + "prfm PSTL1KEEP, [c_ptr5]\n" + "prfm PSTL1KEEP, [c_ptr6]\n" + "cbz %[regs], 3f\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q9, [a_ptr1]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q10, [a_ptr2]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "ldr q11, [a_ptr3]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "ldr q12, [a_ptr4]\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "ldr q13, [a_ptr5]\n" + "fmla v30.4s, v16.4s, v6.s[0]\n" + "ldr q14, [a_ptr6]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "add a_ptr2, a_ptr2, #0x10\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "add a_ptr3, a_ptr3, #0x10\n" + "fmla v29.4s, v17.4s, v5.s[1]\n" + "add a_ptr4, a_ptr4, #0x10\n" + "fmla v30.4s, v17.4s, v6.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "add a_ptr5, a_ptr5, #0x10\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "add a_ptr6, a_ptr6, #0x10\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "fmla v29.4s, v18.4s, v5.s[2]\n" + "fmla v30.4s, v18.4s, v6.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "fmla v29.4s, v19.4s, v5.s[3]\n" + "fmla v30.4s, v19.4s, v6.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "fmla v27.4s, v16.4s, v11.s[0]\n" + "fmla v28.4s, v16.4s, v12.s[0]\n" + "fmla v29.4s, v16.4s, v13.s[0]\n" + "fmla v30.4s, v16.4s, v14.s[0]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "fmla v27.4s, v17.4s, v11.s[1]\n" + "fmla v28.4s, v17.4s, v12.s[1]\n" + "fmla v29.4s, v17.4s, v13.s[1]\n" + "fmla v30.4s, v17.4s, v14.s[1]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, 
v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "fmla v27.4s, v18.4s, v11.s[2]\n" + "fmla v28.4s, v18.4s, v12.s[2]\n" + "fmla v29.4s, v18.4s, v13.s[2]\n" + "fmla v30.4s, v18.4s, v14.s[2]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "fmla v27.4s, v19.4s, v11.s[3]\n" + "fmla v28.4s, v19.4s, v12.s[3]\n" + "fmla v29.4s, v19.4s, v13.s[3]\n" + "fmla v30.4s, v19.4s, v14.s[3]\n" + "b 4f\n" + "3:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "fmla v30.4s, v16.4s, v6.s[0]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "fmla v29.4s, v17.4s, v5.s[1]\n" + "fmla v30.4s, v17.4s, v6.s[1]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "fmla v29.4s, v18.4s, v5.s[2]\n" + "fmla v30.4s, v18.4s, v6.s[2]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "fmla v29.4s, v19.4s, v5.s[3]\n" + "fmla v30.4s, v19.4s, v6.s[3]\n" + "4:\n" + "cbz %[blocks], 5f\n" + "6:\n" + "ldr q16, [%[b_ptr0]]\n" + "subs %[blocks], %[blocks], #0x1\n" + "add %[b_ptr0], %[b_ptr0], #0x10\n" + "ldr s0, [%[a_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x4\n" + "ldr s1, [a_ptr1]\n" + "add a_ptr1, a_ptr1, #0x4\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr s2, [a_ptr2]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "add a_ptr2, a_ptr2, #0x4\n" + "ldr s3, [a_ptr3]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "add a_ptr3, a_ptr3, #0x4\n" + "ldr s4, [a_ptr4]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "add a_ptr4, a_ptr4, #0x4\n" + "ldr s5, [a_ptr5]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "add a_ptr5, a_ptr5, #0x4\n" + "ldr s6, [a_ptr6]\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "add a_ptr6, a_ptr6, #0x4\n" + "fmla v30.4s, v16.4s, v6.s[0]\n" + "b.ne 6b\n" + "5:\n" + "ld1r {v22.4s}, [%[minptr]]\n" + "ld1r {v23.4s}, [%[maxptr]]\n" + "fmax v24.4s, v24.4s, v22.4s\n" + "fmax v25.4s, v25.4s, v22.4s\n" + "fmax v26.4s, v26.4s, v22.4s\n" + "fmax v27.4s, v27.4s, v22.4s\n" + "fmin v24.4s, v24.4s, v23.4s\n" + "fmin v25.4s, v25.4s, v23.4s\n" + "fmin v26.4s, v26.4s, v23.4s\n" + "fmin v27.4s, v27.4s, v23.4s\n" + "str q24, [%[c_ptr0]]\n" + "fmax v28.4s, v28.4s, v22.4s\n" + "add %[c_ptr0], %[c_ptr0], #0x10\n" + "fmax v29.4s, v29.4s, v22.4s\n" + "str q25, [c_ptr1]\n" + "fmax v30.4s, v30.4s, v22.4s\n" + "fmin v28.4s, v28.4s, v23.4s\n" + "fmin v29.4s, v29.4s, v23.4s\n" + "str q26, [c_ptr2]\n" + "fmin v30.4s, v30.4s, v23.4s\n" + "str q27, [c_ptr3]\n" + "str q28, [c_ptr4]\n" + "str q29, [c_ptr5]\n" + "str q30, [c_ptr6]\n" + ".unreq a_ptr1\n" + ".unreq a_ptr2\n" + ".unreq a_ptr3\n" + ".unreq a_ptr4\n" + ".unreq a_ptr5\n" + ".unreq a_ptr6\n" + ".unreq c_ptr1\n" + ".unreq c_ptr2\n" + ".unreq c_ptr3\n" + ".unreq c_ptr4\n" + ".unreq c_ptr5\n" + ".unreq c_ptr6\n" + : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks) + : [width] "r" (width), [append] "r" (static_cast(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr) + : 
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "cc", "memory" + ); + break; + default: + case 8: + __asm __volatile ( + "a_ptr1 .req X0\n" + "a_ptr2 .req X1\n" + "a_ptr3 .req X2\n" + "a_ptr4 .req X3\n" + "a_ptr5 .req X4\n" + "a_ptr6 .req X5\n" + "a_ptr7 .req X6\n" + "c_ptr1 .req X7\n" + "c_ptr2 .req X8\n" + "c_ptr3 .req X9\n" + "c_ptr4 .req X10\n" + "c_ptr5 .req X11\n" + "c_ptr6 .req X12\n" + "c_ptr7 .req X13\n" + "ldr q24, [%[biasptr]]\n" + "add a_ptr1, %[a_ptr0], %[lda]\n" + "ldr q0, [%[a_ptr0]]\n" + "add a_ptr2, a_ptr1, %[lda]\n" + "mov v25.16b, v24.16b\n" + "ldr q1, [a_ptr1]\n" + "mov v26.16b, v24.16b\n" + "ldr q2, [a_ptr2]\n" + "mov v27.16b, v24.16b\n" + "ldr q16, [%[b_ptr0]]\n" + "mov v28.16b, v24.16b\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "mov v29.16b, v24.16b\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "mov v30.16b, v24.16b\n" + "add a_ptr3, a_ptr2, %[lda]\n" + "mov v31.16b, v24.16b\n" + "ldr q3, [a_ptr3]\n" + "add a_ptr4, a_ptr3, %[lda]\n" + "add c_ptr1, %[c_ptr0], %[ldc]\n" + "ldr q4, [a_ptr4]\n" + "add a_ptr5, a_ptr4, %[lda]\n" + "add c_ptr2, c_ptr1, %[ldc]\n" + "ldr q5, [a_ptr5]\n" + "add a_ptr6, a_ptr5, %[lda]\n" + "add c_ptr3, c_ptr2, %[ldc]\n" + "ldr q6, [a_ptr6]\n" + "add a_ptr7, a_ptr6, %[lda]\n" + "add c_ptr4, c_ptr3, %[ldc]\n" + "ldr q7, [a_ptr7]\n" + "add c_ptr5, c_ptr4, %[ldc]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "add c_ptr6, c_ptr5, %[ldc]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "add c_ptr7, c_ptr6, %[ldc]\n" + "add a_ptr2, a_ptr2, #0x10\n" + "add a_ptr3, a_ptr3, #0x10\n" + "add a_ptr4, a_ptr4, #0x10\n" + "add a_ptr5, a_ptr5, #0x10\n" + "add a_ptr6, a_ptr6, #0x10\n" + "add a_ptr7, a_ptr7, #0x10\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "cbz %[loops], 1f\n" + "2:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q9, [a_ptr1]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "ldr q10, [a_ptr2]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "ldr q11, [a_ptr3]\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "ldr q12, [a_ptr4]\n" + "fmla v30.4s, v16.4s, v6.s[0]\n" + "ldr q13, [a_ptr5]\n" + "fmla v31.4s, v16.4s, v7.s[0]\n" + "ldr q14, [a_ptr6]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q15, [a_ptr7]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "subs %[loops], %[loops], #0x1\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "add %[a_ptr0], %[a_ptr0], #0x20\n" + "fmla v29.4s, v17.4s, v5.s[1]\n" + "add a_ptr1, a_ptr1, #0x20\n" + "fmla v30.4s, v17.4s, v6.s[1]\n" + "add a_ptr2, a_ptr2, #0x20\n" + "fmla v31.4s, v17.4s, v7.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "add a_ptr3, a_ptr3, #0x20\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "add a_ptr4, a_ptr4, #0x20\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "add a_ptr5, a_ptr5, #0x20\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "add a_ptr6, a_ptr6, #0x20\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "add a_ptr7, a_ptr7, #0x20\n" + "fmla v29.4s, v18.4s, v5.s[2]\n" + "prfm PLDL1KEEP, [a_ptr1, #0x40]\n" + "fmla v30.4s, v18.4s, v6.s[2]\n" + "prfm PLDL1KEEP, [a_ptr2, #0x40]\n" + "fmla v31.4s, v18.4s, v7.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, 
v19.4s, v0.s[3]\n" + "ldr q0, [%[a_ptr0], #-0x10]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "ldr q1, [a_ptr1, #-0x10]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "ldr q2, [a_ptr2, #-0x10]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "ldr q3, [a_ptr3, #-0x10]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "ldr q4, [a_ptr4, #-0x10]\n" + "fmla v29.4s, v19.4s, v5.s[3]\n" + "ldr q5, [a_ptr5, #-0x10]\n" + "fmla v30.4s, v19.4s, v6.s[3]\n" + "ldr q6, [a_ptr6, #-0x10]\n" + "fmla v31.4s, v19.4s, v7.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "ldr q7, [a_ptr7, #-0x10]\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "prfm PLDL1KEEP, [a_ptr3, #0x40]\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "fmla v27.4s, v16.4s, v11.s[0]\n" + "fmla v28.4s, v16.4s, v12.s[0]\n" + "fmla v29.4s, v16.4s, v13.s[0]\n" + "fmla v30.4s, v16.4s, v14.s[0]\n" + "fmla v31.4s, v16.4s, v15.s[0]\n" + "ldr q16, [%[b_ptr0], #0x40]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "fmla v27.4s, v17.4s, v11.s[1]\n" + "fmla v28.4s, v17.4s, v12.s[1]\n" + "fmla v29.4s, v17.4s, v13.s[1]\n" + "fmla v30.4s, v17.4s, v14.s[1]\n" + "fmla v31.4s, v17.4s, v15.s[1]\n" + "ldr q17, [%[b_ptr0], #0x50]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "fmla v27.4s, v18.4s, v11.s[2]\n" + "fmla v28.4s, v18.4s, v12.s[2]\n" + "fmla v29.4s, v18.4s, v13.s[2]\n" + "fmla v30.4s, v18.4s, v14.s[2]\n" + "fmla v31.4s, v18.4s, v15.s[2]\n" + "ldr q18, [%[b_ptr0], #0x60]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "fmla v27.4s, v19.4s, v11.s[3]\n" + "fmla v28.4s, v19.4s, v12.s[3]\n" + "fmla v29.4s, v19.4s, v13.s[3]\n" + "fmla v30.4s, v19.4s, v14.s[3]\n" + "fmla v31.4s, v19.4s, v15.s[3]\n" + "b.ne 2b\n" + "1:\n" + "ldr q19, [%[b_ptr0], #-0x10]\n" + "prfm PSTL1KEEP, [%[c_ptr0]]\n" + "prfm PSTL1KEEP, [c_ptr1]\n" + "prfm PSTL1KEEP, [c_ptr2]\n" + "prfm PSTL1KEEP, [c_ptr3]\n" + "prfm PSTL1KEEP, [c_ptr4]\n" + "prfm PSTL1KEEP, [c_ptr5]\n" + "prfm PSTL1KEEP, [c_ptr6]\n" + "prfm PSTL1KEEP, [c_ptr7]\n" + "cbz %[regs], 3f\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr q8, [%[a_ptr0]]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "ldr q9, [a_ptr1]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "ldr q10, [a_ptr2]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "ldr q11, [a_ptr3]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "ldr q12, [a_ptr4]\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "ldr q13, [a_ptr5]\n" + "fmla v30.4s, v16.4s, v6.s[0]\n" + "ldr q14, [a_ptr6]\n" + "fmla v31.4s, v16.4s, v7.s[0]\n" + "ldr q15, [a_ptr7]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "ldr q16, [%[b_ptr0]]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "add a_ptr2, a_ptr2, #0x10\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "add a_ptr3, a_ptr3, #0x10\n" + "fmla v29.4s, v17.4s, v5.s[1]\n" + "add a_ptr4, a_ptr4, #0x10\n" + "fmla v30.4s, v17.4s, v6.s[1]\n" + "add a_ptr5, a_ptr5, #0x10\n" + "fmla v31.4s, v17.4s, v7.s[1]\n" + "ldr q17, [%[b_ptr0], #0x10]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "add a_ptr6, a_ptr6, #0x10\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "add a_ptr7, a_ptr7, #0x10\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "fmla v29.4s, v18.4s, v5.s[2]\n" + "fmla v30.4s, v18.4s, v6.s[2]\n" + "fmla v31.4s, v18.4s, 
v7.s[2]\n" + "ldr q18, [%[b_ptr0], #0x20]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "fmla v29.4s, v19.4s, v5.s[3]\n" + "fmla v30.4s, v19.4s, v6.s[3]\n" + "fmla v31.4s, v19.4s, v7.s[3]\n" + "ldr q19, [%[b_ptr0], #0x30]\n" + "fmla v24.4s, v16.4s, v8.s[0]\n" + "add %[b_ptr0], %[b_ptr0], #0x40\n" + "fmla v25.4s, v16.4s, v9.s[0]\n" + "fmla v26.4s, v16.4s, v10.s[0]\n" + "fmla v27.4s, v16.4s, v11.s[0]\n" + "fmla v28.4s, v16.4s, v12.s[0]\n" + "fmla v29.4s, v16.4s, v13.s[0]\n" + "fmla v30.4s, v16.4s, v14.s[0]\n" + "fmla v31.4s, v16.4s, v15.s[0]\n" + "fmla v24.4s, v17.4s, v8.s[1]\n" + "fmla v25.4s, v17.4s, v9.s[1]\n" + "fmla v26.4s, v17.4s, v10.s[1]\n" + "fmla v27.4s, v17.4s, v11.s[1]\n" + "fmla v28.4s, v17.4s, v12.s[1]\n" + "fmla v29.4s, v17.4s, v13.s[1]\n" + "fmla v30.4s, v17.4s, v14.s[1]\n" + "fmla v31.4s, v17.4s, v15.s[1]\n" + "fmla v24.4s, v18.4s, v8.s[2]\n" + "fmla v25.4s, v18.4s, v9.s[2]\n" + "fmla v26.4s, v18.4s, v10.s[2]\n" + "fmla v27.4s, v18.4s, v11.s[2]\n" + "fmla v28.4s, v18.4s, v12.s[2]\n" + "fmla v29.4s, v18.4s, v13.s[2]\n" + "fmla v30.4s, v18.4s, v14.s[2]\n" + "fmla v31.4s, v18.4s, v15.s[2]\n" + "fmla v24.4s, v19.4s, v8.s[3]\n" + "fmla v25.4s, v19.4s, v9.s[3]\n" + "fmla v26.4s, v19.4s, v10.s[3]\n" + "fmla v27.4s, v19.4s, v11.s[3]\n" + "fmla v28.4s, v19.4s, v12.s[3]\n" + "fmla v29.4s, v19.4s, v13.s[3]\n" + "fmla v30.4s, v19.4s, v14.s[3]\n" + "fmla v31.4s, v19.4s, v15.s[3]\n" + "b 4f\n" + "3:\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "fmla v30.4s, v16.4s, v6.s[0]\n" + "fmla v31.4s, v16.4s, v7.s[0]\n" + "fmla v24.4s, v17.4s, v0.s[1]\n" + "fmla v25.4s, v17.4s, v1.s[1]\n" + "fmla v26.4s, v17.4s, v2.s[1]\n" + "fmla v27.4s, v17.4s, v3.s[1]\n" + "fmla v28.4s, v17.4s, v4.s[1]\n" + "fmla v29.4s, v17.4s, v5.s[1]\n" + "fmla v30.4s, v17.4s, v6.s[1]\n" + "fmla v31.4s, v17.4s, v7.s[1]\n" + "fmla v24.4s, v18.4s, v0.s[2]\n" + "fmla v25.4s, v18.4s, v1.s[2]\n" + "fmla v26.4s, v18.4s, v2.s[2]\n" + "fmla v27.4s, v18.4s, v3.s[2]\n" + "fmla v28.4s, v18.4s, v4.s[2]\n" + "fmla v29.4s, v18.4s, v5.s[2]\n" + "fmla v30.4s, v18.4s, v6.s[2]\n" + "fmla v31.4s, v18.4s, v7.s[2]\n" + "fmla v24.4s, v19.4s, v0.s[3]\n" + "fmla v25.4s, v19.4s, v1.s[3]\n" + "fmla v26.4s, v19.4s, v2.s[3]\n" + "fmla v27.4s, v19.4s, v3.s[3]\n" + "fmla v28.4s, v19.4s, v4.s[3]\n" + "fmla v29.4s, v19.4s, v5.s[3]\n" + "fmla v30.4s, v19.4s, v6.s[3]\n" + "fmla v31.4s, v19.4s, v7.s[3]\n" + "4:\n" + "cbz %[blocks], 5f\n" + "6:\n" + "ldr q16, [%[b_ptr0]]\n" + "subs %[blocks], %[blocks], #0x1\n" + "add %[b_ptr0], %[b_ptr0], #0x10\n" + "ldr s0, [%[a_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x4\n" + "ldr s1, [a_ptr1]\n" + "add a_ptr1, a_ptr1, #0x4\n" + "fmla v24.4s, v16.4s, v0.s[0]\n" + "ldr s2, [a_ptr2]\n" + "fmla v25.4s, v16.4s, v1.s[0]\n" + "add a_ptr2, a_ptr2, #0x4\n" + "ldr s3, [a_ptr3]\n" + "fmla v26.4s, v16.4s, v2.s[0]\n" + "add a_ptr3, a_ptr3, #0x4\n" + "ldr s4, [a_ptr4]\n" + "fmla v27.4s, v16.4s, v3.s[0]\n" + "add a_ptr4, a_ptr4, #0x4\n" + "ldr s5, [a_ptr5]\n" + "fmla v28.4s, v16.4s, v4.s[0]\n" + "add a_ptr5, a_ptr5, #0x4\n" + "ldr s6, [a_ptr6]\n" + "fmla v29.4s, v16.4s, v5.s[0]\n" + "add a_ptr6, a_ptr6, #0x4\n" + "ldr s7, [a_ptr7]\n" + "fmla v30.4s, v16.4s, v6.s[0]\n" + "add a_ptr7, a_ptr7, #0x4\n" + "fmla v31.4s, v16.4s, v7.s[0]\n" + "b.ne 
6b\n" + "5:\n" + "ld1r {v22.4s}, [%[minptr]]\n" + "ld1r {v23.4s}, [%[maxptr]]\n" + "fmax v24.4s, v24.4s, v22.4s\n" + "fmax v25.4s, v25.4s, v22.4s\n" + "fmax v26.4s, v26.4s, v22.4s\n" + "fmax v27.4s, v27.4s, v22.4s\n" + "fmin v24.4s, v24.4s, v23.4s\n" + "fmin v25.4s, v25.4s, v23.4s\n" + "fmin v26.4s, v26.4s, v23.4s\n" + "fmin v27.4s, v27.4s, v23.4s\n" + "str q24, [%[c_ptr0]]\n" + "fmax v28.4s, v28.4s, v22.4s\n" + "add %[c_ptr0], %[c_ptr0], #0x10\n" + "fmax v29.4s, v29.4s, v22.4s\n" + "str q25, [c_ptr1]\n" + "fmax v30.4s, v30.4s, v22.4s\n" + "fmin v28.4s, v28.4s, v23.4s\n" + "fmax v31.4s, v31.4s, v22.4s\n" + "str q26, [c_ptr2]\n" + "fmin v29.4s, v29.4s, v23.4s\n" + "fmin v30.4s, v30.4s, v23.4s\n" + "fmin v31.4s, v31.4s, v23.4s\n" + "str q27, [c_ptr3]\n" + "str q28, [c_ptr4]\n" + "str q29, [c_ptr5]\n" + "str q30, [c_ptr6]\n" + "str q31, [c_ptr7]\n" + ".unreq a_ptr1\n" + ".unreq a_ptr2\n" + ".unreq a_ptr3\n" + ".unreq a_ptr4\n" + ".unreq a_ptr5\n" + ".unreq a_ptr6\n" + ".unreq a_ptr7\n" + ".unreq c_ptr1\n" + ".unreq c_ptr2\n" + ".unreq c_ptr3\n" + ".unreq c_ptr4\n" + ".unreq c_ptr5\n" + ".unreq c_ptr6\n" + ".unreq c_ptr7\n" + : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks) + : [width] "r" (width), [append] "r" (static_cast(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", "memory" + ); + break; + } + if (use_result_buffer) { + for(int cy=0; cy _subgemm = nullptr; int32_t *_row_sums = nullptr; int32_t *_col_sums = nullptr; - ARequantizeLayer32 _params; + Requantize32 _params; GemmArgs _args; barrier _barrier; @@ -125,7 +125,7 @@ public: QuantizeWrapper(const QuantizeWrapper &) = delete; QuantizeWrapper operator=(const QuantizeWrapper &) = delete; - QuantizeWrapper(const GemmArgs &args, const ARequantizeLayer32 &qp) : _params(qp), _args(args), _barrier(args._maxthreads) { + QuantizeWrapper(const GemmArgs &args, const Requantize32 &qp) : _params(qp), _args(args), _barrier(args._maxthreads) { GemmArgs newargs = GemmArgs(args._ci, args._Msize, args._Nsize, args._Ksize, args._nbatches, args._nmulti, args._trA, args._trB, Activation(), args._maxthreads, args._pretransposed_hint, nullptr); _subgemm = gemm(newargs); diff --git a/src/core/NEON/kernels/arm_gemm/quantized.cpp b/src/core/NEON/kernels/arm_gemm/quantized.cpp index bffb7ddcb3..00b42cf422 100644 --- a/src/core/NEON/kernels/arm_gemm/quantized.cpp +++ b/src/core/NEON/kernels/arm_gemm/quantized.cpp @@ -47,13 +47,19 @@ namespace { * applied to negative values being shifted right to make sure they round * properly - if negative values are never output (e.g. fused ReLU) this is * unnecessary. + * + * The 'per_channel' template parameter selects between per channel and per + * layer requantization - in the former case we need to load vectors of + * shifts and multipliers for each column. A separate vector for each + * column is set up in any case (and it is hoped that the compiler can elide + * the needless movs in the per-layer case). 
*/ -template -void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, +template +void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigned int height, const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride, const int32_t *row_bias, const int32_t *col_bias) { - const int32x4_t v_mul = vdupq_n_s32(qp.requant_mul); - const int32x4_t v_shift = vdupq_n_s32(qp.requant_shift); + const int32x4_t v_mul = vdupq_n_s32(qp.per_layer_mul); + const int32x4_t v_shift = vdupq_n_s32(qp.per_layer_shift); const int32x4_t v_minval = vdupq_n_s32(qp.minval); const int32x4_t v_maxval = vdupq_n_s32(qp.maxval); const int32x4_t v_c_offset = vdupq_n_s32(qp.c_offset); @@ -70,6 +76,8 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u unsigned int odds=(width % 4); const int32_t *colptr = col_bias; + const int32_t *perch_mul_ptr = qp.per_channel_muls; + const int32_t *perch_shift_ptr = qp.per_channel_shifts; const int32_t *in_ptr = input + (row * in_stride); int8_t *out_ptr = output + (row * out_stride); @@ -93,6 +101,33 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u const int32x4_t v_row_sum1 = vdupq_n_s32(row_sum1); while (blocks--) { + int32x4_t v_mul0; + int32x4_t v_mul1; + int32x4_t v_mul2; + int32x4_t v_mul3; + + int32x4_t v_shf0; + int32x4_t v_shf1; + int32x4_t v_shf2; + int32x4_t v_shf3; + + if (per_channel) { + v_mul0 = vld1q_s32(perch_mul_ptr); + v_mul1 = vld1q_s32(perch_mul_ptr + 4); + v_mul2 = vld1q_s32(perch_mul_ptr + 8); + v_mul3 = vld1q_s32(perch_mul_ptr + 12); + perch_mul_ptr += 16; + + v_shf0 = vld1q_s32(perch_shift_ptr); + v_shf1 = vld1q_s32(perch_shift_ptr + 4); + v_shf2 = vld1q_s32(perch_shift_ptr + 8); + v_shf3 = vld1q_s32(perch_shift_ptr + 12); + perch_shift_ptr += 16; + } else { + v_mul0=v_mul1=v_mul2=v_mul3=v_mul; + v_shf0=v_shf1=v_shf2=v_shf3=v_shift; + } + // Load column pointers int32x4_t v_col0 = vld1q_s32(colptr); int32x4_t v_col1 = vld1q_s32(colptr + 4); @@ -136,27 +171,27 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u v_in13 = vaddq_s32(v_in13, v_col3); // Quantize - start with multiply - v_in00 = vqrdmulhq_s32(v_in00, v_mul); - v_in01 = vqrdmulhq_s32(v_in01, v_mul); - v_in02 = vqrdmulhq_s32(v_in02, v_mul); - v_in03 = vqrdmulhq_s32(v_in03, v_mul); + v_in00 = vqrdmulhq_s32(v_in00, v_mul0); + v_in01 = vqrdmulhq_s32(v_in01, v_mul1); + v_in02 = vqrdmulhq_s32(v_in02, v_mul2); + v_in03 = vqrdmulhq_s32(v_in03, v_mul3); - v_in10 = vqrdmulhq_s32(v_in10, v_mul); - v_in11 = vqrdmulhq_s32(v_in11, v_mul); - v_in12 = vqrdmulhq_s32(v_in12, v_mul); - v_in13 = vqrdmulhq_s32(v_in13, v_mul); + v_in10 = vqrdmulhq_s32(v_in10, v_mul0); + v_in11 = vqrdmulhq_s32(v_in11, v_mul1); + v_in12 = vqrdmulhq_s32(v_in12, v_mul2); + v_in13 = vqrdmulhq_s32(v_in13, v_mul3); // Compute and add on corrective offset if (do_shift_correction) { - int32x4_t v_temp00 = vandq_s32(v_in00, v_shift); - int32x4_t v_temp01 = vandq_s32(v_in01, v_shift); - int32x4_t v_temp02 = vandq_s32(v_in02, v_shift); - int32x4_t v_temp03 = vandq_s32(v_in03, v_shift); + int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0); + int32x4_t v_temp01 = vandq_s32(v_in01, v_shf1); + int32x4_t v_temp02 = vandq_s32(v_in02, v_shf2); + int32x4_t v_temp03 = vandq_s32(v_in03, v_shf3); - int32x4_t v_temp10 = vandq_s32(v_in10, v_shift); - int32x4_t v_temp11 = vandq_s32(v_in11, v_shift); - int32x4_t v_temp12 = vandq_s32(v_in12, v_shift); - int32x4_t v_temp13 = 
vandq_s32(v_in13, v_shift); + int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0); + int32x4_t v_temp11 = vandq_s32(v_in11, v_shf1); + int32x4_t v_temp12 = vandq_s32(v_in12, v_shf2); + int32x4_t v_temp13 = vandq_s32(v_in13, v_shf3); v_temp00 = vshrq_n_s32(v_temp00, 31); v_temp01 = vshrq_n_s32(v_temp01, 31); @@ -179,15 +214,15 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u v_in13 = vqaddq_s32(v_in13, v_temp13); } - v_in00 = vrshlq_s32(v_in00, v_shift); - v_in01 = vrshlq_s32(v_in01, v_shift); - v_in02 = vrshlq_s32(v_in02, v_shift); - v_in03 = vrshlq_s32(v_in03, v_shift); + v_in00 = vrshlq_s32(v_in00, v_shf0); + v_in01 = vrshlq_s32(v_in01, v_shf1); + v_in02 = vrshlq_s32(v_in02, v_shf2); + v_in03 = vrshlq_s32(v_in03, v_shf3); - v_in10 = vrshlq_s32(v_in10, v_shift); - v_in11 = vrshlq_s32(v_in11, v_shift); - v_in12 = vrshlq_s32(v_in12, v_shift); - v_in13 = vrshlq_s32(v_in13, v_shift); + v_in10 = vrshlq_s32(v_in10, v_shf0); + v_in11 = vrshlq_s32(v_in11, v_shf1); + v_in12 = vrshlq_s32(v_in12, v_shf2); + v_in13 = vrshlq_s32(v_in13, v_shf3); v_in00 = vaddq_s32(v_in00, v_c_offset); v_in01 = vaddq_s32(v_in01, v_c_offset); @@ -235,6 +270,20 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u } while (regs--) { + int32x4_t v_mul0; + int32x4_t v_shf0; + + if (per_channel) { + v_mul0 = vld1q_s32(perch_mul_ptr); + perch_mul_ptr += 4; + + v_shf0 = vld1q_s32(perch_shift_ptr); + perch_shift_ptr += 4; + } else { + v_mul0=v_mul; + v_shf0=v_shift; + } + // Load column pointers int32x4_t v_col0 = vld1q_s32(colptr); colptr += 4; @@ -258,15 +307,15 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u v_in10 = vaddq_s32(v_in10, v_col0); // Quantize - start with multiply - v_in00 = vqrdmulhq_s32(v_in00, v_mul); + v_in00 = vqrdmulhq_s32(v_in00, v_mul0); - v_in10 = vqrdmulhq_s32(v_in10, v_mul); + v_in10 = vqrdmulhq_s32(v_in10, v_mul0); // Compute and add on corrective offset if (do_shift_correction) { - int32x4_t v_temp00 = vandq_s32(v_in00, v_shift); + int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0); - int32x4_t v_temp10 = vandq_s32(v_in10, v_shift); + int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0); v_temp00 = vshrq_n_s32(v_temp00, 31); @@ -277,9 +326,9 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u v_in10 = vqaddq_s32(v_in10, v_temp10); } - v_in00 = vrshlq_s32(v_in00, v_shift); + v_in00 = vrshlq_s32(v_in00, v_shf0); - v_in10 = vrshlq_s32(v_in10, v_shift); + v_in10 = vrshlq_s32(v_in10, v_shf0); v_in00 = vaddq_s32(v_in00, v_c_offset); @@ -307,21 +356,40 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u int32x4_t v_col0 = vdupq_n_s32(0); int32x4_t v_in00 = vdupq_n_s32(0); int32x4_t v_in10 = vdupq_n_s32(0); + int32x4_t v_mul0 = vdupq_n_s32(0); + int32x4_t v_shf0 = vdupq_n_s32(0); + + if (!per_channel) { + v_mul0 = v_mul; + v_shf0 = v_shift; + } do { v_col0 = vld1q_lane_s32(colptr, v_col0, 0); v_in00 = vld1q_lane_s32(in_ptr, v_in00, 0); v_in10 = vld1q_lane_s32(in_ptr1, v_in10, 0); + if (per_channel) { + v_mul0 = vld1q_lane_s32(perch_mul_ptr, v_mul0, 0); + v_shf0 = vld1q_lane_s32(perch_shift_ptr, v_shf0, 0); + } if (odds == 1) { break; } v_col0 = vld1q_lane_s32(colptr + 1, v_col0, 1); v_in00 = vld1q_lane_s32(in_ptr + 1, v_in00, 1); v_in10 = vld1q_lane_s32(in_ptr1 + 1, v_in10, 1); + if (per_channel) { + v_mul0 = vld1q_lane_s32(perch_mul_ptr + 1, v_mul0, 1); + v_shf0 = vld1q_lane_s32(perch_shift_ptr + 1, v_shf0, 1); + } if (odds == 2) { break; } v_col0 = 
vld1q_lane_s32(colptr + 2, v_col0, 2); v_in00 = vld1q_lane_s32(in_ptr + 2, v_in00, 2); v_in10 = vld1q_lane_s32(in_ptr1 + 2, v_in10, 2); + if (per_channel) { + v_mul0 = vld1q_lane_s32(perch_mul_ptr + 2, v_mul0, 2); + v_shf0 = vld1q_lane_s32(perch_shift_ptr + 2, v_shf0, 2); + } } while (0); // Add on row sum and bias constant @@ -335,15 +403,15 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u v_in10 = vaddq_s32(v_in10, v_col0); // Quantize - start with multiply - v_in00 = vqrdmulhq_s32(v_in00, v_mul); + v_in00 = vqrdmulhq_s32(v_in00, v_mul0); - v_in10 = vqrdmulhq_s32(v_in10, v_mul); + v_in10 = vqrdmulhq_s32(v_in10, v_mul0); // Compute and add on corrective offset if (do_shift_correction) { - int32x4_t v_temp00 = vandq_s32(v_in00, v_shift); + int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0); - int32x4_t v_temp10 = vandq_s32(v_in10, v_shift); + int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0); v_temp00 = vshrq_n_s32(v_temp00, 31); @@ -354,9 +422,9 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u v_in10 = vqaddq_s32(v_in10, v_temp10); } - v_in00 = vrshlq_s32(v_in00, v_shift); + v_in00 = vrshlq_s32(v_in00, v_shf0); - v_in10 = vrshlq_s32(v_in10, v_shift); + v_in10 = vrshlq_s32(v_in10, v_shf0); v_in00 = vaddq_s32(v_in00, v_c_offset); @@ -391,23 +459,33 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u } // anonymous namespace template -void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, +void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height, const Tin *input, unsigned int in_stride, Tout *output, unsigned int out_stride, const int32_t *row_bias, const int32_t *col_bias) { - if (qp.minval >= qp.c_offset) { - requantize_block_32_int(qp, width, height, reinterpret_cast(input), in_stride, - reinterpret_cast(output), out_stride, row_bias, col_bias); + if (qp.per_channel_requant) { + if (qp.minval >= qp.c_offset) { + requantize_block_32_int(qp, width, height, reinterpret_cast(input), in_stride, + reinterpret_cast(output), out_stride, row_bias, col_bias); + } else { + requantize_block_32_int(qp, width, height, reinterpret_cast(input), in_stride, + reinterpret_cast(output), out_stride, row_bias, col_bias); + } } else { - requantize_block_32_int(qp, width, height, reinterpret_cast(input), in_stride, - reinterpret_cast(output), out_stride, row_bias, col_bias); + if (qp.minval >= qp.c_offset) { + requantize_block_32_int(qp, width, height, reinterpret_cast(input), in_stride, + reinterpret_cast(output), out_stride, row_bias, col_bias); + } else { + requantize_block_32_int(qp, width, height, reinterpret_cast(input), in_stride, + reinterpret_cast(output), out_stride, row_bias, col_bias); + } } } -template void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, +template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height, const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride, const int32_t *row_bias, const int32_t *col_bias); -template void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, +template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height, const uint32_t *input, unsigned int in_stride, uint8_t *output, unsigned int out_stride, const int32_t *row_bias, const int32_t *col_bias); @@ -448,7 +526,7 @@ template void requantize_block_32(const 
ARequantizeLayer32 &qp, unsigned int wid */ namespace { struct row_sum_helpers { - const ARequantizeLayer32 &qp; + const Requantize32 &qp; /* Load a full 16 byte vector, pairwise accumulate into 'sum' with uadalp or sadalp */ template @@ -571,7 +649,7 @@ namespace { } } - row_sum_helpers(const ARequantizeLayer32 &qp) : qp(qp) { } + row_sum_helpers(const Requantize32 &qp) : qp(qp) { } }; template<> @@ -612,8 +690,14 @@ namespace { } template -void compute_row_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, +void compute_row_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const T *input, unsigned int in_stride, int32_t *row_bias) { + /* If the 'b' offset is zero, just skip this entirely. */ + if (qp.b_offset == 0) { + memset(row_bias, 0, height * sizeof(int32_t)); + return; + } + row_sum_helpers thehelpers(qp); const int32x4_t offset_mul = vdupq_n_s32(-qp.b_offset); @@ -663,8 +747,8 @@ void compute_row_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned } /* Instantiate the two versions for uint8_t and int8_t. */ -template void compute_row_sums(const ARequantizeLayer32 &, unsigned int, unsigned int, const int8_t *, unsigned int, int32_t *); -template void compute_row_sums(const ARequantizeLayer32 &, unsigned int, unsigned int, const uint8_t *, unsigned int, int32_t *); +template void compute_row_sums(const Requantize32 &, unsigned int, unsigned int, const int8_t *, unsigned int, int32_t *); +template void compute_row_sums(const Requantize32 &, unsigned int, unsigned int, const uint8_t *, unsigned int, int32_t *); template inline void add_block(const T *input, unsigned int in_stride, int32_t *output); @@ -739,41 +823,44 @@ inline void add_block(const int8_t *input, unsigned int in_stride, int32_t *outp * in cases where we are not computing the first columns of the output (i.e. 
* in multithreaded cases where we divide columns across threads) */ template -void compute_col_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, const T *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col) { - memset(reinterpret_cast(col_bias), 0, width * sizeof(int32_t)); - - for (unsigned int row=0; row(input + row * in_stride + col, in_stride, col_bias + col); - break; - - case 2: - add_block<2>(input + row * in_stride + col, in_stride, col_bias + col); - break; - - case 3: - add_block<3>(input + row * in_stride + col, in_stride, col_bias + col); - break; - - case 4: - add_block<4>(input + row * in_stride + col, in_stride, col_bias + col); - break; - } - } else { - for (; col(col_bias), 0, width * sizeof(int32_t)); + + for (unsigned int row=0; row(input + row * in_stride + col, in_stride, col_bias + col); + break; + + case 2: + add_block<2>(input + row * in_stride + col, in_stride, col_bias + col); + break; + + case 3: + add_block<3>(input + row * in_stride + col, in_stride, col_bias + col); + break; + + case 4: + add_block<4>(input + row * in_stride + col, in_stride, col_bias + col); + break; + } + } else { + for (; col -void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, +void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height, const Tin *input, unsigned int in_stride, Tout *output, unsigned int out_stride, const int32_t *row_bias, const int32_t *col_bias); template -void compute_row_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, +void compute_row_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const T *input, unsigned int in_stride, int32_t *row_bias); template -void compute_col_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, +void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const T *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col); diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp index 65d800cb0c..4e43d04446 100644 --- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp +++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. 
* * SPDX-License-Identifier: MIT * @@ -164,6 +164,23 @@ public: arm_gemm::GemmArgs args, const GEMMInfo &gemm_info, MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os = {}); + /** Set requantization shifts to be used + * + * @param[in] shifts Requantization shifts + * + * @return Pointer to the shift data + */ + /** Set requantization data to be used + * + * + * @param shifts Requantization shifts + * @param multipliers Requantization multipliers + * + * @return A tuple with the pointers to the shift and multiplier data respectively + */ + std::tuple set_requantize_data(const std::vector &shifts, + const std::vector &multipliers); + // Inherited methods overridden: void run() override; void prepare() override; @@ -212,8 +229,23 @@ private: FallbackTransform _weights_transform{}; /** GEMM kernel description */ arm_gemm::KernelDescription _kernel_info{}; + /** Per channel quantization shifts */ + std::vector _shifts{}; + /** Per channel quantization multipliers */ + std::vector _multipliers{}; }; +template +std::tuple Fallback::set_requantize_data(const std::vector &shifts, + const std::vector &multipliers) +{ + _multipliers = multipliers; + _shifts = shifts; + std::transform(_shifts.begin(), _shifts.end(), _shifts.begin(), + std::bind1st(std::multiplies(), -1)); + return std::make_tuple(_shifts.data(), _multipliers.data()); +} + template void Fallback::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::GemmArgs args, const GEMMInfo &gemm_info, @@ -435,18 +467,32 @@ void create_arm_gemm_quant(std::unique_ptr &a arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, activation, num_threads, gemm_info.pretranpose_B()); + // Create arm_gemm fallback + auto fallback = support::cpp14::make_unique>(); + // Configure requantization info const int32_t a_offset = -a->info()->quantization_info().uniform().offset; const int32_t b_offset = -b->info()->quantization_info().uniform().offset; const GEMMLowpOutputStageInfo os_info = gemm_info.gemmlowp_output_stage(); - const arm_gemm::ARequantizeLayer32 gemm_requant_info(nullptr, 0, - a_offset, b_offset, os_info.gemmlowp_offset, - -os_info.gemmlowp_shift, os_info.gemmlowp_multiplier, - os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound); + arm_gemm::Requantize32 gemm_requant_info{}; + if(os_info.gemmlowp_shifts.size() > 1) + { + const auto requantize_data = fallback->set_requantize_data(os_info.gemmlowp_shifts, os_info.gemmlowp_multipliers); + gemm_requant_info = arm_gemm::Requantize32(nullptr, 0, + a_offset, b_offset, os_info.gemmlowp_offset, + std::get<0>(requantize_data), std::get<1>(requantize_data), + os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound); + } + else + { + gemm_requant_info = arm_gemm::Requantize32(nullptr, 0, + a_offset, b_offset, os_info.gemmlowp_offset, + -os_info.gemmlowp_shift, os_info.gemmlowp_multiplier, + os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound); + } - // Create arm_gemm fallback - auto fallback = support::cpp14::make_unique>(); + // Configure fallback fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager, gemm_requant_info); arm_gemm = std::move(fallback); } @@ -484,7 +530,6 @@ Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, 
"Only S32 output supported for S8 input"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8_SIGNED && d->data_type() != DataType::S32, "Only S32 output supported for QASYMM8_SIGNED input"); return Status{}; } @@ -524,7 +569,14 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const break; case DataType::S8: case DataType::QASYMM8_SIGNED: - create_arm_gemm(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); + if(d->info()->data_type() == DataType::S32) + { + create_arm_gemm(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); + } + else + { + create_arm_gemm_quant(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager); + } break; #endif /* __aarch64__ */ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp index 440f043527..38481afe88 100644 --- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp +++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -119,7 +119,7 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b, case DataType::U8: case DataType::S8: { - if(a_to_use->info()->data_type() == DataType::QASYMM8 && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT) + if(is_data_type_quantized_asymmetric(a_to_use->info()->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT) { _asm_glue.configure(a_to_use, b, c, output, gemm_info); _fused_assembly_path = _asm_glue.is_configured(); -- cgit v1.2.1