aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp143
1 files changed, 113 insertions, 30 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
index 04cac6095c..05c5116bf3 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
@@ -25,68 +25,151 @@
#include "arm_gemm.hpp"
-#include "kernels/a64_hybrid_s8s32_dot_16x4.hpp"
-#include "kernels/a64_smallK_hybrid_s8s32_dot_4x6.hpp"
-#include "kernels/a64_smallK_hybrid_s8s32_dot_4x8.hpp"
-#include "kernels/sve_hybrid_s8s32_dot_4VLx4.hpp"
-#include "kernels/sve_smallK_hybrid_s8s32_dot_1VLx8.hpp"
+#include "kernels/a64_gemm_s16_8x12.hpp"
+#include "kernels/a64_gemm_s8_4x4.hpp"
+#include "kernels/a64_gemm_s8_8x12.hpp"
+#include "kernels/a64_hybrid_s8qa_dot_4x16.hpp"
+#include "kernels/a64_hybrid_s8qs_dot_6x16.hpp"
+#include "kernels/a64_hybrid_s8s32_dot_6x16.hpp"
+#include "kernels/a64_interleaved_s8s32_mmla_8x12.hpp"
+#include "kernels/a64_smallK_hybrid_s8s32_dot_6x4.hpp"
+#include "kernels/a64_smallK_hybrid_s8s32_dot_8x4.hpp"
+#include "kernels/sve_hybrid_s8s32_dot_6x4VL.hpp"
+#include "kernels/sve_hybrid_s8qa_dot_4x4VL.hpp"
+#include "kernels/sve_hybrid_s8qs_dot_6x4VL.hpp"
+#include "kernels/sve_interleaved_s8s32_dot_8x3VL.hpp"
+#include "kernels/sve_interleaved_s8s32_mmla_8x3VL.hpp"
+#include "kernels/sve_smallK_hybrid_s8s32_dot_8x1VL.hpp"
+
+#include "gemm_hybrid_indirect.hpp"
#include "gemm_hybrid_quantized.hpp"
+#include "gemm_hybrid_quantized_inline.hpp"
+#include "gemm_interleaved.hpp"
#include "quantize_wrapper.hpp"
+#include "utils.hpp"
namespace arm_gemm {
static const GemmImplementation<int8_t, int8_t, Requantize32> gemm_qint8_methods[] =
{
#ifdef __ARM_FEATURE_SVE
+#ifdef MMLA_INT8
{
- GemmMethod::GEMM_HYBRID_QUANTIZED,
- "smallK_hybrid_s8s32_dot_1VLx8",
- [](const GemmArgs &args, const Requantize32 &) { return args._Ksize<=64; },
+ GemmMethod::GEMM_INTERLEAVED,
+ "sve_interleaved_s8s32_mmla_8x3VL",
+ [](const GemmArgs &args, const Requantize32 &) { return (args._Ksize>8); },
nullptr,
- [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<smallK_hybrid_s8s32_dot_1VLx8, int8_t, int8_t>(args, qp); }
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_sve_interleaved_s8s32_mmla_8x3VL, int8_t, int8_t>(args, qp); }
},
+#endif
{
GemmMethod::GEMM_HYBRID_QUANTIZED,
- "hybrid_s8s32_dot_4VLx4",
- [](const GemmArgs &args, const Requantize32 &) { return args._Ksize>=16; },
- [](const GemmArgs &args, const Requantize32 &) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); },
- [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<hybrid_s8s32_dot_4VLx4, int8_t, int8_t>(args, qp); }
+ "sve_smallK_hybrid_s8s32_dot_8x1VL",
+ [](const GemmArgs &args, const Requantize32 &) { return args._Ksize<=64 && !args._indirect_input; },
+ nullptr,
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<cls_sve_smallK_hybrid_s8s32_dot_8x1VL, int8_t, int8_t>(args, qp); }
+},
+#ifdef SVE2
+{
+ GemmMethod::GEMM_HYBRID,
+ "sve_hybrid_s8qs_dot_6x4VL",
+ [](const GemmArgs &args, const Requantize32 &qp) { return quant_hybrid_symmetric(qp); },
+ nullptr,
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_sve_hybrid_s8qs_dot_6x4VL, int8_t, int8_t, Requantize32>(args, qp); }
+},
+{
+ GemmMethod::GEMM_HYBRID,
+ "sve_hybrid_s8qa_dot_4x4VL",
+ [](const GemmArgs &args, const Requantize32 &qp) { return quant_hybrid_asymmetric(qp); },
+ nullptr,
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_sve_hybrid_s8qa_dot_4x4VL, int8_t, int8_t, Requantize32>(args, qp); }
},
#endif
{
- GemmMethod::GEMM_HYBRID_QUANTIZED,
- "smallK_hybrid_s8s32_dot_4x8",
- [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32); },
+ GemmMethod::GEMM_HYBRID,
+ "sve_hybrid_s8s32_dot_6x4VL",
+ nullptr,
+ nullptr,
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_sve_hybrid_s8s32_dot_6x4VL, int8_t, int8_t, Requantize32, true>(args, qp); }
+},
+{
+ GemmMethod::GEMM_INTERLEAVED,
+ "sve_interleaved_s8s32_dot_8x3VL",
+ [](const GemmArgs &args, const Requantize32 &) { return (args._Ksize>4); },
nullptr,
- [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<smallK_hybrid_s8s32_dot_4x8, int8_t, int8_t>(args, qp); }
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_sve_interleaved_s8s32_dot_8x3VL, int8_t, int8_t>(args, qp); }
},
+#endif // SVE
+#ifdef MMLA_INT8
+{
+ GemmMethod::GEMM_INTERLEAVED,
+ "a64_interleaved_s8s32_mmla_8x12",
+ [](const GemmArgs &args, const Requantize32 &) { return (args._Ksize>8); },
+ nullptr,
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_a64_interleaved_s8s32_mmla_8x12, int8_t, int8_t>(args, qp); }
+},
+#endif
{
GemmMethod::GEMM_HYBRID_QUANTIZED,
- "smallK_hybrid_s8s32_dot_4x6",
- [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64); },
+ "a64_smallK_hybrid_s8s32_dot_8x4",
+ [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; },
nullptr,
- [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<smallK_hybrid_s8s32_dot_4x6, int8_t, int8_t>(args, qp); }
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<cls_a64_smallK_hybrid_s8s32_dot_8x4, int8_t, int8_t>(args, qp); }
},
{
GemmMethod::GEMM_HYBRID_QUANTIZED,
- "hybrid_s8s32_dot_16x4",
- [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && args._Ksize>=16; },
- [](const GemmArgs &args, const Requantize32 &) { return args._Nsize<=256 && args._Ksize>128; },
- [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<hybrid_s8s32_dot_16x4, int8_t, int8_t>(args, qp); }
+ "a64_smallK_hybrid_s8s32_dot_6x4",
+ [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; },
+ nullptr,
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<cls_a64_smallK_hybrid_s8s32_dot_6x4, int8_t, int8_t>(args, qp); }
},
-/** QUANTIZE_WRAPPER_2D enables 2D parallelisation hint for IScheduler in NEGEMMAssemblyDispatch */
{
- GemmMethod::QUANTIZE_WRAPPER_2D,
- "quantized_wrapper_2d",
+ GemmMethod::GEMM_INTERLEAVED,
+ "a64_gemm_s16_8x12",
nullptr,
- [](const GemmArgs &args, const Requantize32 &) { return (args._maxthreads >= 8) && (args._Msize >= 8) && (args._Nsize >= 8);},
- [](const GemmArgs &args, const Requantize32 &qp) { return new QuantizeWrapper<int8_t, int8_t, int32_t>(args, qp); }
+ [](const GemmArgs &args, const Requantize32 &) { return args._ci->get_cpu_model() == CPUModel::A53; },
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_a64_gemm_s16_8x12, int8_t, int8_t>(args, qp); }
+},
+{
+ GemmMethod::GEMM_HYBRID,
+ "a64_hybrid_s8qs_dot_6x16",
+ [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_dotprod() && quant_hybrid_symmetric(qp); },
+ nullptr,
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_a64_hybrid_s8qs_dot_6x16, int8_t, int8_t, Requantize32>(args, qp); }
+},
+{
+ GemmMethod::GEMM_HYBRID,
+ "a64_hybrid_s8qa_dot_4x16",
+ [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_dotprod() && quant_hybrid_asymmetric(qp); },
+ nullptr,
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_a64_hybrid_s8qa_dot_4x16, int8_t, int8_t, Requantize32>(args, qp); }
+},
+{
+ GemmMethod::GEMM_HYBRID,
+ "a64_hybrid_s8s32_dot_6x16",
+ [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod(); },
+ nullptr,
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_a64_hybrid_s8s32_dot_6x16, int8_t, int8_t, Requantize32, true>(args, qp); }
+},
+{
+ GemmMethod::GEMM_INTERLEAVED,
+ "a64_gemm_s8_8x12",
+ [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod(); },
+ nullptr,
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_a64_gemm_s8_8x12, int8_t, int8_t>(args, qp); }
+},
+{
+ GemmMethod::GEMM_INTERLEAVED,
+ "a64_gemm_s8_4x4",
+ nullptr,
+ nullptr,
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_a64_gemm_s8_4x4, int8_t, int8_t>(args, qp); }
},
{
GemmMethod::QUANTIZE_WRAPPER,
"quantized_wrapper",
- nullptr,
+ [](const GemmArgs &args, const Requantize32 &) { return !args._indirect_input; },
nullptr,
[](const GemmArgs &args, const Requantize32 &qp) { return new QuantizeWrapper<int8_t, int8_t, int32_t>(args, qp); }
},