aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp')
-rw-r--r-- src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp 26
1 files changed, 24 insertions, 2 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
index 5f2840b243..91012218e5 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
@@ -31,6 +31,7 @@
#include "gemm_hybrid.hpp"
#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
+#include "gemm_interleaved_pretransposed_2d.hpp"
#include "kernels/a32_sgemm_8x6.hpp"
#include "kernels/a64_hgemm_24x8.hpp"
@@ -60,8 +61,19 @@ static const GemmImplementation<__fp16, __fp16> gemm_fp16_methods[] = {
#if defined(__aarch64__) && (defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || defined(FP16_KERNELS))
{
+ GemmMethod::GEMM_INTERLEAVED_2D,
+ "hgemm_24x8_2d",
+#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ [](const GemmArgs &args) { return args._ci->has_fp16(); },
+#else
+ nullptr,
+#endif
+ [](const GemmArgs &args) { return args._maxthreads >= 8; },
+ [](const GemmArgs &args) { return new GemmInterleavedPretransposed2d<hgemm_24x8, __fp16, __fp16>(args); }
+},
+{
GemmMethod::GEMM_INTERLEAVED,
- "hgemm_24x8",
+ "hgemm_24x8_1d",
#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
[](const GemmArgs &args) { return args._ci->has_fp16(); },
#else
@@ -70,11 +82,21 @@ static const GemmImplementation<__fp16, __fp16> gemm_fp16_methods[] = {
nullptr,
[](const GemmArgs &args) { return new GemmInterleaved<hgemm_24x8, __fp16, __fp16>(args); }
},
+
#endif // aarch64 && FP16
#ifdef __aarch64__
+//Pretranspose, 2D split
+{
+ GemmMethod::GEMM_INTERLEAVED_2D,
+ "sgemm_12x8_2d",
+ nullptr,
+ [](const GemmArgs &args) { return args._maxthreads >= 8; },
+ [](const GemmArgs &args) { return new GemmInterleavedPretransposed2d<sgemm_12x8, __fp16, __fp16>(args); }
+},
+//Transpose, 1D split, with blockmanager
{
GemmMethod::GEMM_INTERLEAVED,
- "sgemm_12x8",
+ "sgemm_12x8_1d",
nullptr,
nullptr,
[](const GemmArgs &args) { return new GemmInterleaved<sgemm_12x8, __fp16, __fp16>(args); }