diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/transforms')
3 files changed, 39 insertions, 0 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_1VL_2x2.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_1VL_2x2.hpp index dfa0167c00..7120d1d33e 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_1VL_2x2.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_1VL_2x2.hpp @@ -193,4 +193,17 @@ void Transform<1, 2, true, VLType::SME>( ); } +template<> +void Transform<1, 2, true, VLType::SME>( + __fp16 *out, const __fp16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_1VL_2x2( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(__fp16) / 2, + stride * sizeof(__fp16), + (kmax-k0) + ); +} + #endif diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_2VL_2x2.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_2VL_2x2.hpp index 13e0a38ebc..3fc3920500 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_2VL_2x2.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_2VL_2x2.hpp @@ -195,4 +195,17 @@ void Transform<2, 2, true, VLType::SME>( ); } +template<> +void Transform<2, 2, true, VLType::SME>( + __fp16 *out, const __fp16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_2VL_2x2( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(__fp16) / 2, + stride * sizeof(__fp16), + (kmax-k0) + ); +} + #endif diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_4VL_2x2.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_4VL_2x2.hpp index 8badde53a9..9b28578217 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_4VL_2x2.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_4VL_2x2.hpp @@ -155,4 +155,17 @@ void Transform<4, 2, true, VLType::SME>( ); } +template<> +void Transform<4, 2, true, VLType::SME>( + __fp16 *out, const __fp16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_4VL_2x2( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(__fp16) / 2, + stride * sizeof(__fp16), + (kmax-k0) + ); +} + #endif |