aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp28
1 files changed, 27 insertions, 1 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
index 515d55c73b..2d743a4bd6 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020, 2022 Arm Limited.
+ * Copyright (c) 2017-2020, 2022-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -57,6 +57,8 @@
#include "kernels/sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL.hpp"
#endif // ARM_COMPUTE_ENABLE_SME2
+#include "kernels/sve_ffhybrid_bf16fp32_mmla_6x4VL.hpp"
+#include "kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp"
#include "kernels/sve_hybrid_bf16fp32_dot_6x4VL.hpp"
#include "kernels/sve_hybrid_bf16fp32_mmla_6x4VL.hpp"
#include "kernels/sve_interleaved_bf16fp32_dot_8x3VL.hpp"
@@ -204,6 +206,30 @@ GemmImplementation<bfloat16, float>::with_estimate(
[](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_dot_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); },
[](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_dot_8x12, bfloat16, float>(args); }
),
+GemmImplementation<bfloat16, float>::with_estimate(
+ GemmMethod::GEMM_INTERLEAVED,
+ "a64_ffinterleaved_bf16fp32_mmla_8x12",
+ KernelWeightFormat::VL256_BL64,
+ [](const GemmArgs &args) { return args._ci->has_bf16(); },
+ [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_mmla_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); },
+ [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_mmla_8x12, bfloat16, float>(args); }
+),
+GemmImplementation<bfloat16, float>::with_estimate(
+ GemmMethod::GEMM_INTERLEAVED,
+ "a64_ffhybrid_bf16fp32_mmla_6x16",
+ KernelWeightFormat::VL256_BL64,
+ [](const GemmArgs &args) { return args._ci->has_bf16(); },
+ [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_bf16fp32_mmla_6x16, bfloat16, float>::estimate_cycles<bfloat16>(args); },
+ [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_bf16fp32_mmla_6x16, bfloat16, float>(args); }
+),
+GemmImplementation<bfloat16, float>::with_estimate(
+ GemmMethod::GEMM_INTERLEAVED,
+ "a64_ffinterleaved_bf16fp32_dot_8x12",
+ KernelWeightFormat::VL128_BL32,
+ [](const GemmArgs &args) { return args._ci->has_bf16(); },
+ [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_dot_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); },
+ [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_dot_8x12, bfloat16, float>(args); }
+),
#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
GemmImplementation<bfloat16, float>::with_estimate(
GemmMethod::GEMM_INTERLEAVED,