diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp | 18 | ||||
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp | 18 | ||||
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp | 18 | ||||
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp | 2 | ||||
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp | 7 | ||||
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_int16.cpp | 2 | ||||
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_int8.cpp | 2 | ||||
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp | 2 | ||||
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp | 2 | ||||
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp | 2 | ||||
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp | 2 | ||||
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/utils.hpp | 2 |
12 files changed, 40 insertions, 37 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp index 50fc5bdb8a..58e4861bc0 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp @@ -33,21 +33,21 @@ #include "kernels/a32_sgemm_8x6.hpp" -#ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/a64_ffhybrid_bf16fp32_mmla_6x16.hpp" #include "kernels/a64_ffinterleaved_bf16fp32_dot_8x12.hpp" #include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp" -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/a64_hybrid_bf16fp32_dot_6x16.hpp" #include "kernels/a64_hybrid_bf16fp32_mmla_6x16.hpp" #include "kernels/a64_interleaved_bf16fp32_dot_8x12.hpp" #include "kernels/a64_interleaved_bf16fp32_mmla_8x12.hpp" #include "kernels/a64_sgemm_8x12.hpp" -#ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/sve_ffhybrid_bf16fp32_mmla_6x4VL.hpp" #include "kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp" -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/sve_hybrid_bf16fp32_dot_6x4VL.hpp" #include "kernels/sve_hybrid_bf16fp32_mmla_6x4VL.hpp" #include "kernels/sve_interleaved_bf16fp32_dot_8x3VL.hpp" @@ -89,7 +89,7 @@ GemmImplementation<bfloat16, float>::with_estimate( [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_bf16fp32_dot_8x3VL, bfloat16, float>::estimate_cycles<bfloat16>(args); }, [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_bf16fp32_dot_8x3VL, bfloat16, float>(args); } ), -#ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS GemmImplementation<bfloat16, float>::with_estimate( GemmMethod::GEMM_INTERLEAVED, "sve_ffinterleaved_bf16fp32_mmla_8x3VL", @@ -106,7 +106,7 @@ GemmImplementation<bfloat16, float>::with_estimate( [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_bf16fp32_mmla_6x4VL, bfloat16, float>::estimate_cycles<bfloat16>(args); }, [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_bf16fp32_mmla_6x4VL, bfloat16, float>(args); } ), -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #endif // ARM_COMPUTE_ENABLE_SVE GemmImplementation<bfloat16, float>::with_estimate( GemmMethod::GEMM_HYBRID, @@ -136,7 +136,7 @@ GemmImplementation<bfloat16, float>::with_estimate( [](const GemmArgs &args) { return GemmInterleaved<cls_a64_interleaved_bf16fp32_dot_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); }, [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_interleaved_bf16fp32_dot_8x12, bfloat16, float>(args); } ), -#ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS GemmImplementation<bfloat16, float>::with_estimate( GemmMethod::GEMM_INTERLEAVED, "a64_ffinterleaved_bf16fp32_mmla_8x12", @@ -161,7 +161,7 @@ GemmImplementation<bfloat16, float>::with_estimate( [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_dot_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); }, [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_dot_8x12, bfloat16, float>(args); } ), -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS GemmImplementation<bfloat16, float>::with_estimate( GemmMethod::GEMM_INTERLEAVED, "a64_sgemm_8x12", @@ -197,7 +197,7 @@ const GemmImplementation<bfloat16, float> *gemm_implementation_list<bfloat16, fl /* Explicitly instantiate the external functions for these types. */ template UniqueGemmCommon<bfloat16, float> gemm<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &); -template bool has_opt_gemm<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<bfloat16, float, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); template KernelDescription get_gemm_method<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &); template std::vector<KernelDescription> get_compatible_kernels<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &); diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp index 2796b0d204..d749dce98d 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp @@ -34,17 +34,17 @@ #include "gemm_interleaved.hpp" #include "kernels/a32_sgemm_8x6.hpp" -#ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/a64_ffhybrid_fp16_mla_6x32.hpp" #include "kernels/a64_ffinterleaved_fp16_mla_8x24.hpp" -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/a64_hgemm_8x24.hpp" #include "kernels/a64_hybrid_fp16_mla_6x32.hpp" #include "kernels/a64_sgemm_8x12.hpp" -#ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/sve_ffhybrid_fp16_mla_6x4VL.hpp" #include "kernels/sve_ffinterleaved_fp16_mla_8x3VL.hpp" -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/sve_hybrid_fp16_mla_6x4VL.hpp" #include "kernels/sve_interleaved_fp16_mla_8x3VL.hpp" @@ -66,7 +66,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate( [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); }, [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16>(args); } ), -#ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS GemmImplementation<__fp16, __fp16>::with_estimate( GemmMethod::GEMM_INTERLEAVED, "sve_ffinterleaved_fp16_mla_8x3VL", @@ -83,7 +83,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate( [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp16_mla_6x4VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); }, [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp16_mla_6x4VL, __fp16, __fp16>(args); } ), -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #endif // ARM_COMPUTE_ENABLE_SVE #if defined(__aarch64__) GemmImplementation<__fp16, __fp16>::with_estimate( @@ -100,7 +100,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate( [](const GemmArgs &args) { return GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>::estimate_cycles<__fp16>(args); }, [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>(args); } ), -#ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS GemmImplementation<__fp16, __fp16>::with_estimate( GemmMethod::GEMM_INTERLEAVED, "a64_ffinterleaved_fp16_mla_8x24", @@ -117,7 +117,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate( [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp16_mla_6x32, __fp16, __fp16>::estimate_cycles<__fp16>(args); }, [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp16_mla_6x32, __fp16, __fp16>(args); } ), -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS { GemmMethod::GEMM_INTERLEAVED, "a64_sgemm_8x12", @@ -152,7 +152,7 @@ const GemmImplementation<__fp16, __fp16> *gemm_implementation_list<__fp16, __fp1 /* Explicitly instantiate the external functions for these types. */ template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); -template bool has_opt_gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<__fp16, __fp16, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); template KernelDescription get_gemm_method<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); template std::vector<KernelDescription> get_compatible_kernels<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp index 4f7e191fb3..0fc9e8b912 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp @@ -31,12 +31,12 @@ #include "gemv_pretransposed.hpp" #include "kernels/a32_sgemm_8x6.hpp" -#ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/a64_ffhybrid_fp32_mla_6x16.hpp" #include "kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp" #include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp" #include "kernels/a64_ffinterleaved_fp32_mla_8x12.hpp" -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/a64_hybrid_fp32bf16fp32_mmla_4x24.hpp" #include "kernels/a64_hybrid_fp32bf16fp32_mmla_6x16.hpp" #include "kernels/a64_hybrid_fp32_mla_4x24.hpp" @@ -48,12 +48,12 @@ #include "kernels/a64_smallK_hybrid_fp32_mla_6x4.hpp" #include "kernels/a64_smallK_hybrid_fp32_mla_8x4.hpp" -#ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/sve_ffhybrid_fp32_mla_6x4VL.hpp" #include "kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp" #include "kernels/sve_ffinterleaved_fp32_mla_8x3VL.hpp" #include "kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp" -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/sve_hybrid_fp32bf16fp32_mmla_4x6VL.hpp" #include "kernels/sve_hybrid_fp32bf16fp32_mmla_6x4VL.hpp" #include "kernels/sve_hybrid_fp32_mla_6x4VL.hpp" @@ -165,7 +165,7 @@ GemmImplementation<float, float>::with_estimate( [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_fp32_mla_8x3VL, float, float>::estimate_cycles<float>(args); }, [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_fp32_mla_8x3VL, float, float>(args); } ), - #ifdef ENABLE_FIXED_FORMAT_KERNELS + #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #ifdef ARM_COMPUTE_ENABLE_BF16 GemmImplementation<float, float>::with_estimate( GemmMethod::GEMM_INTERLEAVED, @@ -200,7 +200,7 @@ GemmImplementation<float, float>::with_estimate( [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32_mla_6x4VL, float, float>::estimate_cycles<float>(args); }, [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32_mla_6x4VL, float, float>(args); } ), -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #endif // ARM_COMPUTE_ENABLE_SVE // Cortex-A35 specific kernel - use for any problem on A35, and never in any other cases. { @@ -253,7 +253,7 @@ GemmImplementation<float, float>::with_estimate( [](const GemmArgs &args) { return GemmInterleaved<cls_a64_sgemm_8x12, float, float>::estimate_cycles<float>(args); }, [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_sgemm_8x12, float, float>(args); } ), -#ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #ifdef ARM_COMPUTE_ENABLE_BF16 // "fast mode" (BF16) kernels GemmImplementation<float, float>::with_estimate( @@ -289,7 +289,7 @@ GemmImplementation<float, float>::with_estimate( [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32_mla_6x16, float, float>::estimate_cycles<float>(args); }, [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32_mla_6x16, float, float>(args); } ), -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #endif // __aarch64__ #ifdef __arm__ @@ -318,7 +318,7 @@ const GemmImplementation<float, float> *gemm_implementation_list<float, float>() /* Explicitly instantiate the external functions for these types. */ template UniqueGemmCommon<float, float> gemm<float, float, Nothing>(const GemmArgs &args, const Nothing &); -template bool has_opt_gemm<float, float, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<float, float, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); template KernelDescription get_gemm_method<float, float, Nothing>(const GemmArgs &args, const Nothing &); template std::vector<KernelDescription> get_compatible_kernels<float, float, Nothing> (const GemmArgs &args, const Nothing &); diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp index c41b0a5b3e..90e2f07607 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp @@ -450,7 +450,7 @@ public: } /* Make sure we've been set up correctly. */ - assert(_B_transposed); + assert(FixedFormat || _B_transposed); static_assert(std::is_same<To, Tloi>::value, "gemm_native: Operand types must be the same."); // static_assert(std::is_same<Tr, Tri>::value, "gemm_native: Result types must be the same."); diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp index 75fb1cb306..19c8fcadd3 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp @@ -306,9 +306,12 @@ std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, cons } template<typename Top, typename Tret, class OutputStage> -bool has_opt_gemm(const GemmArgs &args, const OutputStage &os) { +bool has_opt_gemm(WeightFormat &wf, const GemmArgs &args, const OutputStage &os) { const GemmImplementation<Top, Tret, OutputStage> *impl; - return find_implementation<Top, Tret, OutputStage>(args, os, impl); + const bool success = find_implementation<Top, Tret, OutputStage>(args, os, impl); + if (success) + wf = UniqueGemmCommon<Top, Tret>(impl->do_instantiate(args, os))->get_config().weight_format; + return success; } template<typename Top, typename Tret, class OutputStage> diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp index 3915861112..18d8fc9312 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp @@ -56,7 +56,7 @@ const GemmImplementation<int16_t, int32_t> *gemm_implementation_list<int16_t, in /* Explicitly instantiate the external functions for these types. */ template UniqueGemmCommon<int16_t, int32_t> gemm<int16_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &); -template bool has_opt_gemm<int16_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<int16_t, int32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); template std::vector<KernelDescription> get_compatible_kernels<int16_t, int32_t, Nothing> (const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp index 0c68e4dd99..24507486ac 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp @@ -159,7 +159,7 @@ const GemmImplementation<int8_t, int32_t> *gemm_implementation_list<int8_t, int3 /* Explicitly instantiate the external functions for these types. */ template UniqueGemmCommon<int8_t, int32_t> gemm<int8_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &); -template bool has_opt_gemm<int8_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<int8_t, int32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); template std::vector<KernelDescription> get_compatible_kernels<int8_t, int32_t, Nothing> (const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp index 6b813c7974..1d7b9c5b73 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp @@ -230,7 +230,7 @@ const GemmImplementation<int8_t, int8_t, Requantize32> *gemm_implementation_list } template UniqueGemmCommon<int8_t, int8_t> gemm<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); -template bool has_opt_gemm<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); +template bool has_opt_gemm<int8_t, int8_t, Requantize32>(WeightFormat &weight_format, const GemmArgs &args, const Requantize32 &os); template std::vector<KernelDescription> get_compatible_kernels<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp index 95139c2bf6..be7a4ee570 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp @@ -197,7 +197,7 @@ const GemmImplementation<uint8_t, uint8_t, Requantize32> *gemm_implementation_li } template UniqueGemmCommon<uint8_t, uint8_t> gemm<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); -template bool has_opt_gemm<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); +template bool has_opt_gemm<uint8_t, uint8_t, Requantize32>(WeightFormat &weight_format, const GemmArgs &args, const Requantize32 &os); template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp index 20cee556f0..fc836f9790 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp @@ -56,7 +56,7 @@ const GemmImplementation<uint16_t, uint32_t> *gemm_implementation_list<uint16_t, /* Explicitly instantiate the external functions for these types. */ template UniqueGemmCommon<uint16_t, uint32_t> gemm<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &); -template bool has_opt_gemm<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<uint16_t, uint32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); template std::vector<KernelDescription> get_compatible_kernels<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp index a2d2cc86f0..03e9cd6c1f 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp @@ -157,7 +157,7 @@ const GemmImplementation<uint8_t, uint32_t> *gemm_implementation_list<uint8_t, u /* Explicitly instantiate the external functions for these types. */ template UniqueGemmCommon<uint8_t, uint32_t> gemm<uint8_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &); -template bool has_opt_gemm<uint8_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<uint8_t, uint32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint32_t, Nothing> (const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/utils.hpp b/src/core/NEON/kernels/arm_gemm/utils.hpp index 18e124b83e..d7b5398488 100644 --- a/src/core/NEON/kernels/arm_gemm/utils.hpp +++ b/src/core/NEON/kernels/arm_gemm/utils.hpp @@ -24,7 +24,7 @@ #pragma once -#include "arm_gemm.hpp" +#include "src/cpu/kernels/assembly/arm_gemm.hpp" #include <cstddef> #include <limits> |