diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_bf16bf16.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_bf16bf16.cpp | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_bf16bf16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_bf16bf16.cpp index aa761b46e4..1e4de4a39e 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_bf16bf16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_bf16bf16.cpp @@ -32,12 +32,12 @@ namespace arm_gemm { -static const GemmImplementation<bfloat16, bfloat16> gemm_bf16bf16_methods[] = +static const GemmImplementation<bfloat16, bfloat16, bfloat16> gemm_bf16bf16_methods[] = { #ifdef __aarch64__ #ifdef ARM_COMPUTE_ENABLE_BF16 #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS -GemmImplementation<bfloat16, bfloat16>::with_estimate( +GemmImplementation<bfloat16, bfloat16, bfloat16>::with_estimate( GemmMethod::GEMM_INTERLEAVED, "a64_ffinterleaved_bf16fp32_mmla_8x12", KernelWeightFormat::VL256_BL64, @@ -45,7 +45,7 @@ GemmImplementation<bfloat16, bfloat16>::with_estimate( [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_mmla_8x12, bfloat16, bfloat16>::estimate_cycles<bfloat16>(args); }, [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_mmla_8x12, bfloat16, bfloat16>(args); } ), -GemmImplementation<bfloat16, bfloat16>::with_estimate( +GemmImplementation<bfloat16, bfloat16, bfloat16>::with_estimate( GemmMethod::GEMM_INTERLEAVED, "sve_ffinterleaved_bf16fp32_mmla_8x3VL", KernelWeightFormat::VL2VL_BL64, @@ -66,14 +66,14 @@ GemmImplementation<bfloat16, bfloat16>::with_estimate( }; template<> -const GemmImplementation<bfloat16, bfloat16> *gemm_implementation_list<bfloat16, bfloat16>() { +const GemmImplementation<bfloat16, bfloat16, bfloat16> *gemm_implementation_list<bfloat16, bfloat16, bfloat16>() { return gemm_bf16bf16_methods; } /* Explicitly instantiate the external functions for these types. */ -template UniqueGemmCommon<bfloat16, bfloat16> gemm<bfloat16, bfloat16, Nothing>(const GemmArgs &args, const Nothing &); -template bool has_opt_gemm<bfloat16, bfloat16, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); -template KernelDescription get_gemm_method<bfloat16, bfloat16, Nothing>(const GemmArgs &args, const Nothing &); -template std::vector<KernelDescription> get_compatible_kernels<bfloat16, bfloat16, Nothing>(const GemmArgs &args, const Nothing &); +template UniqueGemmCommon<bfloat16, bfloat16, bfloat16> gemm<bfloat16, bfloat16, bfloat16, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<bfloat16, bfloat16, bfloat16, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); +template KernelDescription get_gemm_method<bfloat16, bfloat16, bfloat16, Nothing>(const GemmArgs &args, const Nothing &); +template std::vector<KernelDescription> get_compatible_kernels<bfloat16, bfloat16, bfloat16, Nothing>(const GemmArgs &args, const Nothing &); } // namespace arm_gemm |