author    Michalis Spyrou <michalis.spyrou@arm.com>  2019-11-14 14:31:44 +0000
committer Georgios Pinitas <georgios.pinitas@arm.com>  2020-01-23 14:57:14 +0000
commit    71ac9037abce1c6c4af42c485d5395dd6fd79a5a (patch)
tree      7097d94d7760bf8e172fc4c3725a2eff90bea9a1
parent    19bd412fd044197726dbd8c756dbd74a9e33fd2b (diff)
COMPMID-2923 Integrate arm_gemm per channel quantization
Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com>
Change-Id: I8667e75843fdd6ac75bd8272a86a348b830da28d
Reviewed-on: https://review.mlplatform.org/c/2548
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--  arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp                             52
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp                                    12
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp                         4
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp                                  105
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp                                  36
-rw-r--r--  src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x8.hpp              89
-rw-r--r--  src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x8/generic.cpp    1923
-rw-r--r--  src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4.hpp             2
-rw-r--r--  src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/a55.cpp       375
-rw-r--r--  src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/generic.cpp   321
-rw-r--r--  src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp                              4
-rw-r--r--  src/core/NEON/kernels/arm_gemm/quantized.cpp                                   265
-rw-r--r--  src/core/NEON/kernels/arm_gemm/quantized.hpp                                     6
-rw-r--r--  src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp                           70
-rw-r--r--  src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp                      4
15 files changed, 2887 insertions, 381 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp
index d51fda525b..e89523981d 100644
--- a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp
+++ b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp
@@ -108,23 +108,45 @@ public:
}
};
-struct ARequantizeLayer32
+struct Requantize32
{
public:
- const int32_t *bias;
- size_t bias_multi_stride;
- int32_t a_offset;
- int32_t b_offset;
- int32_t c_offset;
- int32_t requant_shift;
- int32_t requant_mul;
- int32_t minval;
- int32_t maxval;
-
- ARequantizeLayer32() = default;
-
- ARequantizeLayer32(const int32_t *b, size_t bms, int32_t ao, int32_t bo, int32_t co, int32_t rs, int32_t rm, int32_t minv, int32_t maxv) :
- bias(b), bias_multi_stride(bms), a_offset(ao), b_offset(bo), c_offset(co), requant_shift(rs), requant_mul(rm), minval(minv), maxval(maxv)
+ const int32_t *bias = nullptr;
+ size_t bias_multi_stride = 0;
+ int32_t a_offset = 0;
+ int32_t b_offset = 0;
+ int32_t c_offset = 0;
+ bool per_channel_requant = false;
+ int32_t per_layer_shift = 0;
+ int32_t per_layer_mul = 0;
+ const int32_t *per_channel_shifts = nullptr;
+ const int32_t *per_channel_muls = nullptr;
+ int32_t minval = 0;
+ int32_t maxval = 0;
+
+ Requantize32() = default;
+
+ // Constructor for per-tensor quantization
+ Requantize32(const int32_t *bias, size_t bias_multi_stride,
+ int32_t a_offset, int32_t b_offset, int32_t c_offset,
+ int32_t requant_shift, int32_t requant_mul,
+ int32_t minv, int32_t maxv) :
+ bias(bias), bias_multi_stride(bias_multi_stride),
+ a_offset(a_offset), b_offset(b_offset), c_offset(c_offset),
+ per_channel_requant(false), per_layer_shift(requant_shift), per_layer_mul(requant_mul),
+ minval(minv), maxval(maxv)
+ {
+ }
+
+ // Constructor for per-channel quantization
+ Requantize32(const int32_t *bias, size_t bias_multi_stride,
+ int32_t a_offset, int32_t b_offset, int32_t c_offset,
+ const int32_t *requant_shifts, const int32_t *requant_muls,
+ int32_t minv, int32_t maxv) :
+ bias(bias), bias_multi_stride(bias_multi_stride),
+ a_offset(a_offset), b_offset(b_offset), c_offset(c_offset),
+ per_channel_requant(true), per_channel_shifts(requant_shifts), per_channel_muls(requant_muls),
+ minval(minv), maxval(maxv)
{
}
};
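Note: a minimal usage sketch of the two Requantize32 constructors added above, showing how the per-layer and per-channel paths are selected. The offsets, multipliers, shifts and bias values are illustrative placeholders (not taken from this change), and the arm_gemm namespace qualification is assumed from the surrounding files.

    #include <cstdint>
    #include "arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp"

    // Hypothetical placeholder data for a GEMM with 16 output channels.
    static const int32_t bias[16]      = {};  // one bias value per output channel
    static const int32_t ch_muls[16]   = {};  // per-channel requantize multipliers
    static const int32_t ch_shifts[16] = {};  // per-channel requantize shifts

    // Per-layer (per-tensor) quantization: a single multiplier/shift pair.
    // This overload sets per_channel_requant = false.
    arm_gemm::Requantize32 qp_layer(bias, /* bias_multi_stride */ 0,
                                    /* a_offset */ 10, /* b_offset */ -5, /* c_offset */ 3,
                                    /* requant_shift */ -8, /* requant_mul */ 1073741824,
                                    /* minval */ -128, /* maxval */ 127);

    // Per-channel quantization: arrays of shifts/multipliers, one entry per
    // output channel. This overload sets per_channel_requant = true.
    arm_gemm::Requantize32 qp_channel(bias, 0,
                                      10, -5, 3,
                                      ch_shifts, ch_muls,
                                      -128, 127);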
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index cf91ee0652..7f171ec15a 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -33,6 +33,7 @@
#include "kernels/a32_sgemm_8x6.hpp"
#include "kernels/a64_hybrid_fp32_mla_16x4.hpp"
+#include "kernels/a64_hybrid_fp32_mla_4x8.hpp"
#include "kernels/a64_native_fp32_mla_16x4.hpp"
#include "kernels/a64_smallK_hybrid_fp32_mla_4x6.hpp"
#include "kernels/a64_smallK_hybrid_fp32_mla_4x8.hpp"
@@ -106,9 +107,16 @@ static const GemmImplementation<float, float> gemm_fp32_methods[] =
},
{
GemmMethod::GEMM_HYBRID,
+ "hybrid_fp32_mla_4x8_normal",
+ [](const GemmArgs &args) { return (args._Ksize >= 4) && !args._trA && args._pretransposed_hint; },
+ [](const GemmArgs &args) { return (args._Nsize < 12); },
+ [](const GemmArgs &args) { return new GemmHybrid<hybrid_fp32_mla_4x8, float, float>(args); }
+},
+{
+ GemmMethod::GEMM_HYBRID,
"hybrid_fp32_mla_16x4",
[](const GemmArgs &args) { return (args._Ksize >= 4) && !args._trA && args._pretransposed_hint; },
- [](const GemmArgs &args) { return ((args._Ksize <= 256) && (args._Nsize <= 256)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); },
+ [](const GemmArgs &args) { return ((args._Ksize <= 256) && (args._Nsize <= 256)) || (args._Msize < 16) || (args._nmulti > 1); },
[](const GemmArgs &args) { return new GemmHybrid<hybrid_fp32_mla_16x4, float, float>(args); }
},
{
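Note: each entry in the method tables above pairs a kernel with a "supported" predicate, an optional "recommended" predicate and a factory lambda. Below is a self-contained sketch of that selection pattern using a local stand-in struct; the field names and the exact tie-breaking rules of the library's GemmImplementation type are assumptions, as its full definition is not part of this diff.

    #include <functional>
    #include <vector>

    // Local stand-in for the method-table pattern; not the library's type.
    struct MethodEntry {
        const char *name;
        std::function<bool(int M, int N, int K)> is_supported;    // hard requirement (may be null)
        std::function<bool(int M, int N, int K)> is_recommended;  // heuristic preference (may be null)
    };

    // Return the first entry that is both supported and recommended,
    // falling back to the first supported entry if none is recommended.
    const MethodEntry *pick(const std::vector<MethodEntry> &table, int M, int N, int K) {
        const MethodEntry *fallback = nullptr;
        for (const auto &e : table) {
            if (e.is_supported && !e.is_supported(M, N, K)) {
                continue;  // kernel cannot handle this problem shape
            }
            if (!e.is_recommended || e.is_recommended(M, N, K)) {
                return &e; // supported and recommended: take it
            }
            if (!fallback) {
                fallback = &e; // remember the first merely-supported entry
            }
        }
        return fallback;
    }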
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
index 574ecef5b2..22b6960baf 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
@@ -68,7 +68,7 @@ class GemmHybridQuantized : public GemmCommon<To, Tr> {
const NDRange<4> _window_range;
- ARequantizeLayer32 _qp;
+ Requantize32 _qp;
int32_t *row_bias = nullptr;
int32_t *col_bias = nullptr;
@@ -140,7 +140,7 @@ public:
GemmHybridQuantized & operator= (GemmHybridQuantized &) = delete;
/* Constructor */
- GemmHybridQuantized(const GemmArgs &args, const ARequantizeLayer32 &qp)
+ GemmHybridQuantized(const GemmArgs &args, const Requantize32 &qp)
: _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize),
_nbatches(args._nbatches), _nmulti(args._nmulti), _trB(args._trB),
_k_block(compute_k_block(args)), _n_block(compute_n_block(args)),
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
new file mode 100644
index 0000000000..73d0c272a6
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) 2019 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifdef __aarch64__
+
+#include "arm_gemm.hpp"
+
+#include "kernels/a64_hybrid_s8s32_dot_16x4.hpp"
+#include "kernels/a64_smallK_hybrid_s8s32_dot_4x6.hpp"
+#include "kernels/a64_smallK_hybrid_s8s32_dot_4x8.hpp"
+#include "kernels/sve_hybrid_s8s32_dot_4VLx4.hpp"
+#include "kernels/sve_smallK_hybrid_s8s32_dot_1VLx8.hpp"
+
+#include "gemm_hybrid_quantized.hpp"
+#include "quantize_wrapper.hpp"
+
+namespace arm_gemm {
+
+static const GemmImplementation<int8_t, int8_t, Requantize32> gemm_qint8_methods[] =
+{
+#ifdef __ARM_FEATURE_SVE
+{
+ GemmMethod::GEMM_HYBRID_QUANTIZED,
+ "smallK_hybrid_s8s32_dot_1VLx8",
+ [](const GemmArgs &args, const Requantize32 &) { return args._Ksize<=64 && !args._trA && args._pretransposed_hint; },
+ nullptr,
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<smallK_hybrid_s8s32_dot_1VLx8, int8_t, int8_t>(args, qp); }
+},
+{
+ GemmMethod::GEMM_HYBRID_QUANTIZED,
+ "hybrid_s8s32_dot_4VLx4",
+ [](const GemmArgs &args, const Requantize32 &) { return args._Ksize>=16 && !args._trA && !args._trB && args._pretransposed_hint; },
+ [](const GemmArgs &args, const Requantize32 &) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); },
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<hybrid_s8s32_dot_4VLx4, int8_t, int8_t>(args, qp); }
+},
+#endif
+{
+ GemmMethod::GEMM_HYBRID_QUANTIZED,
+ "smallK_hybrid_s8s32_dot_4x8",
+ [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._trA && args._pretransposed_hint; },
+ nullptr,
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<smallK_hybrid_s8s32_dot_4x8, int8_t, int8_t>(args, qp); }
+},
+{
+ GemmMethod::GEMM_HYBRID_QUANTIZED,
+ "smallK_hybrid_s8s32_dot_4x6",
+ [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._trA && args._pretransposed_hint; },
+ nullptr,
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<smallK_hybrid_s8s32_dot_4x6, int8_t, int8_t>(args, qp); }
+},
+{
+ GemmMethod::GEMM_HYBRID_QUANTIZED,
+ "hybrid_s8s32_dot_16x4",
+ [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && args._Ksize>=16 && !args._trA && !args._trB && args._pretransposed_hint; },
+ [](const GemmArgs &args, const Requantize32 &) { return args._Nsize<=256 && args._Ksize>128; },
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<hybrid_s8s32_dot_16x4, int8_t, int8_t>(args, qp); }
+},
+{
+ GemmMethod::QUANTIZE_WRAPPER,
+ "quantized_wrapper",
+ nullptr,
+ nullptr,
+ [](const GemmArgs &args, const Requantize32 &qp) { return new QuantizeWrapper<int8_t, int8_t, int32_t>(args, qp); }
+},
+{
+ GemmMethod::DEFAULT,
+ "",
+ nullptr,
+ nullptr,
+ nullptr
+}
+};
+
+template<>
+const GemmImplementation<int8_t, int8_t, Requantize32> *gemm_implementation_list<int8_t, int8_t, Requantize32>() {
+ return gemm_qint8_methods;
+}
+
+template UniqueGemmCommon<int8_t, int8_t> gemm<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
+template KernelDescription get_gemm_method<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
+template std::vector<KernelDescription> get_compatible_kernels<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
+
+} // namespace arm_gemm
+
+#endif // __aarch64__
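Note: a hedged sketch of how a caller would obtain an int8 quantized GEMM through the factory that this new file instantiates (the explicit gemm<int8_t, int8_t, Requantize32> instantiation above). Constructing GemmArgs is left to the caller because its constructor is not part of this diff; make_int8_quantized_gemm is a hypothetical helper name.

    #include <cstdint>
    #include "arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp"

    // Returns a GEMM object chosen from gemm_qint8_methods[] for the given
    // problem description and requantization parameters.
    arm_gemm::UniqueGemmCommon<int8_t, int8_t>
    make_int8_quantized_gemm(const arm_gemm::GemmArgs &args, const arm_gemm::Requantize32 &qp)
    {
        // The factory consults the method table defined above, honouring each
        // entry's supported/recommended predicates.
        return arm_gemm::gemm<int8_t, int8_t, arm_gemm::Requantize32>(args, qp);
    }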
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
index 079c04ae06..59cd1704ff 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
@@ -36,51 +36,51 @@
namespace arm_gemm {
-static const GemmImplementation<uint8_t, uint8_t, ARequantizeLayer32> gemm_quint8_methods[] =
+static const GemmImplementation<uint8_t, uint8_t, Requantize32> gemm_quint8_methods[] =
{
#ifdef __ARM_FEATURE_SVE
{
GemmMethod::GEMM_HYBRID_QUANTIZED,
"smallK_hybrid_u8u32_dot_1VLx8",
- [](const GemmArgs &args, const ARequantizeLayer32 &) { return args._Ksize<=64 && !args._trA && args._pretransposed_hint; },
+ [](const GemmArgs &args, const Requantize32 &) { return args._Ksize<=64 && !args._trA && args._pretransposed_hint; },
nullptr,
- [](const GemmArgs &args, const ARequantizeLayer32 &qp) { return new GemmHybridQuantized<smallK_hybrid_u8u32_dot_1VLx8, uint8_t, uint8_t>(args, qp); }
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<smallK_hybrid_u8u32_dot_1VLx8, uint8_t, uint8_t>(args, qp); }
},
{
GemmMethod::GEMM_HYBRID_QUANTIZED,
"hybrid_u8u32_dot_4VLx4",
- [](const GemmArgs &args, const ARequantizeLayer32 &) { return args._Ksize>=16 && !args._trA && !args._trB && args._pretransposed_hint; },
- [](const GemmArgs &args, const ARequantizeLayer32 &) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); },
- [](const GemmArgs &args, const ARequantizeLayer32 &qp) { return new GemmHybridQuantized<hybrid_u8u32_dot_4VLx4, uint8_t, uint8_t>(args, qp); }
+ [](const GemmArgs &args, const Requantize32 &) { return args._Ksize>=16 && !args._trA && !args._trB && args._pretransposed_hint; },
+ [](const GemmArgs &args, const Requantize32 &) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); },
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<hybrid_u8u32_dot_4VLx4, uint8_t, uint8_t>(args, qp); }
},
#endif
{
GemmMethod::GEMM_HYBRID_QUANTIZED,
"smallK_hybrid_u8u32_dot_4x8",
- [](const GemmArgs &args, const ARequantizeLayer32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._trA && args._pretransposed_hint; },
+ [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._trA && args._pretransposed_hint; },
nullptr,
- [](const GemmArgs &args, const ARequantizeLayer32 &qp) { return new GemmHybridQuantized<smallK_hybrid_u8u32_dot_4x8, uint8_t, uint8_t>(args, qp); }
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<smallK_hybrid_u8u32_dot_4x8, uint8_t, uint8_t>(args, qp); }
},
{
GemmMethod::GEMM_HYBRID_QUANTIZED,
"smallK_hybrid_u8u32_dot_4x6",
- [](const GemmArgs &args, const ARequantizeLayer32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._trA && args._pretransposed_hint; },
+ [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._trA && args._pretransposed_hint; },
nullptr,
- [](const GemmArgs &args, const ARequantizeLayer32 &qp) { return new GemmHybridQuantized<smallK_hybrid_u8u32_dot_4x6, uint8_t, uint8_t>(args, qp); }
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<smallK_hybrid_u8u32_dot_4x6, uint8_t, uint8_t>(args, qp); }
},
{
GemmMethod::GEMM_HYBRID_QUANTIZED,
"hybrid_u8u32_dot_16x4",
- [](const GemmArgs &args, const ARequantizeLayer32 &) { return args._ci->has_dotprod() && args._Ksize>=16 && !args._trA && !args._trB && args._pretransposed_hint; },
- [](const GemmArgs &args, const ARequantizeLayer32 &) { return args._Nsize<=256 && args._Ksize>128; },
- [](const GemmArgs &args, const ARequantizeLayer32 &qp) { return new GemmHybridQuantized<hybrid_u8u32_dot_16x4, uint8_t, uint8_t>(args, qp); }
+ [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && args._Ksize>=16 && !args._trA && !args._trB && args._pretransposed_hint; },
+ [](const GemmArgs &args, const Requantize32 &) { return args._Nsize<=256 && args._Ksize>128; },
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<hybrid_u8u32_dot_16x4, uint8_t, uint8_t>(args, qp); }
},
{
GemmMethod::QUANTIZE_WRAPPER,
"quantized_wrapper",
nullptr,
nullptr,
- [](const GemmArgs &args, const ARequantizeLayer32 &qp) { return new QuantizeWrapper<uint8_t, uint8_t, uint32_t>(args, qp); }
+ [](const GemmArgs &args, const Requantize32 &qp) { return new QuantizeWrapper<uint8_t, uint8_t, uint32_t>(args, qp); }
},
{
GemmMethod::DEFAULT,
@@ -92,13 +92,13 @@ static const GemmImplementation<uint8_t, uint8_t, ARequantizeLayer32> gemm_quint
};
template<>
-const GemmImplementation<uint8_t, uint8_t, ARequantizeLayer32> *gemm_implementation_list<uint8_t, uint8_t, ARequantizeLayer32>() {
+const GemmImplementation<uint8_t, uint8_t, Requantize32> *gemm_implementation_list<uint8_t, uint8_t, Requantize32>() {
return gemm_quint8_methods;
}
-template UniqueGemmCommon<uint8_t, uint8_t> gemm<uint8_t, uint8_t, ARequantizeLayer32>(const GemmArgs &args, const ARequantizeLayer32 &os);
-template KernelDescription get_gemm_method<uint8_t, uint8_t, ARequantizeLayer32>(const GemmArgs &args, const ARequantizeLayer32 &os);
-template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint8_t, ARequantizeLayer32>(const GemmArgs &args, const ARequantizeLayer32 &os);
+template UniqueGemmCommon<uint8_t, uint8_t> gemm<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
+template KernelDescription get_gemm_method<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
+template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x8.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x8.hpp
new file mode 100644
index 0000000000..da5beef48c
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x8.hpp
@@ -0,0 +1,89 @@
+/*
+ * Copyright (c) 2018-2019 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+#ifdef __aarch64__
+
+
+#include "../std_transforms_fixed.hpp"
+
+namespace arm_gemm
+{
+
+// Actual kernel implementations
+void a64_hybrid_fp32_mla_4x8(const float *, int, const float *, float *, int, int, int, int, const float *, Activation, bool);
+
+class hybrid_fp32_mla_4x8
+{
+public:
+ typedef float operand_type;
+ typedef float result_type;
+
+ typedef void (*kern_type)(const float *, int, const float *, float *, int, int, int, int, const float *, Activation, bool);
+
+ /* Kernel blocking parameters */
+ static constexpr unsigned int out_height()
+ {
+ return 8;
+ }
+
+ static unsigned int out_width()
+ {
+ return 4;
+ }
+
+ static constexpr unsigned int k_unroll()
+ {
+ return 1;
+ }
+
+ static constexpr bool supports_append()
+ {
+ return false;
+ }
+
+ static constexpr bool supports_bias()
+ {
+ return true;
+ }
+
+ static constexpr bool supports_activation()
+ {
+ return true;
+ }
+
+ StdTransformsFixed<operand_type, result_type, 8, 4, 1> transforms = {};
+
+ // Default to the generic kernel
+ kern_type kernel=a64_hybrid_fp32_mla_4x8;
+
+ hybrid_fp32_mla_4x8(const CPUInfo *ci)
+ {
+ UNUSED(ci);
+ }
+};
+
+} // namespace arm_gemm
+
+#endif // __aarch64__
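Note: the blocking parameters above (out_height() == 8, out_width() == 4) fix the footprint of one output tile, which is what the partial-width path in generic.cpp below sizes its result_buffer for. A trivial sketch of that relationship:

    #include <cstddef>

    // 8 rows x 4 columns of float accumulators per kernel invocation,
    // matching result_buffer[32] in generic.cpp.
    constexpr std::size_t c_tile_elems = 8 * 4;
    static_assert(c_tile_elems == 32, "one C tile of hybrid_fp32_mla_4x8 is 8x4 floats");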
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x8/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x8/generic.cpp
new file mode 100644
index 0000000000..db7eb83160
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x8/generic.cpp
@@ -0,0 +1,1923 @@
+/*
+ * Copyright (c) 2018-2019 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifdef __aarch64__
+
+#include <algorithm>
+
+#include "arm_gemm.hpp"
+
+#include "../../asmlib.hpp"
+#include "../../utils.hpp"
+
+namespace arm_gemm {
+
+void a64_hybrid_fp32_mla_4x8(const float *A, int lda, const float *B, float *C, int ldc, int M, int N, int K, const float *bias, Activation act, bool append) {
+ const int K_stride = K;
+ const long loops_count = ((K + 4) / 8) - 1;
+ K -= loops_count * 8;
+ const long regs_count = (K / 4) - 1;
+ K -= (regs_count + 1) * 4;
+ const long blocks_count = K / 1;
+ float nullbias[4];
+ if (!append && !bias) {
+ memset(nullbias, 0, (4 * sizeof(float)));
+ }
+ float minval = - static_cast<float>(std::numeric_limits<float>::infinity());
+ float maxval = static_cast<float>(std::numeric_limits<float>::infinity());
+ const float * const minptr = &minval;
+ const float * const maxptr = &maxval;
+
+ switch(act.type)
+ {
+ default:
+ case Activation::Type::None:
+ break;
+ case Activation::Type::BoundedReLU:
+ maxval = static_cast<float>(act.param1);
+ /* fall through */
+ case Activation::Type::ReLU:
+ minval = 0.0f;
+ break;
+ }
+
+ for (int y=0; y<M; y+=8) {
+ const float * const a_ptr0_base = A + (y * lda);
+ const unsigned long ldab = lda * sizeof(float);
+
+ float *c_ptr0 = C + (y * ldc);
+
+ for (int x0=0; x0<N; x0+=4ul) {
+ const long width = std::min((unsigned long)N-x0, 4ul);
+ long loops = loops_count;
+ long regs = regs_count;
+ long blocks = blocks_count;
+ const float *a_ptr0 = a_ptr0_base;
+ const float *b_ptr0 = B + (K_stride * x0);
+ const bool use_result_buffer = (width < 4);
+ float result_buffer[32];
+ const unsigned long ldcb = (use_result_buffer ? 4 : ldc) * sizeof(float);
+ float *c_ptr_real = c_ptr0;
+ if (use_result_buffer && append) {
+ for(int cy=0; cy<std::min(M-y, 8); cy++) {
+ for(unsigned int cx=0; cx<width; cx++) {
+ result_buffer[cy * 4 + cx] = c_ptr_real[cy * ldc + cx];
+ }
+ }
+ }
+ if (use_result_buffer) {
+ c_ptr0 = result_buffer;
+ }
+ const float *biasptr = bias ? bias+x0 : nullbias;
+
+ switch(M-y) {
+ case 1:
+ __asm __volatile (
+ "ldr q24, [%[biasptr]]\n"
+ "ldr q0, [%[a_ptr0]]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x40\n"
+ "cbz %[loops], 1f\n"
+ "2:\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "ldr q19, [%[b_ptr0], #-0x10]\n"
+ "ldr q8, [%[a_ptr0]]\n"
+ "subs %[loops], %[loops], #0x1\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x20\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "ldr q19, [%[b_ptr0], #0x30]\n"
+ "ldr q0, [%[a_ptr0], #-0x10]\n"
+ "fmla v24.4s, v16.4s, v8.s[0]\n"
+ "ldr q16, [%[b_ptr0], #0x40]\n"
+ "fmla v24.4s, v17.4s, v8.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x50]\n"
+ "fmla v24.4s, v18.4s, v8.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x60]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x80\n"
+ "fmla v24.4s, v19.4s, v8.s[3]\n"
+ "b.ne 2b\n"
+ "1:\n"
+ "ldr q19, [%[b_ptr0], #-0x10]\n"
+ "prfm PSTL1KEEP, [%[c_ptr0]]\n"
+ "cbz %[regs], 3f\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "ldr q8, [%[a_ptr0]]\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "ldr q19, [%[b_ptr0], #0x30]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x40\n"
+ "fmla v24.4s, v16.4s, v8.s[0]\n"
+ "fmla v24.4s, v17.4s, v8.s[1]\n"
+ "fmla v24.4s, v18.4s, v8.s[2]\n"
+ "fmla v24.4s, v19.4s, v8.s[3]\n"
+ "b 4f\n"
+ "3:\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "4:\n"
+ "cbz %[blocks], 5f\n"
+ "6:\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "subs %[blocks], %[blocks], #0x1\n"
+ "add %[b_ptr0], %[b_ptr0], #0x10\n"
+ "ldr s0, [%[a_ptr0]]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x4\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "b.ne 6b\n"
+ "5:\n"
+ "ld1r {v22.4s}, [%[minptr]]\n"
+ "ld1r {v23.4s}, [%[maxptr]]\n"
+ "fmax v24.4s, v24.4s, v22.4s\n"
+ "fmin v24.4s, v24.4s, v23.4s\n"
+ "str q24, [%[c_ptr0]]\n"
+ "add %[c_ptr0], %[c_ptr0], #0x10\n"
+ : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks)
+ : [width] "r" (width), [append] "r" (static_cast<uint64_t>(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr)
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"
+ );
+ break;
+ case 2:
+ __asm __volatile (
+ "a_ptr1 .req X0\n"
+ "c_ptr1 .req X1\n"
+ "ldr q24, [%[biasptr]]\n"
+ "add a_ptr1, %[a_ptr0], %[lda]\n"
+ "ldr q0, [%[a_ptr0]]\n"
+ "add c_ptr1, %[c_ptr0], %[ldc]\n"
+ "mov v25.16b, v24.16b\n"
+ "ldr q1, [a_ptr1]\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "add a_ptr1, a_ptr1, #0x10\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x40\n"
+ "cbz %[loops], 1f\n"
+ "2:\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "ldr q19, [%[b_ptr0], #-0x10]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "ldr q8, [%[a_ptr0]]\n"
+ "ldr q9, [a_ptr1]\n"
+ "subs %[loops], %[loops], #0x1\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "fmla v25.4s, v17.4s, v1.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x20\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "add a_ptr1, a_ptr1, #0x20\n"
+ "fmla v25.4s, v18.4s, v1.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "prfm PLDL1KEEP, [a_ptr1, #0x40]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "ldr q0, [%[a_ptr0], #-0x10]\n"
+ "fmla v25.4s, v19.4s, v1.s[3]\n"
+ "ldr q19, [%[b_ptr0], #0x30]\n"
+ "ldr q1, [a_ptr1, #-0x10]\n"
+ "fmla v24.4s, v16.4s, v8.s[0]\n"
+ "fmla v25.4s, v16.4s, v9.s[0]\n"
+ "ldr q16, [%[b_ptr0], #0x40]\n"
+ "fmla v24.4s, v17.4s, v8.s[1]\n"
+ "fmla v25.4s, v17.4s, v9.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x50]\n"
+ "fmla v24.4s, v18.4s, v8.s[2]\n"
+ "fmla v25.4s, v18.4s, v9.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x60]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x80\n"
+ "fmla v24.4s, v19.4s, v8.s[3]\n"
+ "fmla v25.4s, v19.4s, v9.s[3]\n"
+ "b.ne 2b\n"
+ "1:\n"
+ "ldr q19, [%[b_ptr0], #-0x10]\n"
+ "prfm PSTL1KEEP, [%[c_ptr0]]\n"
+ "prfm PSTL1KEEP, [c_ptr1]\n"
+ "cbz %[regs], 3f\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "ldr q8, [%[a_ptr0]]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "ldr q9, [a_ptr1]\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "add a_ptr1, a_ptr1, #0x10\n"
+ "fmla v25.4s, v17.4s, v1.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "fmla v25.4s, v18.4s, v1.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "fmla v25.4s, v19.4s, v1.s[3]\n"
+ "ldr q19, [%[b_ptr0], #0x30]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x40\n"
+ "fmla v24.4s, v16.4s, v8.s[0]\n"
+ "fmla v25.4s, v16.4s, v9.s[0]\n"
+ "fmla v24.4s, v17.4s, v8.s[1]\n"
+ "fmla v25.4s, v17.4s, v9.s[1]\n"
+ "fmla v24.4s, v18.4s, v8.s[2]\n"
+ "fmla v25.4s, v18.4s, v9.s[2]\n"
+ "fmla v24.4s, v19.4s, v8.s[3]\n"
+ "fmla v25.4s, v19.4s, v9.s[3]\n"
+ "b 4f\n"
+ "3:\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "fmla v25.4s, v17.4s, v1.s[1]\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "fmla v25.4s, v18.4s, v1.s[2]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "fmla v25.4s, v19.4s, v1.s[3]\n"
+ "4:\n"
+ "cbz %[blocks], 5f\n"
+ "6:\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "subs %[blocks], %[blocks], #0x1\n"
+ "add %[b_ptr0], %[b_ptr0], #0x10\n"
+ "ldr s0, [%[a_ptr0]]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x4\n"
+ "ldr s1, [a_ptr1]\n"
+ "add a_ptr1, a_ptr1, #0x4\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "b.ne 6b\n"
+ "5:\n"
+ "ld1r {v22.4s}, [%[minptr]]\n"
+ "ld1r {v23.4s}, [%[maxptr]]\n"
+ "fmax v24.4s, v24.4s, v22.4s\n"
+ "fmax v25.4s, v25.4s, v22.4s\n"
+ "fmin v24.4s, v24.4s, v23.4s\n"
+ "fmin v25.4s, v25.4s, v23.4s\n"
+ "str q24, [%[c_ptr0]]\n"
+ "add %[c_ptr0], %[c_ptr0], #0x10\n"
+ "str q25, [c_ptr1]\n"
+ ".unreq a_ptr1\n"
+ ".unreq c_ptr1\n"
+ : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks)
+ : [width] "r" (width), [append] "r" (static_cast<uint64_t>(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr)
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "cc", "memory"
+ );
+ break;
+ case 3:
+ __asm __volatile (
+ "a_ptr1 .req X0\n"
+ "a_ptr2 .req X1\n"
+ "c_ptr1 .req X2\n"
+ "c_ptr2 .req X3\n"
+ "ldr q24, [%[biasptr]]\n"
+ "add a_ptr1, %[a_ptr0], %[lda]\n"
+ "ldr q0, [%[a_ptr0]]\n"
+ "add a_ptr2, a_ptr1, %[lda]\n"
+ "mov v25.16b, v24.16b\n"
+ "ldr q1, [a_ptr1]\n"
+ "mov v26.16b, v24.16b\n"
+ "ldr q2, [a_ptr2]\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "add c_ptr1, %[c_ptr0], %[ldc]\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "add c_ptr2, c_ptr1, %[ldc]\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
+ "add a_ptr1, a_ptr1, #0x10\n"
+ "add a_ptr2, a_ptr2, #0x10\n"
+ "add %[b_ptr0], %[b_ptr0], #0x40\n"
+ "cbz %[loops], 1f\n"
+ "2:\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "ldr q19, [%[b_ptr0], #-0x10]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "ldr q8, [%[a_ptr0]]\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "ldr q9, [a_ptr1]\n"
+ "ldr q10, [a_ptr2]\n"
+ "subs %[loops], %[loops], #0x1\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "fmla v25.4s, v17.4s, v1.s[1]\n"
+ "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n"
+ "fmla v26.4s, v17.4s, v2.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x20\n"
+ "fmla v25.4s, v18.4s, v1.s[2]\n"
+ "add a_ptr1, a_ptr1, #0x20\n"
+ "fmla v26.4s, v18.4s, v2.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "ldr q0, [%[a_ptr0], #-0x10]\n"
+ "fmla v25.4s, v19.4s, v1.s[3]\n"
+ "ldr q1, [a_ptr1, #-0x10]\n"
+ "add a_ptr2, a_ptr2, #0x20\n"
+ "fmla v26.4s, v19.4s, v2.s[3]\n"
+ "ldr q19, [%[b_ptr0], #0x30]\n"
+ "fmla v24.4s, v16.4s, v8.s[0]\n"
+ "ldr q2, [a_ptr2, #-0x10]\n"
+ "fmla v25.4s, v16.4s, v9.s[0]\n"
+ "prfm PLDL1KEEP, [a_ptr1, #0x40]\n"
+ "fmla v26.4s, v16.4s, v10.s[0]\n"
+ "ldr q16, [%[b_ptr0], #0x40]\n"
+ "fmla v24.4s, v17.4s, v8.s[1]\n"
+ "prfm PLDL1KEEP, [a_ptr2, #0x40]\n"
+ "fmla v25.4s, v17.4s, v9.s[1]\n"
+ "fmla v26.4s, v17.4s, v10.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x50]\n"
+ "fmla v24.4s, v18.4s, v8.s[2]\n"
+ "fmla v25.4s, v18.4s, v9.s[2]\n"
+ "fmla v26.4s, v18.4s, v10.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x60]\n"
+ "fmla v24.4s, v19.4s, v8.s[3]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x80\n"
+ "fmla v25.4s, v19.4s, v9.s[3]\n"
+ "fmla v26.4s, v19.4s, v10.s[3]\n"
+ "b.ne 2b\n"
+ "1:\n"
+ "ldr q19, [%[b_ptr0], #-0x10]\n"
+ "prfm PSTL1KEEP, [%[c_ptr0]]\n"
+ "prfm PSTL1KEEP, [c_ptr1]\n"
+ "prfm PSTL1KEEP, [c_ptr2]\n"
+ "cbz %[regs], 3f\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "ldr q8, [%[a_ptr0]]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "ldr q9, [a_ptr1]\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "ldr q10, [a_ptr2]\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "add a_ptr1, a_ptr1, #0x10\n"
+ "fmla v25.4s, v17.4s, v1.s[1]\n"
+ "add a_ptr2, a_ptr2, #0x10\n"
+ "fmla v26.4s, v17.4s, v2.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "fmla v25.4s, v18.4s, v1.s[2]\n"
+ "fmla v26.4s, v18.4s, v2.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "fmla v25.4s, v19.4s, v1.s[3]\n"
+ "fmla v26.4s, v19.4s, v2.s[3]\n"
+ "ldr q19, [%[b_ptr0], #0x30]\n"
+ "fmla v24.4s, v16.4s, v8.s[0]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x40\n"
+ "fmla v25.4s, v16.4s, v9.s[0]\n"
+ "fmla v26.4s, v16.4s, v10.s[0]\n"
+ "fmla v24.4s, v17.4s, v8.s[1]\n"
+ "fmla v25.4s, v17.4s, v9.s[1]\n"
+ "fmla v26.4s, v17.4s, v10.s[1]\n"
+ "fmla v24.4s, v18.4s, v8.s[2]\n"
+ "fmla v25.4s, v18.4s, v9.s[2]\n"
+ "fmla v26.4s, v18.4s, v10.s[2]\n"
+ "fmla v24.4s, v19.4s, v8.s[3]\n"
+ "fmla v25.4s, v19.4s, v9.s[3]\n"
+ "fmla v26.4s, v19.4s, v10.s[3]\n"
+ "b 4f\n"
+ "3:\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "fmla v25.4s, v17.4s, v1.s[1]\n"
+ "fmla v26.4s, v17.4s, v2.s[1]\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "fmla v25.4s, v18.4s, v1.s[2]\n"
+ "fmla v26.4s, v18.4s, v2.s[2]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "fmla v25.4s, v19.4s, v1.s[3]\n"
+ "fmla v26.4s, v19.4s, v2.s[3]\n"
+ "4:\n"
+ "cbz %[blocks], 5f\n"
+ "6:\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "subs %[blocks], %[blocks], #0x1\n"
+ "add %[b_ptr0], %[b_ptr0], #0x10\n"
+ "ldr s0, [%[a_ptr0]]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x4\n"
+ "ldr s1, [a_ptr1]\n"
+ "add a_ptr1, a_ptr1, #0x4\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "ldr s2, [a_ptr2]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "add a_ptr2, a_ptr2, #0x4\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "b.ne 6b\n"
+ "5:\n"
+ "ld1r {v22.4s}, [%[minptr]]\n"
+ "ld1r {v23.4s}, [%[maxptr]]\n"
+ "fmax v24.4s, v24.4s, v22.4s\n"
+ "fmax v25.4s, v25.4s, v22.4s\n"
+ "fmax v26.4s, v26.4s, v22.4s\n"
+ "fmin v24.4s, v24.4s, v23.4s\n"
+ "fmin v25.4s, v25.4s, v23.4s\n"
+ "fmin v26.4s, v26.4s, v23.4s\n"
+ "str q24, [%[c_ptr0]]\n"
+ "add %[c_ptr0], %[c_ptr0], #0x10\n"
+ "str q25, [c_ptr1]\n"
+ "str q26, [c_ptr2]\n"
+ ".unreq a_ptr1\n"
+ ".unreq a_ptr2\n"
+ ".unreq c_ptr1\n"
+ ".unreq c_ptr2\n"
+ : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks)
+ : [width] "r" (width), [append] "r" (static_cast<uint64_t>(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr)
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "cc", "memory"
+ );
+ break;
+ case 4:
+ __asm __volatile (
+ "a_ptr1 .req X0\n"
+ "a_ptr2 .req X1\n"
+ "a_ptr3 .req X2\n"
+ "c_ptr1 .req X3\n"
+ "c_ptr2 .req X4\n"
+ "c_ptr3 .req X5\n"
+ "ldr q24, [%[biasptr]]\n"
+ "add a_ptr1, %[a_ptr0], %[lda]\n"
+ "ldr q0, [%[a_ptr0]]\n"
+ "add a_ptr2, a_ptr1, %[lda]\n"
+ "mov v25.16b, v24.16b\n"
+ "ldr q1, [a_ptr1]\n"
+ "mov v26.16b, v24.16b\n"
+ "ldr q2, [a_ptr2]\n"
+ "mov v27.16b, v24.16b\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "add a_ptr3, a_ptr2, %[lda]\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "add c_ptr1, %[c_ptr0], %[ldc]\n"
+ "ldr q3, [a_ptr3]\n"
+ "add c_ptr2, c_ptr1, %[ldc]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
+ "add c_ptr3, c_ptr2, %[ldc]\n"
+ "add a_ptr1, a_ptr1, #0x10\n"
+ "add a_ptr2, a_ptr2, #0x10\n"
+ "add a_ptr3, a_ptr3, #0x10\n"
+ "add %[b_ptr0], %[b_ptr0], #0x40\n"
+ "cbz %[loops], 1f\n"
+ "2:\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "ldr q19, [%[b_ptr0], #-0x10]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "ldr q8, [%[a_ptr0]]\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "ldr q9, [a_ptr1]\n"
+ "fmla v27.4s, v16.4s, v3.s[0]\n"
+ "ldr q10, [a_ptr2]\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "ldr q11, [a_ptr3]\n"
+ "fmla v25.4s, v17.4s, v1.s[1]\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "fmla v26.4s, v17.4s, v2.s[1]\n"
+ "subs %[loops], %[loops], #0x1\n"
+ "fmla v27.4s, v17.4s, v3.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n"
+ "fmla v25.4s, v18.4s, v1.s[2]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x20\n"
+ "fmla v26.4s, v18.4s, v2.s[2]\n"
+ "add a_ptr1, a_ptr1, #0x20\n"
+ "fmla v27.4s, v18.4s, v3.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "ldr q0, [%[a_ptr0], #-0x10]\n"
+ "fmla v25.4s, v19.4s, v1.s[3]\n"
+ "ldr q1, [a_ptr1, #-0x10]\n"
+ "fmla v26.4s, v19.4s, v2.s[3]\n"
+ "add a_ptr2, a_ptr2, #0x20\n"
+ "fmla v27.4s, v19.4s, v3.s[3]\n"
+ "ldr q19, [%[b_ptr0], #0x30]\n"
+ "fmla v24.4s, v16.4s, v8.s[0]\n"
+ "ldr q2, [a_ptr2, #-0x10]\n"
+ "fmla v25.4s, v16.4s, v9.s[0]\n"
+ "add a_ptr3, a_ptr3, #0x20\n"
+ "fmla v26.4s, v16.4s, v10.s[0]\n"
+ "ldr q3, [a_ptr3, #-0x10]\n"
+ "fmla v27.4s, v16.4s, v11.s[0]\n"
+ "ldr q16, [%[b_ptr0], #0x40]\n"
+ "fmla v24.4s, v17.4s, v8.s[1]\n"
+ "prfm PLDL1KEEP, [a_ptr1, #0x40]\n"
+ "fmla v25.4s, v17.4s, v9.s[1]\n"
+ "prfm PLDL1KEEP, [a_ptr2, #0x40]\n"
+ "fmla v26.4s, v17.4s, v10.s[1]\n"
+ "prfm PLDL1KEEP, [a_ptr3, #0x40]\n"
+ "fmla v27.4s, v17.4s, v11.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x50]\n"
+ "fmla v24.4s, v18.4s, v8.s[2]\n"
+ "fmla v25.4s, v18.4s, v9.s[2]\n"
+ "fmla v26.4s, v18.4s, v10.s[2]\n"
+ "fmla v27.4s, v18.4s, v11.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x60]\n"
+ "fmla v24.4s, v19.4s, v8.s[3]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x80\n"
+ "fmla v25.4s, v19.4s, v9.s[3]\n"
+ "fmla v26.4s, v19.4s, v10.s[3]\n"
+ "fmla v27.4s, v19.4s, v11.s[3]\n"
+ "b.ne 2b\n"
+ "1:\n"
+ "ldr q19, [%[b_ptr0], #-0x10]\n"
+ "prfm PSTL1KEEP, [%[c_ptr0]]\n"
+ "prfm PSTL1KEEP, [c_ptr1]\n"
+ "prfm PSTL1KEEP, [c_ptr2]\n"
+ "prfm PSTL1KEEP, [c_ptr3]\n"
+ "cbz %[regs], 3f\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "ldr q8, [%[a_ptr0]]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "ldr q9, [a_ptr1]\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "ldr q10, [a_ptr2]\n"
+ "fmla v27.4s, v16.4s, v3.s[0]\n"
+ "ldr q11, [a_ptr3]\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "fmla v25.4s, v17.4s, v1.s[1]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
+ "fmla v26.4s, v17.4s, v2.s[1]\n"
+ "add a_ptr1, a_ptr1, #0x10\n"
+ "fmla v27.4s, v17.4s, v3.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "add a_ptr2, a_ptr2, #0x10\n"
+ "fmla v25.4s, v18.4s, v1.s[2]\n"
+ "add a_ptr3, a_ptr3, #0x10\n"
+ "fmla v26.4s, v18.4s, v2.s[2]\n"
+ "fmla v27.4s, v18.4s, v3.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "fmla v25.4s, v19.4s, v1.s[3]\n"
+ "fmla v26.4s, v19.4s, v2.s[3]\n"
+ "fmla v27.4s, v19.4s, v3.s[3]\n"
+ "ldr q19, [%[b_ptr0], #0x30]\n"
+ "fmla v24.4s, v16.4s, v8.s[0]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x40\n"
+ "fmla v25.4s, v16.4s, v9.s[0]\n"
+ "fmla v26.4s, v16.4s, v10.s[0]\n"
+ "fmla v27.4s, v16.4s, v11.s[0]\n"
+ "fmla v24.4s, v17.4s, v8.s[1]\n"
+ "fmla v25.4s, v17.4s, v9.s[1]\n"
+ "fmla v26.4s, v17.4s, v10.s[1]\n"
+ "fmla v27.4s, v17.4s, v11.s[1]\n"
+ "fmla v24.4s, v18.4s, v8.s[2]\n"
+ "fmla v25.4s, v18.4s, v9.s[2]\n"
+ "fmla v26.4s, v18.4s, v10.s[2]\n"
+ "fmla v27.4s, v18.4s, v11.s[2]\n"
+ "fmla v24.4s, v19.4s, v8.s[3]\n"
+ "fmla v25.4s, v19.4s, v9.s[3]\n"
+ "fmla v26.4s, v19.4s, v10.s[3]\n"
+ "fmla v27.4s, v19.4s, v11.s[3]\n"
+ "b 4f\n"
+ "3:\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "fmla v27.4s, v16.4s, v3.s[0]\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "fmla v25.4s, v17.4s, v1.s[1]\n"
+ "fmla v26.4s, v17.4s, v2.s[1]\n"
+ "fmla v27.4s, v17.4s, v3.s[1]\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "fmla v25.4s, v18.4s, v1.s[2]\n"
+ "fmla v26.4s, v18.4s, v2.s[2]\n"
+ "fmla v27.4s, v18.4s, v3.s[2]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "fmla v25.4s, v19.4s, v1.s[3]\n"
+ "fmla v26.4s, v19.4s, v2.s[3]\n"
+ "fmla v27.4s, v19.4s, v3.s[3]\n"
+ "4:\n"
+ "cbz %[blocks], 5f\n"
+ "6:\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "subs %[blocks], %[blocks], #0x1\n"
+ "add %[b_ptr0], %[b_ptr0], #0x10\n"
+ "ldr s0, [%[a_ptr0]]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x4\n"
+ "ldr s1, [a_ptr1]\n"
+ "add a_ptr1, a_ptr1, #0x4\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "ldr s2, [a_ptr2]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "add a_ptr2, a_ptr2, #0x4\n"
+ "ldr s3, [a_ptr3]\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "add a_ptr3, a_ptr3, #0x4\n"
+ "fmla v27.4s, v16.4s, v3.s[0]\n"
+ "b.ne 6b\n"
+ "5:\n"
+ "ld1r {v22.4s}, [%[minptr]]\n"
+ "ld1r {v23.4s}, [%[maxptr]]\n"
+ "fmax v24.4s, v24.4s, v22.4s\n"
+ "fmax v25.4s, v25.4s, v22.4s\n"
+ "fmax v26.4s, v26.4s, v22.4s\n"
+ "fmax v27.4s, v27.4s, v22.4s\n"
+ "fmin v24.4s, v24.4s, v23.4s\n"
+ "fmin v25.4s, v25.4s, v23.4s\n"
+ "fmin v26.4s, v26.4s, v23.4s\n"
+ "fmin v27.4s, v27.4s, v23.4s\n"
+ "str q24, [%[c_ptr0]]\n"
+ "add %[c_ptr0], %[c_ptr0], #0x10\n"
+ "str q25, [c_ptr1]\n"
+ "str q26, [c_ptr2]\n"
+ "str q27, [c_ptr3]\n"
+ ".unreq a_ptr1\n"
+ ".unreq a_ptr2\n"
+ ".unreq a_ptr3\n"
+ ".unreq c_ptr1\n"
+ ".unreq c_ptr2\n"
+ ".unreq c_ptr3\n"
+ : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks)
+ : [width] "r" (width), [append] "r" (static_cast<uint64_t>(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr)
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "cc", "memory"
+ );
+ break;
+ case 5:
+ __asm __volatile (
+ "a_ptr1 .req X0\n"
+ "a_ptr2 .req X1\n"
+ "a_ptr3 .req X2\n"
+ "a_ptr4 .req X3\n"
+ "c_ptr1 .req X4\n"
+ "c_ptr2 .req X5\n"
+ "c_ptr3 .req X6\n"
+ "c_ptr4 .req X7\n"
+ "ldr q24, [%[biasptr]]\n"
+ "add a_ptr1, %[a_ptr0], %[lda]\n"
+ "ldr q0, [%[a_ptr0]]\n"
+ "add a_ptr2, a_ptr1, %[lda]\n"
+ "mov v25.16b, v24.16b\n"
+ "ldr q1, [a_ptr1]\n"
+ "mov v26.16b, v24.16b\n"
+ "ldr q2, [a_ptr2]\n"
+ "mov v27.16b, v24.16b\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "mov v28.16b, v24.16b\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "add a_ptr3, a_ptr2, %[lda]\n"
+ "add c_ptr1, %[c_ptr0], %[ldc]\n"
+ "ldr q3, [a_ptr3]\n"
+ "add a_ptr4, a_ptr3, %[lda]\n"
+ "add c_ptr2, c_ptr1, %[ldc]\n"
+ "ldr q4, [a_ptr4]\n"
+ "add c_ptr3, c_ptr2, %[ldc]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
+ "add c_ptr4, c_ptr3, %[ldc]\n"
+ "add a_ptr1, a_ptr1, #0x10\n"
+ "add a_ptr2, a_ptr2, #0x10\n"
+ "add a_ptr3, a_ptr3, #0x10\n"
+ "add a_ptr4, a_ptr4, #0x10\n"
+ "add %[b_ptr0], %[b_ptr0], #0x40\n"
+ "cbz %[loops], 1f\n"
+ "2:\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "ldr q19, [%[b_ptr0], #-0x10]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "ldr q8, [%[a_ptr0]]\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "ldr q9, [a_ptr1]\n"
+ "fmla v27.4s, v16.4s, v3.s[0]\n"
+ "ldr q10, [a_ptr2]\n"
+ "fmla v28.4s, v16.4s, v4.s[0]\n"
+ "ldr q11, [a_ptr3]\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "ldr q12, [a_ptr4]\n"
+ "fmla v25.4s, v17.4s, v1.s[1]\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "fmla v26.4s, v17.4s, v2.s[1]\n"
+ "subs %[loops], %[loops], #0x1\n"
+ "fmla v27.4s, v17.4s, v3.s[1]\n"
+ "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n"
+ "fmla v28.4s, v17.4s, v4.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x20\n"
+ "fmla v25.4s, v18.4s, v1.s[2]\n"
+ "add a_ptr1, a_ptr1, #0x20\n"
+ "fmla v26.4s, v18.4s, v2.s[2]\n"
+ "add a_ptr2, a_ptr2, #0x20\n"
+ "fmla v27.4s, v18.4s, v3.s[2]\n"
+ "add a_ptr3, a_ptr3, #0x20\n"
+ "fmla v28.4s, v18.4s, v4.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "ldr q0, [%[a_ptr0], #-0x10]\n"
+ "fmla v25.4s, v19.4s, v1.s[3]\n"
+ "ldr q1, [a_ptr1, #-0x10]\n"
+ "fmla v26.4s, v19.4s, v2.s[3]\n"
+ "ldr q2, [a_ptr2, #-0x10]\n"
+ "fmla v27.4s, v19.4s, v3.s[3]\n"
+ "ldr q3, [a_ptr3, #-0x10]\n"
+ "fmla v28.4s, v19.4s, v4.s[3]\n"
+ "ldr q19, [%[b_ptr0], #0x30]\n"
+ "fmla v24.4s, v16.4s, v8.s[0]\n"
+ "add a_ptr4, a_ptr4, #0x20\n"
+ "fmla v25.4s, v16.4s, v9.s[0]\n"
+ "ldr q4, [a_ptr4, #-0x10]\n"
+ "fmla v26.4s, v16.4s, v10.s[0]\n"
+ "prfm PLDL1KEEP, [a_ptr1, #0x40]\n"
+ "fmla v27.4s, v16.4s, v11.s[0]\n"
+ "prfm PLDL1KEEP, [a_ptr2, #0x40]\n"
+ "fmla v28.4s, v16.4s, v12.s[0]\n"
+ "ldr q16, [%[b_ptr0], #0x40]\n"
+ "fmla v24.4s, v17.4s, v8.s[1]\n"
+ "prfm PLDL1KEEP, [a_ptr3, #0x40]\n"
+ "fmla v25.4s, v17.4s, v9.s[1]\n"
+ "fmla v26.4s, v17.4s, v10.s[1]\n"
+ "fmla v27.4s, v17.4s, v11.s[1]\n"
+ "fmla v28.4s, v17.4s, v12.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x50]\n"
+ "fmla v24.4s, v18.4s, v8.s[2]\n"
+ "fmla v25.4s, v18.4s, v9.s[2]\n"
+ "fmla v26.4s, v18.4s, v10.s[2]\n"
+ "fmla v27.4s, v18.4s, v11.s[2]\n"
+ "fmla v28.4s, v18.4s, v12.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x60]\n"
+ "fmla v24.4s, v19.4s, v8.s[3]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x80\n"
+ "fmla v25.4s, v19.4s, v9.s[3]\n"
+ "fmla v26.4s, v19.4s, v10.s[3]\n"
+ "fmla v27.4s, v19.4s, v11.s[3]\n"
+ "fmla v28.4s, v19.4s, v12.s[3]\n"
+ "b.ne 2b\n"
+ "1:\n"
+ "ldr q19, [%[b_ptr0], #-0x10]\n"
+ "prfm PSTL1KEEP, [%[c_ptr0]]\n"
+ "prfm PSTL1KEEP, [c_ptr1]\n"
+ "prfm PSTL1KEEP, [c_ptr2]\n"
+ "prfm PSTL1KEEP, [c_ptr3]\n"
+ "prfm PSTL1KEEP, [c_ptr4]\n"
+ "cbz %[regs], 3f\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "ldr q8, [%[a_ptr0]]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "ldr q9, [a_ptr1]\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "ldr q10, [a_ptr2]\n"
+ "fmla v27.4s, v16.4s, v3.s[0]\n"
+ "ldr q11, [a_ptr3]\n"
+ "fmla v28.4s, v16.4s, v4.s[0]\n"
+ "ldr q12, [a_ptr4]\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "fmla v25.4s, v17.4s, v1.s[1]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
+ "fmla v26.4s, v17.4s, v2.s[1]\n"
+ "add a_ptr1, a_ptr1, #0x10\n"
+ "fmla v27.4s, v17.4s, v3.s[1]\n"
+ "add a_ptr2, a_ptr2, #0x10\n"
+ "fmla v28.4s, v17.4s, v4.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "add a_ptr3, a_ptr3, #0x10\n"
+ "fmla v25.4s, v18.4s, v1.s[2]\n"
+ "add a_ptr4, a_ptr4, #0x10\n"
+ "fmla v26.4s, v18.4s, v2.s[2]\n"
+ "fmla v27.4s, v18.4s, v3.s[2]\n"
+ "fmla v28.4s, v18.4s, v4.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "fmla v25.4s, v19.4s, v1.s[3]\n"
+ "fmla v26.4s, v19.4s, v2.s[3]\n"
+ "fmla v27.4s, v19.4s, v3.s[3]\n"
+ "fmla v28.4s, v19.4s, v4.s[3]\n"
+ "ldr q19, [%[b_ptr0], #0x30]\n"
+ "fmla v24.4s, v16.4s, v8.s[0]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x40\n"
+ "fmla v25.4s, v16.4s, v9.s[0]\n"
+ "fmla v26.4s, v16.4s, v10.s[0]\n"
+ "fmla v27.4s, v16.4s, v11.s[0]\n"
+ "fmla v28.4s, v16.4s, v12.s[0]\n"
+ "fmla v24.4s, v17.4s, v8.s[1]\n"
+ "fmla v25.4s, v17.4s, v9.s[1]\n"
+ "fmla v26.4s, v17.4s, v10.s[1]\n"
+ "fmla v27.4s, v17.4s, v11.s[1]\n"
+ "fmla v28.4s, v17.4s, v12.s[1]\n"
+ "fmla v24.4s, v18.4s, v8.s[2]\n"
+ "fmla v25.4s, v18.4s, v9.s[2]\n"
+ "fmla v26.4s, v18.4s, v10.s[2]\n"
+ "fmla v27.4s, v18.4s, v11.s[2]\n"
+ "fmla v28.4s, v18.4s, v12.s[2]\n"
+ "fmla v24.4s, v19.4s, v8.s[3]\n"
+ "fmla v25.4s, v19.4s, v9.s[3]\n"
+ "fmla v26.4s, v19.4s, v10.s[3]\n"
+ "fmla v27.4s, v19.4s, v11.s[3]\n"
+ "fmla v28.4s, v19.4s, v12.s[3]\n"
+ "b 4f\n"
+ "3:\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "fmla v27.4s, v16.4s, v3.s[0]\n"
+ "fmla v28.4s, v16.4s, v4.s[0]\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "fmla v25.4s, v17.4s, v1.s[1]\n"
+ "fmla v26.4s, v17.4s, v2.s[1]\n"
+ "fmla v27.4s, v17.4s, v3.s[1]\n"
+ "fmla v28.4s, v17.4s, v4.s[1]\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "fmla v25.4s, v18.4s, v1.s[2]\n"
+ "fmla v26.4s, v18.4s, v2.s[2]\n"
+ "fmla v27.4s, v18.4s, v3.s[2]\n"
+ "fmla v28.4s, v18.4s, v4.s[2]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "fmla v25.4s, v19.4s, v1.s[3]\n"
+ "fmla v26.4s, v19.4s, v2.s[3]\n"
+ "fmla v27.4s, v19.4s, v3.s[3]\n"
+ "fmla v28.4s, v19.4s, v4.s[3]\n"
+ "4:\n"
+ "cbz %[blocks], 5f\n"
+ "6:\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "subs %[blocks], %[blocks], #0x1\n"
+ "add %[b_ptr0], %[b_ptr0], #0x10\n"
+ "ldr s0, [%[a_ptr0]]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x4\n"
+ "ldr s1, [a_ptr1]\n"
+ "add a_ptr1, a_ptr1, #0x4\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "ldr s2, [a_ptr2]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "add a_ptr2, a_ptr2, #0x4\n"
+ "ldr s3, [a_ptr3]\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "add a_ptr3, a_ptr3, #0x4\n"
+ "ldr s4, [a_ptr4]\n"
+ "fmla v27.4s, v16.4s, v3.s[0]\n"
+ "add a_ptr4, a_ptr4, #0x4\n"
+ "fmla v28.4s, v16.4s, v4.s[0]\n"
+ "b.ne 6b\n"
+ "5:\n"
+ "ld1r {v22.4s}, [%[minptr]]\n"
+ "ld1r {v23.4s}, [%[maxptr]]\n"
+ "fmax v24.4s, v24.4s, v22.4s\n"
+ "fmax v25.4s, v25.4s, v22.4s\n"
+ "fmax v26.4s, v26.4s, v22.4s\n"
+ "fmax v27.4s, v27.4s, v22.4s\n"
+ "fmin v24.4s, v24.4s, v23.4s\n"
+ "fmin v25.4s, v25.4s, v23.4s\n"
+ "fmin v26.4s, v26.4s, v23.4s\n"
+ "fmin v27.4s, v27.4s, v23.4s\n"
+ "str q24, [%[c_ptr0]]\n"
+ "fmax v28.4s, v28.4s, v22.4s\n"
+ "add %[c_ptr0], %[c_ptr0], #0x10\n"
+ "str q25, [c_ptr1]\n"
+ "fmin v28.4s, v28.4s, v23.4s\n"
+ "str q26, [c_ptr2]\n"
+ "str q27, [c_ptr3]\n"
+ "str q28, [c_ptr4]\n"
+ ".unreq a_ptr1\n"
+ ".unreq a_ptr2\n"
+ ".unreq a_ptr3\n"
+ ".unreq a_ptr4\n"
+ ".unreq c_ptr1\n"
+ ".unreq c_ptr2\n"
+ ".unreq c_ptr3\n"
+ ".unreq c_ptr4\n"
+ : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks)
+ : [width] "r" (width), [append] "r" (static_cast<uint64_t>(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr)
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory"
+ );
+ break;
+ case 6:
+ __asm __volatile (
+ "a_ptr1 .req X0\n"
+ "a_ptr2 .req X1\n"
+ "a_ptr3 .req X2\n"
+ "a_ptr4 .req X3\n"
+ "a_ptr5 .req X4\n"
+ "c_ptr1 .req X5\n"
+ "c_ptr2 .req X6\n"
+ "c_ptr3 .req X7\n"
+ "c_ptr4 .req X8\n"
+ "c_ptr5 .req X9\n"
+ "ldr q24, [%[biasptr]]\n"
+ "add a_ptr1, %[a_ptr0], %[lda]\n"
+ "ldr q0, [%[a_ptr0]]\n"
+ "add a_ptr2, a_ptr1, %[lda]\n"
+ "mov v25.16b, v24.16b\n"
+ "ldr q1, [a_ptr1]\n"
+ "mov v26.16b, v24.16b\n"
+ "ldr q2, [a_ptr2]\n"
+ "mov v27.16b, v24.16b\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "mov v28.16b, v24.16b\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "mov v29.16b, v24.16b\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "add a_ptr3, a_ptr2, %[lda]\n"
+ "add c_ptr1, %[c_ptr0], %[ldc]\n"
+ "ldr q3, [a_ptr3]\n"
+ "add a_ptr4, a_ptr3, %[lda]\n"
+ "add c_ptr2, c_ptr1, %[ldc]\n"
+ "ldr q4, [a_ptr4]\n"
+ "add a_ptr5, a_ptr4, %[lda]\n"
+ "add c_ptr3, c_ptr2, %[ldc]\n"
+ "ldr q5, [a_ptr5]\n"
+ "add c_ptr4, c_ptr3, %[ldc]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
+ "add c_ptr5, c_ptr4, %[ldc]\n"
+ "add a_ptr1, a_ptr1, #0x10\n"
+ "add a_ptr2, a_ptr2, #0x10\n"
+ "add a_ptr3, a_ptr3, #0x10\n"
+ "add a_ptr4, a_ptr4, #0x10\n"
+ "add a_ptr5, a_ptr5, #0x10\n"
+ "add %[b_ptr0], %[b_ptr0], #0x40\n"
+ "cbz %[loops], 1f\n"
+ "2:\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "ldr q19, [%[b_ptr0], #-0x10]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "ldr q8, [%[a_ptr0]]\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "ldr q9, [a_ptr1]\n"
+ "fmla v27.4s, v16.4s, v3.s[0]\n"
+ "ldr q10, [a_ptr2]\n"
+ "fmla v28.4s, v16.4s, v4.s[0]\n"
+ "ldr q11, [a_ptr3]\n"
+ "fmla v29.4s, v16.4s, v5.s[0]\n"
+ "ldr q12, [a_ptr4]\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "ldr q13, [a_ptr5]\n"
+ "fmla v25.4s, v17.4s, v1.s[1]\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "fmla v26.4s, v17.4s, v2.s[1]\n"
+ "subs %[loops], %[loops], #0x1\n"
+ "fmla v27.4s, v17.4s, v3.s[1]\n"
+ "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n"
+ "fmla v28.4s, v17.4s, v4.s[1]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x20\n"
+ "fmla v29.4s, v17.4s, v5.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "add a_ptr1, a_ptr1, #0x20\n"
+ "fmla v25.4s, v18.4s, v1.s[2]\n"
+ "add a_ptr2, a_ptr2, #0x20\n"
+ "fmla v26.4s, v18.4s, v2.s[2]\n"
+ "add a_ptr3, a_ptr3, #0x20\n"
+ "fmla v27.4s, v18.4s, v3.s[2]\n"
+ "add a_ptr4, a_ptr4, #0x20\n"
+ "fmla v28.4s, v18.4s, v4.s[2]\n"
+ "add a_ptr5, a_ptr5, #0x20\n"
+ "fmla v29.4s, v18.4s, v5.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "ldr q0, [%[a_ptr0], #-0x10]\n"
+ "fmla v25.4s, v19.4s, v1.s[3]\n"
+ "ldr q1, [a_ptr1, #-0x10]\n"
+ "fmla v26.4s, v19.4s, v2.s[3]\n"
+ "ldr q2, [a_ptr2, #-0x10]\n"
+ "fmla v27.4s, v19.4s, v3.s[3]\n"
+ "ldr q3, [a_ptr3, #-0x10]\n"
+ "fmla v28.4s, v19.4s, v4.s[3]\n"
+ "ldr q4, [a_ptr4, #-0x10]\n"
+ "fmla v29.4s, v19.4s, v5.s[3]\n"
+ "ldr q19, [%[b_ptr0], #0x30]\n"
+ "fmla v24.4s, v16.4s, v8.s[0]\n"
+ "ldr q5, [a_ptr5, #-0x10]\n"
+ "fmla v25.4s, v16.4s, v9.s[0]\n"
+ "prfm PLDL1KEEP, [a_ptr1, #0x40]\n"
+ "fmla v26.4s, v16.4s, v10.s[0]\n"
+ "prfm PLDL1KEEP, [a_ptr2, #0x40]\n"
+ "fmla v27.4s, v16.4s, v11.s[0]\n"
+ "prfm PLDL1KEEP, [a_ptr3, #0x40]\n"
+ "fmla v28.4s, v16.4s, v12.s[0]\n"
+ "fmla v29.4s, v16.4s, v13.s[0]\n"
+ "ldr q16, [%[b_ptr0], #0x40]\n"
+ "fmla v24.4s, v17.4s, v8.s[1]\n"
+ "fmla v25.4s, v17.4s, v9.s[1]\n"
+ "fmla v26.4s, v17.4s, v10.s[1]\n"
+ "fmla v27.4s, v17.4s, v11.s[1]\n"
+ "fmla v28.4s, v17.4s, v12.s[1]\n"
+ "fmla v29.4s, v17.4s, v13.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x50]\n"
+ "fmla v24.4s, v18.4s, v8.s[2]\n"
+ "fmla v25.4s, v18.4s, v9.s[2]\n"
+ "fmla v26.4s, v18.4s, v10.s[2]\n"
+ "fmla v27.4s, v18.4s, v11.s[2]\n"
+ "fmla v28.4s, v18.4s, v12.s[2]\n"
+ "fmla v29.4s, v18.4s, v13.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x60]\n"
+ "fmla v24.4s, v19.4s, v8.s[3]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x80\n"
+ "fmla v25.4s, v19.4s, v9.s[3]\n"
+ "fmla v26.4s, v19.4s, v10.s[3]\n"
+ "fmla v27.4s, v19.4s, v11.s[3]\n"
+ "fmla v28.4s, v19.4s, v12.s[3]\n"
+ "fmla v29.4s, v19.4s, v13.s[3]\n"
+ "b.ne 2b\n"
+ "1:\n"
+ "ldr q19, [%[b_ptr0], #-0x10]\n"
+ "prfm PSTL1KEEP, [%[c_ptr0]]\n"
+ "prfm PSTL1KEEP, [c_ptr1]\n"
+ "prfm PSTL1KEEP, [c_ptr2]\n"
+ "prfm PSTL1KEEP, [c_ptr3]\n"
+ "prfm PSTL1KEEP, [c_ptr4]\n"
+ "prfm PSTL1KEEP, [c_ptr5]\n"
+ "cbz %[regs], 3f\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "ldr q8, [%[a_ptr0]]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "ldr q9, [a_ptr1]\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "ldr q10, [a_ptr2]\n"
+ "fmla v27.4s, v16.4s, v3.s[0]\n"
+ "ldr q11, [a_ptr3]\n"
+ "fmla v28.4s, v16.4s, v4.s[0]\n"
+ "ldr q12, [a_ptr4]\n"
+ "fmla v29.4s, v16.4s, v5.s[0]\n"
+ "ldr q13, [a_ptr5]\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "fmla v25.4s, v17.4s, v1.s[1]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
+ "fmla v26.4s, v17.4s, v2.s[1]\n"
+ "add a_ptr1, a_ptr1, #0x10\n"
+ "fmla v27.4s, v17.4s, v3.s[1]\n"
+ "add a_ptr2, a_ptr2, #0x10\n"
+ "fmla v28.4s, v17.4s, v4.s[1]\n"
+ "add a_ptr3, a_ptr3, #0x10\n"
+ "fmla v29.4s, v17.4s, v5.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "add a_ptr4, a_ptr4, #0x10\n"
+ "fmla v25.4s, v18.4s, v1.s[2]\n"
+ "add a_ptr5, a_ptr5, #0x10\n"
+ "fmla v26.4s, v18.4s, v2.s[2]\n"
+ "fmla v27.4s, v18.4s, v3.s[2]\n"
+ "fmla v28.4s, v18.4s, v4.s[2]\n"
+ "fmla v29.4s, v18.4s, v5.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "fmla v25.4s, v19.4s, v1.s[3]\n"
+ "fmla v26.4s, v19.4s, v2.s[3]\n"
+ "fmla v27.4s, v19.4s, v3.s[3]\n"
+ "fmla v28.4s, v19.4s, v4.s[3]\n"
+ "fmla v29.4s, v19.4s, v5.s[3]\n"
+ "ldr q19, [%[b_ptr0], #0x30]\n"
+ "fmla v24.4s, v16.4s, v8.s[0]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x40\n"
+ "fmla v25.4s, v16.4s, v9.s[0]\n"
+ "fmla v26.4s, v16.4s, v10.s[0]\n"
+ "fmla v27.4s, v16.4s, v11.s[0]\n"
+ "fmla v28.4s, v16.4s, v12.s[0]\n"
+ "fmla v29.4s, v16.4s, v13.s[0]\n"
+ "fmla v24.4s, v17.4s, v8.s[1]\n"
+ "fmla v25.4s, v17.4s, v9.s[1]\n"
+ "fmla v26.4s, v17.4s, v10.s[1]\n"
+ "fmla v27.4s, v17.4s, v11.s[1]\n"
+ "fmla v28.4s, v17.4s, v12.s[1]\n"
+ "fmla v29.4s, v17.4s, v13.s[1]\n"
+ "fmla v24.4s, v18.4s, v8.s[2]\n"
+ "fmla v25.4s, v18.4s, v9.s[2]\n"
+ "fmla v26.4s, v18.4s, v10.s[2]\n"
+ "fmla v27.4s, v18.4s, v11.s[2]\n"
+ "fmla v28.4s, v18.4s, v12.s[2]\n"
+ "fmla v29.4s, v18.4s, v13.s[2]\n"
+ "fmla v24.4s, v19.4s, v8.s[3]\n"
+ "fmla v25.4s, v19.4s, v9.s[3]\n"
+ "fmla v26.4s, v19.4s, v10.s[3]\n"
+ "fmla v27.4s, v19.4s, v11.s[3]\n"
+ "fmla v28.4s, v19.4s, v12.s[3]\n"
+ "fmla v29.4s, v19.4s, v13.s[3]\n"
+ "b 4f\n"
+ "3:\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "fmla v27.4s, v16.4s, v3.s[0]\n"
+ "fmla v28.4s, v16.4s, v4.s[0]\n"
+ "fmla v29.4s, v16.4s, v5.s[0]\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "fmla v25.4s, v17.4s, v1.s[1]\n"
+ "fmla v26.4s, v17.4s, v2.s[1]\n"
+ "fmla v27.4s, v17.4s, v3.s[1]\n"
+ "fmla v28.4s, v17.4s, v4.s[1]\n"
+ "fmla v29.4s, v17.4s, v5.s[1]\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "fmla v25.4s, v18.4s, v1.s[2]\n"
+ "fmla v26.4s, v18.4s, v2.s[2]\n"
+ "fmla v27.4s, v18.4s, v3.s[2]\n"
+ "fmla v28.4s, v18.4s, v4.s[2]\n"
+ "fmla v29.4s, v18.4s, v5.s[2]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "fmla v25.4s, v19.4s, v1.s[3]\n"
+ "fmla v26.4s, v19.4s, v2.s[3]\n"
+ "fmla v27.4s, v19.4s, v3.s[3]\n"
+ "fmla v28.4s, v19.4s, v4.s[3]\n"
+ "fmla v29.4s, v19.4s, v5.s[3]\n"
+ "4:\n"
+ "cbz %[blocks], 5f\n"
+ "6:\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "subs %[blocks], %[blocks], #0x1\n"
+ "add %[b_ptr0], %[b_ptr0], #0x10\n"
+ "ldr s0, [%[a_ptr0]]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x4\n"
+ "ldr s1, [a_ptr1]\n"
+ "add a_ptr1, a_ptr1, #0x4\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "ldr s2, [a_ptr2]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "add a_ptr2, a_ptr2, #0x4\n"
+ "ldr s3, [a_ptr3]\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "add a_ptr3, a_ptr3, #0x4\n"
+ "ldr s4, [a_ptr4]\n"
+ "fmla v27.4s, v16.4s, v3.s[0]\n"
+ "add a_ptr4, a_ptr4, #0x4\n"
+ "ldr s5, [a_ptr5]\n"
+ "fmla v28.4s, v16.4s, v4.s[0]\n"
+ "add a_ptr5, a_ptr5, #0x4\n"
+ "fmla v29.4s, v16.4s, v5.s[0]\n"
+ "b.ne 6b\n"
+ "5:\n"
+ "ld1r {v22.4s}, [%[minptr]]\n"
+ "ld1r {v23.4s}, [%[maxptr]]\n"
+ "fmax v24.4s, v24.4s, v22.4s\n"
+ "fmax v25.4s, v25.4s, v22.4s\n"
+ "fmax v26.4s, v26.4s, v22.4s\n"
+ "fmax v27.4s, v27.4s, v22.4s\n"
+ "fmin v24.4s, v24.4s, v23.4s\n"
+ "fmin v25.4s, v25.4s, v23.4s\n"
+ "fmin v26.4s, v26.4s, v23.4s\n"
+ "fmin v27.4s, v27.4s, v23.4s\n"
+ "str q24, [%[c_ptr0]]\n"
+ "fmax v28.4s, v28.4s, v22.4s\n"
+ "add %[c_ptr0], %[c_ptr0], #0x10\n"
+ "fmax v29.4s, v29.4s, v22.4s\n"
+ "str q25, [c_ptr1]\n"
+ "fmin v28.4s, v28.4s, v23.4s\n"
+ "fmin v29.4s, v29.4s, v23.4s\n"
+ "str q26, [c_ptr2]\n"
+ "str q27, [c_ptr3]\n"
+ "str q28, [c_ptr4]\n"
+ "str q29, [c_ptr5]\n"
+ ".unreq a_ptr1\n"
+ ".unreq a_ptr2\n"
+ ".unreq a_ptr3\n"
+ ".unreq a_ptr4\n"
+ ".unreq a_ptr5\n"
+ ".unreq c_ptr1\n"
+ ".unreq c_ptr2\n"
+ ".unreq c_ptr3\n"
+ ".unreq c_ptr4\n"
+ ".unreq c_ptr5\n"
+ : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks)
+ : [width] "r" (width), [append] "r" (static_cast<uint64_t>(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr)
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "cc", "memory"
+ );
+ break;
+ case 7:
+ __asm __volatile (
+ "a_ptr1 .req X0\n"
+ "a_ptr2 .req X1\n"
+ "a_ptr3 .req X2\n"
+ "a_ptr4 .req X3\n"
+ "a_ptr5 .req X4\n"
+ "a_ptr6 .req X5\n"
+ "c_ptr1 .req X6\n"
+ "c_ptr2 .req X7\n"
+ "c_ptr3 .req X8\n"
+ "c_ptr4 .req X9\n"
+ "c_ptr5 .req X10\n"
+ "c_ptr6 .req X11\n"
+ "ldr q24, [%[biasptr]]\n"
+ "add a_ptr1, %[a_ptr0], %[lda]\n"
+ "ldr q0, [%[a_ptr0]]\n"
+ "add a_ptr2, a_ptr1, %[lda]\n"
+ "mov v25.16b, v24.16b\n"
+ "ldr q1, [a_ptr1]\n"
+ "mov v26.16b, v24.16b\n"
+ "ldr q2, [a_ptr2]\n"
+ "mov v27.16b, v24.16b\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "mov v28.16b, v24.16b\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "mov v29.16b, v24.16b\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "mov v30.16b, v24.16b\n"
+ "add a_ptr3, a_ptr2, %[lda]\n"
+ "add c_ptr1, %[c_ptr0], %[ldc]\n"
+ "ldr q3, [a_ptr3]\n"
+ "add a_ptr4, a_ptr3, %[lda]\n"
+ "add c_ptr2, c_ptr1, %[ldc]\n"
+ "ldr q4, [a_ptr4]\n"
+ "add a_ptr5, a_ptr4, %[lda]\n"
+ "add c_ptr3, c_ptr2, %[ldc]\n"
+ "ldr q5, [a_ptr5]\n"
+ "add a_ptr6, a_ptr5, %[lda]\n"
+ "add c_ptr4, c_ptr3, %[ldc]\n"
+ "ldr q6, [a_ptr6]\n"
+ "add c_ptr5, c_ptr4, %[ldc]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
+ "add c_ptr6, c_ptr5, %[ldc]\n"
+ "add a_ptr1, a_ptr1, #0x10\n"
+ "add a_ptr2, a_ptr2, #0x10\n"
+ "add a_ptr3, a_ptr3, #0x10\n"
+ "add a_ptr4, a_ptr4, #0x10\n"
+ "add a_ptr5, a_ptr5, #0x10\n"
+ "add a_ptr6, a_ptr6, #0x10\n"
+ "add %[b_ptr0], %[b_ptr0], #0x40\n"
+ "cbz %[loops], 1f\n"
+ "2:\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "ldr q19, [%[b_ptr0], #-0x10]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "ldr q8, [%[a_ptr0]]\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "ldr q9, [a_ptr1]\n"
+ "fmla v27.4s, v16.4s, v3.s[0]\n"
+ "ldr q10, [a_ptr2]\n"
+ "fmla v28.4s, v16.4s, v4.s[0]\n"
+ "ldr q11, [a_ptr3]\n"
+ "fmla v29.4s, v16.4s, v5.s[0]\n"
+ "ldr q12, [a_ptr4]\n"
+ "fmla v30.4s, v16.4s, v6.s[0]\n"
+ "ldr q13, [a_ptr5]\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "ldr q14, [a_ptr6]\n"
+ "fmla v25.4s, v17.4s, v1.s[1]\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "fmla v26.4s, v17.4s, v2.s[1]\n"
+ "subs %[loops], %[loops], #0x1\n"
+ "fmla v27.4s, v17.4s, v3.s[1]\n"
+ "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n"
+ "fmla v28.4s, v17.4s, v4.s[1]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x20\n"
+ "fmla v29.4s, v17.4s, v5.s[1]\n"
+ "add a_ptr1, a_ptr1, #0x20\n"
+ "fmla v30.4s, v17.4s, v6.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "add a_ptr2, a_ptr2, #0x20\n"
+ "fmla v25.4s, v18.4s, v1.s[2]\n"
+ "add a_ptr3, a_ptr3, #0x20\n"
+ "fmla v26.4s, v18.4s, v2.s[2]\n"
+ "add a_ptr4, a_ptr4, #0x20\n"
+ "fmla v27.4s, v18.4s, v3.s[2]\n"
+ "add a_ptr5, a_ptr5, #0x20\n"
+ "fmla v28.4s, v18.4s, v4.s[2]\n"
+ "add a_ptr6, a_ptr6, #0x20\n"
+ "fmla v29.4s, v18.4s, v5.s[2]\n"
+ "prfm PLDL1KEEP, [a_ptr1, #0x40]\n"
+ "fmla v30.4s, v18.4s, v6.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "ldr q0, [%[a_ptr0], #-0x10]\n"
+ "fmla v25.4s, v19.4s, v1.s[3]\n"
+ "ldr q1, [a_ptr1, #-0x10]\n"
+ "fmla v26.4s, v19.4s, v2.s[3]\n"
+ "ldr q2, [a_ptr2, #-0x10]\n"
+ "fmla v27.4s, v19.4s, v3.s[3]\n"
+ "ldr q3, [a_ptr3, #-0x10]\n"
+ "fmla v28.4s, v19.4s, v4.s[3]\n"
+ "ldr q4, [a_ptr4, #-0x10]\n"
+ "fmla v29.4s, v19.4s, v5.s[3]\n"
+ "ldr q5, [a_ptr5, #-0x10]\n"
+ "fmla v30.4s, v19.4s, v6.s[3]\n"
+ "ldr q19, [%[b_ptr0], #0x30]\n"
+ "fmla v24.4s, v16.4s, v8.s[0]\n"
+ "ldr q6, [a_ptr6, #-0x10]\n"
+ "fmla v25.4s, v16.4s, v9.s[0]\n"
+ "prfm PLDL1KEEP, [a_ptr2, #0x40]\n"
+ "fmla v26.4s, v16.4s, v10.s[0]\n"
+ "prfm PLDL1KEEP, [a_ptr3, #0x40]\n"
+ "fmla v27.4s, v16.4s, v11.s[0]\n"
+ "fmla v28.4s, v16.4s, v12.s[0]\n"
+ "fmla v29.4s, v16.4s, v13.s[0]\n"
+ "fmla v30.4s, v16.4s, v14.s[0]\n"
+ "ldr q16, [%[b_ptr0], #0x40]\n"
+ "fmla v24.4s, v17.4s, v8.s[1]\n"
+ "fmla v25.4s, v17.4s, v9.s[1]\n"
+ "fmla v26.4s, v17.4s, v10.s[1]\n"
+ "fmla v27.4s, v17.4s, v11.s[1]\n"
+ "fmla v28.4s, v17.4s, v12.s[1]\n"
+ "fmla v29.4s, v17.4s, v13.s[1]\n"
+ "fmla v30.4s, v17.4s, v14.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x50]\n"
+ "fmla v24.4s, v18.4s, v8.s[2]\n"
+ "fmla v25.4s, v18.4s, v9.s[2]\n"
+ "fmla v26.4s, v18.4s, v10.s[2]\n"
+ "fmla v27.4s, v18.4s, v11.s[2]\n"
+ "fmla v28.4s, v18.4s, v12.s[2]\n"
+ "fmla v29.4s, v18.4s, v13.s[2]\n"
+ "fmla v30.4s, v18.4s, v14.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x60]\n"
+ "fmla v24.4s, v19.4s, v8.s[3]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x80\n"
+ "fmla v25.4s, v19.4s, v9.s[3]\n"
+ "fmla v26.4s, v19.4s, v10.s[3]\n"
+ "fmla v27.4s, v19.4s, v11.s[3]\n"
+ "fmla v28.4s, v19.4s, v12.s[3]\n"
+ "fmla v29.4s, v19.4s, v13.s[3]\n"
+ "fmla v30.4s, v19.4s, v14.s[3]\n"
+ "b.ne 2b\n"
+ "1:\n"
+ "ldr q19, [%[b_ptr0], #-0x10]\n"
+ "prfm PSTL1KEEP, [%[c_ptr0]]\n"
+ "prfm PSTL1KEEP, [c_ptr1]\n"
+ "prfm PSTL1KEEP, [c_ptr2]\n"
+ "prfm PSTL1KEEP, [c_ptr3]\n"
+ "prfm PSTL1KEEP, [c_ptr4]\n"
+ "prfm PSTL1KEEP, [c_ptr5]\n"
+ "prfm PSTL1KEEP, [c_ptr6]\n"
+ "cbz %[regs], 3f\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "ldr q8, [%[a_ptr0]]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "ldr q9, [a_ptr1]\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "ldr q10, [a_ptr2]\n"
+ "fmla v27.4s, v16.4s, v3.s[0]\n"
+ "ldr q11, [a_ptr3]\n"
+ "fmla v28.4s, v16.4s, v4.s[0]\n"
+ "ldr q12, [a_ptr4]\n"
+ "fmla v29.4s, v16.4s, v5.s[0]\n"
+ "ldr q13, [a_ptr5]\n"
+ "fmla v30.4s, v16.4s, v6.s[0]\n"
+ "ldr q14, [a_ptr6]\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "fmla v25.4s, v17.4s, v1.s[1]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
+ "fmla v26.4s, v17.4s, v2.s[1]\n"
+ "add a_ptr1, a_ptr1, #0x10\n"
+ "fmla v27.4s, v17.4s, v3.s[1]\n"
+ "add a_ptr2, a_ptr2, #0x10\n"
+ "fmla v28.4s, v17.4s, v4.s[1]\n"
+ "add a_ptr3, a_ptr3, #0x10\n"
+ "fmla v29.4s, v17.4s, v5.s[1]\n"
+ "add a_ptr4, a_ptr4, #0x10\n"
+ "fmla v30.4s, v17.4s, v6.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "add a_ptr5, a_ptr5, #0x10\n"
+ "fmla v25.4s, v18.4s, v1.s[2]\n"
+ "add a_ptr6, a_ptr6, #0x10\n"
+ "fmla v26.4s, v18.4s, v2.s[2]\n"
+ "fmla v27.4s, v18.4s, v3.s[2]\n"
+ "fmla v28.4s, v18.4s, v4.s[2]\n"
+ "fmla v29.4s, v18.4s, v5.s[2]\n"
+ "fmla v30.4s, v18.4s, v6.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "fmla v25.4s, v19.4s, v1.s[3]\n"
+ "fmla v26.4s, v19.4s, v2.s[3]\n"
+ "fmla v27.4s, v19.4s, v3.s[3]\n"
+ "fmla v28.4s, v19.4s, v4.s[3]\n"
+ "fmla v29.4s, v19.4s, v5.s[3]\n"
+ "fmla v30.4s, v19.4s, v6.s[3]\n"
+ "ldr q19, [%[b_ptr0], #0x30]\n"
+ "fmla v24.4s, v16.4s, v8.s[0]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x40\n"
+ "fmla v25.4s, v16.4s, v9.s[0]\n"
+ "fmla v26.4s, v16.4s, v10.s[0]\n"
+ "fmla v27.4s, v16.4s, v11.s[0]\n"
+ "fmla v28.4s, v16.4s, v12.s[0]\n"
+ "fmla v29.4s, v16.4s, v13.s[0]\n"
+ "fmla v30.4s, v16.4s, v14.s[0]\n"
+ "fmla v24.4s, v17.4s, v8.s[1]\n"
+ "fmla v25.4s, v17.4s, v9.s[1]\n"
+ "fmla v26.4s, v17.4s, v10.s[1]\n"
+ "fmla v27.4s, v17.4s, v11.s[1]\n"
+ "fmla v28.4s, v17.4s, v12.s[1]\n"
+ "fmla v29.4s, v17.4s, v13.s[1]\n"
+ "fmla v30.4s, v17.4s, v14.s[1]\n"
+ "fmla v24.4s, v18.4s, v8.s[2]\n"
+ "fmla v25.4s, v18.4s, v9.s[2]\n"
+ "fmla v26.4s, v18.4s, v10.s[2]\n"
+ "fmla v27.4s, v18.4s, v11.s[2]\n"
+ "fmla v28.4s, v18.4s, v12.s[2]\n"
+ "fmla v29.4s, v18.4s, v13.s[2]\n"
+ "fmla v30.4s, v18.4s, v14.s[2]\n"
+ "fmla v24.4s, v19.4s, v8.s[3]\n"
+ "fmla v25.4s, v19.4s, v9.s[3]\n"
+ "fmla v26.4s, v19.4s, v10.s[3]\n"
+ "fmla v27.4s, v19.4s, v11.s[3]\n"
+ "fmla v28.4s, v19.4s, v12.s[3]\n"
+ "fmla v29.4s, v19.4s, v13.s[3]\n"
+ "fmla v30.4s, v19.4s, v14.s[3]\n"
+ "b 4f\n"
+ "3:\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "fmla v27.4s, v16.4s, v3.s[0]\n"
+ "fmla v28.4s, v16.4s, v4.s[0]\n"
+ "fmla v29.4s, v16.4s, v5.s[0]\n"
+ "fmla v30.4s, v16.4s, v6.s[0]\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "fmla v25.4s, v17.4s, v1.s[1]\n"
+ "fmla v26.4s, v17.4s, v2.s[1]\n"
+ "fmla v27.4s, v17.4s, v3.s[1]\n"
+ "fmla v28.4s, v17.4s, v4.s[1]\n"
+ "fmla v29.4s, v17.4s, v5.s[1]\n"
+ "fmla v30.4s, v17.4s, v6.s[1]\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "fmla v25.4s, v18.4s, v1.s[2]\n"
+ "fmla v26.4s, v18.4s, v2.s[2]\n"
+ "fmla v27.4s, v18.4s, v3.s[2]\n"
+ "fmla v28.4s, v18.4s, v4.s[2]\n"
+ "fmla v29.4s, v18.4s, v5.s[2]\n"
+ "fmla v30.4s, v18.4s, v6.s[2]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "fmla v25.4s, v19.4s, v1.s[3]\n"
+ "fmla v26.4s, v19.4s, v2.s[3]\n"
+ "fmla v27.4s, v19.4s, v3.s[3]\n"
+ "fmla v28.4s, v19.4s, v4.s[3]\n"
+ "fmla v29.4s, v19.4s, v5.s[3]\n"
+ "fmla v30.4s, v19.4s, v6.s[3]\n"
+ "4:\n"
+ "cbz %[blocks], 5f\n"
+ "6:\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "subs %[blocks], %[blocks], #0x1\n"
+ "add %[b_ptr0], %[b_ptr0], #0x10\n"
+ "ldr s0, [%[a_ptr0]]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x4\n"
+ "ldr s1, [a_ptr1]\n"
+ "add a_ptr1, a_ptr1, #0x4\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "ldr s2, [a_ptr2]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "add a_ptr2, a_ptr2, #0x4\n"
+ "ldr s3, [a_ptr3]\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "add a_ptr3, a_ptr3, #0x4\n"
+ "ldr s4, [a_ptr4]\n"
+ "fmla v27.4s, v16.4s, v3.s[0]\n"
+ "add a_ptr4, a_ptr4, #0x4\n"
+ "ldr s5, [a_ptr5]\n"
+ "fmla v28.4s, v16.4s, v4.s[0]\n"
+ "add a_ptr5, a_ptr5, #0x4\n"
+ "ldr s6, [a_ptr6]\n"
+ "fmla v29.4s, v16.4s, v5.s[0]\n"
+ "add a_ptr6, a_ptr6, #0x4\n"
+ "fmla v30.4s, v16.4s, v6.s[0]\n"
+ "b.ne 6b\n"
+ "5:\n"
+ "ld1r {v22.4s}, [%[minptr]]\n"
+ "ld1r {v23.4s}, [%[maxptr]]\n"
+ "fmax v24.4s, v24.4s, v22.4s\n"
+ "fmax v25.4s, v25.4s, v22.4s\n"
+ "fmax v26.4s, v26.4s, v22.4s\n"
+ "fmax v27.4s, v27.4s, v22.4s\n"
+ "fmin v24.4s, v24.4s, v23.4s\n"
+ "fmin v25.4s, v25.4s, v23.4s\n"
+ "fmin v26.4s, v26.4s, v23.4s\n"
+ "fmin v27.4s, v27.4s, v23.4s\n"
+ "str q24, [%[c_ptr0]]\n"
+ "fmax v28.4s, v28.4s, v22.4s\n"
+ "add %[c_ptr0], %[c_ptr0], #0x10\n"
+ "fmax v29.4s, v29.4s, v22.4s\n"
+ "str q25, [c_ptr1]\n"
+ "fmax v30.4s, v30.4s, v22.4s\n"
+ "fmin v28.4s, v28.4s, v23.4s\n"
+ "fmin v29.4s, v29.4s, v23.4s\n"
+ "str q26, [c_ptr2]\n"
+ "fmin v30.4s, v30.4s, v23.4s\n"
+ "str q27, [c_ptr3]\n"
+ "str q28, [c_ptr4]\n"
+ "str q29, [c_ptr5]\n"
+ "str q30, [c_ptr6]\n"
+ ".unreq a_ptr1\n"
+ ".unreq a_ptr2\n"
+ ".unreq a_ptr3\n"
+ ".unreq a_ptr4\n"
+ ".unreq a_ptr5\n"
+ ".unreq a_ptr6\n"
+ ".unreq c_ptr1\n"
+ ".unreq c_ptr2\n"
+ ".unreq c_ptr3\n"
+ ".unreq c_ptr4\n"
+ ".unreq c_ptr5\n"
+ ".unreq c_ptr6\n"
+ : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks)
+ : [width] "r" (width), [append] "r" (static_cast<uint64_t>(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr)
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "cc", "memory"
+ );
+ break;
+ default:
+ case 8:
+ __asm __volatile (
+ "a_ptr1 .req X0\n"
+ "a_ptr2 .req X1\n"
+ "a_ptr3 .req X2\n"
+ "a_ptr4 .req X3\n"
+ "a_ptr5 .req X4\n"
+ "a_ptr6 .req X5\n"
+ "a_ptr7 .req X6\n"
+ "c_ptr1 .req X7\n"
+ "c_ptr2 .req X8\n"
+ "c_ptr3 .req X9\n"
+ "c_ptr4 .req X10\n"
+ "c_ptr5 .req X11\n"
+ "c_ptr6 .req X12\n"
+ "c_ptr7 .req X13\n"
+ "ldr q24, [%[biasptr]]\n"
+ "add a_ptr1, %[a_ptr0], %[lda]\n"
+ "ldr q0, [%[a_ptr0]]\n"
+ "add a_ptr2, a_ptr1, %[lda]\n"
+ "mov v25.16b, v24.16b\n"
+ "ldr q1, [a_ptr1]\n"
+ "mov v26.16b, v24.16b\n"
+ "ldr q2, [a_ptr2]\n"
+ "mov v27.16b, v24.16b\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "mov v28.16b, v24.16b\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "mov v29.16b, v24.16b\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "mov v30.16b, v24.16b\n"
+ "add a_ptr3, a_ptr2, %[lda]\n"
+ "mov v31.16b, v24.16b\n"
+ "ldr q3, [a_ptr3]\n"
+ "add a_ptr4, a_ptr3, %[lda]\n"
+ "add c_ptr1, %[c_ptr0], %[ldc]\n"
+ "ldr q4, [a_ptr4]\n"
+ "add a_ptr5, a_ptr4, %[lda]\n"
+ "add c_ptr2, c_ptr1, %[ldc]\n"
+ "ldr q5, [a_ptr5]\n"
+ "add a_ptr6, a_ptr5, %[lda]\n"
+ "add c_ptr3, c_ptr2, %[ldc]\n"
+ "ldr q6, [a_ptr6]\n"
+ "add a_ptr7, a_ptr6, %[lda]\n"
+ "add c_ptr4, c_ptr3, %[ldc]\n"
+ "ldr q7, [a_ptr7]\n"
+ "add c_ptr5, c_ptr4, %[ldc]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
+ "add c_ptr6, c_ptr5, %[ldc]\n"
+ "add a_ptr1, a_ptr1, #0x10\n"
+ "add c_ptr7, c_ptr6, %[ldc]\n"
+ "add a_ptr2, a_ptr2, #0x10\n"
+ "add a_ptr3, a_ptr3, #0x10\n"
+ "add a_ptr4, a_ptr4, #0x10\n"
+ "add a_ptr5, a_ptr5, #0x10\n"
+ "add a_ptr6, a_ptr6, #0x10\n"
+ "add a_ptr7, a_ptr7, #0x10\n"
+ "add %[b_ptr0], %[b_ptr0], #0x40\n"
+ "cbz %[loops], 1f\n"
+ "2:\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "ldr q19, [%[b_ptr0], #-0x10]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "ldr q8, [%[a_ptr0]]\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "ldr q9, [a_ptr1]\n"
+ "fmla v27.4s, v16.4s, v3.s[0]\n"
+ "ldr q10, [a_ptr2]\n"
+ "fmla v28.4s, v16.4s, v4.s[0]\n"
+ "ldr q11, [a_ptr3]\n"
+ "fmla v29.4s, v16.4s, v5.s[0]\n"
+ "ldr q12, [a_ptr4]\n"
+ "fmla v30.4s, v16.4s, v6.s[0]\n"
+ "ldr q13, [a_ptr5]\n"
+ "fmla v31.4s, v16.4s, v7.s[0]\n"
+ "ldr q14, [a_ptr6]\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "ldr q15, [a_ptr7]\n"
+ "fmla v25.4s, v17.4s, v1.s[1]\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "fmla v26.4s, v17.4s, v2.s[1]\n"
+ "subs %[loops], %[loops], #0x1\n"
+ "fmla v27.4s, v17.4s, v3.s[1]\n"
+ "prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n"
+ "fmla v28.4s, v17.4s, v4.s[1]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x20\n"
+ "fmla v29.4s, v17.4s, v5.s[1]\n"
+ "add a_ptr1, a_ptr1, #0x20\n"
+ "fmla v30.4s, v17.4s, v6.s[1]\n"
+ "add a_ptr2, a_ptr2, #0x20\n"
+ "fmla v31.4s, v17.4s, v7.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "add a_ptr3, a_ptr3, #0x20\n"
+ "fmla v25.4s, v18.4s, v1.s[2]\n"
+ "add a_ptr4, a_ptr4, #0x20\n"
+ "fmla v26.4s, v18.4s, v2.s[2]\n"
+ "add a_ptr5, a_ptr5, #0x20\n"
+ "fmla v27.4s, v18.4s, v3.s[2]\n"
+ "add a_ptr6, a_ptr6, #0x20\n"
+ "fmla v28.4s, v18.4s, v4.s[2]\n"
+ "add a_ptr7, a_ptr7, #0x20\n"
+ "fmla v29.4s, v18.4s, v5.s[2]\n"
+ "prfm PLDL1KEEP, [a_ptr1, #0x40]\n"
+ "fmla v30.4s, v18.4s, v6.s[2]\n"
+ "prfm PLDL1KEEP, [a_ptr2, #0x40]\n"
+ "fmla v31.4s, v18.4s, v7.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "ldr q0, [%[a_ptr0], #-0x10]\n"
+ "fmla v25.4s, v19.4s, v1.s[3]\n"
+ "ldr q1, [a_ptr1, #-0x10]\n"
+ "fmla v26.4s, v19.4s, v2.s[3]\n"
+ "ldr q2, [a_ptr2, #-0x10]\n"
+ "fmla v27.4s, v19.4s, v3.s[3]\n"
+ "ldr q3, [a_ptr3, #-0x10]\n"
+ "fmla v28.4s, v19.4s, v4.s[3]\n"
+ "ldr q4, [a_ptr4, #-0x10]\n"
+ "fmla v29.4s, v19.4s, v5.s[3]\n"
+ "ldr q5, [a_ptr5, #-0x10]\n"
+ "fmla v30.4s, v19.4s, v6.s[3]\n"
+ "ldr q6, [a_ptr6, #-0x10]\n"
+ "fmla v31.4s, v19.4s, v7.s[3]\n"
+ "ldr q19, [%[b_ptr0], #0x30]\n"
+ "fmla v24.4s, v16.4s, v8.s[0]\n"
+ "ldr q7, [a_ptr7, #-0x10]\n"
+ "fmla v25.4s, v16.4s, v9.s[0]\n"
+ "prfm PLDL1KEEP, [a_ptr3, #0x40]\n"
+ "fmla v26.4s, v16.4s, v10.s[0]\n"
+ "fmla v27.4s, v16.4s, v11.s[0]\n"
+ "fmla v28.4s, v16.4s, v12.s[0]\n"
+ "fmla v29.4s, v16.4s, v13.s[0]\n"
+ "fmla v30.4s, v16.4s, v14.s[0]\n"
+ "fmla v31.4s, v16.4s, v15.s[0]\n"
+ "ldr q16, [%[b_ptr0], #0x40]\n"
+ "fmla v24.4s, v17.4s, v8.s[1]\n"
+ "fmla v25.4s, v17.4s, v9.s[1]\n"
+ "fmla v26.4s, v17.4s, v10.s[1]\n"
+ "fmla v27.4s, v17.4s, v11.s[1]\n"
+ "fmla v28.4s, v17.4s, v12.s[1]\n"
+ "fmla v29.4s, v17.4s, v13.s[1]\n"
+ "fmla v30.4s, v17.4s, v14.s[1]\n"
+ "fmla v31.4s, v17.4s, v15.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x50]\n"
+ "fmla v24.4s, v18.4s, v8.s[2]\n"
+ "fmla v25.4s, v18.4s, v9.s[2]\n"
+ "fmla v26.4s, v18.4s, v10.s[2]\n"
+ "fmla v27.4s, v18.4s, v11.s[2]\n"
+ "fmla v28.4s, v18.4s, v12.s[2]\n"
+ "fmla v29.4s, v18.4s, v13.s[2]\n"
+ "fmla v30.4s, v18.4s, v14.s[2]\n"
+ "fmla v31.4s, v18.4s, v15.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x60]\n"
+ "fmla v24.4s, v19.4s, v8.s[3]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x80\n"
+ "fmla v25.4s, v19.4s, v9.s[3]\n"
+ "fmla v26.4s, v19.4s, v10.s[3]\n"
+ "fmla v27.4s, v19.4s, v11.s[3]\n"
+ "fmla v28.4s, v19.4s, v12.s[3]\n"
+ "fmla v29.4s, v19.4s, v13.s[3]\n"
+ "fmla v30.4s, v19.4s, v14.s[3]\n"
+ "fmla v31.4s, v19.4s, v15.s[3]\n"
+ "b.ne 2b\n"
+ "1:\n"
+ "ldr q19, [%[b_ptr0], #-0x10]\n"
+ "prfm PSTL1KEEP, [%[c_ptr0]]\n"
+ "prfm PSTL1KEEP, [c_ptr1]\n"
+ "prfm PSTL1KEEP, [c_ptr2]\n"
+ "prfm PSTL1KEEP, [c_ptr3]\n"
+ "prfm PSTL1KEEP, [c_ptr4]\n"
+ "prfm PSTL1KEEP, [c_ptr5]\n"
+ "prfm PSTL1KEEP, [c_ptr6]\n"
+ "prfm PSTL1KEEP, [c_ptr7]\n"
+ "cbz %[regs], 3f\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "ldr q8, [%[a_ptr0]]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "ldr q9, [a_ptr1]\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "ldr q10, [a_ptr2]\n"
+ "fmla v27.4s, v16.4s, v3.s[0]\n"
+ "ldr q11, [a_ptr3]\n"
+ "fmla v28.4s, v16.4s, v4.s[0]\n"
+ "ldr q12, [a_ptr4]\n"
+ "fmla v29.4s, v16.4s, v5.s[0]\n"
+ "ldr q13, [a_ptr5]\n"
+ "fmla v30.4s, v16.4s, v6.s[0]\n"
+ "ldr q14, [a_ptr6]\n"
+ "fmla v31.4s, v16.4s, v7.s[0]\n"
+ "ldr q15, [a_ptr7]\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "fmla v25.4s, v17.4s, v1.s[1]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
+ "fmla v26.4s, v17.4s, v2.s[1]\n"
+ "add a_ptr1, a_ptr1, #0x10\n"
+ "fmla v27.4s, v17.4s, v3.s[1]\n"
+ "add a_ptr2, a_ptr2, #0x10\n"
+ "fmla v28.4s, v17.4s, v4.s[1]\n"
+ "add a_ptr3, a_ptr3, #0x10\n"
+ "fmla v29.4s, v17.4s, v5.s[1]\n"
+ "add a_ptr4, a_ptr4, #0x10\n"
+ "fmla v30.4s, v17.4s, v6.s[1]\n"
+ "add a_ptr5, a_ptr5, #0x10\n"
+ "fmla v31.4s, v17.4s, v7.s[1]\n"
+ "ldr q17, [%[b_ptr0], #0x10]\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "add a_ptr6, a_ptr6, #0x10\n"
+ "fmla v25.4s, v18.4s, v1.s[2]\n"
+ "add a_ptr7, a_ptr7, #0x10\n"
+ "fmla v26.4s, v18.4s, v2.s[2]\n"
+ "fmla v27.4s, v18.4s, v3.s[2]\n"
+ "fmla v28.4s, v18.4s, v4.s[2]\n"
+ "fmla v29.4s, v18.4s, v5.s[2]\n"
+ "fmla v30.4s, v18.4s, v6.s[2]\n"
+ "fmla v31.4s, v18.4s, v7.s[2]\n"
+ "ldr q18, [%[b_ptr0], #0x20]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "fmla v25.4s, v19.4s, v1.s[3]\n"
+ "fmla v26.4s, v19.4s, v2.s[3]\n"
+ "fmla v27.4s, v19.4s, v3.s[3]\n"
+ "fmla v28.4s, v19.4s, v4.s[3]\n"
+ "fmla v29.4s, v19.4s, v5.s[3]\n"
+ "fmla v30.4s, v19.4s, v6.s[3]\n"
+ "fmla v31.4s, v19.4s, v7.s[3]\n"
+ "ldr q19, [%[b_ptr0], #0x30]\n"
+ "fmla v24.4s, v16.4s, v8.s[0]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x40\n"
+ "fmla v25.4s, v16.4s, v9.s[0]\n"
+ "fmla v26.4s, v16.4s, v10.s[0]\n"
+ "fmla v27.4s, v16.4s, v11.s[0]\n"
+ "fmla v28.4s, v16.4s, v12.s[0]\n"
+ "fmla v29.4s, v16.4s, v13.s[0]\n"
+ "fmla v30.4s, v16.4s, v14.s[0]\n"
+ "fmla v31.4s, v16.4s, v15.s[0]\n"
+ "fmla v24.4s, v17.4s, v8.s[1]\n"
+ "fmla v25.4s, v17.4s, v9.s[1]\n"
+ "fmla v26.4s, v17.4s, v10.s[1]\n"
+ "fmla v27.4s, v17.4s, v11.s[1]\n"
+ "fmla v28.4s, v17.4s, v12.s[1]\n"
+ "fmla v29.4s, v17.4s, v13.s[1]\n"
+ "fmla v30.4s, v17.4s, v14.s[1]\n"
+ "fmla v31.4s, v17.4s, v15.s[1]\n"
+ "fmla v24.4s, v18.4s, v8.s[2]\n"
+ "fmla v25.4s, v18.4s, v9.s[2]\n"
+ "fmla v26.4s, v18.4s, v10.s[2]\n"
+ "fmla v27.4s, v18.4s, v11.s[2]\n"
+ "fmla v28.4s, v18.4s, v12.s[2]\n"
+ "fmla v29.4s, v18.4s, v13.s[2]\n"
+ "fmla v30.4s, v18.4s, v14.s[2]\n"
+ "fmla v31.4s, v18.4s, v15.s[2]\n"
+ "fmla v24.4s, v19.4s, v8.s[3]\n"
+ "fmla v25.4s, v19.4s, v9.s[3]\n"
+ "fmla v26.4s, v19.4s, v10.s[3]\n"
+ "fmla v27.4s, v19.4s, v11.s[3]\n"
+ "fmla v28.4s, v19.4s, v12.s[3]\n"
+ "fmla v29.4s, v19.4s, v13.s[3]\n"
+ "fmla v30.4s, v19.4s, v14.s[3]\n"
+ "fmla v31.4s, v19.4s, v15.s[3]\n"
+ "b 4f\n"
+ "3:\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "fmla v27.4s, v16.4s, v3.s[0]\n"
+ "fmla v28.4s, v16.4s, v4.s[0]\n"
+ "fmla v29.4s, v16.4s, v5.s[0]\n"
+ "fmla v30.4s, v16.4s, v6.s[0]\n"
+ "fmla v31.4s, v16.4s, v7.s[0]\n"
+ "fmla v24.4s, v17.4s, v0.s[1]\n"
+ "fmla v25.4s, v17.4s, v1.s[1]\n"
+ "fmla v26.4s, v17.4s, v2.s[1]\n"
+ "fmla v27.4s, v17.4s, v3.s[1]\n"
+ "fmla v28.4s, v17.4s, v4.s[1]\n"
+ "fmla v29.4s, v17.4s, v5.s[1]\n"
+ "fmla v30.4s, v17.4s, v6.s[1]\n"
+ "fmla v31.4s, v17.4s, v7.s[1]\n"
+ "fmla v24.4s, v18.4s, v0.s[2]\n"
+ "fmla v25.4s, v18.4s, v1.s[2]\n"
+ "fmla v26.4s, v18.4s, v2.s[2]\n"
+ "fmla v27.4s, v18.4s, v3.s[2]\n"
+ "fmla v28.4s, v18.4s, v4.s[2]\n"
+ "fmla v29.4s, v18.4s, v5.s[2]\n"
+ "fmla v30.4s, v18.4s, v6.s[2]\n"
+ "fmla v31.4s, v18.4s, v7.s[2]\n"
+ "fmla v24.4s, v19.4s, v0.s[3]\n"
+ "fmla v25.4s, v19.4s, v1.s[3]\n"
+ "fmla v26.4s, v19.4s, v2.s[3]\n"
+ "fmla v27.4s, v19.4s, v3.s[3]\n"
+ "fmla v28.4s, v19.4s, v4.s[3]\n"
+ "fmla v29.4s, v19.4s, v5.s[3]\n"
+ "fmla v30.4s, v19.4s, v6.s[3]\n"
+ "fmla v31.4s, v19.4s, v7.s[3]\n"
+ "4:\n"
+ "cbz %[blocks], 5f\n"
+ "6:\n"
+ "ldr q16, [%[b_ptr0]]\n"
+ "subs %[blocks], %[blocks], #0x1\n"
+ "add %[b_ptr0], %[b_ptr0], #0x10\n"
+ "ldr s0, [%[a_ptr0]]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x4\n"
+ "ldr s1, [a_ptr1]\n"
+ "add a_ptr1, a_ptr1, #0x4\n"
+ "fmla v24.4s, v16.4s, v0.s[0]\n"
+ "ldr s2, [a_ptr2]\n"
+ "fmla v25.4s, v16.4s, v1.s[0]\n"
+ "add a_ptr2, a_ptr2, #0x4\n"
+ "ldr s3, [a_ptr3]\n"
+ "fmla v26.4s, v16.4s, v2.s[0]\n"
+ "add a_ptr3, a_ptr3, #0x4\n"
+ "ldr s4, [a_ptr4]\n"
+ "fmla v27.4s, v16.4s, v3.s[0]\n"
+ "add a_ptr4, a_ptr4, #0x4\n"
+ "ldr s5, [a_ptr5]\n"
+ "fmla v28.4s, v16.4s, v4.s[0]\n"
+ "add a_ptr5, a_ptr5, #0x4\n"
+ "ldr s6, [a_ptr6]\n"
+ "fmla v29.4s, v16.4s, v5.s[0]\n"
+ "add a_ptr6, a_ptr6, #0x4\n"
+ "ldr s7, [a_ptr7]\n"
+ "fmla v30.4s, v16.4s, v6.s[0]\n"
+ "add a_ptr7, a_ptr7, #0x4\n"
+ "fmla v31.4s, v16.4s, v7.s[0]\n"
+ "b.ne 6b\n"
+ "5:\n"
+ "ld1r {v22.4s}, [%[minptr]]\n"
+ "ld1r {v23.4s}, [%[maxptr]]\n"
+ "fmax v24.4s, v24.4s, v22.4s\n"
+ "fmax v25.4s, v25.4s, v22.4s\n"
+ "fmax v26.4s, v26.4s, v22.4s\n"
+ "fmax v27.4s, v27.4s, v22.4s\n"
+ "fmin v24.4s, v24.4s, v23.4s\n"
+ "fmin v25.4s, v25.4s, v23.4s\n"
+ "fmin v26.4s, v26.4s, v23.4s\n"
+ "fmin v27.4s, v27.4s, v23.4s\n"
+ "str q24, [%[c_ptr0]]\n"
+ "fmax v28.4s, v28.4s, v22.4s\n"
+ "add %[c_ptr0], %[c_ptr0], #0x10\n"
+ "fmax v29.4s, v29.4s, v22.4s\n"
+ "str q25, [c_ptr1]\n"
+ "fmax v30.4s, v30.4s, v22.4s\n"
+ "fmin v28.4s, v28.4s, v23.4s\n"
+ "fmax v31.4s, v31.4s, v22.4s\n"
+ "str q26, [c_ptr2]\n"
+ "fmin v29.4s, v29.4s, v23.4s\n"
+ "fmin v30.4s, v30.4s, v23.4s\n"
+ "fmin v31.4s, v31.4s, v23.4s\n"
+ "str q27, [c_ptr3]\n"
+ "str q28, [c_ptr4]\n"
+ "str q29, [c_ptr5]\n"
+ "str q30, [c_ptr6]\n"
+ "str q31, [c_ptr7]\n"
+ ".unreq a_ptr1\n"
+ ".unreq a_ptr2\n"
+ ".unreq a_ptr3\n"
+ ".unreq a_ptr4\n"
+ ".unreq a_ptr5\n"
+ ".unreq a_ptr6\n"
+ ".unreq a_ptr7\n"
+ ".unreq c_ptr1\n"
+ ".unreq c_ptr2\n"
+ ".unreq c_ptr3\n"
+ ".unreq c_ptr4\n"
+ ".unreq c_ptr5\n"
+ ".unreq c_ptr6\n"
+ ".unreq c_ptr7\n"
+ : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks)
+ : [width] "r" (width), [append] "r" (static_cast<uint64_t>(append)), [lda] "r" (ldab), [ldc] "r" (ldcb), [biasptr] "r" (biasptr), [minptr] "r" (minptr), [maxptr] "r" (maxptr)
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", "memory"
+ );
+ break;
+ }
+ if (use_result_buffer) {
+ for(int cy=0; cy<std::min(M-y, 8); cy++) {
+ for(unsigned int cx=0; cx<width; cx++) {
+ c_ptr_real[cy * ldc + cx] = result_buffer[cy * 4 + cx];
+ }
+ }
+ }
+ }
+ }
+}
+
+} // namespace arm_gemm
+
+#endif // __aarch64__
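A minimal scalar sketch (not the kernel itself) of what each row block in the a64_hybrid_fp32_mla_4x8 assembly above computes: a bias-seeded accumulation over K, a min/max clamp taken from minptr/maxptr (the fused activation), and a copy of the valid columns out of the 4-wide result buffer when the real output row is narrower than the kernel width. The packed-B layout and every name below are assumptions made purely for illustration.

    // Illustrative reference for one output row of the 4-wide FP32 hybrid kernel.
    // Assumes B is packed 4 columns per K step; names are hypothetical.
    #include <algorithm>

    static void fp32_mla_4x8_reference(const float *a, const float *b_packed,
                                       const float *bias, float *c,
                                       unsigned int width, unsigned int K,
                                       float minval, float maxval) {
        float acc[4];
        for (unsigned int x = 0; x < 4; x++) {
            acc[x] = bias ? bias[x] : 0.0f;               // "ldr q24, [%[biasptr]]" + mov copies
        }
        for (unsigned int k = 0; k < K; k++) {
            for (unsigned int x = 0; x < 4; x++) {
                acc[x] += a[k] * b_packed[k * 4 + x];     // the fmla chains in the asm above
            }
        }
        for (unsigned int x = 0; x < width; x++) {        // clamp, then write only valid columns
            c[x] = std::min(std::max(acc[x], minval), maxval);  // fmax with v22 / fmin with v23
        }
    }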
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4.hpp
index 5a6fabcfa9..bdc62ea181 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4.hpp
@@ -61,7 +61,7 @@ public:
static constexpr bool supports_append()
{
- return false;
+ return true;
}
static constexpr bool supports_bias()
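With supports_append() now returning true, the s8s32 dot-product kernels in the following files gain an append path: when the append flag is set, the int32 accumulators are seeded from the values already in C (so a split along K can accumulate across calls) instead of starting from zero. A hedged sketch of that behaviour, with illustrative names only:

    // Accumulator initialisation as the append path below effectively does it.
    #include <cstdint>
    #include <cstring>

    static void init_accumulators(int32_t *acc, const int32_t *c_row,
                                  unsigned int out_width, bool append) {
        if (append) {
            std::memcpy(acc, c_row, out_width * sizeof(int32_t));  // mirrors "ldr q16, [%[c_ptr0]]" ...
        } else {
            std::memset(acc, 0, out_width * sizeof(int32_t));      // mirrors "movi v16.4s, #0" ...
        }
    }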
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/a55.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/a55.cpp
index 3ecf0151aa..7c08aa2165 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/a55.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/a55.cpp
@@ -35,7 +35,6 @@ namespace arm_gemm {
void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, int32_t *C, int ldc, int M, int N, int K, const int32_t *bias, Activation act, bool append) {
UNUSED(bias);
UNUSED(act);
-
const int K_stride = ((K + 3) / 4) * 4;
const long loops_count = ((K + 16) / 32) - 1;
K -= loops_count * 32;
@@ -80,6 +79,7 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
"temploadreg1 .req X1\n"
"temploadreg2 .req X2\n"
"temploadreg3 .req X3\n"
+ "cbnz %[append], 1f\n"
"movi v16.4s, #0\n"
"ldr q0, [%[a_ptr0]]\n"
"movi v17.4s, #0\n"
@@ -95,8 +95,26 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
"ldr d14, [%[b_ptr0], #0x60]\n"
"ldr temploadreg2, [%[b_ptr0], #0x68]\n"
"add %[b_ptr0], %[b_ptr0], #0x80\n"
- "cbz %[loops], 1f\n"
- "2:\n"
+ "cbz %[loops], 2f\n"
+ "b 3f\n"
+ "1:\n"
+ "ldr q16, [%[c_ptr0]]\n"
+ "ldr q17, [%[c_ptr0], #0x10]\n"
+ "ldr q18, [%[c_ptr0], #0x20]\n"
+ "ldr q19, [%[c_ptr0], #0x30]\n"
+ "ldr q0, [%[a_ptr0]]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
+ "ldr q8, [%[b_ptr0]]\n"
+ "ldr q9, [%[b_ptr0], #0x10]\n"
+ "ldr q10, [%[b_ptr0], #0x20]\n"
+ "ldr q11, [%[b_ptr0], #0x30]\n"
+ "ldr q12, [%[b_ptr0], #0x40]\n"
+ "ldr q13, [%[b_ptr0], #0x50]\n"
+ "ldr d14, [%[b_ptr0], #0x60]\n"
+ "ldr temploadreg2, [%[b_ptr0], #0x68]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x80\n"
+ "cbz %[loops], 2f\n"
+ "3:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ins v14.d[1], temploadreg2\n"
".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n"
@@ -236,14 +254,14 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
"ins v11.d[1], temploadreg3\n"
"ins v12.d[1], temploadreg0\n"
"ins v13.d[1], temploadreg1\n"
- "b.ne 2b\n"
- "1:\n"
+ "b.ne 3b\n"
+ "2:\n"
"ins v14.d[1], temploadreg2\n"
"prfm PSTL1KEEP, [%[c_ptr0]]\n"
"ldr d15, [%[b_ptr0], #-0x10]\n"
"ldr temploadreg3, [%[b_ptr0], #-0x8]\n"
"ins v15.d[1], temploadreg3\n"
- "cbz %[regs], 3f\n"
+ "cbz %[regs], 4f\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr d4, [%[a_ptr0]]\n"
".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n"
@@ -354,8 +372,8 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
".inst 0x4fa4e9b1 // sdot v17.4s, v13.16b, v4.4b[3]\n"
".inst 0x4fa4e9d2 // sdot v18.4s, v14.16b, v4.4b[3]\n"
".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n"
- "b 4f\n"
- "3:\n"
+ "b 5f\n"
+ "4:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr d8, [%[b_ptr0]]\n"
".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n"
@@ -397,9 +415,9 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
".inst 0x4fa0e9b1 // sdot v17.4s, v13.16b, v0.4b[3]\n"
".inst 0x4fa0e9d2 // sdot v18.4s, v14.16b, v0.4b[3]\n"
".inst 0x4fa0e9f3 // sdot v19.4s, v15.16b, v0.4b[3]\n"
- "4:\n"
- "cbz %[blocks], 5f\n"
- "6:\n"
+ "5:\n"
+ "cbz %[blocks], 6f\n"
+ "7:\n"
"ldr q8, [%[b_ptr0]]\n"
"subs %[blocks], %[blocks], #0x1\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
@@ -412,17 +430,17 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n"
".inst 0x4f80e152 // sdot v18.4s, v10.16b, v0.4b[0]\n"
".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n"
- "b.ne 6b\n"
- "5:\n"
- "cbz %[odds], 7f\n"
+ "b.ne 7b\n"
+ "6:\n"
+ "cbz %[odds], 8f\n"
"ld1 {v0.b}[0], [%[a_ptr0]], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[1], [%[a_ptr0]], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[2], [%[a_ptr0]]\n"
- "8:\n"
+ "9:\n"
"ldr q8, [%[b_ptr0]]\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
"ldr q10, [%[b_ptr0], #0x20]\n"
@@ -431,7 +449,7 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n"
".inst 0x4f80e152 // sdot v18.4s, v10.16b, v0.4b[0]\n"
".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n"
- "7:\n"
+ "8:\n"
"str q16, [%[c_ptr0]]\n"
"str q17, [%[c_ptr0], #0x10]\n"
"str q18, [%[c_ptr0], #0x20]\n"
@@ -454,74 +472,99 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
"temploadreg1 .req X3\n"
"temploadreg2 .req X4\n"
"temploadreg3 .req X5\n"
+ "add a_ptr1, %[a_ptr0], %[lda]\n"
+ "add c_ptr1, %[c_ptr0], %[ldc]\n"
+ "cbnz %[append], 1f\n"
"movi v16.4s, #0\n"
"ldr q0, [%[a_ptr0]]\n"
"movi v17.4s, #0\n"
- "ldr q8, [%[b_ptr0]]\n"
+ "ldr q1, [a_ptr1]\n"
"movi v18.4s, #0\n"
- "ldr q9, [%[b_ptr0], #0x10]\n"
+ "ldr q8, [%[b_ptr0]]\n"
"movi v19.4s, #0\n"
- "ldr q10, [%[b_ptr0], #0x20]\n"
+ "ldr q9, [%[b_ptr0], #0x10]\n"
"movi v20.4s, #0\n"
- "ldr q11, [%[b_ptr0], #0x30]\n"
+ "ldr q10, [%[b_ptr0], #0x20]\n"
"movi v21.4s, #0\n"
- "ldr q12, [%[b_ptr0], #0x40]\n"
+ "ldr q11, [%[b_ptr0], #0x30]\n"
"movi v22.4s, #0\n"
- "ldr q13, [%[b_ptr0], #0x50]\n"
+ "ldr q12, [%[b_ptr0], #0x40]\n"
"movi v23.4s, #0\n"
+ "ldr q13, [%[b_ptr0], #0x50]\n"
"ldr d14, [%[b_ptr0], #0x60]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
"ldr temploadreg2, [%[b_ptr0], #0x68]\n"
- "add a_ptr1, %[a_ptr0], %[lda]\n"
- "add c_ptr1, %[c_ptr0], %[ldc]\n"
- "ldr q1, [a_ptr1]\n"
+ "add a_ptr1, a_ptr1, #0x10\n"
+ "add %[b_ptr0], %[b_ptr0], #0x80\n"
+ "cbz %[loops], 2f\n"
+ "b 3f\n"
+ "1:\n"
+ "ldr q16, [%[c_ptr0]]\n"
+ "ldr q17, [%[c_ptr0], #0x10]\n"
+ "ldr q18, [%[c_ptr0], #0x20]\n"
+ "ldr q19, [%[c_ptr0], #0x30]\n"
+ "ldr q20, [c_ptr1]\n"
+ "ldr q21, [c_ptr1, #0x10]\n"
+ "ldr q22, [c_ptr1, #0x20]\n"
+ "ldr q23, [c_ptr1, #0x30]\n"
+ "ldr q0, [%[a_ptr0]]\n"
"add %[a_ptr0], %[a_ptr0], #0x10\n"
- "ins v14.d[1], temploadreg2\n"
+ "ldr q1, [a_ptr1]\n"
"add a_ptr1, a_ptr1, #0x10\n"
+ "ldr q8, [%[b_ptr0]]\n"
+ "ldr q9, [%[b_ptr0], #0x10]\n"
+ "ldr q10, [%[b_ptr0], #0x20]\n"
+ "ldr q11, [%[b_ptr0], #0x30]\n"
+ "ldr q12, [%[b_ptr0], #0x40]\n"
+ "ldr q13, [%[b_ptr0], #0x50]\n"
+ "ldr d14, [%[b_ptr0], #0x60]\n"
+ "ldr temploadreg2, [%[b_ptr0], #0x68]\n"
"add %[b_ptr0], %[b_ptr0], #0x80\n"
- "cbz %[loops], 1f\n"
- "2:\n"
+ "cbz %[loops], 2f\n"
+ "3:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
- "ldr d15, [%[b_ptr0], #-0x10]\n"
+ "ins v14.d[1], temploadreg2\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
- "ldr temploadreg3, [%[b_ptr0], #-0x8]\n"
+ "ldr d15, [%[b_ptr0], #-0x10]\n"
".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n"
- "ldr d4, [%[a_ptr0]]\n"
+ "ldr temploadreg3, [%[b_ptr0], #-0x8]\n"
".inst 0x4f81e135 // sdot v21.4s, v9.16b, v1.4b[0]\n"
- "ldr temploadreg0, [%[a_ptr0], #0x8]\n"
+ "ldr d4, [%[a_ptr0]]\n"
".inst 0x4f80e152 // sdot v18.4s, v10.16b, v0.4b[0]\n"
- "ldr d5, [a_ptr1]\n"
+ "ldr temploadreg0, [%[a_ptr0], #0x8]\n"
".inst 0x4f81e156 // sdot v22.4s, v10.16b, v1.4b[0]\n"
- "ldr temploadreg1, [a_ptr1, #0x8]\n"
+ "ldr d5, [a_ptr1]\n"
".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n"
- "ldr d8, [%[b_ptr0]]\n"
+ "ldr temploadreg1, [a_ptr1, #0x8]\n"
".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n"
- "ins v4.d[1], temploadreg0\n"
+ "ldr d8, [%[b_ptr0]]\n"
".inst 0x4fa0e190 // sdot v16.4s, v12.16b, v0.4b[1]\n"
- "ldr temploadreg0, [%[b_ptr0], #0x8]\n"
+ "ins v4.d[1], temploadreg0\n"
".inst 0x4fa1e194 // sdot v20.4s, v12.16b, v1.4b[1]\n"
- "ldr d9, [%[b_ptr0], #0x10]\n"
+ "ldr temploadreg0, [%[b_ptr0], #0x8]\n"
".inst 0x4fa0e1b1 // sdot v17.4s, v13.16b, v0.4b[1]\n"
- "ins v5.d[1], temploadreg1\n"
+ "ldr d9, [%[b_ptr0], #0x10]\n"
".inst 0x4fa1e1b5 // sdot v21.4s, v13.16b, v1.4b[1]\n"
- "ldr temploadreg1, [%[b_ptr0], #0x18]\n"
+ "ins v5.d[1], temploadreg1\n"
".inst 0x4fa0e1d2 // sdot v18.4s, v14.16b, v0.4b[1]\n"
- "ldr d10, [%[b_ptr0], #0x20]\n"
+ "ldr temploadreg1, [%[b_ptr0], #0x18]\n"
".inst 0x4fa1e1d6 // sdot v22.4s, v14.16b, v1.4b[1]\n"
+ "ldr d10, [%[b_ptr0], #0x20]\n"
"ldr temploadreg2, [%[b_ptr0], #0x28]\n"
- "ldr d11, [%[b_ptr0], #0x30]\n"
"subs %[loops], %[loops], #0x1\n"
- "ins v15.d[1], temploadreg3\n"
+ "ldr d11, [%[b_ptr0], #0x30]\n"
"prfm PLDL1KEEP, [%[a_ptr0], #0x40]\n"
- "ldr temploadreg3, [%[b_ptr0], #0x38]\n"
+ "ins v15.d[1], temploadreg3\n"
"add %[a_ptr0], %[a_ptr0], #0x20\n"
+ "ldr temploadreg3, [%[b_ptr0], #0x38]\n"
+ "add a_ptr1, a_ptr1, #0x20\n"
".inst 0x4fa0e1f3 // sdot v19.4s, v15.16b, v0.4b[1]\n"
"ldr d12, [%[b_ptr0], #0x40]\n"
".inst 0x4fa1e1f7 // sdot v23.4s, v15.16b, v1.4b[1]\n"
"ins v8.d[1], temploadreg0\n"
"ldr temploadreg0, [%[b_ptr0], #0x48]\n"
- "add a_ptr1, a_ptr1, #0x20\n"
- "ldr d13, [%[b_ptr0], #0x50]\n"
"prfm PLDL1KEEP, [a_ptr1, #0x40]\n"
+ "ldr d13, [%[b_ptr0], #0x50]\n"
".inst 0x4f80e910 // sdot v16.4s, v8.16b, v0.4b[2]\n"
"ins v9.d[1], temploadreg1\n"
".inst 0x4f81e914 // sdot v20.4s, v8.16b, v1.4b[2]\n"
@@ -658,15 +701,15 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
"ins v11.d[1], temploadreg3\n"
"ins v12.d[1], temploadreg0\n"
"ins v13.d[1], temploadreg1\n"
+ "b.ne 3b\n"
+ "2:\n"
"ins v14.d[1], temploadreg2\n"
- "b.ne 2b\n"
- "1:\n"
- "ldr d15, [%[b_ptr0], #-0x10]\n"
"prfm PSTL1KEEP, [%[c_ptr0]]\n"
- "ldr temploadreg3, [%[b_ptr0], #-0x8]\n"
+ "ldr d15, [%[b_ptr0], #-0x10]\n"
"prfm PSTL1KEEP, [c_ptr1]\n"
+ "ldr temploadreg3, [%[b_ptr0], #-0x8]\n"
"ins v15.d[1], temploadreg3\n"
- "cbz %[regs], 3f\n"
+ "cbz %[regs], 4f\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr d4, [%[a_ptr0]]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
@@ -813,8 +856,8 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
".inst 0x4fa5e9d6 // sdot v22.4s, v14.16b, v5.4b[3]\n"
".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n"
".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n"
- "b 4f\n"
- "3:\n"
+ "b 5f\n"
+ "4:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr temploadreg0, [%[b_ptr0], #0x8]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
@@ -872,9 +915,9 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
".inst 0x4fa1e9d6 // sdot v22.4s, v14.16b, v1.4b[3]\n"
".inst 0x4fa0e9f3 // sdot v19.4s, v15.16b, v0.4b[3]\n"
".inst 0x4fa1e9f7 // sdot v23.4s, v15.16b, v1.4b[3]\n"
- "4:\n"
- "cbz %[blocks], 5f\n"
- "6:\n"
+ "5:\n"
+ "cbz %[blocks], 6f\n"
+ "7:\n"
"ldr q8, [%[b_ptr0]]\n"
"subs %[blocks], %[blocks], #0x1\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
@@ -893,20 +936,20 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
".inst 0x4f81e156 // sdot v22.4s, v10.16b, v1.4b[0]\n"
".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n"
".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n"
- "b.ne 6b\n"
- "5:\n"
- "cbz %[odds], 7f\n"
+ "b.ne 7b\n"
+ "6:\n"
+ "cbz %[odds], 8f\n"
"ld1 {v0.b}[0], [%[a_ptr0]], #1\n"
"ld1 {v1.b}[0], [a_ptr1], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[1], [%[a_ptr0]], #1\n"
"ld1 {v1.b}[1], [a_ptr1], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[2], [%[a_ptr0]]\n"
"ld1 {v1.b}[2], [a_ptr1]\n"
- "8:\n"
+ "9:\n"
"ldr q8, [%[b_ptr0]]\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
"ldr q10, [%[b_ptr0], #0x20]\n"
@@ -919,7 +962,7 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
".inst 0x4f81e156 // sdot v22.4s, v10.16b, v1.4b[0]\n"
".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n"
".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n"
- "7:\n"
+ "8:\n"
"str q16, [%[c_ptr0]]\n"
"str q17, [%[c_ptr0], #0x10]\n"
"str q18, [%[c_ptr0], #0x20]\n"
@@ -950,40 +993,72 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
"temploadreg1 .req X5\n"
"temploadreg2 .req X6\n"
"temploadreg3 .req X7\n"
+ "add a_ptr1, %[a_ptr0], %[lda]\n"
+ "add c_ptr1, %[c_ptr0], %[ldc]\n"
+ "add a_ptr2, a_ptr1, %[lda]\n"
+ "add c_ptr2, c_ptr1, %[ldc]\n"
+ "cbnz %[append], 1f\n"
"movi v16.4s, #0\n"
"ldr q0, [%[a_ptr0]]\n"
"movi v17.4s, #0\n"
- "ldr q8, [%[b_ptr0]]\n"
+ "ldr q1, [a_ptr1]\n"
"movi v18.4s, #0\n"
- "ldr q9, [%[b_ptr0], #0x10]\n"
+ "ldr q2, [a_ptr2]\n"
"movi v19.4s, #0\n"
- "ldr q10, [%[b_ptr0], #0x20]\n"
+ "ldr q8, [%[b_ptr0]]\n"
"movi v20.4s, #0\n"
- "ldr q11, [%[b_ptr0], #0x30]\n"
+ "ldr q9, [%[b_ptr0], #0x10]\n"
"movi v21.4s, #0\n"
- "ldr q12, [%[b_ptr0], #0x40]\n"
+ "ldr q10, [%[b_ptr0], #0x20]\n"
"movi v22.4s, #0\n"
- "ldr q13, [%[b_ptr0], #0x50]\n"
+ "ldr q11, [%[b_ptr0], #0x30]\n"
"movi v23.4s, #0\n"
- "ldr d14, [%[b_ptr0], #0x60]\n"
+ "ldr q12, [%[b_ptr0], #0x40]\n"
"movi v24.4s, #0\n"
- "ldr temploadreg2, [%[b_ptr0], #0x68]\n"
+ "ldr q13, [%[b_ptr0], #0x50]\n"
"movi v25.4s, #0\n"
- "add a_ptr1, %[a_ptr0], %[lda]\n"
+ "ldr d14, [%[b_ptr0], #0x60]\n"
"movi v26.4s, #0\n"
- "ldr q1, [a_ptr1]\n"
+ "ldr temploadreg2, [%[b_ptr0], #0x68]\n"
"movi v27.4s, #0\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
+ "add a_ptr1, a_ptr1, #0x10\n"
"ins v14.d[1], temploadreg2\n"
- "add a_ptr2, a_ptr1, %[lda]\n"
- "add c_ptr1, %[c_ptr0], %[ldc]\n"
- "ldr q2, [a_ptr2]\n"
- "add c_ptr2, c_ptr1, %[ldc]\n"
+ "add a_ptr2, a_ptr2, #0x10\n"
+ "add %[b_ptr0], %[b_ptr0], #0x80\n"
+ "cbz %[loops], 2f\n"
+ "b 3f\n"
+ "1:\n"
+ "ldr q16, [%[c_ptr0]]\n"
+ "ldr q17, [%[c_ptr0], #0x10]\n"
+ "ldr q18, [%[c_ptr0], #0x20]\n"
+ "ldr q19, [%[c_ptr0], #0x30]\n"
+ "ldr q20, [c_ptr1]\n"
+ "ldr q21, [c_ptr1, #0x10]\n"
+ "ldr q22, [c_ptr1, #0x20]\n"
+ "ldr q23, [c_ptr1, #0x30]\n"
+ "ldr q24, [c_ptr2]\n"
+ "ldr q25, [c_ptr2, #0x10]\n"
+ "ldr q26, [c_ptr2, #0x20]\n"
+ "ldr q27, [c_ptr2, #0x30]\n"
+ "ldr q0, [%[a_ptr0]]\n"
"add %[a_ptr0], %[a_ptr0], #0x10\n"
+ "ldr q1, [a_ptr1]\n"
"add a_ptr1, a_ptr1, #0x10\n"
+ "ldr q2, [a_ptr2]\n"
"add a_ptr2, a_ptr2, #0x10\n"
+ "ldr q8, [%[b_ptr0]]\n"
+ "ldr q9, [%[b_ptr0], #0x10]\n"
+ "ldr q10, [%[b_ptr0], #0x20]\n"
+ "ldr q11, [%[b_ptr0], #0x30]\n"
+ "ldr q12, [%[b_ptr0], #0x40]\n"
+ "ldr q13, [%[b_ptr0], #0x50]\n"
+ "ldr d14, [%[b_ptr0], #0x60]\n"
+ "ldr temploadreg2, [%[b_ptr0], #0x68]\n"
"add %[b_ptr0], %[b_ptr0], #0x80\n"
- "cbz %[loops], 1f\n"
- "2:\n"
+ "ins v14.d[1], temploadreg2\n"
+ "cbz %[loops], 2f\n"
+ "3:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr d15, [%[b_ptr0], #-0x10]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
@@ -1203,15 +1278,15 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
"ins v12.d[1], temploadreg0\n"
"ins v13.d[1], temploadreg1\n"
"ins v14.d[1], temploadreg2\n"
- "b.ne 2b\n"
- "1:\n"
+ "b.ne 3b\n"
+ "2:\n"
"ldr d15, [%[b_ptr0], #-0x10]\n"
"prfm PSTL1KEEP, [%[c_ptr0]]\n"
"ldr temploadreg3, [%[b_ptr0], #-0x8]\n"
"prfm PSTL1KEEP, [c_ptr1]\n"
"prfm PSTL1KEEP, [c_ptr2]\n"
"ins v15.d[1], temploadreg3\n"
- "cbz %[regs], 3f\n"
+ "cbz %[regs], 4f\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr d4, [%[a_ptr0]]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
@@ -1394,8 +1469,8 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n"
".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n"
".inst 0x4fa6e9fb // sdot v27.4s, v15.16b, v6.4b[3]\n"
- "b 4f\n"
- "3:\n"
+ "b 5f\n"
+ "4:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr temploadreg0, [%[b_ptr0], #0x8]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
@@ -1469,9 +1544,9 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
".inst 0x4fa0e9f3 // sdot v19.4s, v15.16b, v0.4b[3]\n"
".inst 0x4fa1e9f7 // sdot v23.4s, v15.16b, v1.4b[3]\n"
".inst 0x4fa2e9fb // sdot v27.4s, v15.16b, v2.4b[3]\n"
- "4:\n"
- "cbz %[blocks], 5f\n"
- "6:\n"
+ "5:\n"
+ "cbz %[blocks], 6f\n"
+ "7:\n"
"ldr q8, [%[b_ptr0]]\n"
"subs %[blocks], %[blocks], #0x1\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
@@ -1496,23 +1571,23 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n"
".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n"
".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n"
- "b.ne 6b\n"
- "5:\n"
- "cbz %[odds], 7f\n"
+ "b.ne 7b\n"
+ "6:\n"
+ "cbz %[odds], 8f\n"
"ld1 {v0.b}[0], [%[a_ptr0]], #1\n"
"ld1 {v1.b}[0], [a_ptr1], #1\n"
"ld1 {v2.b}[0], [a_ptr2], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[1], [%[a_ptr0]], #1\n"
"ld1 {v1.b}[1], [a_ptr1], #1\n"
"ld1 {v2.b}[1], [a_ptr2], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[2], [%[a_ptr0]]\n"
"ld1 {v1.b}[2], [a_ptr1]\n"
"ld1 {v2.b}[2], [a_ptr2]\n"
- "8:\n"
+ "9:\n"
"ldr q8, [%[b_ptr0]]\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
"ldr q10, [%[b_ptr0], #0x20]\n"
@@ -1529,7 +1604,7 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n"
".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n"
".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n"
- "7:\n"
+ "8:\n"
"str q16, [%[c_ptr0]]\n"
"str q17, [%[c_ptr0], #0x10]\n"
"str q18, [%[c_ptr0], #0x20]\n"
@@ -1569,48 +1644,86 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
"temploadreg1 .req X7\n"
"temploadreg2 .req X8\n"
"temploadreg3 .req X9\n"
+ "add a_ptr1, %[a_ptr0], %[lda]\n"
+ "add c_ptr1, %[c_ptr0], %[ldc]\n"
+ "add a_ptr2, a_ptr1, %[lda]\n"
+ "add c_ptr2, c_ptr1, %[ldc]\n"
+ "add a_ptr3, a_ptr2, %[lda]\n"
+ "add c_ptr3, c_ptr2, %[ldc]\n"
+ "cbnz %[append], 1f\n"
"movi v16.4s, #0\n"
"ldr q0, [%[a_ptr0]]\n"
"movi v17.4s, #0\n"
- "ldr q8, [%[b_ptr0]]\n"
+ "ldr q1, [a_ptr1]\n"
"movi v18.4s, #0\n"
- "ldr q9, [%[b_ptr0], #0x10]\n"
+ "ldr q2, [a_ptr2]\n"
"movi v19.4s, #0\n"
- "ldr q10, [%[b_ptr0], #0x20]\n"
+ "ldr q3, [a_ptr3]\n"
"movi v20.4s, #0\n"
- "ldr q11, [%[b_ptr0], #0x30]\n"
+ "ldr q8, [%[b_ptr0]]\n"
"movi v21.4s, #0\n"
- "ldr q12, [%[b_ptr0], #0x40]\n"
+ "ldr q9, [%[b_ptr0], #0x10]\n"
"movi v22.4s, #0\n"
- "ldr q13, [%[b_ptr0], #0x50]\n"
+ "ldr q10, [%[b_ptr0], #0x20]\n"
"movi v23.4s, #0\n"
- "ldr d14, [%[b_ptr0], #0x60]\n"
+ "ldr q11, [%[b_ptr0], #0x30]\n"
"movi v24.4s, #0\n"
- "ldr temploadreg2, [%[b_ptr0], #0x68]\n"
+ "ldr q12, [%[b_ptr0], #0x40]\n"
"movi v25.4s, #0\n"
- "add a_ptr1, %[a_ptr0], %[lda]\n"
+ "ldr q13, [%[b_ptr0], #0x50]\n"
"movi v26.4s, #0\n"
- "ldr q1, [a_ptr1]\n"
+ "ldr d14, [%[b_ptr0], #0x60]\n"
"movi v27.4s, #0\n"
- "ins v14.d[1], temploadreg2\n"
+ "ldr temploadreg2, [%[b_ptr0], #0x68]\n"
"movi v28.4s, #0\n"
- "add a_ptr2, a_ptr1, %[lda]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
"movi v29.4s, #0\n"
- "ldr q2, [a_ptr2]\n"
+ "ins v14.d[1], temploadreg2\n"
"movi v30.4s, #0\n"
- "add a_ptr3, a_ptr2, %[lda]\n"
+ "add a_ptr1, a_ptr1, #0x10\n"
"movi v31.4s, #0\n"
- "ldr q3, [a_ptr3]\n"
- "add c_ptr1, %[c_ptr0], %[ldc]\n"
+ "add a_ptr2, a_ptr2, #0x10\n"
+ "add a_ptr3, a_ptr3, #0x10\n"
+ "add %[b_ptr0], %[b_ptr0], #0x80\n"
+ "cbz %[loops], 2f\n"
+ "b 3f\n"
+ "1:\n"
+ "ldr q16, [%[c_ptr0]]\n"
+ "ldr q17, [%[c_ptr0], #0x10]\n"
+ "ldr q18, [%[c_ptr0], #0x20]\n"
+ "ldr q19, [%[c_ptr0], #0x30]\n"
+ "ldr q20, [c_ptr1]\n"
+ "ldr q21, [c_ptr1, #0x10]\n"
+ "ldr q22, [c_ptr1, #0x20]\n"
+ "ldr q23, [c_ptr1, #0x30]\n"
+ "ldr q24, [c_ptr2]\n"
+ "ldr q25, [c_ptr2, #0x10]\n"
+ "ldr q26, [c_ptr2, #0x20]\n"
+ "ldr q27, [c_ptr2, #0x30]\n"
+ "ldr q28, [c_ptr3]\n"
+ "ldr q29, [c_ptr3, #0x10]\n"
+ "ldr q30, [c_ptr3, #0x20]\n"
+ "ldr q31, [c_ptr3, #0x30]\n"
+ "ldr q0, [%[a_ptr0]]\n"
"add %[a_ptr0], %[a_ptr0], #0x10\n"
- "add c_ptr2, c_ptr1, %[ldc]\n"
+ "ldr q1, [a_ptr1]\n"
"add a_ptr1, a_ptr1, #0x10\n"
- "add c_ptr3, c_ptr2, %[ldc]\n"
+ "ldr q2, [a_ptr2]\n"
"add a_ptr2, a_ptr2, #0x10\n"
+ "ldr q3, [a_ptr3]\n"
"add a_ptr3, a_ptr3, #0x10\n"
+ "ldr q8, [%[b_ptr0]]\n"
+ "ldr q9, [%[b_ptr0], #0x10]\n"
+ "ldr q10, [%[b_ptr0], #0x20]\n"
+ "ldr q11, [%[b_ptr0], #0x30]\n"
+ "ldr q12, [%[b_ptr0], #0x40]\n"
+ "ldr q13, [%[b_ptr0], #0x50]\n"
+ "ldr d14, [%[b_ptr0], #0x60]\n"
+ "ldr temploadreg2, [%[b_ptr0], #0x68]\n"
"add %[b_ptr0], %[b_ptr0], #0x80\n"
- "cbz %[loops], 1f\n"
- "2:\n"
+ "ins v14.d[1], temploadreg2\n"
+ "cbz %[loops], 2f\n"
+ "3:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr d15, [%[b_ptr0], #-0x10]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
@@ -1870,8 +1983,8 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
"ins v13.d[1], temploadreg1\n"
"prfm PLDL1KEEP, [a_ptr3, #0x40]\n"
"ins v14.d[1], temploadreg2\n"
- "b.ne 2b\n"
- "1:\n"
+ "b.ne 3b\n"
+ "2:\n"
"ldr d15, [%[b_ptr0], #-0x10]\n"
"prfm PSTL1KEEP, [%[c_ptr0]]\n"
"ldr temploadreg3, [%[b_ptr0], #-0x8]\n"
@@ -1879,7 +1992,7 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
"prfm PSTL1KEEP, [c_ptr2]\n"
"prfm PSTL1KEEP, [c_ptr3]\n"
"ins v15.d[1], temploadreg3\n"
- "cbz %[regs], 3f\n"
+ "cbz %[regs], 4f\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr d4, [%[a_ptr0]]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
@@ -2098,8 +2211,8 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n"
".inst 0x4fa6e9fb // sdot v27.4s, v15.16b, v6.4b[3]\n"
".inst 0x4fa7e9ff // sdot v31.4s, v15.16b, v7.4b[3]\n"
- "b 4f\n"
- "3:\n"
+ "b 5f\n"
+ "4:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr temploadreg0, [%[b_ptr0], #0x8]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
@@ -2189,9 +2302,9 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
".inst 0x4fa1e9f7 // sdot v23.4s, v15.16b, v1.4b[3]\n"
".inst 0x4fa2e9fb // sdot v27.4s, v15.16b, v2.4b[3]\n"
".inst 0x4fa3e9ff // sdot v31.4s, v15.16b, v3.4b[3]\n"
- "4:\n"
- "cbz %[blocks], 5f\n"
- "6:\n"
+ "5:\n"
+ "cbz %[blocks], 6f\n"
+ "7:\n"
"ldr q8, [%[b_ptr0]]\n"
"subs %[blocks], %[blocks], #0x1\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
@@ -2222,26 +2335,26 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n"
".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n"
".inst 0x4f83e17f // sdot v31.4s, v11.16b, v3.4b[0]\n"
- "b.ne 6b\n"
- "5:\n"
- "cbz %[odds], 7f\n"
+ "b.ne 7b\n"
+ "6:\n"
+ "cbz %[odds], 8f\n"
"ld1 {v0.b}[0], [%[a_ptr0]], #1\n"
"ld1 {v1.b}[0], [a_ptr1], #1\n"
"ld1 {v2.b}[0], [a_ptr2], #1\n"
"ld1 {v3.b}[0], [a_ptr3], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[1], [%[a_ptr0]], #1\n"
"ld1 {v1.b}[1], [a_ptr1], #1\n"
"ld1 {v2.b}[1], [a_ptr2], #1\n"
"ld1 {v3.b}[1], [a_ptr3], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[2], [%[a_ptr0]]\n"
"ld1 {v1.b}[2], [a_ptr1]\n"
"ld1 {v2.b}[2], [a_ptr2]\n"
"ld1 {v3.b}[2], [a_ptr3]\n"
- "8:\n"
+ "9:\n"
"ldr q8, [%[b_ptr0]]\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
"ldr q10, [%[b_ptr0], #0x20]\n"
@@ -2262,7 +2375,7 @@ void a64_hybrid_s8s32_dot_16x4_a55(const int8_t *A, int lda, const int8_t *B, in
".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n"
".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n"
".inst 0x4f83e17f // sdot v31.4s, v11.16b, v3.4b[0]\n"
- "7:\n"
+ "8:\n"
"str q16, [%[c_ptr0]]\n"
"str q17, [%[c_ptr0], #0x10]\n"
"str q18, [%[c_ptr0], #0x20]\n"
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/generic.cpp
index b48b674621..9f06a48ff5 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/generic.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/generic.cpp
@@ -35,7 +35,6 @@ namespace arm_gemm {
void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_t *C, int ldc, int M, int N, int K, const int32_t *bias, Activation act, bool append) {
UNUSED(bias);
UNUSED(act);
-
const int K_stride = ((K + 3) / 4) * 4;
const long loops_count = ((K + 16) / 32) - 1;
K -= loops_count * 32;
@@ -76,6 +75,7 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
switch(M-y) {
case 1:
__asm __volatile (
+ "cbnz %[append], 1f\n"
"movi v16.4s, #0\n"
"ldr q0, [%[a_ptr0]]\n"
"movi v17.4s, #0\n"
@@ -90,8 +90,25 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
"ldr q13, [%[b_ptr0], #0x50]\n"
"ldr q14, [%[b_ptr0], #0x60]\n"
"add %[b_ptr0], %[b_ptr0], #0x80\n"
- "cbz %[loops], 1f\n"
- "2:\n"
+ "cbz %[loops], 2f\n"
+ "b 3f\n"
+ "1:\n"
+ "ldr q16, [%[c_ptr0]]\n"
+ "ldr q17, [%[c_ptr0], #0x10]\n"
+ "ldr q18, [%[c_ptr0], #0x20]\n"
+ "ldr q19, [%[c_ptr0], #0x30]\n"
+ "ldr q0, [%[a_ptr0]]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
+ "ldr q8, [%[b_ptr0]]\n"
+ "ldr q9, [%[b_ptr0], #0x10]\n"
+ "ldr q10, [%[b_ptr0], #0x20]\n"
+ "ldr q11, [%[b_ptr0], #0x30]\n"
+ "ldr q12, [%[b_ptr0], #0x40]\n"
+ "ldr q13, [%[b_ptr0], #0x50]\n"
+ "ldr q14, [%[b_ptr0], #0x60]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x80\n"
+ "cbz %[loops], 2f\n"
+ "3:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr q15, [%[b_ptr0], #-0x10]\n"
".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n"
@@ -163,11 +180,11 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
"ldr q12, [%[b_ptr0], #-0x40]\n"
"ldr q13, [%[b_ptr0], #-0x30]\n"
"ldr q14, [%[b_ptr0], #-0x20]\n"
- "b.ne 2b\n"
- "1:\n"
+ "b.ne 3b\n"
+ "2:\n"
"ldr q15, [%[b_ptr0], #-0x10]\n"
"prfm PSTL1KEEP, [%[c_ptr0]]\n"
- "cbz %[regs], 3f\n"
+ "cbz %[regs], 4f\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr q4, [%[a_ptr0]]\n"
".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n"
@@ -228,8 +245,8 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4fa4e9b1 // sdot v17.4s, v13.16b, v4.4b[3]\n"
".inst 0x4fa4e9d2 // sdot v18.4s, v14.16b, v4.4b[3]\n"
".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n"
- "b 4f\n"
- "3:\n"
+ "b 5f\n"
+ "4:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr q8, [%[b_ptr0]]\n"
".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n"
@@ -255,9 +272,9 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4fa0e9b1 // sdot v17.4s, v13.16b, v0.4b[3]\n"
".inst 0x4fa0e9d2 // sdot v18.4s, v14.16b, v0.4b[3]\n"
".inst 0x4fa0e9f3 // sdot v19.4s, v15.16b, v0.4b[3]\n"
- "4:\n"
- "cbz %[blocks], 5f\n"
- "6:\n"
+ "5:\n"
+ "cbz %[blocks], 6f\n"
+ "7:\n"
"ldr q8, [%[b_ptr0]]\n"
"subs %[blocks], %[blocks], #0x1\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
@@ -270,17 +287,17 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n"
".inst 0x4f80e152 // sdot v18.4s, v10.16b, v0.4b[0]\n"
".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n"
- "b.ne 6b\n"
- "5:\n"
- "cbz %[odds], 7f\n"
+ "b.ne 7b\n"
+ "6:\n"
+ "cbz %[odds], 8f\n"
"ld1 {v0.b}[0], [%[a_ptr0]], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[1], [%[a_ptr0]], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[2], [%[a_ptr0]]\n"
- "8:\n"
+ "9:\n"
"ldr q8, [%[b_ptr0]]\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
"ldr q10, [%[b_ptr0], #0x20]\n"
@@ -289,7 +306,7 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n"
".inst 0x4f80e152 // sdot v18.4s, v10.16b, v0.4b[0]\n"
".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n"
- "7:\n"
+ "8:\n"
"str q16, [%[c_ptr0]]\n"
"str q17, [%[c_ptr0], #0x10]\n"
"str q18, [%[c_ptr0], #0x20]\n"
@@ -304,30 +321,54 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
__asm __volatile (
"a_ptr1 .req X0\n"
"c_ptr1 .req X1\n"
+ "add a_ptr1, %[a_ptr0], %[lda]\n"
+ "add c_ptr1, %[c_ptr0], %[ldc]\n"
+ "cbnz %[append], 1f\n"
"movi v16.4s, #0\n"
"ldr q0, [%[a_ptr0]]\n"
"movi v17.4s, #0\n"
- "ldr q8, [%[b_ptr0]]\n"
+ "ldr q1, [a_ptr1]\n"
"movi v18.4s, #0\n"
- "ldr q9, [%[b_ptr0], #0x10]\n"
+ "ldr q8, [%[b_ptr0]]\n"
"movi v19.4s, #0\n"
- "ldr q10, [%[b_ptr0], #0x20]\n"
+ "ldr q9, [%[b_ptr0], #0x10]\n"
"movi v20.4s, #0\n"
- "ldr q11, [%[b_ptr0], #0x30]\n"
+ "ldr q10, [%[b_ptr0], #0x20]\n"
"movi v21.4s, #0\n"
- "ldr q12, [%[b_ptr0], #0x40]\n"
+ "ldr q11, [%[b_ptr0], #0x30]\n"
"movi v22.4s, #0\n"
- "ldr q13, [%[b_ptr0], #0x50]\n"
+ "ldr q12, [%[b_ptr0], #0x40]\n"
"movi v23.4s, #0\n"
+ "ldr q13, [%[b_ptr0], #0x50]\n"
"ldr q14, [%[b_ptr0], #0x60]\n"
- "add a_ptr1, %[a_ptr0], %[lda]\n"
- "add c_ptr1, %[c_ptr0], %[ldc]\n"
- "ldr q1, [a_ptr1]\n"
"add %[a_ptr0], %[a_ptr0], #0x10\n"
"add a_ptr1, a_ptr1, #0x10\n"
"add %[b_ptr0], %[b_ptr0], #0x80\n"
- "cbz %[loops], 1f\n"
- "2:\n"
+ "cbz %[loops], 2f\n"
+ "b 3f\n"
+ "1:\n"
+ "ldr q16, [%[c_ptr0]]\n"
+ "ldr q17, [%[c_ptr0], #0x10]\n"
+ "ldr q18, [%[c_ptr0], #0x20]\n"
+ "ldr q19, [%[c_ptr0], #0x30]\n"
+ "ldr q20, [c_ptr1]\n"
+ "ldr q21, [c_ptr1, #0x10]\n"
+ "ldr q22, [c_ptr1, #0x20]\n"
+ "ldr q23, [c_ptr1, #0x30]\n"
+ "ldr q0, [%[a_ptr0]]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
+ "ldr q1, [a_ptr1]\n"
+ "add a_ptr1, a_ptr1, #0x10\n"
+ "ldr q8, [%[b_ptr0]]\n"
+ "ldr q9, [%[b_ptr0], #0x10]\n"
+ "ldr q10, [%[b_ptr0], #0x20]\n"
+ "ldr q11, [%[b_ptr0], #0x30]\n"
+ "ldr q12, [%[b_ptr0], #0x40]\n"
+ "ldr q13, [%[b_ptr0], #0x50]\n"
+ "ldr q14, [%[b_ptr0], #0x60]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x80\n"
+ "cbz %[loops], 2f\n"
+ "3:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr q15, [%[b_ptr0], #-0x10]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
@@ -435,12 +476,12 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
"ldr q14, [%[b_ptr0], #-0x20]\n"
".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n"
".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n"
- "b.ne 2b\n"
- "1:\n"
+ "b.ne 3b\n"
+ "2:\n"
"ldr q15, [%[b_ptr0], #-0x10]\n"
"prfm PSTL1KEEP, [%[c_ptr0]]\n"
"prfm PSTL1KEEP, [c_ptr1]\n"
- "cbz %[regs], 3f\n"
+ "cbz %[regs], 4f\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr q4, [%[a_ptr0]]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
@@ -535,8 +576,8 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4fa5e9d6 // sdot v22.4s, v14.16b, v5.4b[3]\n"
".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n"
".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n"
- "b 4f\n"
- "3:\n"
+ "b 5f\n"
+ "4:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
"ldr q8, [%[b_ptr0]]\n"
@@ -578,9 +619,9 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4fa1e9d6 // sdot v22.4s, v14.16b, v1.4b[3]\n"
".inst 0x4fa0e9f3 // sdot v19.4s, v15.16b, v0.4b[3]\n"
".inst 0x4fa1e9f7 // sdot v23.4s, v15.16b, v1.4b[3]\n"
- "4:\n"
- "cbz %[blocks], 5f\n"
- "6:\n"
+ "5:\n"
+ "cbz %[blocks], 6f\n"
+ "7:\n"
"ldr q8, [%[b_ptr0]]\n"
"subs %[blocks], %[blocks], #0x1\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
@@ -599,20 +640,20 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4f81e156 // sdot v22.4s, v10.16b, v1.4b[0]\n"
".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n"
".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n"
- "b.ne 6b\n"
- "5:\n"
- "cbz %[odds], 7f\n"
+ "b.ne 7b\n"
+ "6:\n"
+ "cbz %[odds], 8f\n"
"ld1 {v0.b}[0], [%[a_ptr0]], #1\n"
"ld1 {v1.b}[0], [a_ptr1], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[1], [%[a_ptr0]], #1\n"
"ld1 {v1.b}[1], [a_ptr1], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[2], [%[a_ptr0]]\n"
"ld1 {v1.b}[2], [a_ptr1]\n"
- "8:\n"
+ "9:\n"
"ldr q8, [%[b_ptr0]]\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
"ldr q10, [%[b_ptr0], #0x20]\n"
@@ -625,7 +666,7 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4f81e156 // sdot v22.4s, v10.16b, v1.4b[0]\n"
".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n"
".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n"
- "7:\n"
+ "8:\n"
"str q16, [%[c_ptr0]]\n"
"str q17, [%[c_ptr0], #0x10]\n"
"str q18, [%[c_ptr0], #0x20]\n"
@@ -648,38 +689,68 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
"a_ptr2 .req X1\n"
"c_ptr1 .req X2\n"
"c_ptr2 .req X3\n"
+ "add a_ptr1, %[a_ptr0], %[lda]\n"
+ "add c_ptr1, %[c_ptr0], %[ldc]\n"
+ "add a_ptr2, a_ptr1, %[lda]\n"
+ "add c_ptr2, c_ptr1, %[ldc]\n"
+ "cbnz %[append], 1f\n"
"movi v16.4s, #0\n"
"ldr q0, [%[a_ptr0]]\n"
"movi v17.4s, #0\n"
- "ldr q8, [%[b_ptr0]]\n"
+ "ldr q1, [a_ptr1]\n"
"movi v18.4s, #0\n"
- "ldr q9, [%[b_ptr0], #0x10]\n"
+ "ldr q2, [a_ptr2]\n"
"movi v19.4s, #0\n"
- "ldr q10, [%[b_ptr0], #0x20]\n"
+ "ldr q8, [%[b_ptr0]]\n"
"movi v20.4s, #0\n"
- "ldr q11, [%[b_ptr0], #0x30]\n"
+ "ldr q9, [%[b_ptr0], #0x10]\n"
"movi v21.4s, #0\n"
- "ldr q12, [%[b_ptr0], #0x40]\n"
+ "ldr q10, [%[b_ptr0], #0x20]\n"
"movi v22.4s, #0\n"
- "ldr q13, [%[b_ptr0], #0x50]\n"
+ "ldr q11, [%[b_ptr0], #0x30]\n"
"movi v23.4s, #0\n"
- "ldr q14, [%[b_ptr0], #0x60]\n"
+ "ldr q12, [%[b_ptr0], #0x40]\n"
"movi v24.4s, #0\n"
- "add a_ptr1, %[a_ptr0], %[lda]\n"
+ "ldr q13, [%[b_ptr0], #0x50]\n"
"movi v25.4s, #0\n"
- "ldr q1, [a_ptr1]\n"
+ "ldr q14, [%[b_ptr0], #0x60]\n"
"movi v26.4s, #0\n"
- "add a_ptr2, a_ptr1, %[lda]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
"movi v27.4s, #0\n"
- "ldr q2, [a_ptr2]\n"
- "add c_ptr1, %[c_ptr0], %[ldc]\n"
+ "add a_ptr1, a_ptr1, #0x10\n"
+ "add a_ptr2, a_ptr2, #0x10\n"
+ "add %[b_ptr0], %[b_ptr0], #0x80\n"
+ "cbz %[loops], 2f\n"
+ "b 3f\n"
+ "1:\n"
+ "ldr q16, [%[c_ptr0]]\n"
+ "ldr q17, [%[c_ptr0], #0x10]\n"
+ "ldr q18, [%[c_ptr0], #0x20]\n"
+ "ldr q19, [%[c_ptr0], #0x30]\n"
+ "ldr q20, [c_ptr1]\n"
+ "ldr q21, [c_ptr1, #0x10]\n"
+ "ldr q22, [c_ptr1, #0x20]\n"
+ "ldr q23, [c_ptr1, #0x30]\n"
+ "ldr q24, [c_ptr2]\n"
+ "ldr q25, [c_ptr2, #0x10]\n"
+ "ldr q26, [c_ptr2, #0x20]\n"
+ "ldr q27, [c_ptr2, #0x30]\n"
+ "ldr q0, [%[a_ptr0]]\n"
"add %[a_ptr0], %[a_ptr0], #0x10\n"
- "add c_ptr2, c_ptr1, %[ldc]\n"
+ "ldr q1, [a_ptr1]\n"
"add a_ptr1, a_ptr1, #0x10\n"
+ "ldr q2, [a_ptr2]\n"
"add a_ptr2, a_ptr2, #0x10\n"
+ "ldr q8, [%[b_ptr0]]\n"
+ "ldr q9, [%[b_ptr0], #0x10]\n"
+ "ldr q10, [%[b_ptr0], #0x20]\n"
+ "ldr q11, [%[b_ptr0], #0x30]\n"
+ "ldr q12, [%[b_ptr0], #0x40]\n"
+ "ldr q13, [%[b_ptr0], #0x50]\n"
+ "ldr q14, [%[b_ptr0], #0x60]\n"
"add %[b_ptr0], %[b_ptr0], #0x80\n"
- "cbz %[loops], 1f\n"
- "2:\n"
+ "cbz %[loops], 2f\n"
+ "3:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr q15, [%[b_ptr0], #-0x10]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
@@ -823,13 +894,13 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n"
".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n"
".inst 0x4fa6e9fb // sdot v27.4s, v15.16b, v6.4b[3]\n"
- "b.ne 2b\n"
- "1:\n"
+ "b.ne 3b\n"
+ "2:\n"
"ldr q15, [%[b_ptr0], #-0x10]\n"
"prfm PSTL1KEEP, [%[c_ptr0]]\n"
"prfm PSTL1KEEP, [c_ptr1]\n"
"prfm PSTL1KEEP, [c_ptr2]\n"
- "cbz %[regs], 3f\n"
+ "cbz %[regs], 4f\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr q4, [%[a_ptr0]]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
@@ -958,8 +1029,8 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n"
".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n"
".inst 0x4fa6e9fb // sdot v27.4s, v15.16b, v6.4b[3]\n"
- "b 4f\n"
- "3:\n"
+ "b 5f\n"
+ "4:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
".inst 0x4f82e118 // sdot v24.4s, v8.16b, v2.4b[0]\n"
@@ -1017,9 +1088,9 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4fa0e9f3 // sdot v19.4s, v15.16b, v0.4b[3]\n"
".inst 0x4fa1e9f7 // sdot v23.4s, v15.16b, v1.4b[3]\n"
".inst 0x4fa2e9fb // sdot v27.4s, v15.16b, v2.4b[3]\n"
- "4:\n"
- "cbz %[blocks], 5f\n"
- "6:\n"
+ "5:\n"
+ "cbz %[blocks], 6f\n"
+ "7:\n"
"ldr q8, [%[b_ptr0]]\n"
"subs %[blocks], %[blocks], #0x1\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
@@ -1044,23 +1115,23 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n"
".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n"
".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n"
- "b.ne 6b\n"
- "5:\n"
- "cbz %[odds], 7f\n"
+ "b.ne 7b\n"
+ "6:\n"
+ "cbz %[odds], 8f\n"
"ld1 {v0.b}[0], [%[a_ptr0]], #1\n"
"ld1 {v1.b}[0], [a_ptr1], #1\n"
"ld1 {v2.b}[0], [a_ptr2], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[1], [%[a_ptr0]], #1\n"
"ld1 {v1.b}[1], [a_ptr1], #1\n"
"ld1 {v2.b}[1], [a_ptr2], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[2], [%[a_ptr0]]\n"
"ld1 {v1.b}[2], [a_ptr1]\n"
"ld1 {v2.b}[2], [a_ptr2]\n"
- "8:\n"
+ "9:\n"
"ldr q8, [%[b_ptr0]]\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
"ldr q10, [%[b_ptr0], #0x20]\n"
@@ -1077,7 +1148,7 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n"
".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n"
".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n"
- "7:\n"
+ "8:\n"
"str q16, [%[c_ptr0]]\n"
"str q17, [%[c_ptr0], #0x10]\n"
"str q18, [%[c_ptr0], #0x20]\n"
@@ -1109,46 +1180,82 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
"c_ptr1 .req X3\n"
"c_ptr2 .req X4\n"
"c_ptr3 .req X5\n"
+ "add a_ptr1, %[a_ptr0], %[lda]\n"
+ "add c_ptr1, %[c_ptr0], %[ldc]\n"
+ "add a_ptr2, a_ptr1, %[lda]\n"
+ "add c_ptr2, c_ptr1, %[ldc]\n"
+ "add a_ptr3, a_ptr2, %[lda]\n"
+ "add c_ptr3, c_ptr2, %[ldc]\n"
+ "cbnz %[append], 1f\n"
"movi v16.4s, #0\n"
"ldr q0, [%[a_ptr0]]\n"
"movi v17.4s, #0\n"
- "ldr q8, [%[b_ptr0]]\n"
+ "ldr q1, [a_ptr1]\n"
"movi v18.4s, #0\n"
- "ldr q9, [%[b_ptr0], #0x10]\n"
+ "ldr q2, [a_ptr2]\n"
"movi v19.4s, #0\n"
- "ldr q10, [%[b_ptr0], #0x20]\n"
+ "ldr q3, [a_ptr3]\n"
"movi v20.4s, #0\n"
- "ldr q11, [%[b_ptr0], #0x30]\n"
+ "ldr q8, [%[b_ptr0]]\n"
"movi v21.4s, #0\n"
- "ldr q12, [%[b_ptr0], #0x40]\n"
+ "ldr q9, [%[b_ptr0], #0x10]\n"
"movi v22.4s, #0\n"
- "ldr q13, [%[b_ptr0], #0x50]\n"
+ "ldr q10, [%[b_ptr0], #0x20]\n"
"movi v23.4s, #0\n"
- "ldr q14, [%[b_ptr0], #0x60]\n"
+ "ldr q11, [%[b_ptr0], #0x30]\n"
"movi v24.4s, #0\n"
- "add a_ptr1, %[a_ptr0], %[lda]\n"
+ "ldr q12, [%[b_ptr0], #0x40]\n"
"movi v25.4s, #0\n"
- "ldr q1, [a_ptr1]\n"
+ "ldr q13, [%[b_ptr0], #0x50]\n"
"movi v26.4s, #0\n"
- "add a_ptr2, a_ptr1, %[lda]\n"
+ "ldr q14, [%[b_ptr0], #0x60]\n"
"movi v27.4s, #0\n"
- "ldr q2, [a_ptr2]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
"movi v28.4s, #0\n"
- "add a_ptr3, a_ptr2, %[lda]\n"
+ "add a_ptr1, a_ptr1, #0x10\n"
"movi v29.4s, #0\n"
- "ldr q3, [a_ptr3]\n"
+ "add a_ptr2, a_ptr2, #0x10\n"
"movi v30.4s, #0\n"
- "add c_ptr1, %[c_ptr0], %[ldc]\n"
+ "add a_ptr3, a_ptr3, #0x10\n"
"movi v31.4s, #0\n"
- "add c_ptr2, c_ptr1, %[ldc]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x80\n"
+ "cbz %[loops], 2f\n"
+ "b 3f\n"
+ "1:\n"
+ "ldr q16, [%[c_ptr0]]\n"
+ "ldr q17, [%[c_ptr0], #0x10]\n"
+ "ldr q18, [%[c_ptr0], #0x20]\n"
+ "ldr q19, [%[c_ptr0], #0x30]\n"
+ "ldr q20, [c_ptr1]\n"
+ "ldr q21, [c_ptr1, #0x10]\n"
+ "ldr q22, [c_ptr1, #0x20]\n"
+ "ldr q23, [c_ptr1, #0x30]\n"
+ "ldr q24, [c_ptr2]\n"
+ "ldr q25, [c_ptr2, #0x10]\n"
+ "ldr q26, [c_ptr2, #0x20]\n"
+ "ldr q27, [c_ptr2, #0x30]\n"
+ "ldr q28, [c_ptr3]\n"
+ "ldr q29, [c_ptr3, #0x10]\n"
+ "ldr q30, [c_ptr3, #0x20]\n"
+ "ldr q31, [c_ptr3, #0x30]\n"
+ "ldr q0, [%[a_ptr0]]\n"
"add %[a_ptr0], %[a_ptr0], #0x10\n"
- "add c_ptr3, c_ptr2, %[ldc]\n"
+ "ldr q1, [a_ptr1]\n"
"add a_ptr1, a_ptr1, #0x10\n"
+ "ldr q2, [a_ptr2]\n"
"add a_ptr2, a_ptr2, #0x10\n"
+ "ldr q3, [a_ptr3]\n"
"add a_ptr3, a_ptr3, #0x10\n"
+ "ldr q8, [%[b_ptr0]]\n"
+ "ldr q9, [%[b_ptr0], #0x10]\n"
+ "ldr q10, [%[b_ptr0], #0x20]\n"
+ "ldr q11, [%[b_ptr0], #0x30]\n"
+ "ldr q12, [%[b_ptr0], #0x40]\n"
+ "ldr q13, [%[b_ptr0], #0x50]\n"
+ "ldr q14, [%[b_ptr0], #0x60]\n"
"add %[b_ptr0], %[b_ptr0], #0x80\n"
- "cbz %[loops], 1f\n"
- "2:\n"
+ "cbz %[loops], 2f\n"
+ "3:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr q15, [%[b_ptr0], #-0x10]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
@@ -1328,14 +1435,14 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n"
".inst 0x4fa6e9fb // sdot v27.4s, v15.16b, v6.4b[3]\n"
".inst 0x4fa7e9ff // sdot v31.4s, v15.16b, v7.4b[3]\n"
- "b.ne 2b\n"
- "1:\n"
+ "b.ne 3b\n"
+ "2:\n"
"ldr q15, [%[b_ptr0], #-0x10]\n"
"prfm PSTL1KEEP, [%[c_ptr0]]\n"
"prfm PSTL1KEEP, [c_ptr1]\n"
"prfm PSTL1KEEP, [c_ptr2]\n"
"prfm PSTL1KEEP, [c_ptr3]\n"
- "cbz %[regs], 3f\n"
+ "cbz %[regs], 4f\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr q4, [%[a_ptr0]]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
@@ -1498,8 +1605,8 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n"
".inst 0x4fa6e9fb // sdot v27.4s, v15.16b, v6.4b[3]\n"
".inst 0x4fa7e9ff // sdot v31.4s, v15.16b, v7.4b[3]\n"
- "b 4f\n"
- "3:\n"
+ "b 5f\n"
+ "4:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
".inst 0x4f82e118 // sdot v24.4s, v8.16b, v2.4b[0]\n"
@@ -1573,9 +1680,9 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4fa1e9f7 // sdot v23.4s, v15.16b, v1.4b[3]\n"
".inst 0x4fa2e9fb // sdot v27.4s, v15.16b, v2.4b[3]\n"
".inst 0x4fa3e9ff // sdot v31.4s, v15.16b, v3.4b[3]\n"
- "4:\n"
- "cbz %[blocks], 5f\n"
- "6:\n"
+ "5:\n"
+ "cbz %[blocks], 6f\n"
+ "7:\n"
"ldr q8, [%[b_ptr0]]\n"
"subs %[blocks], %[blocks], #0x1\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
@@ -1606,26 +1713,26 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n"
".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n"
".inst 0x4f83e17f // sdot v31.4s, v11.16b, v3.4b[0]\n"
- "b.ne 6b\n"
- "5:\n"
- "cbz %[odds], 7f\n"
+ "b.ne 7b\n"
+ "6:\n"
+ "cbz %[odds], 8f\n"
"ld1 {v0.b}[0], [%[a_ptr0]], #1\n"
"ld1 {v1.b}[0], [a_ptr1], #1\n"
"ld1 {v2.b}[0], [a_ptr2], #1\n"
"ld1 {v3.b}[0], [a_ptr3], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[1], [%[a_ptr0]], #1\n"
"ld1 {v1.b}[1], [a_ptr1], #1\n"
"ld1 {v2.b}[1], [a_ptr2], #1\n"
"ld1 {v3.b}[1], [a_ptr3], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[2], [%[a_ptr0]]\n"
"ld1 {v1.b}[2], [a_ptr1]\n"
"ld1 {v2.b}[2], [a_ptr2]\n"
"ld1 {v3.b}[2], [a_ptr3]\n"
- "8:\n"
+ "9:\n"
"ldr q8, [%[b_ptr0]]\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
"ldr q10, [%[b_ptr0], #0x20]\n"
@@ -1646,7 +1753,7 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n"
".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n"
".inst 0x4f83e17f // sdot v31.4s, v11.16b, v3.4b[0]\n"
- "7:\n"
+ "8:\n"
"str q16, [%[c_ptr0]]\n"
"str q17, [%[c_ptr0], #0x10]\n"
"str q18, [%[c_ptr0], #0x20]\n"
diff --git a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
index 188dd0b06d..345060f206 100644
--- a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
+++ b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
@@ -40,7 +40,7 @@ private:
UniqueGemmCommon<To, Tgemm> _subgemm = nullptr;
int32_t *_row_sums = nullptr;
int32_t *_col_sums = nullptr;
- ARequantizeLayer32 _params;
+ Requantize32 _params;
GemmArgs _args;
barrier _barrier;
@@ -125,7 +125,7 @@ public:
QuantizeWrapper(const QuantizeWrapper &) = delete;
QuantizeWrapper operator=(const QuantizeWrapper &) = delete;
- QuantizeWrapper(const GemmArgs &args, const ARequantizeLayer32 &qp) : _params(qp), _args(args), _barrier(args._maxthreads) {
+ QuantizeWrapper(const GemmArgs &args, const Requantize32 &qp) : _params(qp), _args(args), _barrier(args._maxthreads) {
GemmArgs newargs = GemmArgs(args._ci, args._Msize, args._Nsize, args._Ksize, args._nbatches, args._nmulti, args._trA, args._trB, Activation(), args._maxthreads, args._pretransposed_hint, nullptr);
_subgemm = gemm<To, Tgemm>(newargs);
diff --git a/src/core/NEON/kernels/arm_gemm/quantized.cpp b/src/core/NEON/kernels/arm_gemm/quantized.cpp
index bffb7ddcb3..00b42cf422 100644
--- a/src/core/NEON/kernels/arm_gemm/quantized.cpp
+++ b/src/core/NEON/kernels/arm_gemm/quantized.cpp
@@ -47,13 +47,19 @@ namespace {
* applied to negative values being shifted right to make sure they round
* properly - if negative values are never output (e.g. fused ReLU) this is
* unnecessary.
+ *
+ * The 'per_channel' template parameter selects between per channel and per
+ * layer requantization - in the former case we need to load vectors of
+ * shifts and multipliers for each column. A separate vector for each
+ * column is set up in any case (and it is hoped that the compiler can elide
+ * the needless movs in the per-layer case).
*/
-template<bool do_shift_correction>
-void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height,
+template<bool do_shift_correction, bool per_channel>
+void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigned int height,
const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
const int32_t *row_bias, const int32_t *col_bias) {
- const int32x4_t v_mul = vdupq_n_s32(qp.requant_mul);
- const int32x4_t v_shift = vdupq_n_s32(qp.requant_shift);
+ const int32x4_t v_mul = vdupq_n_s32(qp.per_layer_mul);
+ const int32x4_t v_shift = vdupq_n_s32(qp.per_layer_shift);
const int32x4_t v_minval = vdupq_n_s32(qp.minval);
const int32x4_t v_maxval = vdupq_n_s32(qp.maxval);
const int32x4_t v_c_offset = vdupq_n_s32(qp.c_offset);
@@ -70,6 +76,8 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u
unsigned int odds=(width % 4);
const int32_t *colptr = col_bias;
+ const int32_t *perch_mul_ptr = qp.per_channel_muls;
+ const int32_t *perch_shift_ptr = qp.per_channel_shifts;
const int32_t *in_ptr = input + (row * in_stride);
int8_t *out_ptr = output + (row * out_stride);
@@ -93,6 +101,33 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u
const int32x4_t v_row_sum1 = vdupq_n_s32(row_sum1);
while (blocks--) {
+ int32x4_t v_mul0;
+ int32x4_t v_mul1;
+ int32x4_t v_mul2;
+ int32x4_t v_mul3;
+
+ int32x4_t v_shf0;
+ int32x4_t v_shf1;
+ int32x4_t v_shf2;
+ int32x4_t v_shf3;
+
+ if (per_channel) {
+ v_mul0 = vld1q_s32(perch_mul_ptr);
+ v_mul1 = vld1q_s32(perch_mul_ptr + 4);
+ v_mul2 = vld1q_s32(perch_mul_ptr + 8);
+ v_mul3 = vld1q_s32(perch_mul_ptr + 12);
+ perch_mul_ptr += 16;
+
+ v_shf0 = vld1q_s32(perch_shift_ptr);
+ v_shf1 = vld1q_s32(perch_shift_ptr + 4);
+ v_shf2 = vld1q_s32(perch_shift_ptr + 8);
+ v_shf3 = vld1q_s32(perch_shift_ptr + 12);
+ perch_shift_ptr += 16;
+ } else {
+ v_mul0=v_mul1=v_mul2=v_mul3=v_mul;
+ v_shf0=v_shf1=v_shf2=v_shf3=v_shift;
+ }
+
// Load column pointers
int32x4_t v_col0 = vld1q_s32(colptr);
int32x4_t v_col1 = vld1q_s32(colptr + 4);
@@ -136,27 +171,27 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u
v_in13 = vaddq_s32(v_in13, v_col3);
// Quantize - start with multiply
- v_in00 = vqrdmulhq_s32(v_in00, v_mul);
- v_in01 = vqrdmulhq_s32(v_in01, v_mul);
- v_in02 = vqrdmulhq_s32(v_in02, v_mul);
- v_in03 = vqrdmulhq_s32(v_in03, v_mul);
+ v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
+ v_in01 = vqrdmulhq_s32(v_in01, v_mul1);
+ v_in02 = vqrdmulhq_s32(v_in02, v_mul2);
+ v_in03 = vqrdmulhq_s32(v_in03, v_mul3);
- v_in10 = vqrdmulhq_s32(v_in10, v_mul);
- v_in11 = vqrdmulhq_s32(v_in11, v_mul);
- v_in12 = vqrdmulhq_s32(v_in12, v_mul);
- v_in13 = vqrdmulhq_s32(v_in13, v_mul);
+ v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
+ v_in11 = vqrdmulhq_s32(v_in11, v_mul1);
+ v_in12 = vqrdmulhq_s32(v_in12, v_mul2);
+ v_in13 = vqrdmulhq_s32(v_in13, v_mul3);
// Compute and add on corrective offset
if (do_shift_correction) {
- int32x4_t v_temp00 = vandq_s32(v_in00, v_shift);
- int32x4_t v_temp01 = vandq_s32(v_in01, v_shift);
- int32x4_t v_temp02 = vandq_s32(v_in02, v_shift);
- int32x4_t v_temp03 = vandq_s32(v_in03, v_shift);
+ int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);
+ int32x4_t v_temp01 = vandq_s32(v_in01, v_shf1);
+ int32x4_t v_temp02 = vandq_s32(v_in02, v_shf2);
+ int32x4_t v_temp03 = vandq_s32(v_in03, v_shf3);
- int32x4_t v_temp10 = vandq_s32(v_in10, v_shift);
- int32x4_t v_temp11 = vandq_s32(v_in11, v_shift);
- int32x4_t v_temp12 = vandq_s32(v_in12, v_shift);
- int32x4_t v_temp13 = vandq_s32(v_in13, v_shift);
+ int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);
+ int32x4_t v_temp11 = vandq_s32(v_in11, v_shf1);
+ int32x4_t v_temp12 = vandq_s32(v_in12, v_shf2);
+ int32x4_t v_temp13 = vandq_s32(v_in13, v_shf3);
v_temp00 = vshrq_n_s32(v_temp00, 31);
v_temp01 = vshrq_n_s32(v_temp01, 31);
@@ -179,15 +214,15 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u
v_in13 = vqaddq_s32(v_in13, v_temp13);
}
- v_in00 = vrshlq_s32(v_in00, v_shift);
- v_in01 = vrshlq_s32(v_in01, v_shift);
- v_in02 = vrshlq_s32(v_in02, v_shift);
- v_in03 = vrshlq_s32(v_in03, v_shift);
+ v_in00 = vrshlq_s32(v_in00, v_shf0);
+ v_in01 = vrshlq_s32(v_in01, v_shf1);
+ v_in02 = vrshlq_s32(v_in02, v_shf2);
+ v_in03 = vrshlq_s32(v_in03, v_shf3);
- v_in10 = vrshlq_s32(v_in10, v_shift);
- v_in11 = vrshlq_s32(v_in11, v_shift);
- v_in12 = vrshlq_s32(v_in12, v_shift);
- v_in13 = vrshlq_s32(v_in13, v_shift);
+ v_in10 = vrshlq_s32(v_in10, v_shf0);
+ v_in11 = vrshlq_s32(v_in11, v_shf1);
+ v_in12 = vrshlq_s32(v_in12, v_shf2);
+ v_in13 = vrshlq_s32(v_in13, v_shf3);
v_in00 = vaddq_s32(v_in00, v_c_offset);
v_in01 = vaddq_s32(v_in01, v_c_offset);
@@ -235,6 +270,20 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u
}
while (regs--) {
+ int32x4_t v_mul0;
+ int32x4_t v_shf0;
+
+ if (per_channel) {
+ v_mul0 = vld1q_s32(perch_mul_ptr);
+ perch_mul_ptr += 4;
+
+ v_shf0 = vld1q_s32(perch_shift_ptr);
+ perch_shift_ptr += 4;
+ } else {
+ v_mul0=v_mul;
+ v_shf0=v_shift;
+ }
+
// Load column pointers
int32x4_t v_col0 = vld1q_s32(colptr);
colptr += 4;
@@ -258,15 +307,15 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u
v_in10 = vaddq_s32(v_in10, v_col0);
// Quantize - start with multiply
- v_in00 = vqrdmulhq_s32(v_in00, v_mul);
+ v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
- v_in10 = vqrdmulhq_s32(v_in10, v_mul);
+ v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
// Compute and add on corrective offset
if (do_shift_correction) {
- int32x4_t v_temp00 = vandq_s32(v_in00, v_shift);
+ int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);
- int32x4_t v_temp10 = vandq_s32(v_in10, v_shift);
+ int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);
v_temp00 = vshrq_n_s32(v_temp00, 31);
@@ -277,9 +326,9 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u
v_in10 = vqaddq_s32(v_in10, v_temp10);
}
- v_in00 = vrshlq_s32(v_in00, v_shift);
+ v_in00 = vrshlq_s32(v_in00, v_shf0);
- v_in10 = vrshlq_s32(v_in10, v_shift);
+ v_in10 = vrshlq_s32(v_in10, v_shf0);
v_in00 = vaddq_s32(v_in00, v_c_offset);
@@ -307,21 +356,40 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u
int32x4_t v_col0 = vdupq_n_s32(0);
int32x4_t v_in00 = vdupq_n_s32(0);
int32x4_t v_in10 = vdupq_n_s32(0);
+ int32x4_t v_mul0 = vdupq_n_s32(0);
+ int32x4_t v_shf0 = vdupq_n_s32(0);
+
+ if (!per_channel) {
+ v_mul0 = v_mul;
+ v_shf0 = v_shift;
+ }
do {
v_col0 = vld1q_lane_s32(colptr, v_col0, 0);
v_in00 = vld1q_lane_s32(in_ptr, v_in00, 0);
v_in10 = vld1q_lane_s32(in_ptr1, v_in10, 0);
+ if (per_channel) {
+ v_mul0 = vld1q_lane_s32(perch_mul_ptr, v_mul0, 0);
+ v_shf0 = vld1q_lane_s32(perch_shift_ptr, v_shf0, 0);
+ }
if (odds == 1) { break; }
v_col0 = vld1q_lane_s32(colptr + 1, v_col0, 1);
v_in00 = vld1q_lane_s32(in_ptr + 1, v_in00, 1);
v_in10 = vld1q_lane_s32(in_ptr1 + 1, v_in10, 1);
+ if (per_channel) {
+ v_mul0 = vld1q_lane_s32(perch_mul_ptr + 1, v_mul0, 1);
+ v_shf0 = vld1q_lane_s32(perch_shift_ptr + 1, v_shf0, 1);
+ }
if (odds == 2) { break; }
v_col0 = vld1q_lane_s32(colptr + 2, v_col0, 2);
v_in00 = vld1q_lane_s32(in_ptr + 2, v_in00, 2);
v_in10 = vld1q_lane_s32(in_ptr1 + 2, v_in10, 2);
+ if (per_channel) {
+ v_mul0 = vld1q_lane_s32(perch_mul_ptr + 2, v_mul0, 2);
+ v_shf0 = vld1q_lane_s32(perch_shift_ptr + 2, v_shf0, 2);
+ }
} while (0);
// Add on row sum and bias constant
@@ -335,15 +403,15 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u
v_in10 = vaddq_s32(v_in10, v_col0);
// Quantize - start with multiply
- v_in00 = vqrdmulhq_s32(v_in00, v_mul);
+ v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
- v_in10 = vqrdmulhq_s32(v_in10, v_mul);
+ v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
// Compute and add on corrective offset
if (do_shift_correction) {
- int32x4_t v_temp00 = vandq_s32(v_in00, v_shift);
+ int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);
- int32x4_t v_temp10 = vandq_s32(v_in10, v_shift);
+ int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);
v_temp00 = vshrq_n_s32(v_temp00, 31);
@@ -354,9 +422,9 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u
v_in10 = vqaddq_s32(v_in10, v_temp10);
}
- v_in00 = vrshlq_s32(v_in00, v_shift);
+ v_in00 = vrshlq_s32(v_in00, v_shf0);
- v_in10 = vrshlq_s32(v_in10, v_shift);
+ v_in10 = vrshlq_s32(v_in10, v_shf0);
v_in00 = vaddq_s32(v_in00, v_c_offset);
@@ -391,23 +459,33 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u
} // anonymous namespace
template<typename Tin, typename Tout>
-void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height,
+void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
const Tin *input, unsigned int in_stride, Tout *output, unsigned int out_stride,
const int32_t *row_bias, const int32_t *col_bias) {
- if (qp.minval >= qp.c_offset) {
- requantize_block_32_int<false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
- reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+ if (qp.per_channel_requant) {
+ if (qp.minval >= qp.c_offset) {
+ requantize_block_32_int<false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+ } else {
+ requantize_block_32_int<true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+ }
} else {
- requantize_block_32_int<true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
- reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+ if (qp.minval >= qp.c_offset) {
+ requantize_block_32_int<false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+ } else {
+ requantize_block_32_int<true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+ }
}
}
-template void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height,
+template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
const int32_t *row_bias, const int32_t *col_bias);
-template void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height,
+template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
const uint32_t *input, unsigned int in_stride, uint8_t *output, unsigned int out_stride,
const int32_t *row_bias, const int32_t *col_bias);
@@ -448,7 +526,7 @@ template void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int wid
*/
namespace {
struct row_sum_helpers {
- const ARequantizeLayer32 &qp;
+ const Requantize32 &qp;
/* Load a full 16 byte vector, pairwise accumulate into 'sum' with uadalp or sadalp */
template<typename T>
@@ -571,7 +649,7 @@ namespace {
}
}
- row_sum_helpers(const ARequantizeLayer32 &qp) : qp(qp) { }
+ row_sum_helpers(const Requantize32 &qp) : qp(qp) { }
};
template<>
@@ -612,8 +690,14 @@ namespace {
}
template<typename T>
-void compute_row_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height,
+void compute_row_sums(const Requantize32 &qp, unsigned int width, unsigned int height,
const T *input, unsigned int in_stride, int32_t *row_bias) {
+ /* If the 'b' offset is zero, just skip this entirely. */
+ if (qp.b_offset == 0) {
+ memset(row_bias, 0, height * sizeof(int32_t));
+ return;
+ }
+
row_sum_helpers thehelpers(qp);
const int32x4_t offset_mul = vdupq_n_s32(-qp.b_offset);
@@ -663,8 +747,8 @@ void compute_row_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned
}
/* Instantiate the two versions for uint8_t and int8_t. */
-template void compute_row_sums(const ARequantizeLayer32 &, unsigned int, unsigned int, const int8_t *, unsigned int, int32_t *);
-template void compute_row_sums(const ARequantizeLayer32 &, unsigned int, unsigned int, const uint8_t *, unsigned int, int32_t *);
+template void compute_row_sums(const Requantize32 &, unsigned int, unsigned int, const int8_t *, unsigned int, int32_t *);
+template void compute_row_sums(const Requantize32 &, unsigned int, unsigned int, const uint8_t *, unsigned int, int32_t *);
template<unsigned int active_rows, typename T>
inline void add_block(const T *input, unsigned int in_stride, int32_t *output);
@@ -739,41 +823,44 @@ inline void add_block(const int8_t *input, unsigned int in_stride, int32_t *outp
* in cases where we are not computing the first columns of the output (i.e.
* in multithreaded cases where we divide columns across threads) */
template<typename T>
-void compute_col_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, const T *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col) {
- memset(reinterpret_cast<void *>(col_bias), 0, width * sizeof(int32_t));
-
- for (unsigned int row=0; row<height; row+=4) {
- unsigned int numrows=std::min(height-row, 4u);
-
- for (unsigned int col=0; col<width; col+=16) {
- unsigned int numcols=std::min(width-col, 16u);
-
- if (numcols==16) {
- switch(numrows) {
- default:
- case 1:
- add_block<1>(input + row * in_stride + col, in_stride, col_bias + col);
- break;
-
- case 2:
- add_block<2>(input + row * in_stride + col, in_stride, col_bias + col);
- break;
-
- case 3:
- add_block<3>(input + row * in_stride + col, in_stride, col_bias + col);
- break;
-
- case 4:
- add_block<4>(input + row * in_stride + col, in_stride, col_bias + col);
- break;
- }
- } else {
- for (; col<width; col++) {
- int32_t sum=0;
- for (unsigned int r=0; r<numrows; r++) {
- sum += input[(row + r)*in_stride + col];
+void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const T *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col) {
+ /* Only actually add up the columns if a_offset is non-zero. */
+ if (qp.a_offset != 0) {
+ memset(reinterpret_cast<void *>(col_bias), 0, width * sizeof(int32_t));
+
+ for (unsigned int row=0; row<height; row+=4) {
+ unsigned int numrows=std::min(height-row, 4u);
+
+ for (unsigned int col=0; col<width; col+=16) {
+ unsigned int numcols=std::min(width-col, 16u);
+
+ if (numcols==16) {
+ switch(numrows) {
+ default:
+ case 1:
+ add_block<1>(input + row * in_stride + col, in_stride, col_bias + col);
+ break;
+
+ case 2:
+ add_block<2>(input + row * in_stride + col, in_stride, col_bias + col);
+ break;
+
+ case 3:
+ add_block<3>(input + row * in_stride + col, in_stride, col_bias + col);
+ break;
+
+ case 4:
+ add_block<4>(input + row * in_stride + col, in_stride, col_bias + col);
+ break;
+ }
+ } else {
+ for (; col<width; col++) {
+ int32_t sum=0;
+ for (unsigned int r=0; r<numrows; r++) {
+ sum += input[(row + r)*in_stride + col];
+ }
+ col_bias[col] += sum;
}
- col_bias[col] += sum;
}
}
}
@@ -792,8 +879,8 @@ void compute_col_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned
}
}
-template void compute_col_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, const int8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);
-template void compute_col_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, const uint8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);
+template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const int8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);
+template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const uint8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);
} // namespace arm_gemm
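The NEON code above keeps the arithmetic unchanged and only changes where the multiplier and shift come from: per-layer values are broadcast once, while the per-channel path loads a fresh vector of multipliers and shifts for every group of output columns. A scalar sketch of the per-column selection and the fixed-point step for the int8 instantiation, assuming the Requantize32 definition from this patch and that right shifts are stored as negative values (as the dispatcher below sets them up); the sign-correction applied before the rounding shift in the real kernel, and saturation of the doubling multiply, are elided for brevity:

#include <algorithm>
#include <cstdint>

// Scalar stand-in for vqrdmulhq_s32 (rounding doubling multiply, high half).
static int32_t rounding_doubling_high_mul(int32_t a, int32_t b)
{
    return static_cast<int32_t>((static_cast<int64_t>(a) * b + (1LL << 30)) >> 31);
}

// Illustrative model of requantize_block_32_int for one output row.
static void requantize_row(const arm_gemm::Requantize32 &qp, unsigned int width,
                           const int32_t *in, int32_t row_sum,
                           const int32_t *col_bias, int8_t *out)
{
    for (unsigned int col = 0; col < width; col++) {
        const int32_t mul   = qp.per_channel_requant ? qp.per_channel_muls[col]   : qp.per_layer_mul;
        const int32_t shift = qp.per_channel_requant ? qp.per_channel_shifts[col] : qp.per_layer_shift;

        int32_t v = in[col] + row_sum + col_bias[col];   // add row/column corrections
        v = rounding_doubling_high_mul(v, mul);          // vqrdmulhq_s32
        const int n = -shift;                            // shifts are stored negated
        if (n > 0) {
            v = static_cast<int32_t>((static_cast<int64_t>(v) + (1LL << (n - 1))) >> n); // vrshlq_s32
        }
        v += qp.c_offset;
        out[col] = static_cast<int8_t>(std::min(std::max(v, qp.minval), qp.maxval));
    }
}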
diff --git a/src/core/NEON/kernels/arm_gemm/quantized.hpp b/src/core/NEON/kernels/arm_gemm/quantized.hpp
index a22750796c..a91a888ad9 100644
--- a/src/core/NEON/kernels/arm_gemm/quantized.hpp
+++ b/src/core/NEON/kernels/arm_gemm/quantized.hpp
@@ -26,16 +26,16 @@
namespace arm_gemm {
template<typename Tin, typename Tout>
-void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height,
+void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
const Tin *input, unsigned int in_stride, Tout *output, unsigned int out_stride,
const int32_t *row_bias, const int32_t *col_bias);
template<typename T>
-void compute_row_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height,
+void compute_row_sums(const Requantize32 &qp, unsigned int width, unsigned int height,
const T *input, unsigned int in_stride, int32_t *row_bias);
template<typename T>
-void compute_col_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height,
+void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height,
const T *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth,
unsigned int multi, unsigned int first_col);
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index 65d800cb0c..4e43d04446 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -164,6 +164,23 @@ public:
arm_gemm::GemmArgs args, const GEMMInfo &gemm_info,
MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os = {});
+ /** Set requantization data to be used
+ *
+ * @param[in] shifts Requantization shifts
+ * @param[in] multipliers Requantization multipliers
+ *
+ * @return A tuple with the pointers to the shift and multiplier data respectively
+ */
+ std::tuple<const int32_t *, const int32_t *> set_requantize_data(const std::vector<int32_t> &shifts,
+ const std::vector<int32_t> &multipliers);
+
// Inherited methods overridden:
void run() override;
void prepare() override;
@@ -212,9 +229,24 @@ private:
FallbackTransform<TypeInput, TypeOutput> _weights_transform{};
/** GEMM kernel description */
arm_gemm::KernelDescription _kernel_info{};
+ /** Per channel quantization shifts */
+ std::vector<int32_t> _shifts{};
+ /** Per channel quantization multipliers */
+ std::vector<int32_t> _multipliers{};
};
template <typename TypeInput, typename TypeOutput, class OutputStage>
+std::tuple<const int32_t *, const int32_t *> Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts,
+ const std::vector<int32_t> &multipliers)
+{
+ _multipliers = multipliers;
+ _shifts = shifts;
+ std::transform(_shifts.begin(), _shifts.end(), _shifts.begin(),
+ [](int32_t s) { return -s; });
+ return std::make_tuple(_shifts.data(), _multipliers.data());
+}
+
+template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d,
arm_gemm::GemmArgs args, const GEMMInfo &gemm_info,
MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os)
@@ -435,18 +467,32 @@ void create_arm_gemm_quant(std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &a
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, activation, num_threads, gemm_info.pretranpose_B());
+ // Create arm_gemm fallback
+ auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>();
+
// Configure requantization info
const int32_t a_offset = -a->info()->quantization_info().uniform().offset;
const int32_t b_offset = -b->info()->quantization_info().uniform().offset;
const GEMMLowpOutputStageInfo os_info = gemm_info.gemmlowp_output_stage();
- const arm_gemm::ARequantizeLayer32 gemm_requant_info(nullptr, 0,
- a_offset, b_offset, os_info.gemmlowp_offset,
- -os_info.gemmlowp_shift, os_info.gemmlowp_multiplier,
- os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound);
+ arm_gemm::Requantize32 gemm_requant_info{};
+ if(os_info.gemmlowp_shifts.size() > 1)
+ {
+ const auto requantize_data = fallback->set_requantize_data(os_info.gemmlowp_shifts, os_info.gemmlowp_multipliers);
+ gemm_requant_info = arm_gemm::Requantize32(nullptr, 0,
+ a_offset, b_offset, os_info.gemmlowp_offset,
+ std::get<0>(requantize_data), std::get<1>(requantize_data),
+ os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound);
+ }
+ else
+ {
+ gemm_requant_info = arm_gemm::Requantize32(nullptr, 0,
+ a_offset, b_offset, os_info.gemmlowp_offset,
+ -os_info.gemmlowp_shift, os_info.gemmlowp_multiplier,
+ os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound);
+ }
- // Create arm_gemm fallback
- auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::ARequantizeLayer32>>();
+ // Configure fallback
fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager, gemm_requant_info);
arm_gemm = std::move(fallback);
}
@@ -484,7 +530,6 @@ Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8_SIGNED && d->data_type() != DataType::S32, "Only S32 output supported for QASYMM8_SIGNED input");
return Status{};
}
@@ -524,7 +569,14 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const
break;
case DataType::S8:
case DataType::QASYMM8_SIGNED:
- create_arm_gemm<int8_t, int32_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
+ if(d->info()->data_type() == DataType::S32)
+ {
+ create_arm_gemm<int8_t, int32_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
+ }
+ else
+ {
+ create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
+ }
break;
#endif /* __aarch64__ */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
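The per-channel path above hinges on two details: gemmlowp_shifts holding more than one entry selects the per-channel Requantize32 constructor, and set_requantize_data() negates the shifts because arm_gemm encodes right shifts as negative values. A condensed sketch of that selection using only the constructor arguments visible above; build_requant_info is an illustrative free function, not the actual member code, and 'neg_shifts'/'multipliers' assume the same lifetime guarantee the _shifts/_multipliers members give in the Fallback class:

#include <algorithm>
#include <cstdint>
#include <vector>

arm_gemm::Requantize32 build_requant_info(const GEMMLowpOutputStageInfo &os_info,
                                          int32_t a_offset, int32_t b_offset,
                                          std::vector<int32_t> &neg_shifts,
                                          std::vector<int32_t> &multipliers)
{
    if (os_info.gemmlowp_shifts.size() > 1)
    {
        // Per-channel: hand arm_gemm pointers to negated shifts and multipliers.
        multipliers = os_info.gemmlowp_multipliers;
        neg_shifts.resize(os_info.gemmlowp_shifts.size());
        std::transform(os_info.gemmlowp_shifts.begin(), os_info.gemmlowp_shifts.end(),
                       neg_shifts.begin(), [](int32_t s) { return -s; });
        return arm_gemm::Requantize32(nullptr, 0, a_offset, b_offset, os_info.gemmlowp_offset,
                                      neg_shifts.data(), multipliers.data(),
                                      os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound);
    }
    // Per-tensor: single multiplier/shift pair, shift negated as before.
    return arm_gemm::Requantize32(nullptr, 0, a_offset, b_offset, os_info.gemmlowp_offset,
                                  -os_info.gemmlowp_shift, os_info.gemmlowp_multiplier,
                                  os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound);
}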
diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
index 440f043527..38481afe88 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -119,7 +119,7 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b,
case DataType::U8:
case DataType::S8:
{
- if(a_to_use->info()->data_type() == DataType::QASYMM8 && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
+ if(is_data_type_quantized_asymmetric(a_to_use->info()->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
{
_asm_glue.configure(a_to_use, b, c, output, gemm_info);
_fused_assembly_path = _asm_glue.is_configured();